[backend/masto-client] Fix concurrent DbContext access

This commit is contained in:
Laura Hausmann 2024-02-03 04:05:48 +01:00
parent 17884c4975
commit d3da11f827
No known key found for this signature in database
GPG key ID: D044E84C5BE01605
5 changed files with 18 additions and 10 deletions

View file

@ -1,7 +1,10 @@
using System.Collections;
using Iceshrimp.Backend.Controllers.Mastodon.Schemas.Entities; using Iceshrimp.Backend.Controllers.Mastodon.Schemas.Entities;
using Iceshrimp.Backend.Core.Configuration; using Iceshrimp.Backend.Core.Configuration;
using Iceshrimp.Backend.Core.Database.Tables; using Iceshrimp.Backend.Core.Database.Tables;
using Iceshrimp.Backend.Core.Extensions;
using Iceshrimp.Backend.Core.Helpers.LibMfm.Conversion; using Iceshrimp.Backend.Core.Helpers.LibMfm.Conversion;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
namespace Iceshrimp.Backend.Controllers.Mastodon.Renderers; namespace Iceshrimp.Backend.Controllers.Mastodon.Renderers;
@ -42,4 +45,8 @@ public class NoteRenderer(IOptions<Config.InstanceSection> config, UserRenderer
return res; return res;
} }
public async Task<IEnumerable<Status>> RenderManyAsync(IEnumerable<Note> notes) {
return await notes.Select(RenderAsync).AwaitAllAsync();
}
} }

View file

@ -44,7 +44,6 @@ public class UserRenderer(IOptions<Config.InstanceSection> config, DatabaseConte
} }
public async Task<Account> RenderAsync(User user) { public async Task<Account> RenderAsync(User user) {
var profile = await db.UserProfiles.FirstOrDefaultAsync(p => p.User == user); return await RenderAsync(user, user.UserProfile);
return await RenderAsync(user, profile);
} }
} }

View file

@ -10,6 +10,7 @@ namespace Iceshrimp.Backend.Core.Extensions;
public static class NoteQueryableExtensions { public static class NoteQueryableExtensions {
public static IQueryable<Note> WithIncludes(this IQueryable<Note> query) { public static IQueryable<Note> WithIncludes(this IQueryable<Note> query) {
return query.Include(p => p.User) return query.Include(p => p.User)
.ThenInclude(p => p.UserProfile)
.Include(p => p.Renote) .Include(p => p.Renote)
.ThenInclude(p => p != null ? p.User : null) .ThenInclude(p => p != null ? p.User : null)
.Include(p => p.Reply) .Include(p => p.Reply)
@ -34,7 +35,7 @@ public static class NoteQueryableExtensions {
.OrderByDescending(note => note.Id), .OrderByDescending(note => note.Id),
{ MinId: not null } => query.Where(note => note.Id.IsGreaterThan(p.MinId)).OrderBy(note => note.Id), { MinId: not null } => query.Where(note => note.Id.IsGreaterThan(p.MinId)).OrderBy(note => note.Id),
{ MaxId: not null } => query.Where(note => note.Id.IsLessThan(p.MaxId)).OrderByDescending(note => note.Id), { MaxId: not null } => query.Where(note => note.Id.IsLessThan(p.MaxId)).OrderByDescending(note => note.Id),
_ => query _ => query.OrderByDescending(note => note.Id)
}; };
return query.Take(Math.Min(p.Limit ?? defaultLimit, maxLimit)); return query.Take(Math.Min(p.Limit ?? defaultLimit, maxLimit));
@ -48,13 +49,9 @@ public static class NoteQueryableExtensions {
return query.Where(note => note.User == user || note.User.IsFollowedBy(user)); return query.Where(note => note.User == user || note.User.IsFollowedBy(user));
} }
public static IQueryable<Note> OrderByIdDesc(this IQueryable<Note> query) {
return query.OrderByDescending(note => note.Id);
}
public static async Task<IEnumerable<Status>> RenderAllForMastodonAsync( public static async Task<IEnumerable<Status>> RenderAllForMastodonAsync(
this IQueryable<Note> notes, NoteRenderer renderer) { this IQueryable<Note> notes, NoteRenderer renderer) {
var list = await notes.ToListAsync(); var list = await notes.ToListAsync();
return await list.Select(renderer.RenderAsync).AwaitAllAsync(); return await renderer.RenderManyAsync(list);
} }
} }

View file

@ -19,7 +19,11 @@ public class AuthenticationMiddleware(DatabaseContext db) : IMiddleware {
} }
var token = header[7..]; var token = header[7..];
var session = await db.Sessions.Include(p => p.User).FirstOrDefaultAsync(p => p.Token == token && p.Active); var session = await db.Sessions
.Include(p => p.User)
.ThenInclude(p => p.UserProfile)
.FirstOrDefaultAsync(p => p.Token == token && p.Active);
if (session == null) { if (session == null) {
await next(ctx); await next(ctx);
return; return;

View file

@ -22,6 +22,7 @@ public class OauthAuthenticationMiddleware(DatabaseContext db) : IMiddleware {
header = header[7..]; header = header[7..];
var token = await db.OauthTokens var token = await db.OauthTokens
.Include(p => p.User) .Include(p => p.User)
.ThenInclude(p => p.UserProfile)
.Include(p => p.App) .Include(p => p.App)
.FirstOrDefaultAsync(p => p.Token == header && p.Active); .FirstOrDefaultAsync(p => p.Token == header && p.Active);