diff --git a/Iceshrimp.Backend/Controllers/Mastodon/Streaming/Channels/PublicChannel.cs b/Iceshrimp.Backend/Controllers/Mastodon/Streaming/Channels/PublicChannel.cs new file mode 100644 index 00000000..12705078 --- /dev/null +++ b/Iceshrimp.Backend/Controllers/Mastodon/Streaming/Channels/PublicChannel.cs @@ -0,0 +1,87 @@ +using System.Text.Json; +using Iceshrimp.Backend.Controllers.Mastodon.Renderers; +using Iceshrimp.Backend.Core.Database.Tables; + +namespace Iceshrimp.Backend.Controllers.Mastodon.Streaming.Channels; + +public class PublicChannel( + WebSocketConnection connection, + string name, + bool local, + bool remote, + bool onlyMedia +) : IChannel +{ + public string Name => name; + public List Scopes => ["read:statuses"]; + public bool IsSubscribed { get; private set; } + + public Task Subscribe(StreamingRequestMessage _) + { + if (IsSubscribed) return Task.CompletedTask; + IsSubscribed = true; + + connection.EventService.NotePublished += OnNotePublished; + connection.EventService.NoteUpdated += OnNoteUpdated; + connection.EventService.NoteDeleted += OnNoteDeleted; + return Task.CompletedTask; + } + + private bool IsApplicable(Note note) + { + if (note.Visibility != Note.NoteVisibility.Public) return false; + if (!local && note.UserHost == null) return false; + if (!remote && note.UserHost != null) return false; + if (onlyMedia && note.FileIds.Count == 0) return false; + + return true; + } + + private async void OnNotePublished(object? _, Note note) + { + if (!IsApplicable(note)) return; + var provider = connection.ScopeFactory.CreateScope().ServiceProvider; + var renderer = provider.GetRequiredService(); + var rendered = await renderer.RenderAsync(note, connection.Token.User); + var message = new StreamingUpdateMessage + { + Stream = [Name], Event = "update", Payload = JsonSerializer.Serialize(rendered) + }; + await connection.SendMessageAsync(JsonSerializer.Serialize(message)); + } + + private async void OnNoteUpdated(object? _, Note note) + { + if (!IsApplicable(note)) return; + var provider = connection.ScopeFactory.CreateScope().ServiceProvider; + var renderer = provider.GetRequiredService(); + var rendered = await renderer.RenderAsync(note, connection.Token.User); + var message = new StreamingUpdateMessage + { + Stream = [Name], Event = "status.update", Payload = JsonSerializer.Serialize(rendered) + }; + await connection.SendMessageAsync(JsonSerializer.Serialize(message)); + } + + private async void OnNoteDeleted(object? _, Note note) + { + if (!IsApplicable(note)) return; + var message = new StreamingUpdateMessage { Stream = [Name], Event = "delete", Payload = note.Id }; + await connection.SendMessageAsync(JsonSerializer.Serialize(message)); + } + + public Task Unsubscribe(StreamingRequestMessage _) + { + if (!IsSubscribed) return Task.CompletedTask; + IsSubscribed = false; + Dispose(); + return Task.CompletedTask; + } + + public void Dispose() + { + connection.EventService.NotePublished -= OnNotePublished; + connection.EventService.NoteUpdated -= OnNoteUpdated; + connection.EventService.NoteDeleted -= OnNoteDeleted; + } +} \ No newline at end of file diff --git a/Iceshrimp.Backend/Controllers/Mastodon/Streaming/Channels/UserChannel.cs b/Iceshrimp.Backend/Controllers/Mastodon/Streaming/Channels/UserChannel.cs new file mode 100644 index 00000000..77a371ab --- /dev/null +++ b/Iceshrimp.Backend/Controllers/Mastodon/Streaming/Channels/UserChannel.cs @@ -0,0 +1,108 @@ +using System.Text.Json; +using Iceshrimp.Backend.Controllers.Mastodon.Renderers; +using Iceshrimp.Backend.Core.Database; +using Iceshrimp.Backend.Core.Database.Tables; +using Microsoft.EntityFrameworkCore; + +namespace Iceshrimp.Backend.Controllers.Mastodon.Streaming.Channels; + +public class UserChannel(WebSocketConnection connection, bool notificationsOnly) : IChannel +{ + public string Name => notificationsOnly ? "user:notification" : "user"; + public List Scopes => ["read:statuses", "read:notifications"]; + public bool IsSubscribed { get; private set; } + + private List _followedUsers = []; + + public async Task Subscribe(StreamingRequestMessage _) + { + if (IsSubscribed) return; + IsSubscribed = true; + + var provider = connection.ScopeFactory.CreateScope().ServiceProvider; + var db = provider.GetRequiredService(); + + _followedUsers = await db.Users.Where(p => p == connection.Token.User) + .SelectMany(p => p.Following) + .Select(p => p.Id) + .ToListAsync(); + + if (!notificationsOnly) + { + connection.EventService.NotePublished += OnNotePublished; + connection.EventService.NoteUpdated += OnNoteUpdated; + connection.EventService.NoteDeleted += OnNoteDeleted; + } + + connection.EventService.Notification += OnNotification; + } + + private bool IsApplicable(Note note) => _followedUsers.Prepend(connection.Token.User.Id).Contains(note.UserId); + private bool IsApplicable(Notification notification) => notification.NotifieeId == connection.Token.User.Id; + + private async void OnNotePublished(object? _, Note note) + { + if (!IsApplicable(note)) return; + var provider = connection.ScopeFactory.CreateScope().ServiceProvider; + var renderer = provider.GetRequiredService(); + var rendered = await renderer.RenderAsync(note, connection.Token.User); + var message = new StreamingUpdateMessage + { + Stream = [Name], Event = "update", Payload = JsonSerializer.Serialize(rendered) + }; + await connection.SendMessageAsync(JsonSerializer.Serialize(message)); + } + + private async void OnNoteUpdated(object? _, Note note) + { + if (!IsApplicable(note)) return; + var provider = connection.ScopeFactory.CreateScope().ServiceProvider; + var renderer = provider.GetRequiredService(); + var rendered = await renderer.RenderAsync(note, connection.Token.User); + var message = new StreamingUpdateMessage + { + Stream = [Name], Event = "status.update", Payload = JsonSerializer.Serialize(rendered) + }; + await connection.SendMessageAsync(JsonSerializer.Serialize(message)); + } + + private async void OnNoteDeleted(object? _, Note note) + { + if (!IsApplicable(note)) return; + var message = new StreamingUpdateMessage { Stream = [Name], Event = "delete", Payload = note.Id }; + await connection.SendMessageAsync(JsonSerializer.Serialize(message)); + } + + private async void OnNotification(object? _, Notification notification) + { + if (!IsApplicable(notification)) return; + var provider = connection.ScopeFactory.CreateScope().ServiceProvider; + var renderer = provider.GetRequiredService(); + var rendered = await renderer.RenderAsync(notification, connection.Token.User); + var message = new StreamingUpdateMessage + { + Stream = [Name], Event = "notification", Payload = JsonSerializer.Serialize(rendered) + }; + await connection.SendMessageAsync(JsonSerializer.Serialize(message)); + } + + public Task Unsubscribe(StreamingRequestMessage _) + { + if (!IsSubscribed) return Task.CompletedTask; + IsSubscribed = false; + Dispose(); + return Task.CompletedTask; + } + + public void Dispose() + { + if (!notificationsOnly) + { + connection.EventService.NotePublished -= OnNotePublished; + connection.EventService.NoteUpdated -= OnNoteUpdated; + connection.EventService.NoteDeleted -= OnNoteDeleted; + } + + connection.EventService.Notification -= OnNotification; + } +} \ No newline at end of file diff --git a/Iceshrimp.Backend/Controllers/Mastodon/Streaming/StreamingRequestMessage.cs b/Iceshrimp.Backend/Controllers/Mastodon/Streaming/StreamingRequestMessage.cs new file mode 100644 index 00000000..9fc2310e --- /dev/null +++ b/Iceshrimp.Backend/Controllers/Mastodon/Streaming/StreamingRequestMessage.cs @@ -0,0 +1,18 @@ +using J = System.Text.Json.Serialization.JsonPropertyNameAttribute; + +namespace Iceshrimp.Backend.Controllers.Mastodon.Streaming; + +public class StreamingRequestMessage +{ + [J("type")] public required string Type { get; set; } + [J("stream")] public required string Stream { get; set; } + [J("list")] public string? List { get; set; } + [J("tag")] public string? Tag { get; set; } +} + +public class StreamingUpdateMessage +{ + [J("stream")] public required List Stream { get; set; } + [J("event")] public required string Event { get; set; } + [J("payload")] public required string Payload { get; set; } +} \ No newline at end of file diff --git a/Iceshrimp.Backend/Controllers/Mastodon/Streaming/WebSocketConnection.cs b/Iceshrimp.Backend/Controllers/Mastodon/Streaming/WebSocketConnection.cs new file mode 100644 index 00000000..dc2113c1 --- /dev/null +++ b/Iceshrimp.Backend/Controllers/Mastodon/Streaming/WebSocketConnection.cs @@ -0,0 +1,118 @@ +using System.Net.WebSockets; +using System.Text; +using System.Text.Json; +using Iceshrimp.Backend.Controllers.Mastodon.Streaming.Channels; +using Iceshrimp.Backend.Core.Database.Tables; +using Iceshrimp.Backend.Core.Helpers; +using Iceshrimp.Backend.Core.Services; + +namespace Iceshrimp.Backend.Controllers.Mastodon.Streaming; + +public sealed class WebSocketConnection( + WebSocket socket, + OauthToken token, + EventService eventSvc, + IServiceScopeFactory scopeFactory, + CancellationToken ct +) : IDisposable +{ + public readonly OauthToken Token = token; + public readonly List Channels = []; + public readonly EventService EventService = eventSvc; + public readonly IServiceScopeFactory ScopeFactory = scopeFactory; + private readonly SemaphoreSlim _lock = new(1); + + public void InitializeStreamingWorker() + { + Channels.Add(new UserChannel(this, true)); + Channels.Add(new UserChannel(this, false)); + Channels.Add(new PublicChannel(this, "public", true, true, false)); + Channels.Add(new PublicChannel(this, "public:media", true, true, true)); + Channels.Add(new PublicChannel(this, "public:allow_local_only", true, true, false)); + Channels.Add(new PublicChannel(this, "public:allow_local_only:media", true, true, true)); + Channels.Add(new PublicChannel(this, "public:local", true, false, false)); + Channels.Add(new PublicChannel(this, "public:local:media", true, false, true)); + Channels.Add(new PublicChannel(this, "public:remote", false, true, false)); + Channels.Add(new PublicChannel(this, "public:remote:media", false, true, true)); + } + + public void Dispose() + { + foreach (var channel in Channels) + channel.Dispose(); + } + + public async Task HandleSocketMessageAsync(string payload) + { + StreamingRequestMessage? message = null; + try + { + message = JsonSerializer.Deserialize(payload); + } + catch + { + // ignored + } + + if (message == null) + { + await CloseAsync(WebSocketCloseStatus.InvalidPayloadData); + return; + } + + switch (message.Type) + { + case "subscribe": + { + var channel = Channels.FirstOrDefault(p => p.Name == message.Stream && !p.IsSubscribed); + if (channel == null) return; + if (channel.Scopes.Except(MastodonOauthHelpers.ExpandScopes(Token.Scopes)).Any()) + await CloseAsync(WebSocketCloseStatus.PolicyViolation); + else + await channel.Subscribe(message); + break; + } + case "unsubscribe": + { + var channel = Channels.FirstOrDefault(p => p.Name == message.Stream && p.IsSubscribed); + if (channel != null) await channel.Unsubscribe(message); + break; + } + default: + { + await CloseAsync(WebSocketCloseStatus.InvalidPayloadData); + return; + } + } + } + + public async Task SendMessageAsync(string message) + { + await _lock.WaitAsync(ct); + try + { + var buffer = new ArraySegment(Encoding.UTF8.GetBytes(message)); + await socket.SendAsync(buffer, WebSocketMessageType.Text, true, ct); + } + finally + { + _lock.Release(); + } + } + + public async Task CloseAsync(WebSocketCloseStatus status) + { + Dispose(); + await socket.CloseAsync(status, null, ct); + } +} + +public interface IChannel +{ + public string Name { get; } + public List Scopes { get; } + public bool IsSubscribed { get; } + public Task Subscribe(StreamingRequestMessage message); + public Task Unsubscribe(StreamingRequestMessage message); + public void Dispose(); +} \ No newline at end of file diff --git a/Iceshrimp.Backend/Controllers/Mastodon/Streaming/WebSocketHandler.cs b/Iceshrimp.Backend/Controllers/Mastodon/Streaming/WebSocketHandler.cs new file mode 100644 index 00000000..bd3f1cc1 --- /dev/null +++ b/Iceshrimp.Backend/Controllers/Mastodon/Streaming/WebSocketHandler.cs @@ -0,0 +1,50 @@ +using System.Net.WebSockets; +using System.Text; +using Iceshrimp.Backend.Core.Database.Tables; +using Iceshrimp.Backend.Core.Services; + +namespace Iceshrimp.Backend.Controllers.Mastodon.Streaming; + +public static class WebSocketHandler +{ + public static async Task HandleConnectionAsync( + WebSocket socket, OauthToken token, EventService eventSvc, IServiceScopeFactory scopeFactory, + CancellationToken ct + ) + { + using var connection = new WebSocketConnection(socket, token, eventSvc, scopeFactory, ct); + var buffer = new byte[256]; + + WebSocketReceiveResult? res = null; + + connection.InitializeStreamingWorker(); + + while ((!res?.CloseStatus.HasValue ?? true) && + !ct.IsCancellationRequested && + socket.State is WebSocketState.Open) + { + res = await socket.ReceiveAsync(new ArraySegment(buffer), ct); + + if (res.Count > buffer.Length) + { + await socket.CloseAsync(WebSocketCloseStatus.MessageTooBig, null, ct); + break; + } + + if (res.MessageType == WebSocketMessageType.Text) + await connection.HandleSocketMessageAsync(Encoding.UTF8.GetString(buffer, 0, res.Count)); + else if (res.MessageType == WebSocketMessageType.Binary) + break; + } + + if (socket.State is not WebSocketState.Open and not WebSocketState.CloseReceived) + return; + + if (res?.CloseStatus != null) + await socket.CloseAsync(res.CloseStatus.Value, res.CloseStatusDescription, ct); + else if (!ct.IsCancellationRequested) + await socket.CloseAsync(WebSocketCloseStatus.InvalidMessageType, null, ct); + else + await socket.CloseAsync(WebSocketCloseStatus.EndpointUnavailable, null, ct); + } +} \ No newline at end of file diff --git a/Iceshrimp.Backend/Controllers/Mastodon/WebSocketController.cs b/Iceshrimp.Backend/Controllers/Mastodon/WebSocketController.cs new file mode 100644 index 00000000..91b61515 --- /dev/null +++ b/Iceshrimp.Backend/Controllers/Mastodon/WebSocketController.cs @@ -0,0 +1,58 @@ +using System.Net.WebSockets; +using Iceshrimp.Backend.Controllers.Mastodon.Attributes; +using Iceshrimp.Backend.Controllers.Mastodon.Streaming; +using Iceshrimp.Backend.Core.Database; +using Iceshrimp.Backend.Core.Database.Tables; +using Iceshrimp.Backend.Core.Middleware; +using Iceshrimp.Backend.Core.Services; +using Microsoft.AspNetCore.Mvc; +using Microsoft.EntityFrameworkCore; + +namespace Iceshrimp.Backend.Controllers.Mastodon; + +[MastodonApiController] +public class WebSocketController( + IHostApplicationLifetime appLifetime, + DatabaseContext db, + EventService eventSvc, + IServiceScopeFactory scopeFactory, + ILogger logger +) : ControllerBase +{ + [Route("/api/v1/streaming")] + [ApiExplorerSettings(IgnoreApi = true)] + public async Task GetStreamingSocket() + { + if (!HttpContext.WebSockets.IsWebSocketRequest) + throw GracefulException.BadRequest("Not a WebSocket request"); + + var ct = appLifetime.ApplicationStopping; + var accessToken = HttpContext.WebSockets.WebSocketRequestedProtocols.FirstOrDefault() ?? + throw GracefulException.BadRequest("Missing WebSocket protocol header"); + + using var webSocket = await HttpContext.WebSockets.AcceptWebSocketAsync(); + try + { + var token = await Authenticate(accessToken); + await WebSocketHandler.HandleConnectionAsync(webSocket, token, eventSvc, scopeFactory, ct); + } + catch (Exception e) + { + if (e is WebSocketException) + logger.LogDebug("WebSocket connection {id} encountered an error: {error}", + HttpContext.TraceIdentifier, e.Message); + else if (!ct.IsCancellationRequested) + throw; + } + } + + private async Task Authenticate(string token) + { + return await db.OauthTokens + .Include(p => p.User) + .ThenInclude(p => p.UserProfile) + .Include(p => p.App) + .FirstOrDefaultAsync(p => p.Token == token && p.Active) ?? + throw GracefulException.Unauthorized("This method requires an authenticated user"); + } +} \ No newline at end of file diff --git a/Iceshrimp.Backend/Core/Extensions/ServiceExtensions.cs b/Iceshrimp.Backend/Core/Extensions/ServiceExtensions.cs index 10189f7a..abfe585d 100644 --- a/Iceshrimp.Backend/Core/Extensions/ServiceExtensions.cs +++ b/Iceshrimp.Backend/Core/Extensions/ServiceExtensions.cs @@ -218,6 +218,7 @@ public static class ServiceExtensions options.AddPolicy("mastodon", policy => { policy.WithOrigins("*") + .WithMethods("GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "CONNECT") .WithHeaders("Authorization", "Content-Type", "Idempotency-Key") .WithExposedHeaders("Link", "Connection", "Sec-Websocket-Accept", "Upgrade"); }); diff --git a/Iceshrimp.Backend/Core/Services/EventService.cs b/Iceshrimp.Backend/Core/Services/EventService.cs index b415e7b9..e4d30fd0 100644 --- a/Iceshrimp.Backend/Core/Services/EventService.cs +++ b/Iceshrimp.Backend/Core/Services/EventService.cs @@ -7,14 +7,14 @@ public class EventService { public event EventHandler? NotePublished; public event EventHandler? NoteUpdated; - public event EventHandler? NoteDeleted; + public event EventHandler? NoteDeleted; public event EventHandler? NoteLiked; public event EventHandler? NoteUnliked; public event EventHandler? Notification; public void RaiseNotePublished(object? sender, Note note) => NotePublished?.Invoke(sender, note); public void RaiseNoteUpdated(object? sender, Note note) => NoteUpdated?.Invoke(sender, note); - public void RaiseNoteDeleted(object? sender, Note note) => NoteDeleted?.Invoke(sender, note.Id); + public void RaiseNoteDeleted(object? sender, Note note) => NoteDeleted?.Invoke(sender, note); public void RaiseNotification(object? sender, Notification notification) => Notification?.Invoke(sender, notification); diff --git a/Iceshrimp.Backend/Startup.cs b/Iceshrimp.Backend/Startup.cs index 2c8b37e0..b5089e8f 100644 --- a/Iceshrimp.Backend/Startup.cs +++ b/Iceshrimp.Backend/Startup.cs @@ -45,6 +45,7 @@ app.UseStaticFiles(); app.UseRateLimiter(); app.UseCors(); app.UseAuthorization(); +app.UseWebSockets(); app.UseCustomMiddleware(); app.MapControllers();