Implement IMiddleware for all middlewares

This commit is contained in:
Laura Hausmann 2024-01-24 19:10:43 +01:00
parent bb365ddc66
commit 2e4a1137ed
No known key found for this signature in database
GPG key ID: D044E84C5BE01605
4 changed files with 21 additions and 15 deletions

View file

@ -2,6 +2,7 @@ using Iceshrimp.Backend.Controllers.Renderers.ActivityPub;
using Iceshrimp.Backend.Core.Configuration;
using Iceshrimp.Backend.Core.Federation.ActivityPub;
using Iceshrimp.Backend.Core.Federation.WebFinger;
using Iceshrimp.Backend.Core.Middleware;
using Iceshrimp.Backend.Core.Services;
namespace Iceshrimp.Backend.Core.Extensions;
@ -18,11 +19,14 @@ public static class ServiceExtensions {
services.AddScoped<UserRenderer>();
services.AddScoped<NoteRenderer>();
services.AddScoped<WebFingerService>();
services.AddScoped<AuthorizedFetchMiddleware>();
// Singleton = instantiated once across application lifetime
services.AddSingleton<HttpClient>();
services.AddSingleton<HttpRequestService>();
services.AddSingleton<ActivityPubService>();
services.AddSingleton<ErrorHandlerMiddleware>();
services.AddSingleton<RequestBufferingMiddleware>();
}
public static void ConfigureServices(this IServiceCollection services, IConfiguration configuration) {

View file

@ -9,15 +9,17 @@ using Microsoft.Extensions.Options;
namespace Iceshrimp.Backend.Core.Middleware;
public class AuthorizedFetchMiddleware(RequestDelegate next) {
public async Task InvokeAsync(HttpContext context, IOptionsSnapshot<Config.SecuritySection> config,
DatabaseContext db, UserResolver userResolver,
ILogger<AuthorizedFetchMiddleware> logger) {
var endpoint = context.Features.Get<IEndpointFeature>()?.Endpoint;
public class AuthorizedFetchMiddleware(
IOptionsSnapshot<Config.SecuritySection> config,
DatabaseContext db,
UserResolver userResolver,
ILogger<AuthorizedFetchMiddleware> logger) : IMiddleware {
public async Task InvokeAsync(HttpContext ctx, RequestDelegate next) {
var endpoint = ctx.Features.Get<IEndpointFeature>()?.Endpoint;
var attribute = endpoint?.Metadata.GetMetadata<AuthorizedFetchAttribute>();
if (attribute != null && config.Value.AuthorizedFetch) {
var request = context.Request;
var request = ctx.Request;
if (!request.Headers.TryGetValue("signature", out var sigHeader))
throw new CustomException(HttpStatusCode.Unauthorized, "Request is missing the signature header");
@ -41,13 +43,13 @@ public class AuthorizedFetchMiddleware(RequestDelegate next) {
? ["(request-target)", "digest", "host", "date"]
: ["(request-target)", "host", "date"];
var verified = await HttpSignature.Verify(context.Request, sig, headers, key.KeyPem);
var verified = await HttpSignature.Verify(ctx.Request, sig, headers, key.KeyPem);
logger.LogDebug("HttpSignature.Verify returned {result} for key {keyId}", verified, sig.KeyId);
if (!verified)
throw new CustomException(HttpStatusCode.Forbidden, "Request signature validation failed");
}
await next(context);
await next(ctx);
}
}

View file

@ -3,8 +3,8 @@ using Iceshrimp.Backend.Controllers.Schemas;
namespace Iceshrimp.Backend.Core.Middleware;
public class ErrorHandlerMiddleware(RequestDelegate next) {
public async Task InvokeAsync(HttpContext ctx, ILoggerFactory loggerFactory) {
public class ErrorHandlerMiddleware(ILoggerFactory loggerFactory) : IMiddleware {
public async Task InvokeAsync(HttpContext ctx, RequestDelegate next) {
try {
await next(ctx);
}

View file

@ -2,14 +2,14 @@ using Microsoft.AspNetCore.Http.Features;
namespace Iceshrimp.Backend.Core.Middleware;
public class RequestBufferingMiddleware(RequestDelegate next) {
public async Task InvokeAsync(HttpContext context) {
var endpoint = context.Features.Get<IEndpointFeature>()?.Endpoint;
public class RequestBufferingMiddleware : IMiddleware {
public async Task InvokeAsync(HttpContext ctx, RequestDelegate next) {
var endpoint = ctx.Features.Get<IEndpointFeature>()?.Endpoint;
var attribute = endpoint?.Metadata.GetMetadata<EnableRequestBufferingAttribute>();
if (attribute != null) context.Request.EnableBuffering(attribute.MaxLength);
if (attribute != null) ctx.Request.EnableBuffering(attribute.MaxLength);
await next(context);
await next(ctx);
}
}