[backend/streaming] Fix streaming updates not containing html markup if supported

This commit is contained in:
Laura Hausmann 2024-11-20 02:15:16 +01:00
parent 862d477dec
commit c0e8a6d680
No known key found for this signature in database
GPG key ID: D044E84C5BE01605
6 changed files with 32 additions and 17 deletions

View file

@ -23,7 +23,7 @@ public class DirectChannel(WebSocketConnection connection) : IChannel
if (IsSubscribed) return; if (IsSubscribed) return;
IsSubscribed = true; IsSubscribed = true;
await using var scope = connection.ScopeFactory.CreateAsyncScope(); await using var scope = connection.GetAsyncServiceScope();
await using var db = scope.ServiceProvider.GetRequiredService<DatabaseContext>(); await using var db = scope.ServiceProvider.GetRequiredService<DatabaseContext>();
connection.EventService.NotePublished += OnNotePublished; connection.EventService.NotePublished += OnNotePublished;
@ -106,7 +106,7 @@ public class DirectChannel(WebSocketConnection connection) : IChannel
if (connection.IsFiltered(note)) return; if (connection.IsFiltered(note)) return;
if (note.CreatedAt < DateTime.UtcNow - TimeSpan.FromMinutes(5)) return; if (note.CreatedAt < DateTime.UtcNow - TimeSpan.FromMinutes(5)) return;
await using var scope = connection.ScopeFactory.CreateAsyncScope(); await using var scope = connection.GetAsyncServiceScope();
if (await connection.IsMutedThreadAsync(note, scope)) return; if (await connection.IsMutedThreadAsync(note, scope)) return;
var message = new StreamingUpdateMessage var message = new StreamingUpdateMessage
@ -132,7 +132,7 @@ public class DirectChannel(WebSocketConnection connection) : IChannel
if (wrapped == null) return; if (wrapped == null) return;
if (connection.IsFiltered(note)) return; if (connection.IsFiltered(note)) return;
await using var scope = connection.ScopeFactory.CreateAsyncScope(); await using var scope = connection.GetAsyncServiceScope();
var message = new StreamingUpdateMessage var message = new StreamingUpdateMessage
{ {
Stream = [Name], Stream = [Name],

View file

@ -105,7 +105,7 @@ public class HashtagChannel(WebSocketConnection connection, bool local) : IChann
var wrapped = IsApplicable(note); var wrapped = IsApplicable(note);
if (wrapped == null) return; if (wrapped == null) return;
if (connection.IsFiltered(note)) return; if (connection.IsFiltered(note)) return;
await using var scope = connection.ScopeFactory.CreateAsyncScope(); await using var scope = connection.GetAsyncServiceScope();
if (await connection.IsMutedThreadAsync(note, scope)) return; if (await connection.IsMutedThreadAsync(note, scope)) return;
var renderer = scope.ServiceProvider.GetRequiredService<NoteRenderer>(); var renderer = scope.ServiceProvider.GetRequiredService<NoteRenderer>();
@ -130,7 +130,7 @@ public class HashtagChannel(WebSocketConnection connection, bool local) : IChann
var wrapped = IsApplicable(note); var wrapped = IsApplicable(note);
if (wrapped == null) return; if (wrapped == null) return;
if (connection.IsFiltered(note)) return; if (connection.IsFiltered(note)) return;
await using var scope = connection.ScopeFactory.CreateAsyncScope(); await using var scope = connection.GetAsyncServiceScope();
var renderer = scope.ServiceProvider.GetRequiredService<NoteRenderer>(); var renderer = scope.ServiceProvider.GetRequiredService<NoteRenderer>();
var data = new NoteRenderer.NoteRendererDto { Filters = connection.Filters.ToList() }; var data = new NoteRenderer.NoteRendererDto { Filters = connection.Filters.ToList() };

View file

@ -43,7 +43,7 @@ public class ListChannel(WebSocketConnection connection) : IChannel
if (_lists.AddIfMissing(msg.List)) if (_lists.AddIfMissing(msg.List))
{ {
await using var scope = connection.ScopeFactory.CreateAsyncScope(); await using var scope = connection.GetAsyncServiceScope();
var db = scope.ServiceProvider.GetRequiredService<DatabaseContext>(); var db = scope.ServiceProvider.GetRequiredService<DatabaseContext>();
var list = await db.UserLists.FirstOrDefaultAsync(p => p.UserId == connection.Token.User.Id && var list = await db.UserLists.FirstOrDefaultAsync(p => p.UserId == connection.Token.User.Id &&
@ -128,7 +128,7 @@ public class ListChannel(WebSocketConnection connection) : IChannel
var wrapped = IsApplicable(note); var wrapped = IsApplicable(note);
if (wrapped == null) return; if (wrapped == null) return;
if (connection.IsFiltered(note)) return; if (connection.IsFiltered(note)) return;
await using var scope = connection.ScopeFactory.CreateAsyncScope(); await using var scope = connection.GetAsyncServiceScope();
if (await connection.IsMutedThreadAsync(note, scope)) return; if (await connection.IsMutedThreadAsync(note, scope)) return;
var renderer = scope.ServiceProvider.GetRequiredService<NoteRenderer>(); var renderer = scope.ServiceProvider.GetRequiredService<NoteRenderer>();
@ -154,7 +154,7 @@ public class ListChannel(WebSocketConnection connection) : IChannel
var wrapped = IsApplicable(note); var wrapped = IsApplicable(note);
if (wrapped == null) return; if (wrapped == null) return;
if (connection.IsFiltered(note)) return; if (connection.IsFiltered(note)) return;
await using var scope = connection.ScopeFactory.CreateAsyncScope(); await using var scope = connection.GetAsyncServiceScope();
var renderer = scope.ServiceProvider.GetRequiredService<NoteRenderer>(); var renderer = scope.ServiceProvider.GetRequiredService<NoteRenderer>();
var data = new NoteRenderer.NoteRendererDto { Filters = connection.Filters.ToList() }; var data = new NoteRenderer.NoteRendererDto { Filters = connection.Filters.ToList() };
@ -197,7 +197,7 @@ public class ListChannel(WebSocketConnection connection) : IChannel
if (list.UserId != connection.Token.User.Id) return; if (list.UserId != connection.Token.User.Id) return;
if (!_lists.Contains(list.Id)) return; if (!_lists.Contains(list.Id)) return;
await using var scope = connection.ScopeFactory.CreateAsyncScope(); await using var scope = connection.GetAsyncServiceScope();
var db = scope.ServiceProvider.GetRequiredService<DatabaseContext>(); var db = scope.ServiceProvider.GetRequiredService<DatabaseContext>();
var members = await db.UserListMembers.Where(p => p.UserListId == list.Id) var members = await db.UserListMembers.Where(p => p.UserListId == list.Id)

View file

@ -88,7 +88,7 @@ public class PublicChannel(
var wrapped = IsApplicable(note); var wrapped = IsApplicable(note);
if (wrapped == null) return; if (wrapped == null) return;
if (connection.IsFiltered(note)) return; if (connection.IsFiltered(note)) return;
await using var scope = connection.ScopeFactory.CreateAsyncScope(); await using var scope = connection.GetAsyncServiceScope();
if (await connection.IsMutedThreadAsync(note, scope)) return; if (await connection.IsMutedThreadAsync(note, scope)) return;
var renderer = scope.ServiceProvider.GetRequiredService<NoteRenderer>(); var renderer = scope.ServiceProvider.GetRequiredService<NoteRenderer>();
@ -116,7 +116,7 @@ public class PublicChannel(
var wrapped = IsApplicable(note); var wrapped = IsApplicable(note);
if (wrapped == null) return; if (wrapped == null) return;
if (connection.IsFiltered(note)) return; if (connection.IsFiltered(note)) return;
await using var scope = connection.ScopeFactory.CreateAsyncScope(); await using var scope = connection.GetAsyncServiceScope();
var renderer = scope.ServiceProvider.GetRequiredService<NoteRenderer>(); var renderer = scope.ServiceProvider.GetRequiredService<NoteRenderer>();
var data = new NoteRenderer.NoteRendererDto { Filters = connection.Filters.ToList() }; var data = new NoteRenderer.NoteRendererDto { Filters = connection.Filters.ToList() };

View file

@ -22,7 +22,7 @@ public class UserChannel(WebSocketConnection connection, bool notificationsOnly)
if (IsSubscribed) return; if (IsSubscribed) return;
IsSubscribed = true; IsSubscribed = true;
await using var scope = connection.ScopeFactory.CreateAsyncScope(); await using var scope = connection.GetAsyncServiceScope();
await using var db = scope.ServiceProvider.GetRequiredService<DatabaseContext>(); await using var db = scope.ServiceProvider.GetRequiredService<DatabaseContext>();
if (!notificationsOnly) if (!notificationsOnly)
@ -101,7 +101,7 @@ public class UserChannel(WebSocketConnection connection, bool notificationsOnly)
if (wrapped == null) return; if (wrapped == null) return;
if (connection.IsFiltered(note)) return; if (connection.IsFiltered(note)) return;
if (note.CreatedAt < DateTime.UtcNow - TimeSpan.FromMinutes(5)) return; if (note.CreatedAt < DateTime.UtcNow - TimeSpan.FromMinutes(5)) return;
await using var scope = connection.ScopeFactory.CreateAsyncScope(); await using var scope = connection.GetAsyncServiceScope();
if (await connection.IsMutedThreadAsync(note, scope)) return; if (await connection.IsMutedThreadAsync(note, scope)) return;
var renderer = scope.ServiceProvider.GetRequiredService<NoteRenderer>(); var renderer = scope.ServiceProvider.GetRequiredService<NoteRenderer>();
@ -128,7 +128,7 @@ public class UserChannel(WebSocketConnection connection, bool notificationsOnly)
var wrapped = IsApplicable(note); var wrapped = IsApplicable(note);
if (wrapped == null) return; if (wrapped == null) return;
if (connection.IsFiltered(note)) return; if (connection.IsFiltered(note)) return;
await using var scope = connection.ScopeFactory.CreateAsyncScope(); await using var scope = connection.GetAsyncServiceScope();
var renderer = scope.ServiceProvider.GetRequiredService<NoteRenderer>(); var renderer = scope.ServiceProvider.GetRequiredService<NoteRenderer>();
var intermediate = await renderer.RenderAsync(note, connection.Token.User); var intermediate = await renderer.RenderAsync(note, connection.Token.User);
@ -174,7 +174,7 @@ public class UserChannel(WebSocketConnection connection, bool notificationsOnly)
if (!IsApplicable(notification)) return; if (!IsApplicable(notification)) return;
if (IsFiltered(notification)) return; if (IsFiltered(notification)) return;
await using var scope = connection.ScopeFactory.CreateAsyncScope(); await using var scope = connection.GetAsyncServiceScope();
if (notification.Note != null && await connection.IsMutedThreadAsync(notification.Note, scope, true)) if (notification.Note != null && await connection.IsMutedThreadAsync(notification.Note, scope, true))
return; return;

View file

@ -8,6 +8,7 @@ using Iceshrimp.Backend.Core.Database.Tables;
using Iceshrimp.Backend.Core.Events; using Iceshrimp.Backend.Core.Events;
using Iceshrimp.Backend.Core.Extensions; using Iceshrimp.Backend.Core.Extensions;
using Iceshrimp.Backend.Core.Helpers; using Iceshrimp.Backend.Core.Helpers;
using Iceshrimp.Backend.Core.Helpers.LibMfm.Conversion;
using Iceshrimp.Backend.Core.Services; using Iceshrimp.Backend.Core.Services;
using JetBrains.Annotations; using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore;
@ -32,7 +33,6 @@ public sealed class WebSocketConnection(
public readonly WriteLockingList<Filter> Filters = []; public readonly WriteLockingList<Filter> Filters = [];
public readonly EventService EventService = eventSvc; public readonly EventService EventService = eventSvc;
public readonly IServiceScope Scope = scopeFactory.CreateScope(); public readonly IServiceScope Scope = scopeFactory.CreateScope();
public readonly IServiceScopeFactory ScopeFactory = scopeFactory;
public readonly OauthToken Token = token; public readonly OauthToken Token = token;
public HashSet<string> HiddenFromHome = []; public HashSet<string> HiddenFromHome = [];
@ -57,6 +57,8 @@ public sealed class WebSocketConnection(
public void InitializeStreamingWorker() public void InitializeStreamingWorker()
{ {
InitializeScopeLocalParameters(Scope);
_channels.Add(new ListChannel(this)); _channels.Add(new ListChannel(this));
_channels.Add(new DirectChannel(this)); _channels.Add(new DirectChannel(this));
_channels.Add(new UserChannel(this, true)); _channels.Add(new UserChannel(this, true));
@ -343,7 +345,7 @@ public sealed class WebSocketConnection(
if (list.UserId != Token.User.Id) return; if (list.UserId != Token.User.Id) return;
if (!list.HideFromHomeTl) return; if (!list.HideFromHomeTl) return;
await using var scope = ScopeFactory.CreateAsyncScope(); await using var scope = GetAsyncServiceScope();
var db = scope.ServiceProvider.GetRequiredService<DatabaseContext>(); var db = scope.ServiceProvider.GetRequiredService<DatabaseContext>();
HiddenFromHome = await db.UserListMembers HiddenFromHome = await db.UserListMembers
@ -383,6 +385,19 @@ public sealed class WebSocketConnection(
return await db.NoteThreadMutings.AnyAsync(p => p.UserId == Token.UserId && p.ThreadId == note.ThreadId); return await db.NoteThreadMutings.AnyAsync(p => p.UserId == Token.UserId && p.ThreadId == note.ThreadId);
} }
public AsyncServiceScope GetAsyncServiceScope()
{
var scope = scopeFactory.CreateAsyncScope();
InitializeScopeLocalParameters(scope);
return scope;
}
private void InitializeScopeLocalParameters(IServiceScope scope)
{
var mfmConverter = scope.ServiceProvider.GetRequiredService<MfmConverter>();
mfmConverter.SupportsHtmlFormatting.Value = Token.SupportsHtmlFormatting;
}
public async Task CloseAsync(WebSocketCloseStatus status) public async Task CloseAsync(WebSocketCloseStatus status)
{ {
Dispose(); Dispose();