using System.Security.Claims;
using System.Text.Encodings.Web;
using Iceshrimp.Backend.Core.Database;
using Iceshrimp.Backend.Core.Middleware;
using Iceshrimp.Backend.Core.Services;
using Microsoft.AspNetCore.Authentication;
using Microsoft.AspNetCore.Authentication.BearerToken;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.SignalR;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Options;

namespace Iceshrimp.Backend.SignalR.Authentication;

// NOTE(review): generic type arguments appear to be missing in this copy of the file
// (the options type of IOptionsMonitor/AuthenticationHandler, Task<AuthenticateResult>,
// AuthorizationHandler<HubAuthorizationRequirement>, and the DI registrations at the bottom).
// They are intentionally left untouched here — confirm against the canonical source.

/// <summary>
/// Marker requirement for SignalR hub authorization. It carries no data; it is evaluated by
/// <see cref="HubAuthorizationHandler"/>.
/// </summary>
public class HubAuthorizationRequirement : IAuthorizationRequirement;

/// <summary>
/// Authenticates SignalR hub requests by resolving a session token to an active session.
/// </summary>
/// <remarks>
/// The token is read from the <c>access_token</c> query parameter when present (SignalR clients
/// typically use it for transports that cannot send custom headers, such as WebSockets);
/// otherwise it is taken from an <c>Authorization: Bearer …</c> header. When both are supplied,
/// the query parameter wins.
/// </remarks>
public class HubAuthenticationHandler(
    IOptionsMonitor options,
    ILoggerFactory logger,
    UrlEncoder encoder,
    DatabaseContext db,
    UserService userSvc
) : AuthenticationHandler(options, logger, encoder)
{
    private const string AccessTokenQueryKey = "access_token";
    private const string BearerPrefix = "bearer ";

    /// <summary>
    /// Looks up the active session for the request's token and, if found, attaches it to the
    /// request context.
    /// </summary>
    /// <returns>
    /// A success result whose principal carries <c>token</c> and <c>userId</c> claims, or
    /// <c>AuthenticateResult.NoResult()</c> when no usable token is supplied or no active
    /// session matches it.
    /// </returns>
    protected override async Task HandleAuthenticateAsync()
    {
        var token = ExtractToken();

        // No token, or an empty one ("?access_token=" / "Bearer "): nothing to look up.
        if (string.IsNullOrWhiteSpace(token))
            return AuthenticateResult.NoResult();

        var session = await db.Sessions
                              .Include(p => p.User.UserProfile)
                              .Include(p => p.User.UserSettings)
                              .FirstOrDefaultAsync(p => p.Token == token && p.Active);

        // The query already filters on Active; this also rejects a null result.
        if (session is not { Active: true })
            return AuthenticateResult.NoResult();

        var claims = new[] { new Claim("token", token), new Claim("userId", session.UserId) };
        var identity = new ClaimsIdentity(claims, nameof(HubAuthenticationHandler));
        var ticket = new AuthenticationTicket(new ClaimsPrincipal(identity), Scheme.Name);

        // NOTE(review): any return value of UpdateSessionMetadata is not awaited — confirm intended.
        userSvc.UpdateSessionMetadata(session);
        Context.SetSession(session);
        return AuthenticateResult.Success(ticket);
    }

    /// <summary>
    /// Returns the raw token from the <c>access_token</c> query parameter or the bearer
    /// <c>Authorization</c> header, or <c>null</c> when neither supplies one.
    /// </summary>
    private string? ExtractToken()
    {
        if (Request.Query.TryGetValue(AccessTokenQueryKey, out var queryToken))
            return queryToken.ToString();

        var header = Request.Headers.Authorization.ToString();

        // Ordinal, case-insensitive: the scheme name is not linguistic text, and a culture-aware
        // match could accept a prefix whose length differs from BearerPrefix.Length, which would
        // make the slice below cut the token in the wrong place.
        return header.StartsWith(BearerPrefix, StringComparison.OrdinalIgnoreCase)
            ? header[BearerPrefix.Length..]
            : null;
    }
}

/// <summary>
/// Satisfies <see cref="HubAuthorizationRequirement"/> when the current HTTP context has a user
/// attached (via <c>GetUser()</c>), and fails it with reason "Unauthorized" otherwise.
/// </summary>
public class HubAuthorizationHandler(
    IHttpContextAccessor httpContextAccessor
) : AuthorizationHandler
{
    /// <summary>Evaluates the requirement against the user attached to the current HTTP context.</summary>
    /// <exception cref="InvalidOperationException">No HTTP context is available.</exception>
    protected override Task HandleRequirementAsync(
        AuthorizationHandlerContext context, HubAuthorizationRequirement requirement
    )
    {
        var ctx = httpContextAccessor.HttpContext
                  ?? throw new InvalidOperationException("HttpContext must not be null at this stage");

        // Presumably the user is set by Context.SetSession in HubAuthenticationHandler — confirm.
        if (ctx.GetUser() is null)
            context.Fail(new AuthorizationFailureReason(this, "Unauthorized"));
        else
            context.Succeed(requirement);

        return Task.CompletedTask;
    }
}

/// <summary>
/// Maps a SignalR connection to a user id taken from the user attached to the current HTTP
/// context. The <c>connection</c> argument itself is not consulted.
/// </summary>
public class HubUserIdProvider(IHttpContextAccessor httpContextAccessor) : IUserIdProvider
{
    /// <summary>Returns the current user's id, or <c>null</c> when no user is attached.</summary>
    /// <exception cref="InvalidOperationException">No HTTP context is available.</exception>
    public string? GetUserId(HubConnectionContext connection)
    {
        var ctx = httpContextAccessor.HttpContext
                  ?? throw new InvalidOperationException("HttpContext must not be null at this stage");

        return ctx.GetUser()?.Id;
    }
}

/// <summary>Dependency-injection registration for the hub authentication services.</summary>
public static class AuthenticationServiceExtensions
{
    /// <summary>
    /// Registers one scoped and three singleton services on <paramref name="services"/>.
    /// </summary>
    public static void AddAuthenticationServices(this IServiceCollection services)
    {
        services.AddScoped()
                .AddSingleton()
                .AddSingleton()
                .AddSingleton();
    }
}