[backend/masto-client] Handle mutes & blocks in WebSockets (ISH-219)

This commit is contained in:
Laura Hausmann 2024-03-26 18:50:16 +01:00
parent b4adfe7067
commit 45dcbf29fe
No known key found for this signature in database
GPG key ID: D044E84C5BE01605
3 changed files with 114 additions and 5 deletions

View file

@ -55,11 +55,17 @@ public class PublicChannel(
return true; return true;
} }
private bool IsFiltered(Note note) => connection.IsFiltered(note.User) ||
(note.Renote?.User != null && connection.IsFiltered(note.Renote.User)) ||
note.Renote?.Renote?.User != null &&
connection.IsFiltered(note.Renote.Renote.User);
private async void OnNotePublished(object? _, Note note) private async void OnNotePublished(object? _, Note note)
{ {
try try
{ {
if (!IsApplicable(note)) return; if (!IsApplicable(note)) return;
if (IsFiltered(note)) return;
await using var scope = connection.ScopeFactory.CreateAsyncScope(); await using var scope = connection.ScopeFactory.CreateAsyncScope();
var provider = scope.ServiceProvider; var provider = scope.ServiceProvider;
@ -84,6 +90,7 @@ public class PublicChannel(
try try
{ {
if (!IsApplicable(note)) return; if (!IsApplicable(note)) return;
if (IsFiltered(note)) return;
await using var scope = connection.ScopeFactory.CreateAsyncScope(); await using var scope = connection.ScopeFactory.CreateAsyncScope();
var provider = scope.ServiceProvider; var provider = scope.ServiceProvider;
@ -108,6 +115,7 @@ public class PublicChannel(
try try
{ {
if (!IsApplicable(note)) return; if (!IsApplicable(note)) return;
if (IsFiltered(note)) return;
var message = new StreamingUpdateMessage var message = new StreamingUpdateMessage
{ {
Stream = [Name], Stream = [Name],

View file

@ -27,13 +27,13 @@ public class UserChannel(WebSocketConnection connection, bool notificationsOnly)
await using var scope = connection.ScopeFactory.CreateAsyncScope(); await using var scope = connection.ScopeFactory.CreateAsyncScope();
await using var db = scope.ServiceProvider.GetRequiredService<DatabaseContext>(); await using var db = scope.ServiceProvider.GetRequiredService<DatabaseContext>();
if (!notificationsOnly)
{
_followedUsers = await db.Users.Where(p => p == connection.Token.User) _followedUsers = await db.Users.Where(p => p == connection.Token.User)
.SelectMany(p => p.Following) .SelectMany(p => p.Following)
.Select(p => p.Id) .Select(p => p.Id)
.ToListAsync(); .ToListAsync();
if (!notificationsOnly)
{
connection.EventService.NotePublished += OnNotePublished; connection.EventService.NotePublished += OnNotePublished;
connection.EventService.NoteUpdated += OnNoteUpdated; connection.EventService.NoteUpdated += OnNoteUpdated;
connection.EventService.NoteDeleted += OnNoteDeleted; connection.EventService.NoteDeleted += OnNoteDeleted;
@ -74,11 +74,21 @@ public class UserChannel(WebSocketConnection connection, bool notificationsOnly)
private bool IsApplicable(UserInteraction interaction) => interaction.Actor.Id == connection.Token.User.Id || private bool IsApplicable(UserInteraction interaction) => interaction.Actor.Id == connection.Token.User.Id ||
interaction.Object.Id == connection.Token.User.Id; interaction.Object.Id == connection.Token.User.Id;
private bool IsFiltered(Note note) => connection.IsFiltered(note.User) ||
(note.Renote?.User != null && connection.IsFiltered(note.Renote.User)) ||
note.Renote?.Renote?.User != null &&
connection.IsFiltered(note.Renote.Renote.User);
private bool IsFiltered(Notification notification) =>
(notification.Notifier != null && connection.IsFiltered(notification.Notifier)) ||
(notification.Note != null && IsFiltered(notification.Note));
private async void OnNotePublished(object? _, Note note) private async void OnNotePublished(object? _, Note note)
{ {
try try
{ {
if (!IsApplicable(note)) return; if (!IsApplicable(note)) return;
if (IsFiltered(note)) return;
await using var scope = connection.ScopeFactory.CreateAsyncScope(); await using var scope = connection.ScopeFactory.CreateAsyncScope();
var renderer = scope.ServiceProvider.GetRequiredService<NoteRenderer>(); var renderer = scope.ServiceProvider.GetRequiredService<NoteRenderer>();
@ -102,6 +112,7 @@ public class UserChannel(WebSocketConnection connection, bool notificationsOnly)
try try
{ {
if (!IsApplicable(note)) return; if (!IsApplicable(note)) return;
if (IsFiltered(note)) return;
await using var scope = connection.ScopeFactory.CreateAsyncScope(); await using var scope = connection.ScopeFactory.CreateAsyncScope();
var renderer = scope.ServiceProvider.GetRequiredService<NoteRenderer>(); var renderer = scope.ServiceProvider.GetRequiredService<NoteRenderer>();
@ -125,6 +136,7 @@ public class UserChannel(WebSocketConnection connection, bool notificationsOnly)
try try
{ {
if (!IsApplicable(note)) return; if (!IsApplicable(note)) return;
if (IsFiltered(note)) return;
var message = new StreamingUpdateMessage var message = new StreamingUpdateMessage
{ {
Stream = [Name], Stream = [Name],
@ -144,6 +156,7 @@ public class UserChannel(WebSocketConnection connection, bool notificationsOnly)
try try
{ {
if (!IsApplicable(notification)) return; if (!IsApplicable(notification)) return;
if (IsFiltered(notification)) return;
await using var scope = connection.ScopeFactory.CreateAsyncScope(); await using var scope = connection.ScopeFactory.CreateAsyncScope();
var renderer = scope.ServiceProvider.GetRequiredService<NotificationRenderer>(); var renderer = scope.ServiceProvider.GetRequiredService<NotificationRenderer>();

View file

@ -1,8 +1,10 @@
using System.Diagnostics.CodeAnalysis;
using System.Net.WebSockets; using System.Net.WebSockets;
using System.Text; using System.Text;
using System.Text.Json; using System.Text.Json;
using Iceshrimp.Backend.Controllers.Mastodon.Streaming.Channels; using Iceshrimp.Backend.Controllers.Mastodon.Streaming.Channels;
using Iceshrimp.Backend.Core.Database.Tables; using Iceshrimp.Backend.Core.Database.Tables;
using Iceshrimp.Backend.Core.Events;
using Iceshrimp.Backend.Core.Helpers; using Iceshrimp.Backend.Core.Helpers;
using Iceshrimp.Backend.Core.Services; using Iceshrimp.Backend.Core.Services;
@ -22,12 +24,20 @@ public sealed class WebSocketConnection(
public readonly IServiceScope Scope = scopeFactory.CreateScope(); public readonly IServiceScope Scope = scopeFactory.CreateScope();
public readonly IServiceScopeFactory ScopeFactory = scopeFactory; public readonly IServiceScopeFactory ScopeFactory = scopeFactory;
public readonly OauthToken Token = token; public readonly OauthToken Token = token;
private readonly List<string> _blocking = [];
private readonly List<string> _blockedBy = [];
private readonly List<string> _mutedUsers = [];
public void Dispose() public void Dispose()
{ {
foreach (var channel in Channels) foreach (var channel in Channels)
channel.Dispose(); channel.Dispose();
EventService.UserBlocked -= OnUserUnblock;
EventService.UserUnblocked -= OnUserBlock;
EventService.UserMuted -= OnUserMute;
EventService.UserUnmuted -= OnUserUnmute;
Scope.Dispose(); Scope.Dispose();
} }
@ -43,6 +53,11 @@ public sealed class WebSocketConnection(
Channels.Add(new PublicChannel(this, "public:local:media", true, false, true)); Channels.Add(new PublicChannel(this, "public:local:media", true, false, true));
Channels.Add(new PublicChannel(this, "public:remote", false, true, false)); Channels.Add(new PublicChannel(this, "public:remote", false, true, false));
Channels.Add(new PublicChannel(this, "public:remote:media", false, true, true)); Channels.Add(new PublicChannel(this, "public:remote:media", false, true, true));
EventService.UserBlocked += OnUserUnblock;
EventService.UserUnblocked += OnUserBlock;
EventService.UserMuted += OnUserMute;
EventService.UserUnmuted += OnUserUnmute;
} }
public async Task HandleSocketMessageAsync(string payload) public async Task HandleSocketMessageAsync(string payload)
@ -108,6 +123,79 @@ public sealed class WebSocketConnection(
} }
} }
private void OnUserBlock(object? _, UserInteraction interaction)
{
try
{
if (interaction.Actor.Id == Token.User.Id)
lock (_blocking)
_blocking.Add(interaction.Object.Id);
if (interaction.Object.Id == Token.User.Id)
lock (_blockedBy)
_blockedBy.Add(interaction.Actor.Id);
}
catch (Exception e)
{
var logger = Scope.ServiceProvider.GetRequiredService<Logger<WebSocketConnection>>();
logger.LogError("Event handler OnUserBlock threw exception: {e}", e);
}
}
private void OnUserUnblock(object? _, UserInteraction interaction)
{
try
{
if (interaction.Actor.Id == Token.User.Id)
lock (_blocking)
_blocking.Remove(interaction.Object.Id);
if (interaction.Object.Id == Token.User.Id)
lock (_blockedBy)
_blockedBy.Remove(interaction.Actor.Id);
}
catch (Exception e)
{
var logger = Scope.ServiceProvider.GetRequiredService<Logger<WebSocketConnection>>();
logger.LogError("Event handler OnUserUnblock threw exception: {e}", e);
}
}
private void OnUserMute(object? _, UserInteraction interaction)
{
try
{
if (interaction.Actor.Id == Token.User.Id)
lock (_mutedUsers)
_mutedUsers.Add(interaction.Object.Id);
}
catch (Exception e)
{
var logger = Scope.ServiceProvider.GetRequiredService<Logger<WebSocketConnection>>();
logger.LogError("Event handler OnUserMute threw exception: {e}", e);
}
}
private void OnUserUnmute(object? _, UserInteraction interaction)
{
try
{
if (interaction.Actor.Id == Token.User.Id)
lock (_mutedUsers)
_mutedUsers.Remove(interaction.Object.Id);
}
catch (Exception e)
{
var logger = Scope.ServiceProvider.GetRequiredService<Logger<WebSocketConnection>>();
logger.LogError("Event handler OnUserUnmute threw exception: {e}", e);
}
}
[SuppressMessage("ReSharper", "InconsistentlySynchronizedField")]
[SuppressMessage("ReSharper", "SuggestBaseTypeForParameter")]
public bool IsFiltered(User user) =>
_blocking.Contains(user.Id) || _blockedBy.Contains(user.Id) || _mutedUsers.Contains(user.Id);
public async Task CloseAsync(WebSocketCloseStatus status) public async Task CloseAsync(WebSocketCloseStatus status)
{ {
Dispose(); Dispose();