[backend/masto-client] Respect filters in WebSocket connections (ISH-328)
This commit is contained in:
parent
849ecd9841
commit
9636a096fc
7 changed files with 168 additions and 18 deletions
|
@ -22,7 +22,7 @@ namespace Iceshrimp.Backend.Controllers.Mastodon;
|
|||
[EnableRateLimiting("sliding")]
|
||||
[EnableCors("mastodon")]
|
||||
[Produces(MediaTypeNames.Application.Json)]
|
||||
public class FilterController(DatabaseContext db, QueueService queueSvc) : ControllerBase
|
||||
public class FilterController(DatabaseContext db, QueueService queueSvc, EventService eventSvc) : ControllerBase
|
||||
{
|
||||
[HttpGet]
|
||||
[Authorize("read:filters")]
|
||||
|
@ -94,6 +94,7 @@ public class FilterController(DatabaseContext db, QueueService queueSvc) : Contr
|
|||
|
||||
db.Add(filter);
|
||||
await db.SaveChangesAsync();
|
||||
eventSvc.RaiseFilterAdded(this, filter);
|
||||
|
||||
if (expiry.HasValue)
|
||||
{
|
||||
|
@ -157,6 +158,7 @@ public class FilterController(DatabaseContext db, QueueService queueSvc) : Contr
|
|||
|
||||
db.Update(filter);
|
||||
await db.SaveChangesAsync();
|
||||
eventSvc.RaiseFilterUpdated(this, filter);
|
||||
|
||||
if (expiry.HasValue)
|
||||
{
|
||||
|
@ -179,6 +181,7 @@ public class FilterController(DatabaseContext db, QueueService queueSvc) : Contr
|
|||
|
||||
db.Remove(filter);
|
||||
await db.SaveChangesAsync();
|
||||
eventSvc.RaiseFilterRemoved(this, filter);
|
||||
|
||||
return Ok(new object());
|
||||
}
|
||||
|
@ -213,6 +216,7 @@ public class FilterController(DatabaseContext db, QueueService queueSvc) : Contr
|
|||
|
||||
db.Update(keyword);
|
||||
await db.SaveChangesAsync();
|
||||
eventSvc.RaiseFilterUpdated(this, filter);
|
||||
|
||||
return Ok(new FilterKeyword(keyword, filter.Id, filter.Keywords.Count - 1));
|
||||
}
|
||||
|
@ -251,6 +255,7 @@ public class FilterController(DatabaseContext db, QueueService queueSvc) : Contr
|
|||
filter.Keywords[keywordId] = request.WholeWord ? $"\"{request.Keyword}\"" : request.Keyword;
|
||||
db.Update(filter);
|
||||
await db.SaveChangesAsync();
|
||||
eventSvc.RaiseFilterUpdated(this, filter);
|
||||
|
||||
return Ok(new FilterKeyword(filter.Keywords[keywordId], filter.Id, keywordId));
|
||||
}
|
||||
|
@ -271,6 +276,7 @@ public class FilterController(DatabaseContext db, QueueService queueSvc) : Contr
|
|||
filter.Keywords.RemoveAt(keywordId);
|
||||
db.Update(filter);
|
||||
await db.SaveChangesAsync();
|
||||
eventSvc.RaiseFilterUpdated(this, filter);
|
||||
|
||||
return Ok(new object());
|
||||
}
|
||||
|
|
|
@ -92,9 +92,19 @@ public class NoteRenderer(
|
|||
? (data?.Polls ?? await GetPolls([note], user)).FirstOrDefault(p => p.Id == note.Id)
|
||||
: null;
|
||||
|
||||
var filters = data?.Filters ?? await GetFilters(user, filterContext);
|
||||
var filtered = FilterHelper.IsFiltered([note, note.Reply, note.Renote, note.Renote?.Renote], filters);
|
||||
var filterResult = GetFilterResult(filtered);
|
||||
var filters = data?.Filters ?? await GetFilters(user, filterContext);
|
||||
|
||||
List<FilterResultEntity> filterResult;
|
||||
if (filters.Count > 0 && filterContext == null)
|
||||
{
|
||||
var filtered = FilterHelper.CheckFilters([note, note.Reply, note.Renote, note.Renote?.Renote], filters);
|
||||
filterResult = GetFilterResult(filtered);
|
||||
}
|
||||
else
|
||||
{
|
||||
var filtered = FilterHelper.IsFiltered([note, note.Reply, note.Renote, note.Renote?.Renote], filters);
|
||||
filterResult = GetFilterResult(filtered.HasValue ? [filtered.Value] : []);
|
||||
}
|
||||
|
||||
if ((user?.UserSettings?.FilterInaccessible ?? false) && (replyInaccessible || quoteInaccessible))
|
||||
filterResult.Insert(0, InaccessibleFilter);
|
||||
|
@ -204,12 +214,19 @@ public class NoteRenderer(
|
|||
};
|
||||
}
|
||||
|
||||
private static List<FilterResultEntity> GetFilterResult((Filter filter, string keyword)? filtered)
|
||||
private static List<FilterResultEntity> GetFilterResult(
|
||||
IReadOnlyCollection<(Filter filter, string keyword)> filtered
|
||||
)
|
||||
{
|
||||
if (filtered == null) return [];
|
||||
var (filter, keyword) = filtered.Value;
|
||||
var res = new List<FilterResultEntity>();
|
||||
|
||||
return [new FilterResultEntity { Filter = FilterRenderer.RenderOne(filter), KeywordMatches = [keyword] }];
|
||||
foreach (var entry in filtered)
|
||||
{
|
||||
var (filter, keyword) = entry;
|
||||
res.Add(new FilterResultEntity { Filter = FilterRenderer.RenderOne(filter), KeywordMatches = [keyword] });
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
private async Task<List<MentionEntity>> GetMentions(List<Note> notes)
|
||||
|
@ -368,8 +385,9 @@ public class NoteRenderer(
|
|||
|
||||
private async Task<List<Filter>> GetFilters(User? user, Filter.FilterContext? filterContext)
|
||||
{
|
||||
if (filterContext == null) return [];
|
||||
return await db.Filters.Where(p => p.User == user && p.Contexts.Contains(filterContext.Value)).ToListAsync();
|
||||
return filterContext == null
|
||||
? await db.Filters.Where(p => p.User == user).ToListAsync()
|
||||
: await db.Filters.Where(p => p.User == user && p.Contexts.Contains(filterContext.Value)).ToListAsync();
|
||||
}
|
||||
|
||||
public async Task<IEnumerable<StatusEntity>> RenderManyAsync(
|
||||
|
|
|
@ -96,7 +96,8 @@ public class PublicChannel(
|
|||
await using var scope = connection.ScopeFactory.CreateAsyncScope();
|
||||
|
||||
var renderer = scope.ServiceProvider.GetRequiredService<NoteRenderer>();
|
||||
var intermediate = await renderer.RenderAsync(note, connection.Token.User);
|
||||
var data = new NoteRenderer.NoteRendererDto { Filters = connection.Filters.ToList() };
|
||||
var intermediate = await renderer.RenderAsync(note, connection.Token.User, data: data);
|
||||
var rendered = EnforceRenoteReplyVisibility(intermediate, wrapped);
|
||||
var message = new StreamingUpdateMessage
|
||||
{
|
||||
|
@ -122,7 +123,8 @@ public class PublicChannel(
|
|||
await using var scope = connection.ScopeFactory.CreateAsyncScope();
|
||||
|
||||
var renderer = scope.ServiceProvider.GetRequiredService<NoteRenderer>();
|
||||
var intermediate = await renderer.RenderAsync(note, connection.Token.User);
|
||||
var data = new NoteRenderer.NoteRendererDto { Filters = connection.Filters.ToList() };
|
||||
var intermediate = await renderer.RenderAsync(note, connection.Token.User, data: data);
|
||||
var rendered = EnforceRenoteReplyVisibility(intermediate, wrapped);
|
||||
var message = new StreamingUpdateMessage
|
||||
{
|
||||
|
|
|
@ -29,6 +29,7 @@ public sealed class WebSocketConnection(
|
|||
public readonly IServiceScope Scope = scopeFactory.CreateScope();
|
||||
public readonly IServiceScopeFactory ScopeFactory = scopeFactory;
|
||||
public readonly OauthToken Token = token;
|
||||
public readonly WriteLockingList<Filter> Filters = [];
|
||||
public readonly WriteLockingList<string> Following = [];
|
||||
private readonly WriteLockingList<string> _blocking = [];
|
||||
private readonly WriteLockingList<string> _blockedBy = [];
|
||||
|
@ -39,10 +40,15 @@ public sealed class WebSocketConnection(
|
|||
foreach (var channel in Channels)
|
||||
channel.Dispose();
|
||||
|
||||
EventService.UserBlocked -= OnUserUnblock;
|
||||
EventService.UserUnblocked -= OnUserBlock;
|
||||
EventService.UserMuted -= OnUserMute;
|
||||
EventService.UserUnmuted -= OnUserUnmute;
|
||||
EventService.UserBlocked -= OnUserUnblock;
|
||||
EventService.UserUnblocked -= OnUserBlock;
|
||||
EventService.UserMuted -= OnUserMute;
|
||||
EventService.UserUnmuted -= OnUserUnmute;
|
||||
EventService.UserFollowed -= OnUserFollow;
|
||||
EventService.UserUnfollowed -= OnUserUnfollow;
|
||||
EventService.FilterAdded -= OnFilterAdded;
|
||||
EventService.FilterRemoved -= OnFilterRemoved;
|
||||
EventService.FilterUpdated -= OnFilterUpdated;
|
||||
|
||||
Scope.Dispose();
|
||||
}
|
||||
|
@ -64,8 +70,11 @@ public sealed class WebSocketConnection(
|
|||
EventService.UserUnblocked += OnUserBlock;
|
||||
EventService.UserMuted += OnUserMute;
|
||||
EventService.UserUnmuted += OnUserUnmute;
|
||||
EventService.UserFollowed -= OnUserFollow;
|
||||
EventService.UserUnfollowed -= OnUserUnfollow;
|
||||
EventService.UserFollowed += OnUserFollow;
|
||||
EventService.UserUnfollowed += OnUserUnfollow;
|
||||
EventService.FilterAdded += OnFilterAdded;
|
||||
EventService.FilterRemoved += OnFilterRemoved;
|
||||
EventService.FilterUpdated += OnFilterUpdated;
|
||||
|
||||
_ = InitializeRelationships();
|
||||
}
|
||||
|
@ -85,6 +94,19 @@ public sealed class WebSocketConnection(
|
|||
_muting.AddRange(await db.Mutings.Where(p => p.Muter == Token.User)
|
||||
.Select(p => p.MuteeId)
|
||||
.ToListAsync());
|
||||
|
||||
Filters.AddRange(await db.Filters.Where(p => p.User == Token.User)
|
||||
.Select(p => new Filter
|
||||
{
|
||||
Name = p.Name,
|
||||
Action = p.Action,
|
||||
Contexts = p.Contexts,
|
||||
Expiry = p.Expiry,
|
||||
Id = p.Id,
|
||||
Keywords = p.Keywords
|
||||
})
|
||||
.AsNoTracking()
|
||||
.ToListAsync());
|
||||
}
|
||||
|
||||
public async Task HandleSocketMessageAsync(string payload)
|
||||
|
@ -240,6 +262,57 @@ public sealed class WebSocketConnection(
|
|||
}
|
||||
}
|
||||
|
||||
private void OnFilterAdded(object? _, Filter filter)
|
||||
{
|
||||
try
|
||||
{
|
||||
if (filter.User.Id != Token.User.Id) return;
|
||||
Filters.Add(filter.Clone(Token.User));
|
||||
}
|
||||
catch (Exception e)
|
||||
{
|
||||
var logger = Scope.ServiceProvider.GetRequiredService<Logger<WebSocketConnection>>();
|
||||
logger.LogError("Event handler OnFilterAdded threw exception: {e}", e);
|
||||
}
|
||||
}
|
||||
|
||||
private void OnFilterRemoved(object? _, Filter filter)
|
||||
{
|
||||
try
|
||||
{
|
||||
if (filter.User.Id != Token.User.Id) return;
|
||||
var match = Filters.FirstOrDefault(p => p.Id == filter.Id);
|
||||
if (match != null) Filters.Remove(match);
|
||||
}
|
||||
catch (Exception e)
|
||||
{
|
||||
var logger = Scope.ServiceProvider.GetRequiredService<Logger<WebSocketConnection>>();
|
||||
logger.LogError("Event handler OnFilterRemoved threw exception: {e}", e);
|
||||
}
|
||||
}
|
||||
|
||||
private void OnFilterUpdated(object? _, Filter filter)
|
||||
{
|
||||
try
|
||||
{
|
||||
if (filter.User.Id != Token.User.Id) return;
|
||||
var match = Filters.FirstOrDefault(p => p.Id == filter.Id);
|
||||
if (match == null) Filters.Add(filter.Clone(Token.User));
|
||||
else
|
||||
{
|
||||
match.Contexts = filter.Contexts;
|
||||
match.Action = filter.Action;
|
||||
match.Keywords = filter.Keywords;
|
||||
match.Name = filter.Name;
|
||||
}
|
||||
}
|
||||
catch (Exception e)
|
||||
{
|
||||
var logger = Scope.ServiceProvider.GetRequiredService<Logger<WebSocketConnection>>();
|
||||
logger.LogError("Event handler OnFilterUpdated threw exception: {e}", e);
|
||||
}
|
||||
}
|
||||
|
||||
[SuppressMessage("ReSharper", "SuggestBaseTypeForParameter")]
|
||||
public bool IsFiltered(User user) =>
|
||||
_blocking.Contains(user.Id) || _blockedBy.Contains(user.Id) || _muting.Contains(user.Id);
|
||||
|
|
|
@ -40,6 +40,20 @@ public class Filter
|
|||
[Column("keywords")] public List<string> Keywords { get; set; } = [];
|
||||
[Column("contexts")] public List<FilterContext> Contexts { get; set; } = [];
|
||||
[Column("action")] public FilterAction Action { get; set; }
|
||||
|
||||
public Filter Clone(User? user = null)
|
||||
{
|
||||
return new Filter
|
||||
{
|
||||
Name = Name,
|
||||
Action = Action,
|
||||
Contexts = Contexts,
|
||||
Expiry = Expiry,
|
||||
Keywords = Keywords,
|
||||
Id = Id,
|
||||
User = user!
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
public class FilterEntityTypeConfiguration : IEntityTypeConfiguration<Filter>
|
||||
|
|
|
@ -32,6 +32,36 @@ public static class FilterHelper
|
|||
return null;
|
||||
}
|
||||
|
||||
public static List<(Filter filter, string keyword)> CheckFilters(IEnumerable<Note?> notes, List<Filter> filters)
|
||||
{
|
||||
if (filters.Count == 0) return [];
|
||||
|
||||
var res = new List<(Filter filter, string keyword)>();
|
||||
|
||||
foreach (var note in notes.OfType<Note>())
|
||||
{
|
||||
var match = CheckFilters(note, filters);
|
||||
if (match.Count != 0) res.AddRange(match);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
private static List<(Filter filter, string keyword)> CheckFilters(Note note, List<Filter> filters)
|
||||
{
|
||||
if (filters.Count == 0) return [];
|
||||
|
||||
var res = new List<(Filter filter, string keyword)>();
|
||||
|
||||
foreach (var filter in filters)
|
||||
{
|
||||
var match = IsFiltered(note, filter);
|
||||
if (match != null) res.Add((filter, match));
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
private static string? IsFiltered(Note note, Filter filter)
|
||||
{
|
||||
foreach (var keyword in filter.Keywords)
|
||||
|
|
|
@ -19,6 +19,9 @@ public class EventService
|
|||
public event EventHandler<UserInteraction>? UserMuted;
|
||||
public event EventHandler<UserInteraction>? UserUnmuted;
|
||||
public event EventHandler<Notification>? Notification;
|
||||
public event EventHandler<Filter>? FilterAdded;
|
||||
public event EventHandler<Filter>? FilterRemoved;
|
||||
public event EventHandler<Filter>? FilterUpdated;
|
||||
|
||||
public void RaiseNotePublished(object? sender, Note note) => NotePublished?.Invoke(sender, note);
|
||||
public void RaiseNoteUpdated(object? sender, Note note) => NoteUpdated?.Invoke(sender, note);
|
||||
|
@ -61,4 +64,8 @@ public class EventService
|
|||
|
||||
public void RaiseUserUnmuted(object? sender, User actor, User obj) =>
|
||||
UserUnmuted?.Invoke(sender, new UserInteraction { Actor = actor, Object = obj });
|
||||
|
||||
public void RaiseFilterAdded(object? sender, Filter filter) => FilterAdded?.Invoke(sender, filter);
|
||||
public void RaiseFilterRemoved(object? sender, Filter filter) => FilterRemoved?.Invoke(sender, filter);
|
||||
public void RaiseFilterUpdated(object? sender, Filter filter) => FilterUpdated?.Invoke(sender, filter);
|
||||
}
|
Loading…
Add table
Reference in a new issue