[backend/masto-client] Add hashtag streaming channel (ISH-332)

This commit is contained in:
Laura Hausmann 2024-05-23 15:16:46 +02:00
parent a5c1f063d2
commit a4087a4c81
No known key found for this signature in database
GPG key ID: D044E84C5BE01605
5 changed files with 197 additions and 3 deletions

View file

@ -0,0 +1,172 @@
using System.Net.WebSockets;
using System.Text.Json;
using Iceshrimp.Backend.Controllers.Mastodon.Renderers;
using Iceshrimp.Backend.Controllers.Mastodon.Schemas.Entities;
using Iceshrimp.Backend.Core.Database.Tables;
using Iceshrimp.Backend.Core.Extensions;
using Iceshrimp.Backend.Core.Helpers;
namespace Iceshrimp.Backend.Controllers.Mastodon.Streaming.Channels;
public class HashtagChannel(WebSocketConnection connection, bool local) : IChannel
{
private readonly ILogger<HashtagChannel> _logger =
connection.Scope.ServiceProvider.GetRequiredService<ILogger<HashtagChannel>>();
public string Name => local ? "hashtag:local" : "hashtag";
public List<string> Scopes => ["read:statuses"];
public bool IsSubscribed => _tags.Count != 0;
public bool IsAggregate => true;
private readonly WriteLockingList<string> _tags = [];
public async Task Subscribe(StreamingRequestMessage msg)
{
if (msg.Tag == null)
{
await connection.CloseAsync(WebSocketCloseStatus.InvalidPayloadData);
return;
}
if (!IsSubscribed)
{
connection.EventService.NotePublished += OnNotePublished;
connection.EventService.NoteUpdated += OnNoteUpdated;
connection.EventService.NoteDeleted += OnNoteDeleted;
}
_tags.AddIfMissing(msg.Tag);
}
public async Task Unsubscribe(StreamingRequestMessage msg)
{
if (msg.Tag == null)
{
await connection.CloseAsync(WebSocketCloseStatus.InvalidPayloadData);
return;
}
_tags.RemoveAll(p => p == msg.Tag);
if (!IsSubscribed) Dispose();
}
public void Dispose()
{
connection.EventService.NotePublished -= OnNotePublished;
connection.EventService.NoteUpdated -= OnNoteUpdated;
connection.EventService.NoteDeleted -= OnNoteDeleted;
}
private NoteWithVisibilities? IsApplicable(Note note)
{
if (!IsApplicableBool(note)) return null;
var res = EnforceRenoteReplyVisibility(note);
return res is not { Note.IsPureRenote: true, Renote: null } ? res : null;
}
private bool IsApplicableBool(Note note) =>
(!local || note.User.Host == null) &&
note.Tags.Intersects(_tags) &&
note.IsVisibleFor(connection.Token.User, connection.Following);
private NoteWithVisibilities EnforceRenoteReplyVisibility(Note note)
{
var wrapped = new NoteWithVisibilities(note);
if (!wrapped.Renote?.IsVisibleFor(connection.Token.User, connection.Following) ?? false)
wrapped.Renote = null;
return wrapped;
}
private class NoteWithVisibilities(Note note)
{
public readonly Note Note = note;
public Note? Renote = note.Renote;
}
private static StatusEntity EnforceRenoteReplyVisibility(StatusEntity rendered, NoteWithVisibilities note)
{
var renote = note.Renote == null && rendered.Renote != null;
if (!renote) return rendered;
rendered = (StatusEntity)rendered.Clone();
if (renote) rendered.Renote = null;
return rendered;
}
private IEnumerable<StreamingUpdateMessage> RenderMessage(
IEnumerable<string> tags, string eventType, string payload
) => tags.Select(tag => new StreamingUpdateMessage
{
Stream = [Name, tag],
Event = eventType,
Payload = payload
});
private async void OnNotePublished(object? _, Note note)
{
try
{
var wrapped = IsApplicable(note);
if (wrapped == null) return;
if (connection.IsFiltered(note)) return;
await using var scope = connection.ScopeFactory.CreateAsyncScope();
var renderer = scope.ServiceProvider.GetRequiredService<NoteRenderer>();
var data = new NoteRenderer.NoteRendererDto { Filters = connection.Filters.ToList() };
var intermediate = await renderer.RenderAsync(note, connection.Token.User, data: data);
var rendered = EnforceRenoteReplyVisibility(intermediate, wrapped);
var messages = RenderMessage(_tags.Intersect(note.Tags), "update", JsonSerializer.Serialize(rendered));
foreach (var message in messages)
await connection.SendMessageAsync(JsonSerializer.Serialize(message));
}
catch (Exception e)
{
_logger.LogError("Event handler OnNotePublished threw exception: {e}", e);
}
}
private async void OnNoteUpdated(object? _, Note note)
{
try
{
var wrapped = IsApplicable(note);
if (wrapped == null) return;
if (connection.IsFiltered(note)) return;
await using var scope = connection.ScopeFactory.CreateAsyncScope();
var renderer = scope.ServiceProvider.GetRequiredService<NoteRenderer>();
var data = new NoteRenderer.NoteRendererDto { Filters = connection.Filters.ToList() };
var intermediate = await renderer.RenderAsync(note, connection.Token.User, data: data);
var rendered = EnforceRenoteReplyVisibility(intermediate, wrapped);
var messages = RenderMessage(_tags.Intersect(note.Tags), "status.update",
JsonSerializer.Serialize(rendered));
foreach (var message in messages)
await connection.SendMessageAsync(JsonSerializer.Serialize(message));
}
catch (Exception e)
{
_logger.LogError("Event handler OnNoteUpdated threw exception: {e}", e);
}
}
private async void OnNoteDeleted(object? _, Note note)
{
try
{
if (!IsApplicableBool(note)) return;
if (connection.IsFiltered(note)) return;
var messages = RenderMessage(_tags.Intersect(note.Tags), "delete", note.Id);
foreach (var message in messages)
await connection.SendMessageAsync(JsonSerializer.Serialize(message));
}
catch (Exception e)
{
_logger.LogError("Event handler OnNoteDeleted threw exception: {e}", e);
}
}
}

View file

@ -19,6 +19,7 @@ public class PublicChannel(
public string Name => name; public string Name => name;
public List<string> Scopes => ["read:statuses"]; public List<string> Scopes => ["read:statuses"];
public bool IsSubscribed { get; private set; } public bool IsSubscribed { get; private set; }
public bool IsAggregate => false;
public Task Subscribe(StreamingRequestMessage _) public Task Subscribe(StreamingRequestMessage _)
{ {

View file

@ -15,6 +15,7 @@ public class UserChannel(WebSocketConnection connection, bool notificationsOnly)
public string Name => notificationsOnly ? "user:notification" : "user"; public string Name => notificationsOnly ? "user:notification" : "user";
public List<string> Scopes => ["read:statuses", "read:notifications"]; public List<string> Scopes => ["read:statuses", "read:notifications"];
public bool IsSubscribed { get; private set; } public bool IsSubscribed { get; private set; }
public bool IsAggregate => false;
public async Task Subscribe(StreamingRequestMessage _) public async Task Subscribe(StreamingRequestMessage _)
{ {

View file

@ -65,6 +65,8 @@ public sealed class WebSocketConnection(
Channels.Add(new PublicChannel(this, "public:local:media", true, false, true)); Channels.Add(new PublicChannel(this, "public:local:media", true, false, true));
Channels.Add(new PublicChannel(this, "public:remote", false, true, false)); Channels.Add(new PublicChannel(this, "public:remote", false, true, false));
Channels.Add(new PublicChannel(this, "public:remote:media", false, true, true)); Channels.Add(new PublicChannel(this, "public:remote:media", false, true, true));
Channels.Add(new HashtagChannel(this, true));
Channels.Add(new HashtagChannel(this, false));
EventService.UserBlocked += OnUserUnblock; EventService.UserBlocked += OnUserUnblock;
EventService.UserUnblocked += OnUserBlock; EventService.UserUnblocked += OnUserBlock;
@ -136,7 +138,8 @@ public sealed class WebSocketConnection(
{ {
case "subscribe": case "subscribe":
{ {
var channel = Channels.FirstOrDefault(p => p.Name == message.Stream && !p.IsSubscribed); var channel =
Channels.FirstOrDefault(p => p.Name == message.Stream && (!p.IsSubscribed || p.IsAggregate));
if (channel == null) return; if (channel == null) return;
if (channel.Scopes.Except(MastodonOauthHelpers.ExpandScopes(Token.Scopes)).Any()) if (channel.Scopes.Except(MastodonOauthHelpers.ExpandScopes(Token.Scopes)).Any())
await CloseAsync(WebSocketCloseStatus.PolicyViolation); await CloseAsync(WebSocketCloseStatus.PolicyViolation);
@ -146,7 +149,8 @@ public sealed class WebSocketConnection(
} }
case "unsubscribe": case "unsubscribe":
{ {
var channel = Channels.FirstOrDefault(p => p.Name == message.Stream && p.IsSubscribed); var channel =
Channels.FirstOrDefault(p => p.Name == message.Stream && (p.IsSubscribed || p.IsAggregate));
if (channel != null) await channel.Unsubscribe(message); if (channel != null) await channel.Unsubscribe(message);
break; break;
} }
@ -341,6 +345,7 @@ public interface IChannel
public string Name { get; } public string Name { get; }
public List<string> Scopes { get; } public List<string> Scopes { get; }
public bool IsSubscribed { get; } public bool IsSubscribed { get; }
public bool IsAggregate { get; }
public Task Subscribe(StreamingRequestMessage message); public Task Subscribe(StreamingRequestMessage message);
public Task Unsubscribe(StreamingRequestMessage message); public Task Unsubscribe(StreamingRequestMessage message);
public void Dispose(); public void Dispose();

View file

@ -1,7 +1,10 @@
using System.Collections; using System.Collections;
using System.Diagnostics.CodeAnalysis;
namespace Iceshrimp.Backend.Core.Helpers; namespace Iceshrimp.Backend.Core.Helpers;
[SuppressMessage("ReSharper", "InconsistentlySynchronizedField",
Justification = "This is intentional (it's a *write* locking list, after all)")]
public class WriteLockingList<T> : ICollection<T> public class WriteLockingList<T> : ICollection<T>
{ {
private readonly List<T> _list = []; private readonly List<T> _list = [];
@ -14,6 +17,13 @@ public class WriteLockingList<T> : ICollection<T>
lock (_list) _list.Add(item); lock (_list) _list.Add(item);
} }
public void AddIfMissing(T item)
{
lock (_list)
if (!_list.Contains(item))
_list.Add(item);
}
public void AddRange(IEnumerable<T> item) public void AddRange(IEnumerable<T> item)
{ {
lock (_list) _list.AddRange(item); lock (_list) _list.AddRange(item);
@ -33,6 +43,11 @@ public class WriteLockingList<T> : ICollection<T>
lock (_list) return _list.Remove(item); lock (_list) return _list.Remove(item);
} }
public int RemoveAll(Predicate<T> predicate)
{
lock (_list) return _list.RemoveAll(predicate);
}
public int Count => _list.Count; public int Count => _list.Count;
public bool IsReadOnly => ((ICollection<T>)_list).IsReadOnly; public bool IsReadOnly => ((ICollection<T>)_list).IsReadOnly;
} }