[backend/masto-client] Add basic WebSocket support

This commit is contained in:
Laura Hausmann 2024-02-21 02:31:24 +01:00
parent c6a2a99c1b
commit 9b99f9245f
No known key found for this signature in database
GPG key ID: D044E84C5BE01605
9 changed files with 443 additions and 2 deletions

View file

@ -0,0 +1,87 @@
using System.Text.Json;
using Iceshrimp.Backend.Controllers.Mastodon.Renderers;
using Iceshrimp.Backend.Core.Database.Tables;
namespace Iceshrimp.Backend.Controllers.Mastodon.Streaming.Channels;
public class PublicChannel(
WebSocketConnection connection,
string name,
bool local,
bool remote,
bool onlyMedia
) : IChannel
{
public string Name => name;
public List<string> Scopes => ["read:statuses"];
public bool IsSubscribed { get; private set; }
public Task Subscribe(StreamingRequestMessage _)
{
if (IsSubscribed) return Task.CompletedTask;
IsSubscribed = true;
connection.EventService.NotePublished += OnNotePublished;
connection.EventService.NoteUpdated += OnNoteUpdated;
connection.EventService.NoteDeleted += OnNoteDeleted;
return Task.CompletedTask;
}
private bool IsApplicable(Note note)
{
if (note.Visibility != Note.NoteVisibility.Public) return false;
if (!local && note.UserHost == null) return false;
if (!remote && note.UserHost != null) return false;
if (onlyMedia && note.FileIds.Count == 0) return false;
return true;
}
private async void OnNotePublished(object? _, Note note)
{
if (!IsApplicable(note)) return;
var provider = connection.ScopeFactory.CreateScope().ServiceProvider;
var renderer = provider.GetRequiredService<NoteRenderer>();
var rendered = await renderer.RenderAsync(note, connection.Token.User);
var message = new StreamingUpdateMessage
{
Stream = [Name], Event = "update", Payload = JsonSerializer.Serialize(rendered)
};
await connection.SendMessageAsync(JsonSerializer.Serialize(message));
}
private async void OnNoteUpdated(object? _, Note note)
{
if (!IsApplicable(note)) return;
var provider = connection.ScopeFactory.CreateScope().ServiceProvider;
var renderer = provider.GetRequiredService<NoteRenderer>();
var rendered = await renderer.RenderAsync(note, connection.Token.User);
var message = new StreamingUpdateMessage
{
Stream = [Name], Event = "status.update", Payload = JsonSerializer.Serialize(rendered)
};
await connection.SendMessageAsync(JsonSerializer.Serialize(message));
}
private async void OnNoteDeleted(object? _, Note note)
{
if (!IsApplicable(note)) return;
var message = new StreamingUpdateMessage { Stream = [Name], Event = "delete", Payload = note.Id };
await connection.SendMessageAsync(JsonSerializer.Serialize(message));
}
public Task Unsubscribe(StreamingRequestMessage _)
{
if (!IsSubscribed) return Task.CompletedTask;
IsSubscribed = false;
Dispose();
return Task.CompletedTask;
}
public void Dispose()
{
connection.EventService.NotePublished -= OnNotePublished;
connection.EventService.NoteUpdated -= OnNoteUpdated;
connection.EventService.NoteDeleted -= OnNoteDeleted;
}
}

View file

@ -0,0 +1,108 @@
using System.Text.Json;
using Iceshrimp.Backend.Controllers.Mastodon.Renderers;
using Iceshrimp.Backend.Core.Database;
using Iceshrimp.Backend.Core.Database.Tables;
using Microsoft.EntityFrameworkCore;
namespace Iceshrimp.Backend.Controllers.Mastodon.Streaming.Channels;
public class UserChannel(WebSocketConnection connection, bool notificationsOnly) : IChannel
{
public string Name => notificationsOnly ? "user:notification" : "user";
public List<string> Scopes => ["read:statuses", "read:notifications"];
public bool IsSubscribed { get; private set; }
private List<string> _followedUsers = [];
public async Task Subscribe(StreamingRequestMessage _)
{
if (IsSubscribed) return;
IsSubscribed = true;
var provider = connection.ScopeFactory.CreateScope().ServiceProvider;
var db = provider.GetRequiredService<DatabaseContext>();
_followedUsers = await db.Users.Where(p => p == connection.Token.User)
.SelectMany(p => p.Following)
.Select(p => p.Id)
.ToListAsync();
if (!notificationsOnly)
{
connection.EventService.NotePublished += OnNotePublished;
connection.EventService.NoteUpdated += OnNoteUpdated;
connection.EventService.NoteDeleted += OnNoteDeleted;
}
connection.EventService.Notification += OnNotification;
}
private bool IsApplicable(Note note) => _followedUsers.Prepend(connection.Token.User.Id).Contains(note.UserId);
private bool IsApplicable(Notification notification) => notification.NotifieeId == connection.Token.User.Id;
private async void OnNotePublished(object? _, Note note)
{
if (!IsApplicable(note)) return;
var provider = connection.ScopeFactory.CreateScope().ServiceProvider;
var renderer = provider.GetRequiredService<NoteRenderer>();
var rendered = await renderer.RenderAsync(note, connection.Token.User);
var message = new StreamingUpdateMessage
{
Stream = [Name], Event = "update", Payload = JsonSerializer.Serialize(rendered)
};
await connection.SendMessageAsync(JsonSerializer.Serialize(message));
}
private async void OnNoteUpdated(object? _, Note note)
{
if (!IsApplicable(note)) return;
var provider = connection.ScopeFactory.CreateScope().ServiceProvider;
var renderer = provider.GetRequiredService<NoteRenderer>();
var rendered = await renderer.RenderAsync(note, connection.Token.User);
var message = new StreamingUpdateMessage
{
Stream = [Name], Event = "status.update", Payload = JsonSerializer.Serialize(rendered)
};
await connection.SendMessageAsync(JsonSerializer.Serialize(message));
}
private async void OnNoteDeleted(object? _, Note note)
{
if (!IsApplicable(note)) return;
var message = new StreamingUpdateMessage { Stream = [Name], Event = "delete", Payload = note.Id };
await connection.SendMessageAsync(JsonSerializer.Serialize(message));
}
private async void OnNotification(object? _, Notification notification)
{
if (!IsApplicable(notification)) return;
var provider = connection.ScopeFactory.CreateScope().ServiceProvider;
var renderer = provider.GetRequiredService<NotificationRenderer>();
var rendered = await renderer.RenderAsync(notification, connection.Token.User);
var message = new StreamingUpdateMessage
{
Stream = [Name], Event = "notification", Payload = JsonSerializer.Serialize(rendered)
};
await connection.SendMessageAsync(JsonSerializer.Serialize(message));
}
public Task Unsubscribe(StreamingRequestMessage _)
{
if (!IsSubscribed) return Task.CompletedTask;
IsSubscribed = false;
Dispose();
return Task.CompletedTask;
}
public void Dispose()
{
if (!notificationsOnly)
{
connection.EventService.NotePublished -= OnNotePublished;
connection.EventService.NoteUpdated -= OnNoteUpdated;
connection.EventService.NoteDeleted -= OnNoteDeleted;
}
connection.EventService.Notification -= OnNotification;
}
}

View file

@ -0,0 +1,18 @@
using J = System.Text.Json.Serialization.JsonPropertyNameAttribute;
namespace Iceshrimp.Backend.Controllers.Mastodon.Streaming;
public class StreamingRequestMessage
{
[J("type")] public required string Type { get; set; }
[J("stream")] public required string Stream { get; set; }
[J("list")] public string? List { get; set; }
[J("tag")] public string? Tag { get; set; }
}
public class StreamingUpdateMessage
{
[J("stream")] public required List<string> Stream { get; set; }
[J("event")] public required string Event { get; set; }
[J("payload")] public required string Payload { get; set; }
}

View file

@ -0,0 +1,118 @@
using System.Net.WebSockets;
using System.Text;
using System.Text.Json;
using Iceshrimp.Backend.Controllers.Mastodon.Streaming.Channels;
using Iceshrimp.Backend.Core.Database.Tables;
using Iceshrimp.Backend.Core.Helpers;
using Iceshrimp.Backend.Core.Services;
namespace Iceshrimp.Backend.Controllers.Mastodon.Streaming;
public sealed class WebSocketConnection(
WebSocket socket,
OauthToken token,
EventService eventSvc,
IServiceScopeFactory scopeFactory,
CancellationToken ct
) : IDisposable
{
public readonly OauthToken Token = token;
public readonly List<IChannel> Channels = [];
public readonly EventService EventService = eventSvc;
public readonly IServiceScopeFactory ScopeFactory = scopeFactory;
private readonly SemaphoreSlim _lock = new(1);
public void InitializeStreamingWorker()
{
Channels.Add(new UserChannel(this, true));
Channels.Add(new UserChannel(this, false));
Channels.Add(new PublicChannel(this, "public", true, true, false));
Channels.Add(new PublicChannel(this, "public:media", true, true, true));
Channels.Add(new PublicChannel(this, "public:allow_local_only", true, true, false));
Channels.Add(new PublicChannel(this, "public:allow_local_only:media", true, true, true));
Channels.Add(new PublicChannel(this, "public:local", true, false, false));
Channels.Add(new PublicChannel(this, "public:local:media", true, false, true));
Channels.Add(new PublicChannel(this, "public:remote", false, true, false));
Channels.Add(new PublicChannel(this, "public:remote:media", false, true, true));
}
public void Dispose()
{
foreach (var channel in Channels)
channel.Dispose();
}
public async Task HandleSocketMessageAsync(string payload)
{
StreamingRequestMessage? message = null;
try
{
message = JsonSerializer.Deserialize<StreamingRequestMessage>(payload);
}
catch
{
// ignored
}
if (message == null)
{
await CloseAsync(WebSocketCloseStatus.InvalidPayloadData);
return;
}
switch (message.Type)
{
case "subscribe":
{
var channel = Channels.FirstOrDefault(p => p.Name == message.Stream && !p.IsSubscribed);
if (channel == null) return;
if (channel.Scopes.Except(MastodonOauthHelpers.ExpandScopes(Token.Scopes)).Any())
await CloseAsync(WebSocketCloseStatus.PolicyViolation);
else
await channel.Subscribe(message);
break;
}
case "unsubscribe":
{
var channel = Channels.FirstOrDefault(p => p.Name == message.Stream && p.IsSubscribed);
if (channel != null) await channel.Unsubscribe(message);
break;
}
default:
{
await CloseAsync(WebSocketCloseStatus.InvalidPayloadData);
return;
}
}
}
public async Task SendMessageAsync(string message)
{
await _lock.WaitAsync(ct);
try
{
var buffer = new ArraySegment<byte>(Encoding.UTF8.GetBytes(message));
await socket.SendAsync(buffer, WebSocketMessageType.Text, true, ct);
}
finally
{
_lock.Release();
}
}
public async Task CloseAsync(WebSocketCloseStatus status)
{
Dispose();
await socket.CloseAsync(status, null, ct);
}
}
public interface IChannel
{
public string Name { get; }
public List<string> Scopes { get; }
public bool IsSubscribed { get; }
public Task Subscribe(StreamingRequestMessage message);
public Task Unsubscribe(StreamingRequestMessage message);
public void Dispose();
}

View file

@ -0,0 +1,50 @@
using System.Net.WebSockets;
using System.Text;
using Iceshrimp.Backend.Core.Database.Tables;
using Iceshrimp.Backend.Core.Services;
namespace Iceshrimp.Backend.Controllers.Mastodon.Streaming;
public static class WebSocketHandler
{
public static async Task HandleConnectionAsync(
WebSocket socket, OauthToken token, EventService eventSvc, IServiceScopeFactory scopeFactory,
CancellationToken ct
)
{
using var connection = new WebSocketConnection(socket, token, eventSvc, scopeFactory, ct);
var buffer = new byte[256];
WebSocketReceiveResult? res = null;
connection.InitializeStreamingWorker();
while ((!res?.CloseStatus.HasValue ?? true) &&
!ct.IsCancellationRequested &&
socket.State is WebSocketState.Open)
{
res = await socket.ReceiveAsync(new ArraySegment<byte>(buffer), ct);
if (res.Count > buffer.Length)
{
await socket.CloseAsync(WebSocketCloseStatus.MessageTooBig, null, ct);
break;
}
if (res.MessageType == WebSocketMessageType.Text)
await connection.HandleSocketMessageAsync(Encoding.UTF8.GetString(buffer, 0, res.Count));
else if (res.MessageType == WebSocketMessageType.Binary)
break;
}
if (socket.State is not WebSocketState.Open and not WebSocketState.CloseReceived)
return;
if (res?.CloseStatus != null)
await socket.CloseAsync(res.CloseStatus.Value, res.CloseStatusDescription, ct);
else if (!ct.IsCancellationRequested)
await socket.CloseAsync(WebSocketCloseStatus.InvalidMessageType, null, ct);
else
await socket.CloseAsync(WebSocketCloseStatus.EndpointUnavailable, null, ct);
}
}

View file

@ -0,0 +1,58 @@
using System.Net.WebSockets;
using Iceshrimp.Backend.Controllers.Mastodon.Attributes;
using Iceshrimp.Backend.Controllers.Mastodon.Streaming;
using Iceshrimp.Backend.Core.Database;
using Iceshrimp.Backend.Core.Database.Tables;
using Iceshrimp.Backend.Core.Middleware;
using Iceshrimp.Backend.Core.Services;
using Microsoft.AspNetCore.Mvc;
using Microsoft.EntityFrameworkCore;
namespace Iceshrimp.Backend.Controllers.Mastodon;
[MastodonApiController]
public class WebSocketController(
IHostApplicationLifetime appLifetime,
DatabaseContext db,
EventService eventSvc,
IServiceScopeFactory scopeFactory,
ILogger<WebSocketController> logger
) : ControllerBase
{
[Route("/api/v1/streaming")]
[ApiExplorerSettings(IgnoreApi = true)]
public async Task GetStreamingSocket()
{
if (!HttpContext.WebSockets.IsWebSocketRequest)
throw GracefulException.BadRequest("Not a WebSocket request");
var ct = appLifetime.ApplicationStopping;
var accessToken = HttpContext.WebSockets.WebSocketRequestedProtocols.FirstOrDefault() ??
throw GracefulException.BadRequest("Missing WebSocket protocol header");
using var webSocket = await HttpContext.WebSockets.AcceptWebSocketAsync();
try
{
var token = await Authenticate(accessToken);
await WebSocketHandler.HandleConnectionAsync(webSocket, token, eventSvc, scopeFactory, ct);
}
catch (Exception e)
{
if (e is WebSocketException)
logger.LogDebug("WebSocket connection {id} encountered an error: {error}",
HttpContext.TraceIdentifier, e.Message);
else if (!ct.IsCancellationRequested)
throw;
}
}
private async Task<OauthToken> Authenticate(string token)
{
return await db.OauthTokens
.Include(p => p.User)
.ThenInclude(p => p.UserProfile)
.Include(p => p.App)
.FirstOrDefaultAsync(p => p.Token == token && p.Active) ??
throw GracefulException.Unauthorized("This method requires an authenticated user");
}
}

View file

@ -218,6 +218,7 @@ public static class ServiceExtensions
options.AddPolicy("mastodon", policy => options.AddPolicy("mastodon", policy =>
{ {
policy.WithOrigins("*") policy.WithOrigins("*")
.WithMethods("GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "CONNECT")
.WithHeaders("Authorization", "Content-Type", "Idempotency-Key") .WithHeaders("Authorization", "Content-Type", "Idempotency-Key")
.WithExposedHeaders("Link", "Connection", "Sec-Websocket-Accept", "Upgrade"); .WithExposedHeaders("Link", "Connection", "Sec-Websocket-Accept", "Upgrade");
}); });

View file

@ -7,14 +7,14 @@ public class EventService
{ {
public event EventHandler<Note>? NotePublished; public event EventHandler<Note>? NotePublished;
public event EventHandler<Note>? NoteUpdated; public event EventHandler<Note>? NoteUpdated;
public event EventHandler<string>? NoteDeleted; public event EventHandler<Note>? NoteDeleted;
public event EventHandler<NoteInteraction>? NoteLiked; public event EventHandler<NoteInteraction>? NoteLiked;
public event EventHandler<NoteInteraction>? NoteUnliked; public event EventHandler<NoteInteraction>? NoteUnliked;
public event EventHandler<Notification>? Notification; public event EventHandler<Notification>? Notification;
public void RaiseNotePublished(object? sender, Note note) => NotePublished?.Invoke(sender, note); public void RaiseNotePublished(object? sender, Note note) => NotePublished?.Invoke(sender, note);
public void RaiseNoteUpdated(object? sender, Note note) => NoteUpdated?.Invoke(sender, note); public void RaiseNoteUpdated(object? sender, Note note) => NoteUpdated?.Invoke(sender, note);
public void RaiseNoteDeleted(object? sender, Note note) => NoteDeleted?.Invoke(sender, note.Id); public void RaiseNoteDeleted(object? sender, Note note) => NoteDeleted?.Invoke(sender, note);
public void RaiseNotification(object? sender, Notification notification) => public void RaiseNotification(object? sender, Notification notification) =>
Notification?.Invoke(sender, notification); Notification?.Invoke(sender, notification);

View file

@ -45,6 +45,7 @@ app.UseStaticFiles();
app.UseRateLimiter(); app.UseRateLimiter();
app.UseCors(); app.UseCors();
app.UseAuthorization(); app.UseAuthorization();
app.UseWebSockets();
app.UseCustomMiddleware(); app.UseCustomMiddleware();
app.MapControllers(); app.MapControllers();