[backend/masto-client] Respect filters in WebSocket connections (ISH-328)

This commit is contained in:
Laura Hausmann 2024-05-17 16:07:56 +02:00
parent 849ecd9841
commit 9636a096fc
No known key found for this signature in database
GPG key ID: D044E84C5BE01605
7 changed files with 168 additions and 18 deletions

View file

@ -22,7 +22,7 @@ namespace Iceshrimp.Backend.Controllers.Mastodon;
[EnableRateLimiting("sliding")]
[EnableCors("mastodon")]
[Produces(MediaTypeNames.Application.Json)]
public class FilterController(DatabaseContext db, QueueService queueSvc) : ControllerBase
public class FilterController(DatabaseContext db, QueueService queueSvc, EventService eventSvc) : ControllerBase
{
[HttpGet]
[Authorize("read:filters")]
@ -94,6 +94,7 @@ public class FilterController(DatabaseContext db, QueueService queueSvc) : Contr
db.Add(filter);
await db.SaveChangesAsync();
eventSvc.RaiseFilterAdded(this, filter);
if (expiry.HasValue)
{
@ -157,6 +158,7 @@ public class FilterController(DatabaseContext db, QueueService queueSvc) : Contr
db.Update(filter);
await db.SaveChangesAsync();
eventSvc.RaiseFilterUpdated(this, filter);
if (expiry.HasValue)
{
@ -179,6 +181,7 @@ public class FilterController(DatabaseContext db, QueueService queueSvc) : Contr
db.Remove(filter);
await db.SaveChangesAsync();
eventSvc.RaiseFilterRemoved(this, filter);
return Ok(new object());
}
@ -213,6 +216,7 @@ public class FilterController(DatabaseContext db, QueueService queueSvc) : Contr
db.Update(keyword);
await db.SaveChangesAsync();
eventSvc.RaiseFilterUpdated(this, filter);
return Ok(new FilterKeyword(keyword, filter.Id, filter.Keywords.Count - 1));
}
@ -251,6 +255,7 @@ public class FilterController(DatabaseContext db, QueueService queueSvc) : Contr
filter.Keywords[keywordId] = request.WholeWord ? $"\"{request.Keyword}\"" : request.Keyword;
db.Update(filter);
await db.SaveChangesAsync();
eventSvc.RaiseFilterUpdated(this, filter);
return Ok(new FilterKeyword(filter.Keywords[keywordId], filter.Id, keywordId));
}
@ -271,6 +276,7 @@ public class FilterController(DatabaseContext db, QueueService queueSvc) : Contr
filter.Keywords.RemoveAt(keywordId);
db.Update(filter);
await db.SaveChangesAsync();
eventSvc.RaiseFilterUpdated(this, filter);
return Ok(new object());
}

View file

@ -92,9 +92,19 @@ public class NoteRenderer(
? (data?.Polls ?? await GetPolls([note], user)).FirstOrDefault(p => p.Id == note.Id)
: null;
var filters = data?.Filters ?? await GetFilters(user, filterContext);
var filtered = FilterHelper.IsFiltered([note, note.Reply, note.Renote, note.Renote?.Renote], filters);
var filterResult = GetFilterResult(filtered);
var filters = data?.Filters ?? await GetFilters(user, filterContext);
List<FilterResultEntity> filterResult;
if (filters.Count > 0 && filterContext == null)
{
var filtered = FilterHelper.CheckFilters([note, note.Reply, note.Renote, note.Renote?.Renote], filters);
filterResult = GetFilterResult(filtered);
}
else
{
var filtered = FilterHelper.IsFiltered([note, note.Reply, note.Renote, note.Renote?.Renote], filters);
filterResult = GetFilterResult(filtered.HasValue ? [filtered.Value] : []);
}
if ((user?.UserSettings?.FilterInaccessible ?? false) && (replyInaccessible || quoteInaccessible))
filterResult.Insert(0, InaccessibleFilter);
@ -204,12 +214,19 @@ public class NoteRenderer(
};
}
private static List<FilterResultEntity> GetFilterResult((Filter filter, string keyword)? filtered)
private static List<FilterResultEntity> GetFilterResult(
IReadOnlyCollection<(Filter filter, string keyword)> filtered
)
{
if (filtered == null) return [];
var (filter, keyword) = filtered.Value;
var res = new List<FilterResultEntity>();
return [new FilterResultEntity { Filter = FilterRenderer.RenderOne(filter), KeywordMatches = [keyword] }];
foreach (var entry in filtered)
{
var (filter, keyword) = entry;
res.Add(new FilterResultEntity { Filter = FilterRenderer.RenderOne(filter), KeywordMatches = [keyword] });
}
return res;
}
private async Task<List<MentionEntity>> GetMentions(List<Note> notes)
@ -368,8 +385,9 @@ public class NoteRenderer(
private async Task<List<Filter>> GetFilters(User? user, Filter.FilterContext? filterContext)
{
if (filterContext == null) return [];
return await db.Filters.Where(p => p.User == user && p.Contexts.Contains(filterContext.Value)).ToListAsync();
return filterContext == null
? await db.Filters.Where(p => p.User == user).ToListAsync()
: await db.Filters.Where(p => p.User == user && p.Contexts.Contains(filterContext.Value)).ToListAsync();
}
public async Task<IEnumerable<StatusEntity>> RenderManyAsync(

View file

@ -96,7 +96,8 @@ public class PublicChannel(
await using var scope = connection.ScopeFactory.CreateAsyncScope();
var renderer = scope.ServiceProvider.GetRequiredService<NoteRenderer>();
var intermediate = await renderer.RenderAsync(note, connection.Token.User);
var data = new NoteRenderer.NoteRendererDto { Filters = connection.Filters.ToList() };
var intermediate = await renderer.RenderAsync(note, connection.Token.User, data: data);
var rendered = EnforceRenoteReplyVisibility(intermediate, wrapped);
var message = new StreamingUpdateMessage
{
@ -122,7 +123,8 @@ public class PublicChannel(
await using var scope = connection.ScopeFactory.CreateAsyncScope();
var renderer = scope.ServiceProvider.GetRequiredService<NoteRenderer>();
var intermediate = await renderer.RenderAsync(note, connection.Token.User);
var data = new NoteRenderer.NoteRendererDto { Filters = connection.Filters.ToList() };
var intermediate = await renderer.RenderAsync(note, connection.Token.User, data: data);
var rendered = EnforceRenoteReplyVisibility(intermediate, wrapped);
var message = new StreamingUpdateMessage
{

View file

@ -29,6 +29,7 @@ public sealed class WebSocketConnection(
public readonly IServiceScope Scope = scopeFactory.CreateScope();
public readonly IServiceScopeFactory ScopeFactory = scopeFactory;
public readonly OauthToken Token = token;
public readonly WriteLockingList<Filter> Filters = [];
public readonly WriteLockingList<string> Following = [];
private readonly WriteLockingList<string> _blocking = [];
private readonly WriteLockingList<string> _blockedBy = [];
@ -39,10 +40,15 @@ public sealed class WebSocketConnection(
foreach (var channel in Channels)
channel.Dispose();
EventService.UserBlocked -= OnUserUnblock;
EventService.UserUnblocked -= OnUserBlock;
EventService.UserMuted -= OnUserMute;
EventService.UserUnmuted -= OnUserUnmute;
EventService.UserBlocked -= OnUserUnblock;
EventService.UserUnblocked -= OnUserBlock;
EventService.UserMuted -= OnUserMute;
EventService.UserUnmuted -= OnUserUnmute;
EventService.UserFollowed -= OnUserFollow;
EventService.UserUnfollowed -= OnUserUnfollow;
EventService.FilterAdded -= OnFilterAdded;
EventService.FilterRemoved -= OnFilterRemoved;
EventService.FilterUpdated -= OnFilterUpdated;
Scope.Dispose();
}
@ -64,8 +70,11 @@ public sealed class WebSocketConnection(
EventService.UserUnblocked += OnUserBlock;
EventService.UserMuted += OnUserMute;
EventService.UserUnmuted += OnUserUnmute;
EventService.UserFollowed -= OnUserFollow;
EventService.UserUnfollowed -= OnUserUnfollow;
EventService.UserFollowed += OnUserFollow;
EventService.UserUnfollowed += OnUserUnfollow;
EventService.FilterAdded += OnFilterAdded;
EventService.FilterRemoved += OnFilterRemoved;
EventService.FilterUpdated += OnFilterUpdated;
_ = InitializeRelationships();
}
@ -85,6 +94,19 @@ public sealed class WebSocketConnection(
_muting.AddRange(await db.Mutings.Where(p => p.Muter == Token.User)
.Select(p => p.MuteeId)
.ToListAsync());
Filters.AddRange(await db.Filters.Where(p => p.User == Token.User)
.Select(p => new Filter
{
Name = p.Name,
Action = p.Action,
Contexts = p.Contexts,
Expiry = p.Expiry,
Id = p.Id,
Keywords = p.Keywords
})
.AsNoTracking()
.ToListAsync());
}
public async Task HandleSocketMessageAsync(string payload)
@ -240,6 +262,57 @@ public sealed class WebSocketConnection(
}
}
private void OnFilterAdded(object? _, Filter filter)
{
try
{
if (filter.User.Id != Token.User.Id) return;
Filters.Add(filter.Clone(Token.User));
}
catch (Exception e)
{
var logger = Scope.ServiceProvider.GetRequiredService<Logger<WebSocketConnection>>();
logger.LogError("Event handler OnFilterAdded threw exception: {e}", e);
}
}
private void OnFilterRemoved(object? _, Filter filter)
{
try
{
if (filter.User.Id != Token.User.Id) return;
var match = Filters.FirstOrDefault(p => p.Id == filter.Id);
if (match != null) Filters.Remove(match);
}
catch (Exception e)
{
var logger = Scope.ServiceProvider.GetRequiredService<Logger<WebSocketConnection>>();
logger.LogError("Event handler OnFilterRemoved threw exception: {e}", e);
}
}
private void OnFilterUpdated(object? _, Filter filter)
{
try
{
if (filter.User.Id != Token.User.Id) return;
var match = Filters.FirstOrDefault(p => p.Id == filter.Id);
if (match == null) Filters.Add(filter.Clone(Token.User));
else
{
match.Contexts = filter.Contexts;
match.Action = filter.Action;
match.Keywords = filter.Keywords;
match.Name = filter.Name;
}
}
catch (Exception e)
{
var logger = Scope.ServiceProvider.GetRequiredService<Logger<WebSocketConnection>>();
logger.LogError("Event handler OnFilterUpdated threw exception: {e}", e);
}
}
[SuppressMessage("ReSharper", "SuggestBaseTypeForParameter")]
public bool IsFiltered(User user) =>
_blocking.Contains(user.Id) || _blockedBy.Contains(user.Id) || _muting.Contains(user.Id);

View file

@ -40,6 +40,20 @@ public class Filter
[Column("keywords")] public List<string> Keywords { get; set; } = [];
[Column("contexts")] public List<FilterContext> Contexts { get; set; } = [];
[Column("action")] public FilterAction Action { get; set; }
public Filter Clone(User? user = null)
{
return new Filter
{
Name = Name,
Action = Action,
Contexts = Contexts,
Expiry = Expiry,
Keywords = Keywords,
Id = Id,
User = user!
};
}
}
public class FilterEntityTypeConfiguration : IEntityTypeConfiguration<Filter>

View file

@ -32,6 +32,36 @@ public static class FilterHelper
return null;
}
public static List<(Filter filter, string keyword)> CheckFilters(IEnumerable<Note?> notes, List<Filter> filters)
{
if (filters.Count == 0) return [];
var res = new List<(Filter filter, string keyword)>();
foreach (var note in notes.OfType<Note>())
{
var match = CheckFilters(note, filters);
if (match.Count != 0) res.AddRange(match);
}
return res;
}
private static List<(Filter filter, string keyword)> CheckFilters(Note note, List<Filter> filters)
{
if (filters.Count == 0) return [];
var res = new List<(Filter filter, string keyword)>();
foreach (var filter in filters)
{
var match = IsFiltered(note, filter);
if (match != null) res.Add((filter, match));
}
return res;
}
private static string? IsFiltered(Note note, Filter filter)
{
foreach (var keyword in filter.Keywords)

View file

@ -19,6 +19,9 @@ public class EventService
public event EventHandler<UserInteraction>? UserMuted;
public event EventHandler<UserInteraction>? UserUnmuted;
public event EventHandler<Notification>? Notification;
public event EventHandler<Filter>? FilterAdded;
public event EventHandler<Filter>? FilterRemoved;
public event EventHandler<Filter>? FilterUpdated;
public void RaiseNotePublished(object? sender, Note note) => NotePublished?.Invoke(sender, note);
public void RaiseNoteUpdated(object? sender, Note note) => NoteUpdated?.Invoke(sender, note);
@ -61,4 +64,8 @@ public class EventService
public void RaiseUserUnmuted(object? sender, User actor, User obj) =>
UserUnmuted?.Invoke(sender, new UserInteraction { Actor = actor, Object = obj });
public void RaiseFilterAdded(object? sender, Filter filter) => FilterAdded?.Invoke(sender, filter);
public void RaiseFilterRemoved(object? sender, Filter filter) => FilterRemoved?.Invoke(sender, filter);
public void RaiseFilterUpdated(object? sender, Filter filter) => FilterUpdated?.Invoke(sender, filter);
}