[backend/api-shared] Add thread mute support (ISH-172)

This commit is contained in:
Laura Hausmann 2024-07-13 00:52:46 +02:00
parent bc26e39812
commit 1d43f2c30b
No known key found for this signature in database
GPG key ID: D044E84C5BE01605
13 changed files with 108 additions and 11 deletions

View file

@ -41,6 +41,7 @@ public class ConversationsController(
var conversations = await db.Conversations(user) var conversations = await db.Conversations(user)
.IncludeCommonProperties() .IncludeCommonProperties()
.FilterHiddenConversations(user, db) .FilterHiddenConversations(user, db)
.FilterMutedThreads(user, db)
.Paginate(p => p.ThreadId ?? p.Id, pq, ControllerContext) .Paginate(p => p.ThreadId ?? p.Id, pq, ControllerContext)
.Select(p => new Conversation .Select(p => new Conversation
{ {

View file

@ -50,6 +50,7 @@ public class NotificationController(DatabaseContext db, NotificationRenderer not
.FilterByGetNotificationsRequest(request) .FilterByGetNotificationsRequest(request)
.EnsureNoteVisibilityFor(p => p.Note, user) .EnsureNoteVisibilityFor(p => p.Note, user)
.FilterHiddenNotifications(user, db) .FilterHiddenNotifications(user, db)
.FilterMutedThreads(user, db)
.Paginate(p => p.MastoId, query, ControllerContext) .Paginate(p => p.MastoId, query, ControllerContext)
.PrecomputeNoteVisibilities(user) .PrecomputeNoteVisibilities(user)
.RenderAllForMastodonAsync(notificationRenderer, user); .RenderAllForMastodonAsync(notificationRenderer, user);

View file

@ -73,10 +73,10 @@ public class DirectChannel(WebSocketConnection connection) : IChannel
return rendered; return rendered;
} }
private async Task<ConversationEntity> RenderConversation(Note note, NoteWithVisibilities wrapped) private async Task<ConversationEntity> RenderConversation(
Note note, NoteWithVisibilities wrapped, AsyncServiceScope scope
)
{ {
await using var scope = connection.ScopeFactory.CreateAsyncScope();
var db = scope.ServiceProvider.GetRequiredService<DatabaseContext>(); var db = scope.ServiceProvider.GetRequiredService<DatabaseContext>();
var renderer = scope.ServiceProvider.GetRequiredService<NoteRenderer>(); var renderer = scope.ServiceProvider.GetRequiredService<NoteRenderer>();
var userRenderer = scope.ServiceProvider.GetRequiredService<UserRenderer>(); var userRenderer = scope.ServiceProvider.GetRequiredService<UserRenderer>();
@ -106,11 +106,14 @@ public class DirectChannel(WebSocketConnection connection) : IChannel
if (connection.IsFiltered(note)) return; if (connection.IsFiltered(note)) return;
if (note.CreatedAt < DateTime.UtcNow - TimeSpan.FromMinutes(5)) return; if (note.CreatedAt < DateTime.UtcNow - TimeSpan.FromMinutes(5)) return;
await using var scope = connection.ScopeFactory.CreateAsyncScope();
if (await connection.IsMutedThread(note, scope)) return;
var message = new StreamingUpdateMessage var message = new StreamingUpdateMessage
{ {
Stream = [Name], Stream = [Name],
Event = "conversation", Event = "conversation",
Payload = JsonSerializer.Serialize(await RenderConversation(note, wrapped)) Payload = JsonSerializer.Serialize(await RenderConversation(note, wrapped, scope))
}; };
await connection.SendMessageAsync(JsonSerializer.Serialize(message)); await connection.SendMessageAsync(JsonSerializer.Serialize(message));
@ -129,11 +132,12 @@ public class DirectChannel(WebSocketConnection connection) : IChannel
if (wrapped == null) return; if (wrapped == null) return;
if (connection.IsFiltered(note)) return; if (connection.IsFiltered(note)) return;
await using var scope = connection.ScopeFactory.CreateAsyncScope();
var message = new StreamingUpdateMessage var message = new StreamingUpdateMessage
{ {
Stream = [Name], Stream = [Name],
Event = "conversation", Event = "conversation",
Payload = JsonSerializer.Serialize(await RenderConversation(note, wrapped)) Payload = JsonSerializer.Serialize(await RenderConversation(note, wrapped, scope))
}; };
await connection.SendMessageAsync(JsonSerializer.Serialize(message)); await connection.SendMessageAsync(JsonSerializer.Serialize(message));

View file

@ -106,6 +106,7 @@ public class HashtagChannel(WebSocketConnection connection, bool local) : IChann
if (wrapped == null) return; if (wrapped == null) return;
if (connection.IsFiltered(note)) return; if (connection.IsFiltered(note)) return;
await using var scope = connection.ScopeFactory.CreateAsyncScope(); await using var scope = connection.ScopeFactory.CreateAsyncScope();
if (await connection.IsMutedThread(note, scope)) return;
var renderer = scope.ServiceProvider.GetRequiredService<NoteRenderer>(); var renderer = scope.ServiceProvider.GetRequiredService<NoteRenderer>();
var data = new NoteRenderer.NoteRendererDto { Filters = connection.Filters.ToList() }; var data = new NoteRenderer.NoteRendererDto { Filters = connection.Filters.ToList() };

View file

@ -129,6 +129,7 @@ public class ListChannel(WebSocketConnection connection) : IChannel
if (wrapped == null) return; if (wrapped == null) return;
if (connection.IsFiltered(note)) return; if (connection.IsFiltered(note)) return;
await using var scope = connection.ScopeFactory.CreateAsyncScope(); await using var scope = connection.ScopeFactory.CreateAsyncScope();
if (await connection.IsMutedThread(note, scope)) return;
var renderer = scope.ServiceProvider.GetRequiredService<NoteRenderer>(); var renderer = scope.ServiceProvider.GetRequiredService<NoteRenderer>();
var data = new NoteRenderer.NoteRendererDto { Filters = connection.Filters.ToList() }; var data = new NoteRenderer.NoteRendererDto { Filters = connection.Filters.ToList() };

View file

@ -89,6 +89,7 @@ public class PublicChannel(
if (wrapped == null) return; if (wrapped == null) return;
if (connection.IsFiltered(note)) return; if (connection.IsFiltered(note)) return;
await using var scope = connection.ScopeFactory.CreateAsyncScope(); await using var scope = connection.ScopeFactory.CreateAsyncScope();
if (await connection.IsMutedThread(note, scope)) return;
var renderer = scope.ServiceProvider.GetRequiredService<NoteRenderer>(); var renderer = scope.ServiceProvider.GetRequiredService<NoteRenderer>();
var data = new NoteRenderer.NoteRendererDto { Filters = connection.Filters.ToList() }; var data = new NoteRenderer.NoteRendererDto { Filters = connection.Filters.ToList() };

View file

@ -102,6 +102,7 @@ public class UserChannel(WebSocketConnection connection, bool notificationsOnly)
if (connection.IsFiltered(note)) return; if (connection.IsFiltered(note)) return;
if (note.CreatedAt < DateTime.UtcNow - TimeSpan.FromMinutes(5)) return; if (note.CreatedAt < DateTime.UtcNow - TimeSpan.FromMinutes(5)) return;
await using var scope = connection.ScopeFactory.CreateAsyncScope(); await using var scope = connection.ScopeFactory.CreateAsyncScope();
if (await connection.IsMutedThread(note, scope)) return;
var renderer = scope.ServiceProvider.GetRequiredService<NoteRenderer>(); var renderer = scope.ServiceProvider.GetRequiredService<NoteRenderer>();
var intermediate = await renderer.RenderAsync(note, connection.Token.User); var intermediate = await renderer.RenderAsync(note, connection.Token.User);
@ -172,7 +173,9 @@ public class UserChannel(WebSocketConnection connection, bool notificationsOnly)
{ {
if (!IsApplicable(notification)) return; if (!IsApplicable(notification)) return;
if (IsFiltered(notification)) return; if (IsFiltered(notification)) return;
await using var scope = connection.ScopeFactory.CreateAsyncScope(); await using var scope = connection.ScopeFactory.CreateAsyncScope();
if (notification.Note != null && await connection.IsMutedThread(notification.Note, scope)) return;
var renderer = scope.ServiceProvider.GetRequiredService<NotificationRenderer>(); var renderer = scope.ServiceProvider.GetRequiredService<NotificationRenderer>();

View file

@ -366,6 +366,14 @@ public sealed class WebSocketConnection(
(IsFiltered(note.Renote.Renote.User) || (IsFiltered(note.Renote.Renote.User) ||
IsFilteredMentions(note.Renote.Renote.Mentions))); IsFilteredMentions(note.Renote.Renote.Mentions)));
public async Task<bool> IsMutedThread(Note note, AsyncServiceScope scope)
{
if (note.Reply == null) return false;
if (note.User.Id == Token.UserId) return false;
await using var db = scope.ServiceProvider.GetRequiredService<DatabaseContext>();
return await db.NoteThreadMutings.AnyAsync(p => p.UserId == Token.UserId && p.ThreadId == note.ThreadId);
}
public async Task CloseAsync(WebSocketCloseStatus status) public async Task CloseAsync(WebSocketCloseStatus status)
{ {
Dispose(); Dispose();

View file

@ -38,6 +38,7 @@ public class TimelineController(DatabaseContext db, NoteRenderer noteRenderer, C
.FilterByFollowingAndOwn(user, db, heuristic) .FilterByFollowingAndOwn(user, db, heuristic)
.EnsureVisibleFor(user) .EnsureVisibleFor(user)
.FilterHidden(user, db, filterHiddenListMembers: true) .FilterHidden(user, db, filterHiddenListMembers: true)
.FilterMutedThreads(user, db)
.Paginate(query, ControllerContext) .Paginate(query, ControllerContext)
.PrecomputeVisibilities(user) .PrecomputeVisibilities(user)
.RenderAllForMastodonAsync(noteRenderer, user, Filter.FilterContext.Home); .RenderAllForMastodonAsync(noteRenderer, user, Filter.FilterContext.Home);
@ -56,6 +57,7 @@ public class TimelineController(DatabaseContext db, NoteRenderer noteRenderer, C
.HasVisibility(Note.NoteVisibility.Public) .HasVisibility(Note.NoteVisibility.Public)
.FilterByPublicTimelineRequest(request) .FilterByPublicTimelineRequest(request)
.FilterHidden(user, db) .FilterHidden(user, db)
.FilterMutedThreads(user, db)
.Paginate(query, ControllerContext) .Paginate(query, ControllerContext)
.PrecomputeVisibilities(user) .PrecomputeVisibilities(user)
.RenderAllForMastodonAsync(noteRenderer, user, Filter.FilterContext.Public); .RenderAllForMastodonAsync(noteRenderer, user, Filter.FilterContext.Public);
@ -74,6 +76,7 @@ public class TimelineController(DatabaseContext db, NoteRenderer noteRenderer, C
.Where(p => p.Tags.Contains(hashtag.ToLowerInvariant())) .Where(p => p.Tags.Contains(hashtag.ToLowerInvariant()))
.FilterByHashtagTimelineRequest(request) .FilterByHashtagTimelineRequest(request)
.FilterHidden(user, db) .FilterHidden(user, db)
.FilterMutedThreads(user, db)
.Paginate(query, ControllerContext) .Paginate(query, ControllerContext)
.PrecomputeVisibilities(user) .PrecomputeVisibilities(user)
.RenderAllForMastodonAsync(noteRenderer, user, Filter.FilterContext.Public); .RenderAllForMastodonAsync(noteRenderer, user, Filter.FilterContext.Public);
@ -93,6 +96,7 @@ public class TimelineController(DatabaseContext db, NoteRenderer noteRenderer, C
.Where(p => db.UserListMembers.Any(l => l.UserListId == id && l.UserId == p.UserId)) .Where(p => db.UserListMembers.Any(l => l.UserListId == id && l.UserId == p.UserId))
.EnsureVisibleFor(user) .EnsureVisibleFor(user)
.FilterHidden(user, db) .FilterHidden(user, db)
.FilterMutedThreads(user, db)
.Paginate(query, ControllerContext) .Paginate(query, ControllerContext)
.PrecomputeVisibilities(user) .PrecomputeVisibilities(user)
.RenderAllForMastodonAsync(noteRenderer, user, Filter.FilterContext.Lists); .RenderAllForMastodonAsync(noteRenderer, user, Filter.FilterContext.Lists);

View file

@ -336,6 +336,50 @@ public class NoteController(
}; };
} }
[HttpPost("{id}/mute")]
[Authenticate]
[Authorize]
[EnableRateLimiting("strict")]
[ProducesResults(HttpStatusCode.OK)]
[ProducesErrors(HttpStatusCode.NotFound)]
public async Task MuteNoteThread(string id)
{
var user = HttpContext.GetUserOrFail();
var target = await db.Notes.Where(p => p.Id == id)
.EnsureVisibleFor(user)
.Select(p => p.ThreadId ?? p.Id)
.FirstOrDefaultAsync() ??
throw GracefulException.NotFound("Note not found");
var mute = new NoteThreadMuting
{
Id = IdHelpers.GenerateSlowflakeId(),
CreatedAt = DateTime.UtcNow,
ThreadId = target,
UserId = user.Id
};
await db.NoteThreadMutings.Upsert(mute).On(p => new { p.UserId, p.ThreadId }).NoUpdate().RunAsync();
}
[HttpPost("{id}/unmute")]
[Authenticate]
[Authorize]
[EnableRateLimiting("strict")]
[ProducesResults(HttpStatusCode.OK)]
[ProducesErrors(HttpStatusCode.NotFound)]
public async Task UnmuteNoteThread(string id)
{
var user = HttpContext.GetUserOrFail();
var target = await db.Notes.Where(p => p.Id == id)
.EnsureVisibleFor(user)
.Select(p => p.ThreadId ?? p.Id)
.FirstOrDefaultAsync() ??
throw GracefulException.NotFound("Note not found");
await db.NoteThreadMutings.Where(p => p.User == user && p.ThreadId == target).ExecuteDeleteAsync();
}
[HttpPost] [HttpPost]
[Authenticate] [Authenticate]
[Authorize] [Authorize]

View file

@ -34,6 +34,7 @@ public class TimelineController(DatabaseContext db, NoteRenderer noteRenderer, C
.FilterByFollowingAndOwn(user, db, heuristic) .FilterByFollowingAndOwn(user, db, heuristic)
.EnsureVisibleFor(user) .EnsureVisibleFor(user)
.FilterHidden(user, db, filterHiddenListMembers: true) .FilterHidden(user, db, filterHiddenListMembers: true)
.FilterMutedThreads(user, db)
.Paginate(pq, ControllerContext) .Paginate(pq, ControllerContext)
.PrecomputeVisibilities(user) .PrecomputeVisibilities(user)
.ToListAsync(); .ToListAsync();

View file

@ -351,6 +351,19 @@ public static class QueryableExtensions
return query.Where(p => p.VisibleUserIds.IsDisjoint(hidden)); return query.Where(p => p.VisibleUserIds.IsDisjoint(hidden));
} }
public static IQueryable<Note> FilterMutedThreads(this IQueryable<Note> query, User user, DatabaseContext db)
{
return query.Where(p => !db.NoteThreadMutings.Any(m => m.User == user && m.ThreadId == (p.ThreadId ?? p.Id)));
}
public static IQueryable<Notification> FilterMutedThreads(
this IQueryable<Notification> query, User user, DatabaseContext db
)
{
return query.Where(p => p.Note == null ||
!db.NoteThreadMutings.Any(m => m.User == user && m.ThreadId == (p.Note.ThreadId ?? p.Note.Id)));
}
private static (IQueryable<string> hidden, IQueryable<string>? mentionsHidden) FilterHiddenInternal( private static (IQueryable<string> hidden, IQueryable<string>? mentionsHidden) FilterHiddenInternal(
User? user, User? user,
DatabaseContext db, DatabaseContext db,
@ -388,8 +401,8 @@ public static class QueryableExtensions
if (except != null) if (except != null)
{ {
hidden = hidden.Except(new[] { except }); hidden = hidden.Except([except]);
mentionsHidden = mentionsHidden?.Except(new[] { except }); mentionsHidden = mentionsHidden?.Except([except]);
} }
return (hidden, mentionsHidden); return (hidden, mentionsHidden);
@ -404,16 +417,14 @@ public static class QueryableExtensions
return note => !hidden.Contains(note.UserId) && return note => !hidden.Contains(note.UserId) &&
!hidden.Contains(note.RenoteUserId) && !hidden.Contains(note.RenoteUserId) &&
!hidden.Contains(note.ReplyUserId) && !hidden.Contains(note.ReplyUserId) &&
(note.Renote == null || (note.Renote == null || !hidden.Contains(note.Renote.RenoteUserId)) &&
!hidden.Contains(note.Renote.RenoteUserId)) &&
note.Mentions.IsDisjoint(mentionsHidden); note.Mentions.IsDisjoint(mentionsHidden);
} }
return note => !hidden.Contains(note.UserId) && return note => !hidden.Contains(note.UserId) &&
!hidden.Contains(note.RenoteUserId) && !hidden.Contains(note.RenoteUserId) &&
!hidden.Contains(note.ReplyUserId) && !hidden.Contains(note.ReplyUserId) &&
(note.Renote == null || (note.Renote == null || !hidden.Contains(note.Renote.RenoteUserId));
!hidden.Contains(note.Renote.RenoteUserId));
} }
public static IQueryable<TSource> FilterHidden<TSource>( public static IQueryable<TSource> FilterHidden<TSource>(

View file

@ -70,7 +70,9 @@ public sealed class StreamingConnectionAggregate : IDisposable
{ {
if (notification.NotifieeId != _userId) return; if (notification.NotifieeId != _userId) return;
if (notification.Notifier != null && IsFiltered(notification.Notifier)) return; if (notification.Notifier != null && IsFiltered(notification.Notifier)) return;
await using var scope = GetTempScope(); await using var scope = GetTempScope();
if (notification.Note != null && await IsMutedThread(notification.Note, scope)) return;
var renderer = scope.ServiceProvider.GetRequiredService<NotificationRenderer>(); var renderer = scope.ServiceProvider.GetRequiredService<NotificationRenderer>();
var rendered = await renderer.RenderOne(notification, _user); var rendered = await renderer.RenderOne(notification, _user);
@ -91,6 +93,13 @@ public sealed class StreamingConnectionAggregate : IDisposable
var recipients = FindRecipients(data.note); var recipients = FindRecipients(data.note);
if (recipients.connectionIds.Count == 0) return; if (recipients.connectionIds.Count == 0) return;
if (data.note.Reply != null)
{
await using var scope = _scopeFactory.CreateAsyncScope();
if (await IsMutedThread(data.note, scope))
return;
}
var rendered = EnforceRenoteReplyVisibility(await data.rendered(), wrapped); var rendered = EnforceRenoteReplyVisibility(await data.rendered(), wrapped);
await _hub.Clients.Clients(recipients.connectionIds).NotePublished(recipients.timelines, rendered); await _hub.Clients.Clients(recipients.connectionIds).NotePublished(recipients.timelines, rendered);
} }
@ -144,6 +153,14 @@ public sealed class StreamingConnectionAggregate : IDisposable
return res is not { Note.IsPureRenote: true, Renote: null } ? res : null; return res is not { Note.IsPureRenote: true, Renote: null } ? res : null;
} }
private async Task<bool> IsMutedThread(Note note, AsyncServiceScope scope)
{
if (note.Reply == null) return false;
if (note.User.Id == _userId) return false;
await using var db = scope.ServiceProvider.GetRequiredService<DatabaseContext>();
return await db.NoteThreadMutings.AnyAsync(p => p.UserId == _userId && p.ThreadId == note.ThreadId);
}
[SuppressMessage("ReSharper", "SuggestBaseTypeForParameter")] [SuppressMessage("ReSharper", "SuggestBaseTypeForParameter")]
private bool IsFiltered(Note note) => private bool IsFiltered(Note note) =>
IsFiltered(note.User) || _blocking.Intersects(note.Mentions) || _muting.Intersects(note.Mentions); IsFiltered(note.User) || _blocking.Intersects(note.Mentions) || _muting.Intersects(note.Mentions);