using System.Diagnostics.CodeAnalysis; using System.Threading.RateLimiting; using System.Xml.Linq; using Iceshrimp.Backend.Controllers.Federation; using Iceshrimp.Backend.Controllers.Mastodon.Renderers; using Iceshrimp.Backend.Controllers.Web.Renderers; using Iceshrimp.Backend.Core.Configuration; using Iceshrimp.Backend.Core.Database; using Iceshrimp.Backend.Core.Federation.WebFinger; using Iceshrimp.Backend.Core.Helpers.LibMfm.Conversion; using Iceshrimp.Backend.Core.Middleware; using Iceshrimp.Backend.Core.Services; using Iceshrimp.Backend.Core.Services.ImageProcessing; using Iceshrimp.Backend.SignalR.Authentication; using Iceshrimp.Shared.Configuration; using Iceshrimp.Shared.Schemas.Web; using Microsoft.AspNetCore.Authentication; using Microsoft.AspNetCore.DataProtection; using Microsoft.AspNetCore.DataProtection.AuthenticatedEncryption; using Microsoft.AspNetCore.DataProtection.AuthenticatedEncryption.ConfigurationModel; using Microsoft.AspNetCore.DataProtection.EntityFrameworkCore; using Microsoft.AspNetCore.DataProtection.KeyManagement; using Microsoft.AspNetCore.DataProtection.Repositories; using Microsoft.AspNetCore.Http.Json; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Extensions.Options; using Microsoft.OpenApi.Models; using AuthenticationMiddleware = Iceshrimp.Backend.Core.Middleware.AuthenticationMiddleware; using AuthorizationMiddleware = Iceshrimp.Backend.Core.Middleware.AuthorizationMiddleware; using NoteRenderer = Iceshrimp.Backend.Controllers.Web.Renderers.NoteRenderer; using NotificationRenderer = Iceshrimp.Backend.Controllers.Web.Renderers.NotificationRenderer; using UserRenderer = Iceshrimp.Backend.Controllers.Web.Renderers.UserRenderer; namespace Iceshrimp.Backend.Core.Extensions; public static class ServiceExtensions { public static void AddServices(this IServiceCollection services, IConfiguration configuration) { // Transient 
= instantiated per request and class // Scoped = instantiated per request services .AddScoped() .AddScoped() .AddScoped() .AddScoped() .AddScoped() .AddScoped() .AddScoped() .AddScoped() .AddScoped() .AddScoped() .AddScoped() .AddScoped() .AddScoped() .AddScoped() .AddScoped() .AddScoped() .AddScoped() .AddScoped() .AddScoped() .AddScoped() .AddScoped() .AddScoped() .AddScoped() .AddScoped() .AddScoped() .AddScoped() .AddScoped() .AddScoped() .AddScoped() .AddScoped() .AddScoped() .AddScoped() .AddScoped() .AddScoped() .AddScoped() .AddScoped() .AddScoped() .AddScoped() .AddScoped() .AddScoped() .AddScoped(); // Singleton = instantiated once across application lifetime services .AddSingleton() .AddSingleton() .AddSingleton() .AddSingleton() .AddSingleton() .AddSingleton() .AddSingleton() .AddSingleton() .AddSingleton() .AddSingleton() .AddSingleton() .AddSingleton() .AddSingleton() .AddSingleton() .AddSingleton(); var config = configuration.GetSection("Storage").Get() ?? throw new Exception("Failed to read storage config section"); switch (config.MediaProcessing.ImageProcessor) { case Enums.ImageProcessor.LibVips: services.AddSingleton(); services.AddSingleton(); break; case Enums.ImageProcessor.ImageSharp: services.AddSingleton(); break; case Enums.ImageProcessor.None: break; default: throw new ArgumentOutOfRangeException(); } // Hosted services = long running background tasks // Note: These need to be added as a singleton as well to ensure data consistency services.AddHostedService(provider => provider.GetRequiredService()); services.AddHostedService(provider => provider.GetRequiredService()); services.AddHostedService(provider => provider.GetRequiredService()); } public static void ConfigureServices(this IServiceCollection services, IConfiguration configuration) { // @formatter:off services.ConfigureWithValidation(configuration) .ConfigureWithValidation(configuration, "Instance") .ConfigureWithValidation(configuration, "Security") 
.ConfigureWithValidation(configuration, "Performance") .ConfigureWithValidation(configuration, "Performance:QueueConcurrency") .ConfigureWithValidation(configuration, "Backfill") .ConfigureWithValidation(configuration, "Backfill:Replies") .ConfigureWithValidation(configuration, "Queue") .ConfigureWithValidation(configuration, "Queue:JobRetention") .ConfigureWithValidation(configuration, "Database") .ConfigureWithValidation(configuration, "Storage") .ConfigureWithValidation(configuration, "Storage:Local") .ConfigureWithValidation(configuration, "Storage:ObjectStorage") .ConfigureWithValidation(configuration, "Storage:MediaProcessing") .ConfigureWithValidation(configuration, "Storage:MediaProcessing:ImagePipeline") .ConfigureWithValidation(configuration, "Storage:MediaProcessing:ImagePipeline:Original:Local") .ConfigureWithValidation(configuration, "Storage:MediaProcessing:ImagePipeline:Original:Remote") .ConfigureWithValidation(configuration, "Storage:MediaProcessing:ImagePipeline:Thumbnail:Local") .ConfigureWithValidation(configuration, "Storage:MediaProcessing:ImagePipeline:Thumbnail:Remote") .ConfigureWithValidation(configuration, "Storage:MediaProcessing:ImagePipeline:Public:Local") .ConfigureWithValidation(configuration, "Storage:MediaProcessing:ImagePipeline:Public:Remote"); // @formatter:on services.Configure(options => { options.SerializerOptions.PropertyNamingPolicy = JsonSerialization.Options.PropertyNamingPolicy; foreach (var converter in JsonSerialization.Options.Converters) options.SerializerOptions.Converters.Add(converter); }); services.Configure(options => { options.JsonSerializerOptions.PropertyNamingPolicy = JsonSerialization.Options.PropertyNamingPolicy; foreach (var converter in JsonSerialization.Options.Converters) options.JsonSerializerOptions.Converters.Add(converter); }); } private static IServiceCollection ConfigureWithValidation( this IServiceCollection services, IConfiguration config ) where T : class { 
services.AddOptionsWithValidateOnStart() .Bind(config) .ValidateDataAnnotations(); return services; } private static IServiceCollection ConfigureWithValidation( this IServiceCollection services, IConfiguration config, string name ) where T : class { services.AddOptionsWithValidateOnStart() .Bind(config.GetSection(name)) .ValidateDataAnnotations(); return services; } public static void AddDatabaseContext(this IServiceCollection services, IConfiguration configuration) { var config = configuration.GetSection("Database").Get() ?? throw new Exception("Failed to initialize database: Failed to load configuration"); var dataSource = DatabaseContext.GetDataSource(config); services.AddDbContext(options => { DatabaseContext.Configure(options, dataSource, config); }); services.AddKeyedDatabaseContext("cache"); services.AddDataProtection() .PersistKeysToDbContextAsync() .UseCryptographicAlgorithms(new AuthenticatedEncryptorConfiguration { EncryptionAlgorithm = EncryptionAlgorithm.AES_256_CBC, ValidationAlgorithm = ValidationAlgorithm.HMACSHA256 }); } private static void AddKeyedDatabaseContext( this IServiceCollection services, string key, ServiceLifetime contextLifetime = ServiceLifetime.Scoped ) where T : DbContext { services.TryAdd(new ServiceDescriptor(typeof(T), key, typeof(T), contextLifetime)); } public static void AddSwaggerGenWithOptions(this IServiceCollection services) { services.AddEndpointsApiExplorer(); services.AddSwaggerGen(options => { options.SupportNonNullableReferenceTypes(); var version = new Config.InstanceSection().Version; options.SwaggerDoc("iceshrimp", new OpenApiInfo { Title = "Iceshrimp.NET", Version = version }); options.SwaggerDoc("federation", new OpenApiInfo { Title = "Federation", Version = version }); options.SwaggerDoc("mastodon", new OpenApiInfo { Title = "Mastodon", Version = version }); options.AddSecurityDefinition("iceshrimp", new OpenApiSecurityScheme { Name = "Authorization token", In = ParameterLocation.Header, Type = 
SecuritySchemeType.Http, Scheme = "bearer" }); options.AddSecurityDefinition("mastodon", new OpenApiSecurityScheme { Name = "Authorization token", In = ParameterLocation.Header, Type = SecuritySchemeType.Http, Scheme = "bearer" }); options.AddFilters(); }); } public static void AddSlidingWindowRateLimiter(this IServiceCollection services) { //TODO: rate limit status headers - maybe switch to https://github.com/stefanprodan/AspNetCoreRateLimit? //TODO: alternatively just write our own services.AddRateLimiter(options => { var sliding = new SlidingWindowRateLimiterOptions { PermitLimit = 500, SegmentsPerWindow = 60, Window = TimeSpan.FromSeconds(60), QueueProcessingOrder = QueueProcessingOrder.OldestFirst, QueueLimit = 0 }; var auth = new SlidingWindowRateLimiterOptions { PermitLimit = 10, SegmentsPerWindow = 60, Window = TimeSpan.FromSeconds(60), QueueProcessingOrder = QueueProcessingOrder.OldestFirst, QueueLimit = 0 }; var strict = new SlidingWindowRateLimiterOptions { PermitLimit = 3, SegmentsPerWindow = 60, Window = TimeSpan.FromSeconds(60), QueueProcessingOrder = QueueProcessingOrder.OldestFirst, QueueLimit = 0 }; // @formatter:off options.AddPolicy("sliding", ctx => RateLimitPartition.GetSlidingWindowLimiter(ctx.GetRateLimitPartition(false),_ => sliding)); options.AddPolicy("auth", ctx => RateLimitPartition.GetSlidingWindowLimiter(ctx.GetRateLimitPartition(false), _ => auth)); options.AddPolicy("strict", ctx => RateLimitPartition.GetSlidingWindowLimiter(ctx.GetRateLimitPartition(true), _ => strict)); // @formatter:on options.OnRejected = async (context, token) => { context.HttpContext.Response.StatusCode = 429; context.HttpContext.Response.ContentType = "application/json"; var res = new ErrorResponse(new Exception()) { Error = "Too Many Requests", StatusCode = 429, RequestId = context.HttpContext.TraceIdentifier }; await context.HttpContext.Response.WriteAsJsonAsync(res, token); }; }); } public static void AddCorsPolicies(this IServiceCollection services) { 
services.AddCors(options => { options.AddPolicy("well-known", policy => { policy.WithOrigins("*") .WithMethods("GET") .WithHeaders("Accept") .WithExposedHeaders("Vary"); }); options.AddPolicy("drive", policy => { policy.WithOrigins("*") .WithMethods("GET", "HEAD"); }); options.AddPolicy("mastodon", policy => { policy.WithOrigins("*") .WithMethods("GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "CONNECT") .WithHeaders("Authorization", "Content-Type", "Idempotency-Key") .WithExposedHeaders("Link", "Connection", "Sec-Websocket-Accept", "Upgrade"); }); options.AddPolicy("fallback", policy => { policy.WithOrigins("*") .WithMethods("GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "CONNECT") .WithHeaders("Authorization", "Content-Type", "Idempotency-Key") .WithExposedHeaders("Link", "Connection", "Sec-Websocket-Accept", "Upgrade"); }); }); } public static void AddAuthorizationPolicies(this IServiceCollection services) { services.AddAuthorizationBuilder() .AddPolicy("HubAuthorization", policy => { policy.Requirements.Add(new HubAuthorizationRequirement()); policy.AuthenticationSchemes = ["HubAuthenticationScheme"]; }); services.AddAuthentication(options => { options.AddScheme("HubAuthenticationScheme", null); // Add a stub authentication handler to bypass strange ASP.NET Core >=7.0 defaults // Ref: https://github.com/dotnet/aspnetcore/issues/44661 options.AddScheme("StubAuthenticationHandler", null); }); } } public static class HttpContextExtensions { public static string GetRateLimitPartition(this HttpContext ctx, bool includeRoute) => (includeRoute ? ctx.Request.Path.ToString() + "#" : "") + (GetRateLimitPartitionInternal(ctx) ?? ""); private static string? GetRateLimitPartitionInternal(this HttpContext ctx) => ctx.GetUser()?.Id ?? ctx.Request.Headers["X-Forwarded-For"].FirstOrDefault() ?? ctx.Connection.RemoteIpAddress?.ToString(); } #region AsyncDataProtection handlers /// /// Async equivalent of EntityFrameworkCoreDataProtectionExtensions.PersistKeysToDbContext. 
/// Required because Npgsql doesn't support the synchronous APIs when connection multiplexing is enabled, whereas the
/// stock EF Core implementation calls their blocking equivalents.
/// </summary>
file static class DataProtectionExtensions
{
	/// <summary>
	/// Configures data protection to persist its key ring in <typeparamref name="TContext"/>. The repository it
	/// installs issues its database calls through EF Core's async APIs.
	/// </summary>
	/// <param name="builder">The data protection builder to configure.</param>
	/// <returns>The same <paramref name="builder"/>, for chaining.</returns>
	public static IDataProtectionBuilder PersistKeysToDbContextAsync<TContext>(this IDataProtectionBuilder builder)
		where TContext : DbContext, IDataProtectionKeyContext
	{
		builder.Services.AddSingleton<IConfigureOptions<KeyManagementOptions>>(services =>
		{
			var loggerFactory = services.GetService<ILoggerFactory>() ?? NullLoggerFactory.Instance;
			return new ConfigureOptions<KeyManagementOptions>(options => options.XmlRepository =
				new EntityFrameworkCoreXmlRepositoryAsync<TContext>(services, loggerFactory));
		});

		return builder;
	}
}

/// <summary>
/// <see cref="IXmlRepository"/> that stores data protection keys in <typeparamref name="TContext"/>.
/// It uses EF Core's async query/save APIs and blocks on them only at the synchronous interface boundary.
/// </summary>
file sealed class EntityFrameworkCoreXmlRepositoryAsync<TContext> : IXmlRepository
	where TContext : DbContext, IDataProtectionKeyContext
{
	private readonly IServiceProvider _services;

	/// <param name="services">Provider used to create a new scope (and thus a fresh context) per operation.</param>
	/// <param name="loggerFactory">Validated for non-null only; not otherwise used by this implementation.</param>
	/// <exception cref="ArgumentNullException"><paramref name="loggerFactory"/> or <paramref name="services"/> is null.</exception>
	[DynamicDependency(DynamicallyAccessedMemberTypes.PublicProperties, typeof(DataProtectionKey))]
	public EntityFrameworkCoreXmlRepositoryAsync(IServiceProvider services, ILoggerFactory loggerFactory)
	{
		// Same check order as before: loggerFactory is reported first if both are null.
		ArgumentNullException.ThrowIfNull(loggerFactory);
		ArgumentNullException.ThrowIfNull(services);
		_services = services;
	}

	/// <summary>
	/// Returns every stored key element. Rows with a null or empty <c>Xml</c> value are skipped.
	/// </summary>
	/// <exception cref="System.Xml.XmlException">A stored key contains malformed XML.</exception>
	public IReadOnlyCollection<XElement> GetAllElements()
	{
		// IXmlRepository is synchronous, so block on the async enumeration. ToList() drains the iterator completely,
		// which also disposes its scope (and context) before this method returns.
		return GetAllElementsCore().ToBlockingEnumerable().ToList().AsReadOnly();

		async IAsyncEnumerable<XElement> GetAllElementsCore()
		{
			await using var scope = _services.CreateAsyncScope();
			var keys = scope.ServiceProvider.GetRequiredService<TContext>()
			                .DataProtectionKeys
			                .AsNoTracking()
			                .AsAsyncEnumerable();

			await foreach (var key in keys)
			{
				if (!string.IsNullOrEmpty(key.Xml))
					yield return XElement.Parse(key.Xml);
			}
		}
	}

	/// <summary>
	/// Persists <paramref name="element"/> as a new key row.
	/// </summary>
	/// <param name="element">Key XML; stored without formatting whitespace.</param>
	/// <param name="friendlyName">Name stored alongside the key.</param>
	public void StoreElement(XElement element, string friendlyName)
	{
		using var scope = _services.CreateScope();

		// The context is owned by the scope and disposed with it, so it is not disposed separately here.
		var context = scope.ServiceProvider.GetRequiredService<TContext>();
		context.DataProtectionKeys.Add(new DataProtectionKey
		{
			FriendlyName = friendlyName,
			Xml          = element.ToString(SaveOptions.DisableFormatting)
		});

		// IXmlRepository.StoreElement is synchronous. GetAwaiter().GetResult() blocks like Wait(), but rethrows the
		// original exception (e.g. DbUpdateException) instead of wrapping it in an AggregateException.
		context.SaveChangesAsync().GetAwaiter().GetResult();
	}
}

#endregion