diff --git a/Iceshrimp.Backend/Core/Extensions/ServiceExtensions.cs b/Iceshrimp.Backend/Core/Extensions/ServiceExtensions.cs index b1784426..a94595f1 100644 --- a/Iceshrimp.Backend/Core/Extensions/ServiceExtensions.cs +++ b/Iceshrimp.Backend/Core/Extensions/ServiceExtensions.cs @@ -1,6 +1,8 @@ using System.Diagnostics.CodeAnalysis; +using System.Reflection; using System.Threading.RateLimiting; using System.Xml.Linq; +using Iceshrimp.AssemblyUtils; using Iceshrimp.Backend.Components.PublicPreview.Attributes; using Iceshrimp.Backend.Components.PublicPreview.Renderers; using Iceshrimp.Backend.Controllers.Federation; @@ -9,6 +11,7 @@ using Iceshrimp.Backend.Controllers.Web.Renderers; using Iceshrimp.Backend.Core.Configuration; using Iceshrimp.Backend.Core.Database; using Iceshrimp.Backend.Core.Federation.WebFinger; +using Iceshrimp.Backend.Core.Helpers; using Iceshrimp.Backend.Core.Helpers.LibMfm.Conversion; using Iceshrimp.Backend.Core.Middleware; using Iceshrimp.Backend.Core.Services; @@ -30,8 +33,6 @@ using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Extensions.Options; using Microsoft.OpenApi.Models; -using AuthenticationMiddleware = Iceshrimp.Backend.Core.Middleware.AuthenticationMiddleware; -using AuthorizationMiddleware = Iceshrimp.Backend.Core.Middleware.AuthorizationMiddleware; using NoteRenderer = Iceshrimp.Backend.Controllers.Web.Renderers.NoteRenderer; using NotificationRenderer = Iceshrimp.Backend.Controllers.Web.Renderers.NotificationRenderer; using UserRenderer = Iceshrimp.Backend.Controllers.Web.Renderers.UserRenderer; @@ -68,10 +69,6 @@ public static class ServiceExtensions .AddScoped() .AddScoped() .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() .AddScoped() .AddScoped() .AddScoped() @@ -100,16 +97,10 @@ public static class ServiceExtensions .AddSingleton() .AddSingleton() .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddSingleton() .AddSingleton() .AddSingleton() .AddSingleton() .AddSingleton() - .AddSingleton() .AddSingleton() .AddSingleton() .AddSingleton(); @@ -139,6 +130,21 @@ public static class ServiceExtensions services.AddHostedService(provider => provider.GetRequiredService()); } + public static void AddMiddleware(this IServiceCollection services) + { + var types = PluginLoader + .Assemblies.Prepend(Assembly.GetExecutingAssembly()) + .SelectMany(p => AssemblyLoader.GetImplementationsOfInterface(p, typeof(IMiddlewareService))); + + foreach (var type in types) + { + if (type.GetProperty(nameof(IMiddlewareService.Lifetime))?.GetValue(null) is not ServiceLifetime lifetime) + continue; + + services.Add(new ServiceDescriptor(type, type, lifetime)); + } + } + public static void ConfigureServices(this IServiceCollection services, IConfiguration configuration) { // @formatter:off diff --git a/Iceshrimp.Backend/Core/Extensions/WebApplicationExtensions.cs b/Iceshrimp.Backend/Core/Extensions/WebApplicationExtensions.cs index ba04adca..2d10da4c 100644 --- a/Iceshrimp.Backend/Core/Extensions/WebApplicationExtensions.cs +++ b/Iceshrimp.Backend/Core/Extensions/WebApplicationExtensions.cs @@ -1,3 +1,5 @@ +using System.Collections.Concurrent; +using System.Diagnostics.CodeAnalysis; using System.Runtime.InteropServices; using Iceshrimp.Backend.Core.Configuration; using Iceshrimp.Backend.Core.Database; @@ -32,6 +34,10 @@ public static class WebApplicationExtensions .UseMiddleware(); } + // Prevents conditional middleware from being invoked on non-matching requests + private static IApplicationBuilder UseMiddleware(this IApplicationBuilder app) where T : IConditionalMiddleware + => app.UseWhen(T.Predicate, builder => UseMiddlewareExtensions.UseMiddleware(builder)); + public static IApplicationBuilder UseOpenApiWithOptions(this WebApplication app) { app.MapSwagger("/openapi/{documentName}.{extension:regex(^(json|ya?ml)$)}") @@ -305,4 +311,32 @@ public static class WebApplicationExtensions [DllImport("libc")] static extern int chmod(string pathname, int mode); } +} + +public interface IConditionalMiddleware +{ + public static abstract bool Predicate(HttpContext ctx); +} + +public interface IMiddlewareService : IMiddleware +{ + public static abstract ServiceLifetime Lifetime { get; } +} + +public class ConditionalMiddleware : IConditionalMiddleware where T : Attribute +{ + [SuppressMessage("ReSharper", "StaticMemberInGenericType", Justification = "Intended behavior")] + private static readonly ConcurrentDictionary Cache = []; + + public static bool Predicate(HttpContext ctx) + => ctx.GetEndpoint() is { } endpoint && Cache.GetOrAdd(endpoint, e => GetAttribute(e) != null); + + private static T? GetAttribute(Endpoint? endpoint) + => endpoint?.Metadata.GetMetadata(); + + private static T? GetAttribute(HttpContext ctx) + => GetAttribute(ctx.GetEndpoint()); + + protected static T GetAttributeOrFail(HttpContext ctx) + => GetAttribute(ctx) ?? throw new Exception("Failed to get middleware filter attribute"); } \ No newline at end of file diff --git a/Iceshrimp.Backend/Core/Middleware/AuthenticationMiddleware.cs b/Iceshrimp.Backend/Core/Middleware/AuthenticationMiddleware.cs index c0a85066..f614b573 100644 --- a/Iceshrimp.Backend/Core/Middleware/AuthenticationMiddleware.cs +++ b/Iceshrimp.Backend/Core/Middleware/AuthenticationMiddleware.cs @@ -1,6 +1,7 @@ using Iceshrimp.Backend.Controllers.Mastodon.Attributes; using Iceshrimp.Backend.Core.Database; using Iceshrimp.Backend.Core.Database.Tables; +using Iceshrimp.Backend.Core.Extensions; using Iceshrimp.Backend.Core.Helpers; using Iceshrimp.Backend.Core.Helpers.LibMfm.Conversion; using Iceshrimp.Backend.Core.Services; @@ -9,101 +10,104 @@ using Microsoft.EntityFrameworkCore; namespace Iceshrimp.Backend.Core.Middleware; -public class AuthenticationMiddleware(DatabaseContext db, UserService userSvc, MfmConverter mfmConverter) : IMiddleware +public class AuthenticationMiddleware( + DatabaseContext db, + UserService userSvc, + MfmConverter mfmConverter +) : ConditionalMiddleware, IMiddlewareService { + public static ServiceLifetime Lifetime => ServiceLifetime.Scoped; + public async Task InvokeAsync(HttpContext ctx, RequestDelegate next) { var endpoint = ctx.GetEndpoint(); - var attribute = endpoint?.Metadata.GetMetadata(); + var attribute = GetAttributeOrFail(ctx); - if (attribute != null) + var isBlazorSsr = endpoint?.Metadata.GetMetadata() != null; + if (isBlazorSsr) { - var isBlazorSsr = endpoint?.Metadata.GetMetadata() != null; - if (isBlazorSsr) - { - await AuthenticateBlazorSsr(ctx, attribute); - await next(ctx); - return; - } + await AuthenticateBlazorSsr(ctx, attribute); + await next(ctx); + return; + } - ctx.Response.Headers.CacheControl = "private, no-store"; - var request = ctx.Request; - var header = request.Headers.Authorization.ToString(); - if (!header.ToLowerInvariant().StartsWith("bearer ")) + ctx.Response.Headers.CacheControl = "private, no-store"; + var request = ctx.Request; + var header = request.Headers.Authorization.ToString(); + if (!header.ToLowerInvariant().StartsWith("bearer ")) + { + await next(ctx); + return; + } + + var token = header[7..]; + + var isMastodon = endpoint?.Metadata.GetMetadata() != null; + if (isMastodon) + { + var oauthToken = await db.OauthTokens + .Include(p => p.User.UserProfile) + .Include(p => p.User.UserSettings) + .Include(p => p.App) + .FirstOrDefaultAsync(p => p.Token == token && p.Active); + + if (oauthToken?.User.IsSuspended == true) + throw GracefulException + .Unauthorized("Your access has been suspended by the instance administrator."); + + if (oauthToken == null) { await next(ctx); return; } - var token = header[7..]; - - var isMastodon = endpoint?.Metadata.GetMetadata() != null; - if (isMastodon) + if ((attribute.AdminRole && !oauthToken.User.IsAdmin) || + (attribute.ModeratorRole && + oauthToken.User is { IsAdmin: false, IsModerator: false })) { - var oauthToken = await db.OauthTokens - .Include(p => p.User.UserProfile) - .Include(p => p.User.UserSettings) - .Include(p => p.App) - .FirstOrDefaultAsync(p => p.Token == token && p.Active); - - if (oauthToken?.User.IsSuspended == true) - throw GracefulException - .Unauthorized("Your access has been suspended by the instance administrator."); - - if (oauthToken == null) - { - await next(ctx); - return; - } - - if ((attribute.AdminRole && !oauthToken.User.IsAdmin) || - (attribute.ModeratorRole && - oauthToken.User is { IsAdmin: false, IsModerator: false })) - { - await next(ctx); - return; - } - - if (attribute.Scopes.Length > 0 && - attribute.Scopes.Except(MastodonOauthHelpers.ExpandScopes(oauthToken.Scopes)).Any()) - { - await next(ctx); - return; - } - - userSvc.UpdateOauthTokenMetadata(oauthToken); - ctx.SetOauthToken(oauthToken); - - mfmConverter.SupportsHtmlFormatting = oauthToken.SupportsHtmlFormatting; + await next(ctx); + return; } - else + + if (attribute.Scopes.Length > 0 && + attribute.Scopes.Except(MastodonOauthHelpers.ExpandScopes(oauthToken.Scopes)).Any()) { - var session = await db.Sessions - .Include(p => p.User.UserProfile) - .Include(p => p.User.UserSettings) - .FirstOrDefaultAsync(p => p.Token == token && p.Active); - - if (session?.User.IsSuspended == true) - throw GracefulException - .Unauthorized("Your access has been suspended by the instance administrator."); - - if (session == null) - { - await next(ctx); - return; - } - - if ((attribute.AdminRole && !session.User.IsAdmin) || - (attribute.ModeratorRole && - session.User is { IsAdmin: false, IsModerator: false })) - { - await next(ctx); - return; - } - - userSvc.UpdateSessionMetadata(session); - ctx.SetSession(session); + await next(ctx); + return; } + + userSvc.UpdateOauthTokenMetadata(oauthToken); + ctx.SetOauthToken(oauthToken); + + mfmConverter.SupportsHtmlFormatting = oauthToken.SupportsHtmlFormatting; + } + else + { + var session = await db.Sessions + .Include(p => p.User.UserProfile) + .Include(p => p.User.UserSettings) + .FirstOrDefaultAsync(p => p.Token == token && p.Active); + + if (session?.User.IsSuspended == true) + throw GracefulException + .Unauthorized("Your access has been suspended by the instance administrator."); + + if (session == null) + { + await next(ctx); + return; + } + + if ((attribute.AdminRole && !session.User.IsAdmin) || + (attribute.ModeratorRole && + session.User is { IsAdmin: false, IsModerator: false })) + { + await next(ctx); + return; + } + + userSvc.UpdateSessionMetadata(session); + ctx.SetSession(session); } await next(ctx); diff --git a/Iceshrimp.Backend/Core/Middleware/AuthorizationMiddleware.cs b/Iceshrimp.Backend/Core/Middleware/AuthorizationMiddleware.cs index c91cd3f7..370f6c54 100644 --- a/Iceshrimp.Backend/Core/Middleware/AuthorizationMiddleware.cs +++ b/Iceshrimp.Backend/Core/Middleware/AuthorizationMiddleware.cs @@ -1,43 +1,41 @@ using Iceshrimp.Backend.Controllers.Mastodon.Attributes; +using Iceshrimp.Backend.Core.Extensions; using Iceshrimp.Backend.Core.Helpers; namespace Iceshrimp.Backend.Core.Middleware; -public class AuthorizationMiddleware : IMiddleware +public class AuthorizationMiddleware(RequestDelegate next) : ConditionalMiddleware { - public async Task InvokeAsync(HttpContext ctx, RequestDelegate next) + public async Task InvokeAsync(HttpContext ctx) { var endpoint = ctx.GetEndpoint(); - var attribute = endpoint?.Metadata.GetMetadata(); + var attribute = GetAttributeOrFail(ctx); - if (attribute != null) + ctx.Response.Headers.CacheControl = "private, no-store"; + var isMastodon = endpoint?.Metadata.GetMetadata() != null; + + if (isMastodon) { - ctx.Response.Headers.CacheControl = "private, no-store"; - var isMastodon = endpoint?.Metadata.GetMetadata() != null; - - if (isMastodon) - { - var token = ctx.GetOauthToken(); - if (token is not { Active: true }) - throw GracefulException.Unauthorized("This method requires an authenticated user"); - if (attribute.Scopes.Length > 0 && - attribute.Scopes.Except(MastodonOauthHelpers.ExpandScopes(token.Scopes)).Any()) - throw GracefulException.Forbidden("This action is outside the authorized scopes"); - if (attribute.AdminRole && !token.User.IsAdmin) - throw GracefulException.Forbidden("This action is outside the authorized scopes"); - if (attribute.ModeratorRole && token.User is { IsAdmin: false, IsModerator: false }) - throw GracefulException.Forbidden("This action is outside the authorized scopes"); - } - else - { - var session = ctx.GetSession(); - if (session is not { Active: true }) - throw GracefulException.Unauthorized("This method requires an authenticated user"); - if (attribute.AdminRole && !session.User.IsAdmin) - throw GracefulException.Forbidden("This action is outside the authorized scopes"); - if (attribute.ModeratorRole && session.User is { IsAdmin: false, IsModerator: false }) - throw GracefulException.Forbidden("This action is outside the authorized scopes"); - } + var token = ctx.GetOauthToken(); + if (token is not { Active: true }) + throw GracefulException.Unauthorized("This method requires an authenticated user"); + if (attribute.Scopes.Length > 0 && + attribute.Scopes.Except(MastodonOauthHelpers.ExpandScopes(token.Scopes)).Any()) + throw GracefulException.Forbidden("This action is outside the authorized scopes"); + if (attribute.AdminRole && !token.User.IsAdmin) + throw GracefulException.Forbidden("This action is outside the authorized scopes"); + if (attribute.ModeratorRole && token.User is { IsAdmin: false, IsModerator: false }) + throw GracefulException.Forbidden("This action is outside the authorized scopes"); + } + else + { + var session = ctx.GetSession(); + if (session is not { Active: true }) + throw GracefulException.Unauthorized("This method requires an authenticated user"); + if (attribute.AdminRole && !session.User.IsAdmin) + throw GracefulException.Forbidden("This action is outside the authorized scopes"); + if (attribute.ModeratorRole && session.User is { IsAdmin: false, IsModerator: false }) + throw GracefulException.Forbidden("This action is outside the authorized scopes"); } await next(ctx); diff --git a/Iceshrimp.Backend/Core/Middleware/AuthorizedFetchMiddleware.cs b/Iceshrimp.Backend/Core/Middleware/AuthorizedFetchMiddleware.cs index 26526834..f4b47819 100644 --- a/Iceshrimp.Backend/Core/Middleware/AuthorizedFetchMiddleware.cs +++ b/Iceshrimp.Backend/Core/Middleware/AuthorizedFetchMiddleware.cs @@ -3,6 +3,7 @@ using System.Net; using Iceshrimp.Backend.Core.Configuration; using Iceshrimp.Backend.Core.Database; using Iceshrimp.Backend.Core.Database.Tables; +using Iceshrimp.Backend.Core.Extensions; using Iceshrimp.Backend.Core.Federation.Cryptography; using Iceshrimp.Backend.Core.Services; using Microsoft.EntityFrameworkCore; @@ -21,107 +22,110 @@ public class AuthorizedFetchMiddleware( ActivityPub.FederationControlService fedCtrlSvc, ILogger logger, IHostApplicationLifetime appLifetime -) : IMiddleware +) : ConditionalMiddleware, IMiddlewareService { + public static ServiceLifetime Lifetime => ServiceLifetime.Scoped; + public async Task InvokeAsync(HttpContext ctx, RequestDelegate next) { - var attribute = ctx.GetEndpoint()?.Metadata.GetMetadata(); - - if (attribute != null && config.Value.AuthorizedFetch) + if (!config.Value.AuthorizedFetch) { - ctx.Response.Headers.CacheControl = "private, no-store"; + await next(ctx); + return; + } - var request = ctx.Request; - var ct = appLifetime.ApplicationStopping; + ctx.Response.Headers.CacheControl = "private, no-store"; - //TODO: cache this somewhere - var instanceActorUri = $"/users/{(await systemUserSvc.GetInstanceActorAsync()).Id}"; - if (request.Path.Value == instanceActorUri) + var request = ctx.Request; + var ct = appLifetime.ApplicationStopping; + + //TODO: cache this somewhere + var instanceActorUri = $"/users/{(await systemUserSvc.GetInstanceActorAsync()).Id}"; + if (request.Path.Value == instanceActorUri) + { + await next(ctx); + return; + } + + UserPublickey? key = null; + var verified = false; + + logger.LogTrace("Processing authorized fetch request for {path}", request.Path); + + try + { + if (!request.Headers.TryGetValue("signature", out var sigHeader)) + throw new GracefulException(HttpStatusCode.Unauthorized, "Request is missing the signature header"); + + var sig = HttpSignature.Parse(sigHeader.ToString()); + + if (await fedCtrlSvc.ShouldBlockAsync(sig.KeyId)) + throw new GracefulException(HttpStatusCode.Forbidden, "Forbidden", "Instance is blocked", + suppressLog: true); + + // First, we check if we already have the key + key = await db.UserPublickeys.Include(p => p.User) + .FirstOrDefaultAsync(p => p.KeyId == sig.KeyId, ct); + + // If we don't, we need to try to fetch it + if (key == null) { - await next(ctx); - return; + try + { + var user = await userResolver.ResolveAsync(sig.KeyId, ResolveFlags.Uri).WaitAsync(ct); + key = await db.UserPublickeys.Include(p => p.User) + .FirstOrDefaultAsync(p => p.User == user, ct); + + // If the key is still null here, we have a data consistency issue and need to update the key manually + key ??= await userSvc.UpdateUserPublicKeyAsync(user).WaitAsync(ct); + } + catch (Exception e) + { + if (e is GracefulException) throw; + throw new Exception($"Failed to fetch key of signature user ({sig.KeyId}) - {e.Message}"); + } } - UserPublickey? key = null; - var verified = false; + // If we still don't have the key, something went wrong and we need to throw an exception + if (key == null) throw new Exception($"Failed to fetch key of signature user ({sig.KeyId})"); - logger.LogTrace("Processing authorized fetch request for {path}", request.Path); + if (key.User.IsLocalUser) + throw new Exception("Remote user must have a host"); - try + // We want to check both the user host & the keyId host (as account & web domain might be different) + if (await fedCtrlSvc.ShouldBlockAsync(key.User.Host, key.KeyId)) + throw new GracefulException(HttpStatusCode.Forbidden, "Forbidden", "Instance is blocked", + suppressLog: true); + + List headers = ["(request-target)", "host"]; + + if (sig.Created != null && !sig.Headers.Contains("date")) + headers.Add("(created)"); + else + headers.Add("date"); + + verified = await HttpSignature.VerifyAsync(request, sig, headers, key.KeyPem); + logger.LogDebug("HttpSignature.Verify returned {result} for key {keyId}", verified, sig.KeyId); + + if (!verified) { - if (!request.Headers.TryGetValue("signature", out var sigHeader)) - throw new GracefulException(HttpStatusCode.Unauthorized, "Request is missing the signature header"); - - var sig = HttpSignature.Parse(sigHeader.ToString()); - - if (await fedCtrlSvc.ShouldBlockAsync(sig.KeyId)) - throw new GracefulException(HttpStatusCode.Forbidden, "Forbidden", "Instance is blocked", - suppressLog: true); - - // First, we check if we already have the key - key = await db.UserPublickeys.Include(p => p.User) - .FirstOrDefaultAsync(p => p.KeyId == sig.KeyId, ct); - - // If we don't, we need to try to fetch it - if (key == null) - { - try - { - var user = await userResolver.ResolveAsync(sig.KeyId, ResolveFlags.Uri).WaitAsync(ct); - key = await db.UserPublickeys.Include(p => p.User) - .FirstOrDefaultAsync(p => p.User == user, ct); - - // If the key is still null here, we have a data consistency issue and need to update the key manually - key ??= await userSvc.UpdateUserPublicKeyAsync(user).WaitAsync(ct); - } - catch (Exception e) - { - if (e is GracefulException) throw; - throw new Exception($"Failed to fetch key of signature user ({sig.KeyId}) - {e.Message}"); - } - } - - // If we still don't have the key, something went wrong and we need to throw an exception - if (key == null) throw new Exception($"Failed to fetch key of signature user ({sig.KeyId})"); - - if (key.User.IsLocalUser) - throw new Exception("Remote user must have a host"); - - // We want to check both the user host & the keyId host (as account & web domain might be different) - if (await fedCtrlSvc.ShouldBlockAsync(key.User.Host, key.KeyId)) - throw new GracefulException(HttpStatusCode.Forbidden, "Forbidden", "Instance is blocked", - suppressLog: true); - - List headers = ["(request-target)", "host"]; - - if (sig.Created != null && !sig.Headers.Contains("date")) - headers.Add("(created)"); - else - headers.Add("date"); - + logger.LogDebug("Refetching user key..."); + key = await userSvc.UpdateUserPublicKeyAsync(key); verified = await HttpSignature.VerifyAsync(request, sig, headers, key.KeyPem); logger.LogDebug("HttpSignature.Verify returned {result} for key {keyId}", verified, sig.KeyId); - - if (!verified) - { - logger.LogDebug("Refetching user key..."); - key = await userSvc.UpdateUserPublicKeyAsync(key); - verified = await HttpSignature.VerifyAsync(request, sig, headers, key.KeyPem); - logger.LogDebug("HttpSignature.Verify returned {result} for key {keyId}", verified, sig.KeyId); - } } - catch (Exception e) - { - if (e is AuthFetchException afe) throw GracefulException.Accepted(afe.Message); - if (e is GracefulException { SuppressLog: true }) throw; - logger.LogDebug("Error validating HTTP signature: {error}", e.Message); - } - - if (!verified || key == null) - throw new GracefulException(HttpStatusCode.Forbidden, "Request signature validation failed"); - - ctx.SetActor(key.User); } + catch (Exception e) + { + if (e is AuthFetchException afe) throw GracefulException.Accepted(afe.Message); + if (e is GracefulException { SuppressLog: true }) throw; + logger.LogDebug("Error validating HTTP signature: {error}", e.Message); + } + + if (!verified || key == null) + throw new GracefulException(HttpStatusCode.Forbidden, "Request signature validation failed"); + + ctx.SetActor(key.User); await next(ctx); } diff --git a/Iceshrimp.Backend/Core/Middleware/BlazorSsrHandoffMiddleware.cs b/Iceshrimp.Backend/Core/Middleware/BlazorSsrHandoffMiddleware.cs index 6c6ec47e..ade772ca 100644 --- a/Iceshrimp.Backend/Core/Middleware/BlazorSsrHandoffMiddleware.cs +++ b/Iceshrimp.Backend/Core/Middleware/BlazorSsrHandoffMiddleware.cs @@ -1,15 +1,16 @@ using System.Reflection; +using Iceshrimp.Backend.Core.Extensions; using Microsoft.AspNetCore.Components.Endpoints; namespace Iceshrimp.Backend.Core.Middleware; -public class BlazorSsrHandoffMiddleware : IMiddleware +public class BlazorSsrHandoffMiddleware(RequestDelegate next) : ConditionalMiddleware { - public async Task InvokeAsync(HttpContext context, RequestDelegate next) + public async Task InvokeAsync(HttpContext context) { var attribute = context.GetEndpoint() ?.Metadata.GetMetadata() - ?.Type.GetCustomAttributes() + ?.Type.GetCustomAttributes() .FirstOrDefault(); if (attribute != null) @@ -34,4 +35,4 @@ public class BlazorSsrHandoffMiddleware : IMiddleware } } -public class RazorSsrAttribute : Attribute; \ No newline at end of file +public class BlazorSsrAttribute : Attribute; \ No newline at end of file diff --git a/Iceshrimp.Backend/Core/Middleware/ErrorHandlerMiddleware.cs b/Iceshrimp.Backend/Core/Middleware/ErrorHandlerMiddleware.cs index 22b3c6a0..e2d51abb 100644 --- a/Iceshrimp.Backend/Core/Middleware/ErrorHandlerMiddleware.cs +++ b/Iceshrimp.Backend/Core/Middleware/ErrorHandlerMiddleware.cs @@ -15,11 +15,13 @@ namespace Iceshrimp.Backend.Core.Middleware; public class ErrorHandlerMiddleware( [SuppressMessage("ReSharper", "SuggestBaseTypeForParameterInConstructor")] - IOptionsSnapshot options, + IOptionsMonitor options, ILoggerFactory loggerFactory, RazorViewRenderService razor -) : IMiddleware +) : IMiddlewareService { + public static ServiceLifetime Lifetime => ServiceLifetime.Singleton; + private static readonly XmlSerializer XmlSerializer = new(typeof(ErrorResponse)); public async Task InvokeAsync(HttpContext ctx, RequestDelegate next) @@ -37,7 +39,7 @@ public class ErrorHandlerMiddleware( type = type[..(type.IndexOf('>') + 1)]; var logger = loggerFactory.CreateLogger(type); - var verbosity = options.Value.ExceptionVerbosity; + var verbosity = options.CurrentValue.ExceptionVerbosity; if (ctx.Response.HasStarted) { diff --git a/Iceshrimp.Backend/Core/Middleware/FederationSemaphoreMiddleware.cs b/Iceshrimp.Backend/Core/Middleware/FederationSemaphoreMiddleware.cs index df23d0c4..79442124 100644 --- a/Iceshrimp.Backend/Core/Middleware/FederationSemaphoreMiddleware.cs +++ b/Iceshrimp.Backend/Core/Middleware/FederationSemaphoreMiddleware.cs @@ -1,5 +1,6 @@ using System.Net; using Iceshrimp.Backend.Core.Configuration; +using Iceshrimp.Backend.Core.Extensions; using Iceshrimp.Backend.Core.Helpers; using Microsoft.Extensions.Options; @@ -8,8 +9,10 @@ namespace Iceshrimp.Backend.Core.Middleware; public class FederationSemaphoreMiddleware( IOptions config, IHostApplicationLifetime appLifetime -) : IMiddleware +) : ConditionalMiddleware, IMiddlewareService { + public static ServiceLifetime Lifetime => ServiceLifetime.Singleton; + private readonly SemaphorePlus _semaphore = new(Math.Max(config.Value.FederationRequestHandlerConcurrency, 1)); public async Task InvokeAsync(HttpContext ctx, RequestDelegate next) @@ -20,13 +23,6 @@ public class FederationSemaphoreMiddleware( return; } - var attribute = ctx.GetEndpoint()?.Metadata.GetMetadata(); - if (attribute == null) - { - await next(ctx); - return; - } - try { var cts = CancellationTokenSource diff --git a/Iceshrimp.Backend/Core/Middleware/InboxValidationMiddleware.cs b/Iceshrimp.Backend/Core/Middleware/InboxValidationMiddleware.cs index e8f79b16..929171ec 100644 --- a/Iceshrimp.Backend/Core/Middleware/InboxValidationMiddleware.cs +++ b/Iceshrimp.Backend/Core/Middleware/InboxValidationMiddleware.cs @@ -4,6 +4,7 @@ using System.Net.Http.Headers; using Iceshrimp.Backend.Core.Configuration; using Iceshrimp.Backend.Core.Database; using Iceshrimp.Backend.Core.Database.Tables; +using Iceshrimp.Backend.Core.Extensions; using Iceshrimp.Backend.Core.Federation.ActivityStreams; using Iceshrimp.Backend.Core.Federation.ActivityStreams.Types; using Iceshrimp.Backend.Core.Federation.Cryptography; @@ -25,233 +26,230 @@ public class InboxValidationMiddleware( ActivityPub.FederationControlService fedCtrlSvc, ILogger logger, IHostApplicationLifetime appLifetime -) : IMiddleware +) : ConditionalMiddleware, IMiddlewareService { + public static ServiceLifetime Lifetime => ServiceLifetime.Scoped; + private static readonly JsonSerializerSettings JsonSerializerSettings = new() { DateParseHandling = DateParseHandling.None }; public async Task InvokeAsync(HttpContext ctx, RequestDelegate next) { - var attribute = ctx.GetEndpoint()?.Metadata.GetMetadata(); + var request = ctx.Request; + var ct = appLifetime.ApplicationStopping; - if (attribute != null) + if (request is not { ContentType: not null, ContentLength: > 0 }) + throw GracefulException.UnprocessableEntity("Inbox request must have a body"); + + HttpSignature.HttpSignatureHeader? sig = null; + + if (request.Headers.TryGetValue("signature", out var sigHeader)) { - var request = ctx.Request; - var ct = appLifetime.ApplicationStopping; + try + { + sig = HttpSignature.Parse(sigHeader.ToString()); + if (await fedCtrlSvc.ShouldBlockAsync(sig.KeyId)) + throw new GracefulException(HttpStatusCode.Forbidden, "Forbidden", "Instance is blocked", + suppressLog: true); + } + catch (Exception e) + { + if (e is GracefulException { SuppressLog: true }) throw; + } + } - if (request is not { ContentType: not null, ContentLength: > 0 }) - throw GracefulException.UnprocessableEntity("Inbox request must have a body"); + var body = await new StreamReader(request.Body).ReadToEndAsync(ct); + request.Body.Seek(0, SeekOrigin.Begin); - HttpSignature.HttpSignatureHeader? sig = null; + JToken parsed; + try + { + parsed = JToken.Parse(body); + } + catch (Exception e) + { + logger.LogDebug("Failed to parse ASObject ({error}), skipping", e.Message); + return; + } - if (request.Headers.TryGetValue("signature", out var sigHeader)) + JArray? expanded; + try + { + expanded = LdHelpers.Expand(parsed); + if (expanded == null) throw new Exception("Failed to expand ASObject"); + } + catch (Exception e) + { + logger.LogDebug("Failed to expand ASObject ({error}), skipping", e.Message); + return; + } + + ASObject? obj; + try + { + obj = ASObject.Deserialize(expanded); + if (obj == null) throw new Exception("Failed to deserialize ASObject"); + } + catch (Exception e) + { + throw GracefulException + .UnprocessableEntity($"Failed to deserialize request body as ASObject: {e.Message}"); + } + + if (obj is not ASActivity activity) + throw new GracefulException(HttpStatusCode.UnprocessableEntity, + "Request body is not an ASActivity", $"Type: {obj.Type}"); + + UserPublickey? key = null; + var verified = false; + + try + { + if (sig == null) + throw new GracefulException(HttpStatusCode.Unauthorized, "Request is missing the signature header"); + + // First, we check if we already have the key + key = await db.UserPublickeys.Include(p => p.User) + .FirstOrDefaultAsync(p => p.KeyId == sig.KeyId, ct); + + // If we don't, we need to try to fetch it + if (key == null) { try { - sig = HttpSignature.Parse(sigHeader.ToString()); - if (await fedCtrlSvc.ShouldBlockAsync(sig.KeyId)) - throw new GracefulException(HttpStatusCode.Forbidden, "Forbidden", "Instance is blocked", - suppressLog: true); + var flags = activity is ASDelete + ? ResolveFlags.Uri | ResolveFlags.OnlyExisting + : ResolveFlags.Uri; + + var user = await userResolver.ResolveOrNullAsync(sig.KeyId, flags).WaitAsync(ct); + if (user == null) throw AuthFetchException.NotFound("Delete activity actor is unknown"); + key = await db.UserPublickeys.Include(p => p.User) + .FirstOrDefaultAsync(p => p.User == user, ct); + + // If the key is still null here, we have a data consistency issue and need to update the key manually + key ??= await userSvc.UpdateUserPublicKeyAsync(user).WaitAsync(ct); } catch (Exception e) { - if (e is GracefulException { SuppressLog: true }) throw; + if (e is GracefulException) throw; + throw new + GracefulException($"Failed to fetch key of signature user ({sig.KeyId}) - {e.Message}"); } } - var body = await new StreamReader(request.Body).ReadToEndAsync(ct); - request.Body.Seek(0, SeekOrigin.Begin); + // If we still don't have the key, something went wrong and we need to throw an exception + if (key == null) throw new GracefulException($"Failed to fetch key of signature user ({sig.KeyId})"); - JToken parsed; + if (key.User.IsLocalUser) + throw new Exception("Remote user must have a host"); + + // We want to check both the user host & the keyId host (as account & web domain might be different) + if (await fedCtrlSvc.ShouldBlockAsync(key.User.Host, key.KeyId)) + throw new GracefulException(HttpStatusCode.Forbidden, "Forbidden", "Instance is blocked", + suppressLog: true); + + List headers = ["(request-target)", "digest", "host"]; + + if (sig.Created != null && !sig.Headers.Contains("date")) + headers.Add("(created)"); + else + headers.Add("date"); + + verified = await HttpSignature.VerifyAsync(request, sig, headers, key.KeyPem); + logger.LogDebug("HttpSignature.Verify returned {result} for key {keyId}", verified, sig.KeyId); + + if (!verified) + { + logger.LogDebug("Refetching user key..."); + key = await userSvc.UpdateUserPublicKeyAsync(key); + verified = await HttpSignature.VerifyAsync(request, sig, headers, key.KeyPem); + logger.LogDebug("HttpSignature.Verify returned {result} for key {keyId}", verified, sig.KeyId); + } + } + catch (Exception e) + { + if (e is AuthFetchException afe) throw GracefulException.Accepted(afe.Message); + if (e is GracefulException { SuppressLog: true }) throw; + logger.LogDebug("Error validating HTTP signature: {error}", e.Message); + } + + if ( + (!verified || (key?.User.Uri != null && activity.Actor?.Id != key.User.Uri)) && + (activity is ASDelete || config.Value.AcceptLdSignatures) + ) + { + if (activity is ASDelete) + logger.LogDebug("Activity is ASDelete & actor uri is not matching, trying LD signature next..."); + else + logger.LogDebug("Trying LD signature next..."); try { - parsed = JToken.Parse(body); - } - catch (Exception e) - { - logger.LogDebug("Failed to parse ASObject ({error}), skipping", e.Message); - return; - } + var contentType = new MediaTypeHeaderValue(request.ContentType); + if (!ActivityPub.ActivityFetcherService.IsValidActivityContentType(contentType)) + throw new Exception("Request body is not an activity"); - JArray? expanded; - try - { - expanded = LdHelpers.Expand(parsed); - if (expanded == null) throw new Exception("Failed to expand ASObject"); - } - catch (Exception e) - { - logger.LogDebug("Failed to expand ASObject ({error}), skipping", e.Message); - return; - } + if (activity.Actor == null) + throw new Exception("Activity has no actor"); + if (await fedCtrlSvc.ShouldBlockAsync(new Uri(activity.Actor.Id).Host)) + throw new GracefulException(HttpStatusCode.Forbidden, "Forbidden", "Instance is blocked", + suppressLog: true); + key = null; + key = await db.UserPublickeys + .Include(p => p.User) + .FirstOrDefaultAsync(p => p.User.Uri == activity.Actor.Id, ct); - ASObject? obj; - try - { - obj = ASObject.Deserialize(expanded); - if (obj == null) throw new Exception("Failed to deserialize ASObject"); - } - catch (Exception e) - { - throw GracefulException - .UnprocessableEntity($"Failed to deserialize request body as ASObject: {e.Message}"); - } - - if (obj is not ASActivity activity) - throw new GracefulException(HttpStatusCode.UnprocessableEntity, - "Request body is not an ASActivity", $"Type: {obj.Type}"); - - UserPublickey? key = null; - var verified = false; - - try - { - if (sig == null) - throw new GracefulException(HttpStatusCode.Unauthorized, "Request is missing the signature header"); - - // First, we check if we already have the key - key = await db.UserPublickeys.Include(p => p.User) - .FirstOrDefaultAsync(p => p.KeyId == sig.KeyId, ct); - - // If we don't, we need to try to fetch it if (key == null) { - try - { - var flags = activity is ASDelete - ? ResolveFlags.Uri | ResolveFlags.OnlyExisting - : ResolveFlags.Uri; + var flags = activity is ASDelete + ? ResolveFlags.Uri | ResolveFlags.OnlyExisting + : ResolveFlags.Uri; - var user = await userResolver.ResolveOrNullAsync(sig.KeyId, flags).WaitAsync(ct); - if (user == null) throw AuthFetchException.NotFound("Delete activity actor is unknown"); - key = await db.UserPublickeys.Include(p => p.User) - .FirstOrDefaultAsync(p => p.User == user, ct); + var user = await userResolver + .ResolveOrNullAsync(activity.Actor.Id, flags) + .WaitAsync(ct); + if (user == null) throw AuthFetchException.NotFound("Delete activity actor is unknown"); + key = await db.UserPublickeys + .Include(p => p.User) + .FirstOrDefaultAsync(p => p.User == user, ct); - // If the key is still null here, we have a data consistency issue and need to update the key manually - key ??= await userSvc.UpdateUserPublicKeyAsync(user).WaitAsync(ct); - } - catch (Exception e) - { - if (e is GracefulException) throw; - throw new - GracefulException($"Failed to fetch key of signature user ({sig.KeyId}) - {e.Message}"); - } + if (key == null) + throw new Exception($"Failed to fetch public key for user {activity.Actor.Id}"); } - // If we still don't have the key, something went wrong and we need to throw an exception - if (key == null) throw new GracefulException($"Failed to fetch key of signature user ({sig.KeyId})"); - - if (key.User.IsLocalUser) - throw new Exception("Remote user must have a host"); - - // We want to check both the user host & the keyId host (as account & web domain might be different) - if (await fedCtrlSvc.ShouldBlockAsync(key.User.Host, key.KeyId)) + if (await fedCtrlSvc.ShouldBlockAsync(key.User.Host, new Uri(key.KeyId).Host)) throw new GracefulException(HttpStatusCode.Forbidden, "Forbidden", "Instance is blocked", suppressLog: true); - List headers = ["(request-target)", "digest", "host"]; - - if (sig.Created != null && !sig.Headers.Contains("date")) - headers.Add("(created)"); - else - headers.Add("date"); - - verified = await HttpSignature.VerifyAsync(request, sig, headers, key.KeyPem); - logger.LogDebug("HttpSignature.Verify returned {result} for key {keyId}", verified, sig.KeyId); - + // We need to re-run deserialize & expand with date time handling disabled for JSON-LD canonicalization to work correctly + var rawDeserialized = JsonConvert.DeserializeObject(body, JsonSerializerSettings); + var rawExpanded = LdHelpers.Expand(rawDeserialized); + if (rawExpanded == null) + throw new Exception("Failed to expand activity for LD signature processing"); + verified = await LdSignature.VerifyAsync(expanded, rawExpanded, key.KeyPem, key.KeyId); + logger.LogDebug("LdSignature.VerifyAsync returned {result} for actor {id}", + verified, activity.Actor.Id); if (!verified) { logger.LogDebug("Refetching user key..."); key = await userSvc.UpdateUserPublicKeyAsync(key); - verified = await HttpSignature.VerifyAsync(request, sig, headers, key.KeyPem); - logger.LogDebug("HttpSignature.Verify returned {result} for key {keyId}", verified, sig.KeyId); + verified = await LdSignature.VerifyAsync(expanded, rawExpanded, key.KeyPem, key.KeyId); + logger.LogDebug("LdSignature.VerifyAsync returned {result} for actor {id}", + verified, activity.Actor.Id); } } catch (Exception e) { if (e is AuthFetchException afe) throw GracefulException.Accepted(afe.Message); if (e is GracefulException { SuppressLog: true }) throw; - logger.LogDebug("Error validating HTTP signature: {error}", e.Message); + logger.LogError("Error validating JSON-LD signature: {error}", e.Message); } - - if ( - (!verified || (key?.User.Uri != null && activity.Actor?.Id != key.User.Uri)) && - (activity is ASDelete || config.Value.AcceptLdSignatures) - ) - { - if (activity is ASDelete) - logger.LogDebug("Activity is ASDelete & actor uri is not matching, trying LD signature next..."); - else - logger.LogDebug("Trying LD signature next..."); - try - { - var contentType = new MediaTypeHeaderValue(request.ContentType); - if (!ActivityPub.ActivityFetcherService.IsValidActivityContentType(contentType)) - throw new Exception("Request body is not an activity"); - - if (activity.Actor == null) - throw new Exception("Activity has no actor"); - if (await fedCtrlSvc.ShouldBlockAsync(new Uri(activity.Actor.Id).Host)) - throw new GracefulException(HttpStatusCode.Forbidden, "Forbidden", "Instance is blocked", - suppressLog: true); - key = null; - key = await db.UserPublickeys - .Include(p => p.User) - .FirstOrDefaultAsync(p => p.User.Uri == activity.Actor.Id, ct); - - if (key == null) - { - var flags = activity is ASDelete - ? ResolveFlags.Uri | ResolveFlags.OnlyExisting - : ResolveFlags.Uri; - - var user = await userResolver - .ResolveOrNullAsync(activity.Actor.Id, flags) - .WaitAsync(ct); - if (user == null) throw AuthFetchException.NotFound("Delete activity actor is unknown"); - key = await db.UserPublickeys - .Include(p => p.User) - .FirstOrDefaultAsync(p => p.User == user, ct); - - if (key == null) - throw new Exception($"Failed to fetch public key for user {activity.Actor.Id}"); - } - - if (await fedCtrlSvc.ShouldBlockAsync(key.User.Host, new Uri(key.KeyId).Host)) - throw new GracefulException(HttpStatusCode.Forbidden, "Forbidden", "Instance is blocked", - suppressLog: true); - - // We need to re-run deserialize & expand with date time handling disabled for JSON-LD canonicalization to work correctly - var rawDeserialized = JsonConvert.DeserializeObject(body, JsonSerializerSettings); - var rawExpanded = LdHelpers.Expand(rawDeserialized); - if (rawExpanded == null) - throw new Exception("Failed to expand activity for LD signature processing"); - verified = await LdSignature.VerifyAsync(expanded, rawExpanded, key.KeyPem, key.KeyId); - logger.LogDebug("LdSignature.VerifyAsync returned {result} for actor {id}", - verified, activity.Actor.Id); - if (!verified) - { - logger.LogDebug("Refetching user key..."); - key = await userSvc.UpdateUserPublicKeyAsync(key); - verified = await LdSignature.VerifyAsync(expanded, rawExpanded, key.KeyPem, key.KeyId); - logger.LogDebug("LdSignature.VerifyAsync returned {result} for actor {id}", - verified, activity.Actor.Id); - } - } - catch (Exception e) - { - if (e is AuthFetchException afe) throw GracefulException.Accepted(afe.Message); - if (e is GracefulException { SuppressLog: true }) throw; - logger.LogError("Error validating JSON-LD signature: {error}", e.Message); - } - } - - if (!verified || key == null) - throw new GracefulException(HttpStatusCode.Forbidden, "Request signature validation failed"); - - ctx.SetActor(key.User); } + if (!verified || key == null) + throw new GracefulException(HttpStatusCode.Forbidden, "Request signature validation failed"); + + ctx.SetActor(key.User); + await next(ctx); } } diff --git a/Iceshrimp.Backend/Core/Middleware/RequestBufferingMiddleware.cs b/Iceshrimp.Backend/Core/Middleware/RequestBufferingMiddleware.cs index 2370f18d..ecb069bd 100644 --- a/Iceshrimp.Backend/Core/Middleware/RequestBufferingMiddleware.cs +++ b/Iceshrimp.Backend/Core/Middleware/RequestBufferingMiddleware.cs @@ -1,16 +1,22 @@ +using Iceshrimp.Backend.Core.Extensions; +using JetBrains.Annotations; + namespace Iceshrimp.Backend.Core.Middleware; -public class RequestBufferingMiddleware : IMiddleware +[UsedImplicitly] +public class RequestBufferingMiddleware(RequestDelegate next) : ConditionalMiddleware { - public async Task InvokeAsync(HttpContext ctx, RequestDelegate next) + [UsedImplicitly] + public async Task InvokeAsync(HttpContext ctx) { - var attribute = ctx.GetEndpoint()?.Metadata.GetMetadata(); - if (attribute != null) ctx.Request.EnableBuffering(attribute.MaxLength); + var attr = GetAttributeOrFail(ctx); + ctx.Request.EnableBuffering(attr.MaxLength); await next(ctx); } } +[AttributeUsage(AttributeTargets.Class | AttributeTargets.Method)] public class EnableRequestBufferingAttribute(long maxLength) : Attribute { - internal long MaxLength = maxLength; + internal readonly long MaxLength = maxLength; } \ No newline at end of file diff --git a/Iceshrimp.Backend/Core/Middleware/RequestDurationMiddleware.cs b/Iceshrimp.Backend/Core/Middleware/RequestDurationMiddleware.cs index aa0aaf01..64d23b69 100644 --- a/Iceshrimp.Backend/Core/Middleware/RequestDurationMiddleware.cs +++ b/Iceshrimp.Backend/Core/Middleware/RequestDurationMiddleware.cs @@ -1,11 +1,14 @@ using System.Diagnostics; using Iceshrimp.Backend.Core.Extensions; +using JetBrains.Annotations; namespace Iceshrimp.Backend.Core.Middleware; -public class RequestDurationMiddleware : IMiddleware +[UsedImplicitly] +public class RequestDurationMiddleware(RequestDelegate next) { - public async Task InvokeAsync(HttpContext ctx, RequestDelegate next) + [UsedImplicitly] + public async Task InvokeAsync(HttpContext ctx) { if (ctx.GetEndpoint()?.Metadata.GetMetadata() == null) { diff --git a/Iceshrimp.Backend/Core/Middleware/RequestVerificationMiddleware.cs b/Iceshrimp.Backend/Core/Middleware/RequestVerificationMiddleware.cs index 85259363..1f827087 100644 --- a/Iceshrimp.Backend/Core/Middleware/RequestVerificationMiddleware.cs +++ b/Iceshrimp.Backend/Core/Middleware/RequestVerificationMiddleware.cs @@ -1,4 +1,5 @@ using Iceshrimp.Backend.Core.Configuration; +using Iceshrimp.Backend.Core.Extensions; using Microsoft.Extensions.Options; namespace Iceshrimp.Backend.Core.Middleware; @@ -7,8 +8,10 @@ public class RequestVerificationMiddleware( IOptions config, IHostEnvironment environment, ILogger logger -) : IMiddleware +) : IMiddlewareService { + public static ServiceLifetime Lifetime => ServiceLifetime.Singleton; + private readonly bool _isDevelopment = environment.IsDevelopment(); public async Task InvokeAsync(HttpContext ctx, RequestDelegate next) diff --git a/Iceshrimp.Backend/Iceshrimp.Backend.csproj b/Iceshrimp.Backend/Iceshrimp.Backend.csproj index 82ed861c..18ef7611 100644 --- a/Iceshrimp.Backend/Iceshrimp.Backend.csproj +++ b/Iceshrimp.Backend/Iceshrimp.Backend.csproj @@ -47,7 +47,7 @@ - + diff --git a/Iceshrimp.Backend/Pages/Shared/RootComponent.razor b/Iceshrimp.Backend/Pages/Shared/RootComponent.razor index 3cf543ba..42cc73b1 100644 --- a/Iceshrimp.Backend/Pages/Shared/RootComponent.razor +++ b/Iceshrimp.Backend/Pages/Shared/RootComponent.razor @@ -7,7 +7,7 @@ @using Microsoft.AspNetCore.Components.Routing @inject IOptions Instance @preservewhitespace true -@attribute [RazorSsr] +@attribute [BlazorSsr] @inherits AsyncComponentBase diff --git a/Iceshrimp.Backend/Startup.cs b/Iceshrimp.Backend/Startup.cs index 642f2299..ce9da7c3 100644 --- a/Iceshrimp.Backend/Startup.cs +++ b/Iceshrimp.Backend/Startup.cs @@ -37,6 +37,7 @@ builder.Services.AddResponseCompression(); builder.Services.AddRazorPages(); builder.Services.AddRazorComponents(); builder.Services.AddAntiforgery(o => o.Cookie.Name = "CSRF-Token"); +builder.Services.AddMiddleware(); builder.Services.AddServices(builder.Configuration); builder.Services.ConfigureServices(builder.Configuration);