[backend/masto-client] Enforce renote/reply visibility in ws/streaming (ISH-248)

This commit is contained in:
Laura Hausmann 2024-04-08 21:09:58 +02:00
parent f6734aea11
commit bc50aa0259
No known key found for this signature in database
GPG key ID: D044E84C5BE01605
3 changed files with 75 additions and 40 deletions

View file

@ -51,8 +51,17 @@ public class PublicChannel(
if (!local && note.UserHost == null) return false; if (!local && note.UserHost == null) return false;
if (!remote && note.UserHost != null) return false; if (!remote && note.UserHost != null) return false;
if (onlyMedia && note.FileIds.Count == 0) return false; if (onlyMedia && note.FileIds.Count == 0) return false;
return EnforceRenoteReplyVisibility(note) is not { IsPureRenote: true, Renote: null };
}
return true; private Note EnforceRenoteReplyVisibility(Note note)
{
if (note.Renote?.IsVisibleFor(connection.Token.User, connection.Following) ?? false)
note.Renote = null;
if (note.Reply?.IsVisibleFor(connection.Token.User, connection.Following) ?? false)
note.Reply = null;
return note;
} }
private bool IsFiltered(Note note) => connection.IsFiltered(note.User) || private bool IsFiltered(Note note) => connection.IsFiltered(note.User) ||

View file

@ -5,7 +5,6 @@ using Iceshrimp.Backend.Core.Database;
using Iceshrimp.Backend.Core.Database.Tables; using Iceshrimp.Backend.Core.Database.Tables;
using Iceshrimp.Backend.Core.Events; using Iceshrimp.Backend.Core.Events;
using Iceshrimp.Backend.Core.Middleware; using Iceshrimp.Backend.Core.Middleware;
using Microsoft.EntityFrameworkCore;
namespace Iceshrimp.Backend.Controllers.Mastodon.Streaming.Channels; namespace Iceshrimp.Backend.Controllers.Mastodon.Streaming.Channels;
@ -14,10 +13,9 @@ public class UserChannel(WebSocketConnection connection, bool notificationsOnly)
public readonly ILogger<UserChannel> Logger = public readonly ILogger<UserChannel> Logger =
connection.Scope.ServiceProvider.GetRequiredService<ILogger<UserChannel>>(); connection.Scope.ServiceProvider.GetRequiredService<ILogger<UserChannel>>();
private List<string> _followedUsers = []; public string Name => notificationsOnly ? "user:notification" : "user";
public string Name => notificationsOnly ? "user:notification" : "user"; public List<string> Scopes => ["read:statuses", "read:notifications"];
public List<string> Scopes => ["read:statuses", "read:notifications"]; public bool IsSubscribed { get; private set; }
public bool IsSubscribed { get; private set; }
public async Task Subscribe(StreamingRequestMessage _) public async Task Subscribe(StreamingRequestMessage _)
{ {
@ -29,17 +27,9 @@ public class UserChannel(WebSocketConnection connection, bool notificationsOnly)
if (!notificationsOnly) if (!notificationsOnly)
{ {
_followedUsers = await db.Users.Where(p => p == connection.Token.User)
.SelectMany(p => p.Following)
.Select(p => p.Id)
.ToListAsync();
connection.EventService.NotePublished += OnNotePublished; connection.EventService.NotePublished += OnNotePublished;
connection.EventService.NoteUpdated += OnNoteUpdated; connection.EventService.NoteUpdated += OnNoteUpdated;
connection.EventService.NoteDeleted += OnNoteDeleted; connection.EventService.NoteDeleted += OnNoteDeleted;
connection.EventService.UserFollowed += OnRelationChange;
connection.EventService.UserUnfollowed += OnRelationChange;
connection.EventService.UserBlocked += OnRelationChange;
} }
connection.EventService.Notification += OnNotification; connection.EventService.Notification += OnNotification;
@ -60,15 +50,15 @@ public class UserChannel(WebSocketConnection connection, bool notificationsOnly)
connection.EventService.NotePublished -= OnNotePublished; connection.EventService.NotePublished -= OnNotePublished;
connection.EventService.NoteUpdated -= OnNoteUpdated; connection.EventService.NoteUpdated -= OnNoteUpdated;
connection.EventService.NoteDeleted -= OnNoteDeleted; connection.EventService.NoteDeleted -= OnNoteDeleted;
connection.EventService.UserFollowed -= OnRelationChange;
connection.EventService.UserUnfollowed -= OnRelationChange;
connection.EventService.UserBlocked -= OnRelationChange;
} }
connection.EventService.Notification -= OnNotification; connection.EventService.Notification -= OnNotification;
} }
private bool IsApplicable(Note note) => _followedUsers.Prepend(connection.Token.User.Id).Contains(note.UserId); private bool IsApplicable(Note note) =>
connection.Following.Prepend(connection.Token.User.Id).Contains(note.UserId) &&
EnforceRenoteReplyVisibility(note) is not { IsPureRenote: true, Renote: null };
private bool IsApplicable(Notification notification) => notification.NotifieeId == connection.Token.User.Id; private bool IsApplicable(Notification notification) => notification.NotifieeId == connection.Token.User.Id;
private bool IsApplicable(UserInteraction interaction) => interaction.Actor.Id == connection.Token.User.Id || private bool IsApplicable(UserInteraction interaction) => interaction.Actor.Id == connection.Token.User.Id ||
@ -83,6 +73,16 @@ public class UserChannel(WebSocketConnection connection, bool notificationsOnly)
(notification.Notifier != null && connection.IsFiltered(notification.Notifier)) || (notification.Notifier != null && connection.IsFiltered(notification.Notifier)) ||
(notification.Note != null && IsFiltered(notification.Note)); (notification.Note != null && IsFiltered(notification.Note));
private Note EnforceRenoteReplyVisibility(Note note)
{
if (note.Renote?.IsVisibleFor(connection.Token.User, connection.Following) ?? false)
note.Renote = null;
if (note.Reply?.IsVisibleFor(connection.Token.User, connection.Following) ?? false)
note.Reply = null;
return note;
}
private async void OnNotePublished(object? _, Note note) private async void OnNotePublished(object? _, Note note)
{ {
try try
@ -185,22 +185,4 @@ public class UserChannel(WebSocketConnection connection, bool notificationsOnly)
Logger.LogError("Event handler OnNotification threw exception: {e}", e); Logger.LogError("Event handler OnNotification threw exception: {e}", e);
} }
} }
private async void OnRelationChange(object? _, UserInteraction interaction)
{
try
{
if (!IsApplicable(interaction)) return;
await using var scope = connection.ScopeFactory.CreateAsyncScope();
await using var db = scope.ServiceProvider.GetRequiredService<DatabaseContext>();
_followedUsers = await db.Users.Where(p => p == connection.Token.User)
.SelectMany(p => p.Following)
.Select(p => p.Id)
.ToListAsync();
}
catch (Exception e)
{
Logger.LogError("Event handler OnRelationChange threw exception: {e}", e);
}
}
} }

View file

@ -3,10 +3,12 @@ using System.Net.WebSockets;
using System.Text; using System.Text;
using System.Text.Json; using System.Text.Json;
using Iceshrimp.Backend.Controllers.Mastodon.Streaming.Channels; using Iceshrimp.Backend.Controllers.Mastodon.Streaming.Channels;
using Iceshrimp.Backend.Core.Database;
using Iceshrimp.Backend.Core.Database.Tables; using Iceshrimp.Backend.Core.Database.Tables;
using Iceshrimp.Backend.Core.Events; using Iceshrimp.Backend.Core.Events;
using Iceshrimp.Backend.Core.Helpers; using Iceshrimp.Backend.Core.Helpers;
using Iceshrimp.Backend.Core.Services; using Iceshrimp.Backend.Core.Services;
using Microsoft.EntityFrameworkCore;
namespace Iceshrimp.Backend.Controllers.Mastodon.Streaming; namespace Iceshrimp.Backend.Controllers.Mastodon.Streaming;
@ -24,6 +26,7 @@ public sealed class WebSocketConnection(
public readonly IServiceScope Scope = scopeFactory.CreateScope(); public readonly IServiceScope Scope = scopeFactory.CreateScope();
public readonly IServiceScopeFactory ScopeFactory = scopeFactory; public readonly IServiceScopeFactory ScopeFactory = scopeFactory;
public readonly OauthToken Token = token; public readonly OauthToken Token = token;
public readonly WriteLockingList<string> Following = [];
private readonly WriteLockingList<string> _blocking = []; private readonly WriteLockingList<string> _blocking = [];
private readonly WriteLockingList<string> _blockedBy = []; private readonly WriteLockingList<string> _blockedBy = [];
private readonly WriteLockingList<string> _mutedUsers = []; private readonly WriteLockingList<string> _mutedUsers = [];
@ -54,10 +57,23 @@ public sealed class WebSocketConnection(
Channels.Add(new PublicChannel(this, "public:remote", false, true, false)); Channels.Add(new PublicChannel(this, "public:remote", false, true, false));
Channels.Add(new PublicChannel(this, "public:remote:media", false, true, true)); Channels.Add(new PublicChannel(this, "public:remote:media", false, true, true));
EventService.UserBlocked += OnUserUnblock; EventService.UserBlocked += OnUserUnblock;
EventService.UserUnblocked += OnUserBlock; EventService.UserUnblocked += OnUserBlock;
EventService.UserMuted += OnUserMute; EventService.UserMuted += OnUserMute;
EventService.UserUnmuted += OnUserUnmute; EventService.UserUnmuted += OnUserUnmute;
EventService.UserFollowed -= OnUserFollow;
EventService.UserUnfollowed -= OnUserUnfollow;
_ = InitializeFollowing();
}
private async Task InitializeFollowing()
{
await using var db = Scope.ServiceProvider.GetRequiredService<DatabaseContext>();
Following.AddRange(await db.Users.Where(p => p == Token.User)
.SelectMany(p => p.Following)
.Select(p => p.Id)
.ToListAsync());
} }
public async Task HandleSocketMessageAsync(string payload) public async Task HandleSocketMessageAsync(string payload)
@ -185,6 +201,34 @@ public sealed class WebSocketConnection(
} }
} }
private void OnUserFollow(object? _, UserInteraction interaction)
{
try
{
if (interaction.Actor.Id == Token.User.Id)
Following.Add(interaction.Object.Id);
}
catch (Exception e)
{
var logger = Scope.ServiceProvider.GetRequiredService<Logger<WebSocketConnection>>();
logger.LogError("Event handler OnUserFollow threw exception: {e}", e);
}
}
private void OnUserUnfollow(object? _, UserInteraction interaction)
{
try
{
if (interaction.Actor.Id == Token.User.Id)
Following.Remove(interaction.Object.Id);
}
catch (Exception e)
{
var logger = Scope.ServiceProvider.GetRequiredService<Logger<WebSocketConnection>>();
logger.LogError("Event handler OnUserUnfollow threw exception: {e}", e);
}
}
[SuppressMessage("ReSharper", "SuggestBaseTypeForParameter")] [SuppressMessage("ReSharper", "SuggestBaseTypeForParameter")]
public bool IsFiltered(User user) => public bool IsFiltered(User user) =>
_blocking.Contains(user.Id) || _blockedBy.Contains(user.Id) || _mutedUsers.Contains(user.Id); _blocking.Contains(user.Id) || _blockedBy.Contains(user.Id) || _mutedUsers.Contains(user.Id);