// Source: Iceshrimp.NET/Iceshrimp.Backend/Core/Extensions/WebApplicationExtensions.cs
// (Repository viewer metadata at time of capture: 465 lines, 16 KiB, C#)

using System.Collections.Concurrent;
using System.Diagnostics.CodeAnalysis;
using System.Reflection;
using System.Runtime.InteropServices;
using Iceshrimp.Backend.Core.Configuration;
using Iceshrimp.Backend.Core.Database;
using Iceshrimp.Backend.Core.Database.Migrations;
using Iceshrimp.Backend.Core.Helpers;
using Iceshrimp.Backend.Core.Middleware;
using Iceshrimp.Backend.Core.Services;
using Iceshrimp.Backend.Core.Services.ImageProcessing;
using Iceshrimp.WebPush;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Migrations;
using Microsoft.Extensions.Configuration.Ini;
using Microsoft.Extensions.Options;
using Scalar.AspNetCore;
namespace Iceshrimp.Backend.Core.Extensions;
public static class WebApplicationExtensions
{
/// <summary>
/// Registers the application's middleware on <paramref name="app"/>, interleaved with the rate limiter
/// and the output cache.
/// </summary>
/// <param name="app">The application builder to register the middleware on.</param>
/// <returns>The builder returned by the last registration call, so that calls can be chained further.</returns>
public static IApplicationBuilder UseCustomMiddleware(this IApplicationBuilder app)
{
    // Caution: make sure these are in the correct order.
    // Each call is made on the builder returned by the previous one.
    var pipeline = app.UseMiddleware<RequestDurationMiddleware>();
    pipeline = pipeline.UseMiddleware<ErrorHandlerMiddleware>();
    pipeline = pipeline.UseMiddleware<RequestVerificationMiddleware>();
    pipeline = pipeline.UseMiddleware<RequestBufferingMiddleware>();
    pipeline = pipeline.UseMiddleware<AuthenticationMiddleware>();
    pipeline = pipeline.UseRateLimiter();
    pipeline = pipeline.UseMiddleware<AuthorizationMiddleware>();
    pipeline = pipeline.UseMiddleware<FederationSemaphoreMiddleware>();
    pipeline = pipeline.UseMiddleware<AuthorizedFetchMiddleware>();
    pipeline = pipeline.UseMiddleware<InboxValidationMiddleware>();
    pipeline = pipeline.UseOutputCache();
    pipeline = pipeline.UseMiddleware<BlazorSsrHandoffMiddleware>();
    return pipeline;
}
/// <summary>
/// Registers <typeparamref name="T"/> inside a <c>UseWhen</c> branch guarded by <c>T.Predicate</c>.
/// Prevents conditional middleware from being invoked on non-matching requests.
/// </summary>
/// <typeparam name="T">The conditional middleware type to register.</typeparam>
/// <param name="app">The application builder to add the conditional branch to.</param>
/// <returns>The result of the <c>UseWhen</c> call.</returns>
private static IApplicationBuilder UseMiddleware<T>(this IApplicationBuilder app) where T : IConditionalMiddleware
{
    // The framework extension is called through its static class on purpose:
    // writing branch.UseMiddleware<T>() could bind to this very method again.
    Action<IApplicationBuilder> registerInBranch = branch => UseMiddlewareExtensions.UseMiddleware<T>(branch);
    return app.UseWhen(T.Predicate, registerInBranch);
}
/// <summary>
/// Exposes the OpenAPI documents under <c>/openapi</c> (with output caching), configures Swagger UI
/// and maps the Scalar API reference under <c>/scalar</c>.
/// </summary>
/// <param name="app">The web application to configure.</param>
/// <returns>The same <paramref name="app"/> instance.</returns>
public static IApplicationBuilder UseOpenApiWithOptions(this WebApplication app)
{
    const string documentationTitle = "Iceshrimp API documentation";
    const string scalarCss = """
        .open-api-client-button, .darklight-reference > .flex > .text-sm { display: none !important; }
        .darklight-reference > .flex > button > div:nth-child(1) { height: 14px !important; }
        .darklight-reference { padding: 15px 14px !important; }
        """;

    // Serve the documents as JSON or YAML; responses are output-cached for 12 hours.
    var documentLifetime = TimeSpan.FromHours(12);
    var documents = app.MapSwagger("/openapi/{documentName}.{extension:regex(^(json|ya?ml)$)}");
    documents.CacheOutput(policy => policy.Expire(documentLifetime));

    app.UseSwaggerUI(ui =>
    {
        ui.DocumentTitle = documentationTitle;

        ui.SwaggerEndpoint("/openapi/iceshrimp.json", "Iceshrimp.NET");
        ui.SwaggerEndpoint("/openapi/federation.json", "Federation");
        ui.SwaggerEndpoint("/openapi/mastodon.json", "Mastodon");

        ui.InjectStylesheet("/css/swagger.css");
        ui.EnablePersistAuthorization();
        ui.EnableTryItOutByDefault();
        ui.DisplayRequestDuration();

        // Hide "Schemas" section
        ui.DefaultModelsExpandDepth(-1);

        // Sort tags alphabetically
        ui.ConfigObject.AdditionalItems.Add("tagsSorter", "alpha");
    });

    app.MapScalarApiReference("/scalar", scalar =>
    {
        scalar.WithTitle(documentationTitle)
              .WithOpenApiRoutePattern("/openapi/{documentName}.json")
              .WithModels(false)
              .WithCustomCss(scalarCss);
    });

    return app;
}
/// <summary>
/// Maps <paramref name="page"/> as the fallback page for three route shapes: the default fallback
/// pattern, <c>/@{user}</c> and <c>/@{user}@{host}</c>.
/// </summary>
/// <remarks>
/// The first two mappings get explicit orders (<c>int.MaxValue - 2</c> and <c>int.MaxValue - 1</c>);
/// the third keeps whatever order <c>MapFallbackToPage</c> assigns.
/// </remarks>
/// <param name="app">The web application to map the routes on.</param>
/// <param name="page">The page to serve for the fallback routes.</param>
public static void MapFrontendRoutes(this WebApplication app, string page)
{
    const int catchAllOrder = int.MaxValue - 2;
    const int localUserOrder = int.MaxValue - 1;

    var catchAll = app.MapFallbackToPage(page);
    catchAll.WithOrder(catchAllOrder);

    var localUser = app.MapFallbackToPage("/@{user}", page);
    localUser.WithOrder(localUserOrder);

    app.MapFallbackToPage("/@{user}@{host}", page);
}
/// <summary>
/// Runs the startup checks and the one-shot maintenance commands selected through <paramref name="args"/>.
/// </summary>
/// <remarks>
/// In order: reads the <c>Instance</c> config section, validates the configuration, verifies the database
/// connection, handles the migration flags, runs the maintenance and user management commands, verifies
/// storage and temporary directory access, ensures the VAPID keys are set, warms up the meta cache and
/// resolves the image processor. Fatal conditions and completed one-shot commands terminate the process
/// via <see cref="Environment.Exit(int)"/> instead of returning.
/// </remarks>
/// <param name="app">The web application whose configuration, logger and services are used.</param>
/// <param name="args">The command line arguments of the process.</param>
/// <returns>The parsed <c>Instance</c> configuration section.</returns>
/// <exception cref="Exception">The <c>Instance</c> or <c>Storage</c> config section could not be read.</exception>
public static async Task<Config.InstanceSection> InitializeAsync(this WebApplication app, string[] args)
{
    var instanceConfig = app.Configuration.GetSection("Instance").Get<Config.InstanceSection>() ??
                         throw new Exception("Failed to read Instance config section");

    app.Logger.LogInformation("Iceshrimp.NET v{version}, codename \"{codename}\" ({domain})",
                              instanceConfig.Version, instanceConfig.Codename, instanceConfig.AccountDomain);

    await using var scope = app.Services.CreateAsyncScope();
    var provider = scope.ServiceProvider;

    var config = (ConfigurationManager)app.Configuration;
    var files = config.Sources.OfType<IniConfigurationSource>().Select(p => p.Path);
    app.Logger.LogDebug("Loaded configuration files: \n* {files}", string.Join("\n* ", files));

    try
    {
        app.Logger.LogInformation("Validating configuration...");
        provider.GetRequiredService<IStartupValidator>().Validate();
    }
    catch (OptionsValidationException e)
    {
        app.Logger.LogCritical("Failed to validate configuration: {error}", e.Message);
        Environment.Exit(1);
    }

    if (app.Environment.IsDevelopment())
    {
        app.Logger.LogWarning("The hosting environment is set to Development.");
        app.Logger.LogWarning("This application will not validate the Host header for incoming requests.");
        app.Logger.LogWarning("If this is not a local development instance, please set the environment to Production.");
    }

    await using var db = provider.GetService<DatabaseContext>();
    if (db == null)
    {
        app.Logger.LogCritical("Failed to initialize database context");
        Environment.Exit(1);
    }

    app.Logger.LogInformation("Verifying database connection...");
    if (!await db.Database.CanConnectAsync())
    {
        app.Logger.LogCritical("Failed to connect to database. Please make sure your configuration is correct.");
        Environment.Exit(1);
    }

    // @formatter:off
    // Only the first pending migration is inspected; it is compared against the id of the Initial migration below.
    var pendingMigration = (await db.Database.GetPendingMigrationsAsync()).FirstOrDefault();
    if (args.Contains("--migrate-from-js"))
    {
        app.Logger.LogInformation("Initializing migration assistant...");
        var initialMigration = typeof(Initial).GetCustomAttribute<MigrationAttribute>()?.Id;
        if (pendingMigration != initialMigration || await db.IsDatabaseEmptyAsync())
        {
            app.Logger.LogCritical("Database does not appear to be an iceshrimp-js database.");
            Environment.Exit(1);
        }
        else if (!args.Contains("--i-reverted-any-extra-migrations") ||
                 !args.Contains("--i-made-a-database-backup") ||
                 !args.Contains("--i-understand-that-this-is-a-one-way-operation"))
        {
            app.Logger.LogCritical("Missing confirmation argument(s), please follow the instructions on https://iceshrimp.net/help/migrate exactly.");
            Environment.Exit(1);
        }
        else
        {
            app.Logger.LogInformation("Applying initial migration...");
            try
            {
                await db.Database.ExecuteSqlAsync(new MigrationAssistant().InitialMigration);
            }
            catch (Exception e)
            {
                app.Logger.LogCritical("Failed to apply initial migration: {error}", e);
                app.Logger.LogCritical("Manual intervention required, please follow the instructions on https://iceshrimp.net/help/migrate for more information.");
                Environment.Exit(1);
            }

            app.Logger.LogInformation("Successfully applied the initial migration.");
            app.Logger.LogInformation("Please follow the instructions on https://iceshrimp.net/help/migrate to validate the database schema.");
            Environment.Exit(0);
        }
    }
    // @formatter:on

    if (pendingMigration != null)
    {
        var initialMigration = typeof(Initial).GetCustomAttribute<MigrationAttribute>()?.Id;
        if (pendingMigration == initialMigration && !await db.IsDatabaseEmptyAsync())
        {
            app.Logger.LogCritical("Initial migration is pending but database is not empty.");
            app.Logger.LogCritical("If you are trying to migrate from iceshrimp-js, please follow the instructions on https://iceshrimp.net/help/migrate.");
            Environment.Exit(1);
        }

        if (args.Contains("--migrate") || args.Contains("--migrate-and-start"))
        {
            app.Logger.LogInformation("Running migrations...");
            // The command timeout is set to 0 while migrating and to 30 afterwards.
            db.Database.SetCommandTimeout(0);
            await db.Database.MigrateAsync();
            db.Database.SetCommandTimeout(30);
            if (args.Contains("--migrate")) Environment.Exit(0);
        }
        else
        {
            app.Logger.LogCritical("Database has pending migrations, please restart with --migrate or --migrate-and-start");
            Environment.Exit(1);
        }
    }
    else if (args.Contains("--migrate") || args.Contains("--migrate-and-start"))
    {
        app.Logger.LogInformation("No migrations are pending.");
        if (args.Contains("--migrate")) Environment.Exit(0);
    }

    if (args.Contains("--recompute-counters"))
    {
        app.Logger.LogInformation("Recomputing note, user & instance counters, this will take a while...");
        var maintenanceSvc = provider.GetRequiredService<DatabaseMaintenanceService>();
        await maintenanceSvc.RecomputeNoteCountersAsync();
        await maintenanceSvc.RecomputeUserCountersAsync();
        await maintenanceSvc.RecomputeInstanceCountersAsync();
        Environment.Exit(0);
    }

    if (args.Contains("--migrate-storage"))
    {
        app.Logger.LogInformation("Migrating files to object storage, this will take a while...");
        db.Database.SetCommandTimeout(0);
        await provider.GetRequiredService<StorageMaintenanceService>()
                      .MigrateLocalFilesAsync(args.Contains("--purge"));
        Environment.Exit(0);
    }

    if (args.Contains("--fixup-media"))
    {
        await provider.GetRequiredService<StorageMaintenanceService>().FixupMediaAsync(args.Contains("--dry-run"));
        Environment.Exit(0);
    }

    if (args.Contains("--cleanup-storage"))
    {
        await provider.GetRequiredService<StorageMaintenanceService>()
                      .CleanupStorageAsync(args.Contains("--dry-run"));
        Environment.Exit(0);
    }

    string[] userMgmtCommands =
    [
        "--create-user", "--create-admin-user", "--reset-password", "--grant-admin", "--revoke-admin"
    ];

    if (args.FirstOrDefault(userMgmtCommands.Contains) is { } cmd)
    {
        // The command has to be the first of exactly two arguments. Without the position check,
        // "<username> --create-user" would be accepted and the flag itself would be used as the username.
        if (args is not [var first, var username] || first != cmd)
        {
            app.Logger.LogError("Invalid syntax. Usage: {cmd} <username>", cmd);
            Environment.Exit(1);
            return null!;
        }

        if (cmd is "--create-user" or "--create-admin-user")
        {
            var password = CryptographyHelpers.GenerateRandomString(16);
            app.Logger.LogInformation("Creating user {username}...", username);
            var userSvc = provider.GetRequiredService<UserService>();
            await userSvc.CreateLocalUserAsync(username, password, null, force: true);

            if (cmd is "--create-admin-user")
            {
                // NOTE(review): this matches on the username only. If several rows can share a username,
                // all of them are updated; confirm whether the query should be narrowed.
                await db.Users.Where(p => p.Username == username)
                        .ExecuteUpdateAsync(p => p.SetProperty(i => i.IsAdmin, true));
                app.Logger.LogInformation("Successfully created admin user.");
            }
            else
            {
                app.Logger.LogInformation("Successfully created user.");
            }

            app.Logger.LogInformation("Username: {username}", username);
            app.Logger.LogInformation("Password: {password}", password);
            Environment.Exit(0);
        }

        if (cmd is "--reset-password")
        {
            var settings = await db.UserSettings
                                   .FirstOrDefaultAsync(p => p.User.UsernameLower == username.ToLowerInvariant());
            if (settings == null)
            {
                app.Logger.LogError("User {username} not found.", username);
                Environment.Exit(1);
            }

            app.Logger.LogInformation("Resetting password for user {username}...", username);
            var password = CryptographyHelpers.GenerateRandomString(16);
            settings.Password = AuthHelpers.HashPassword(password);
            await db.SaveChangesAsync();
            app.Logger.LogInformation("Password for user {username} was reset to: {password}", username, password);
            Environment.Exit(0);
        }

        if (cmd is "--grant-admin")
        {
            var user = await db.Users.FirstOrDefaultAsync(p => p.UsernameLower == username.ToLowerInvariant());
            if (user == null)
            {
                app.Logger.LogError("User {username} not found.", username);
                Environment.Exit(1);
            }
            else
            {
                user.IsAdmin = true;
                await db.SaveChangesAsync();
                app.Logger.LogInformation("Granted admin privileges to user {username}.", username);
                Environment.Exit(0);
            }
        }

        if (cmd is "--revoke-admin")
        {
            var user = await db.Users.FirstOrDefaultAsync(p => p.UsernameLower == username.ToLowerInvariant());
            if (user == null)
            {
                app.Logger.LogError("User {username} not found.", username);
                Environment.Exit(1);
            }
            else
            {
                user.IsAdmin = false;
                await db.SaveChangesAsync();
                app.Logger.LogInformation("Revoked admin privileges of user {username}.", username);
                Environment.Exit(0);
            }
        }
    }

    var storageConfig = app.Configuration.GetSection("Storage").Get<Config.StorageSection>() ??
                        throw new Exception("Failed to read Storage config section");

    if (storageConfig.Provider == Enums.FileStorage.Local)
    {
        if (string.IsNullOrWhiteSpace(storageConfig.Local?.Path) || !Directory.Exists(storageConfig.Local.Path))
        {
            app.Logger.LogCritical("Local storage path does not exist");
            Environment.Exit(1);
        }
        else
        {
            try
            {
                // Write probe: create a randomly named file that is deleted again when it is closed.
                var path = Path.Combine(storageConfig.Local.Path, Path.GetRandomFileName());
                await using var fs = File.Create(path, 1, FileOptions.DeleteOnClose);
            }
            catch
            {
                app.Logger.LogCritical("Local storage path is not accessible or not writable");
                Environment.Exit(1);
            }
        }
    }
    else if (storageConfig.Provider == Enums.FileStorage.ObjectStorage)
    {
        app.Logger.LogInformation("Verifying object storage configuration...");
        var svc = provider.GetRequiredService<ObjectStorageService>();
        try
        {
            await svc.VerifyCredentialsAsync();
        }
        catch (Exception e)
        {
            app.Logger.LogCritical("Failed to initialize object storage: {message}", e.Message);
            Environment.Exit(1);
        }
    }

    var tempPath = Environment.GetEnvironmentVariable("ASPNETCORE_TEMP") ?? Path.GetTempPath();
    try
    {
        // Same write probe as above, against the temporary directory.
        await using var stream = File.Create(Path.Combine(tempPath, ".iceshrimp-test"), 1, FileOptions.DeleteOnClose);
    }
    catch
    {
        app.Logger.LogCritical("Temporary directory {dir} is not writable. Please adjust permissions or set the ASPNETCORE_TEMP environment variable to a writable directory.",
                               tempPath);
        Environment.Exit(1);
    }

    app.Logger.LogInformation("Initializing VAPID keys...");
    var meta = provider.GetRequiredService<MetaService>();
    await meta.EnsureSetAsync([MetaEntity.VapidPublicKey, MetaEntity.VapidPrivateKey], () =>
    {
        var keypair = VapidHelper.GenerateVapidKeys();
        return [keypair.PublicKey, keypair.PrivateKey];
    });

    app.Logger.LogInformation("Warming up meta cache...");
    await meta.WarmupCacheAsync();

    // Initialize image processing
    provider.GetRequiredService<ImageProcessor>();

    return instanceConfig;
}
/// <summary>
/// Applies the configured permissions to the configured Kestrel unix socket via <c>chmod</c>.
/// </summary>
/// <remarks>
/// Does nothing when no listen socket is configured. An unparsable permission string terminates the
/// process with exit code 1; a failing <c>chmod</c> call is only logged.
/// </remarks>
/// <param name="app">The web application whose configuration and services are used.</param>
/// <exception cref="Exception">
/// The <c>Instance</c> config section could not be read, or a listen socket is configured on a platform
/// other than Linux, macOS or FreeBSD.
/// </exception>
public static void SetKestrelUnixSocketPermissions(this WebApplication app)
{
    var config = app.Configuration.GetSection("Instance").Get<Config.InstanceSection>()
                 ?? throw new Exception("Failed to read instance config");
    if (config.ListenSocket == null) return;

    using var scope = app.Services.CreateScope();
    var logger = scope.ServiceProvider.GetRequiredService<ILoggerFactory>()
                      .CreateLogger("Microsoft.Hosting.Lifetime");

    if (!OperatingSystem.IsLinux() && !OperatingSystem.IsMacOS() && !OperatingSystem.IsFreeBSD())
        throw new Exception("Can't set unix socket permissions on a non-UNIX system");

    int perms;
    try
    {
        // The permission string is parsed as octal digits (base 8).
        // NOTE(review): Convert.ToInt32 returns 0 for a null string, which would result in mode 000;
        // confirm that ListenSocketPerms always has a value.
        perms = Convert.ToInt32(config.ListenSocketPerms, 8);
    }
    catch
    {
        logger.LogError("Failed to set Kestrel unix socket permissions to {SocketPerms}: failed to parse octal digits",
                        config.ListenSocketPerms);
        Environment.Exit(1);
        // Required for definite assignment of perms.
        return;
    }

    var exitCode = chmod(config.ListenSocket, perms);
    if (exitCode < 0)
    {
        // Read errno right away, before another P/Invoke can overwrite it; the return code alone is always -1.
        var errno = Marshal.GetLastPInvokeError();
        logger.LogError("Failed to set Kestrel unix socket permissions to {SocketPerms}, return code: {ExitCode}, errno: {Errno} ({ErrnoMessage})",
                        config.ListenSocketPerms, exitCode, errno, Marshal.GetPInvokeErrorMessage(errno));
    }
    else
    {
        logger.LogInformation("Kestrel unix socket permissions were set to {SocketPerms}",
                              config.ListenSocketPerms);
    }

    return;

    // SetLastError makes the runtime capture errno after the call so it can be reported above.
    [DllImport("libc", SetLastError = true)]
    static extern int chmod(string pathname, int mode);
}
}
/// <summary>
/// Contract for middleware types that expose a static predicate over the request context.
/// </summary>
public interface IConditionalMiddleware
{
    /// <summary>
    /// Evaluates the predicate for the given request context.
    /// </summary>
    /// <param name="ctx">The context of the current request.</param>
    /// <returns>The result of the predicate for <paramref name="ctx"/>.</returns>
    static abstract bool Predicate(HttpContext ctx);
}
/// <summary>
/// An <see cref="IMiddleware"/> that declares a service lifetime as a static member.
/// </summary>
public interface IMiddlewareService : IMiddleware
{
    /// <summary>
    /// The service lifetime declared by the implementing middleware type.
    /// </summary>
    static abstract ServiceLifetime Lifetime { get; }
}
/// <summary>
/// Base class for conditional middleware whose predicate matches requests whose endpoint carries
/// metadata of type <typeparamref name="T"/>.
/// </summary>
/// <typeparam name="T">The attribute type looked up in the endpoint metadata.</typeparam>
public class ConditionalMiddleware<T> : IConditionalMiddleware where T : Attribute
{
    // Per-endpoint result of the metadata lookup; one dictionary exists per closed generic type.
    [SuppressMessage("ReSharper", "StaticMemberInGenericType", Justification = "Intended behavior")]
    private static readonly ConcurrentDictionary<Endpoint, bool> Cache = new();

    /// <summary>
    /// Returns whether the endpoint of <paramref name="ctx"/> carries metadata of type <typeparamref name="T"/>.
    /// </summary>
    /// <param name="ctx">The context of the current request.</param>
    /// <returns><c>false</c> when the context has no endpoint; otherwise the cached lookup result.</returns>
    public static bool Predicate(HttpContext ctx)
    {
        var endpoint = ctx.GetEndpoint();
        if (endpoint is null) return false;

        return Cache.GetOrAdd(endpoint, candidate => FindAttribute(candidate) != null);
    }

    /// <summary>
    /// Returns the <typeparamref name="T"/> metadata of the endpoint of <paramref name="ctx"/>.
    /// </summary>
    /// <param name="ctx">The context of the current request.</param>
    /// <returns>The metadata instance found on the endpoint.</returns>
    /// <exception cref="Exception">The context has no endpoint, or the endpoint has no such metadata.</exception>
    protected static T GetAttributeOrFail(HttpContext ctx)
    {
        var attribute = FindAttribute(ctx.GetEndpoint());
        return attribute ?? throw new Exception("Failed to get middleware filter attribute");
    }

    // Null-tolerant metadata lookup shared by the predicate and GetAttributeOrFail.
    private static T? FindAttribute(Endpoint? endpoint)
    {
        return endpoint?.Metadata.GetMetadata<T>();
    }
}