[backend/federation] Fix possibly unbounded UserResolver recursion

This commit is contained in:
Laura Hausmann 2024-08-14 03:44:14 +02:00
parent 92f957a536
commit 4f98fa8461
No known key found for this signature in database
GPG key ID: D044E84C5BE01605
10 changed files with 48 additions and 47 deletions

View file

@ -604,7 +604,7 @@ public class AccountController(
[ProducesErrors(HttpStatusCode.NotFound)] [ProducesErrors(HttpStatusCode.NotFound)]
public async Task<AccountEntity> LookupUser([FromQuery] string acct) public async Task<AccountEntity> LookupUser([FromQuery] string acct)
{ {
var user = await userResolver.LookupAsync(acct) ?? throw GracefulException.RecordNotFound(); var user = await userResolver.LookupAsync(acct, false) ?? throw GracefulException.RecordNotFound();
return await userRenderer.RenderAsync(user); return await userRenderer.RenderAsync(user);
} }

View file

@ -86,7 +86,7 @@ public class SearchController(
if (pagination.Offset is not null and not 0) return []; if (pagination.Offset is not null and not 0) return [];
try try
{ {
var result = await userResolver.ResolveAsync(search.Query); var result = await userResolver.ResolveAsync(search.Query, false);
return [await userRenderer.RenderAsync(result)]; return [await userRenderer.RenderAsync(result)];
} }
catch catch
@ -118,7 +118,7 @@ public class SearchController(
try try
{ {
var result = await userResolver.ResolveAsync($"@{username}@{host}"); var result = await userResolver.ResolveAsync($"@{username}@{host}", false);
return [await userRenderer.RenderAsync(result)]; return [await userRenderer.RenderAsync(result)];
} }
catch catch

View file

@ -87,7 +87,7 @@ public class SearchController(
if (target.StartsWith('@') || target.StartsWith(userPrefixAlt)) if (target.StartsWith('@') || target.StartsWith(userPrefixAlt))
{ {
var hit = await userResolver.ResolveAsyncOrNull(target); var hit = await userResolver.ResolveAsyncOrNull(target, false);
if (hit != null) return new RedirectResponse { TargetUrl = $"/users/{hit.Id}" }; if (hit != null) return new RedirectResponse { TargetUrl = $"/users/{hit.Id}" };
throw GracefulException.NotFound("No result found"); throw GracefulException.NotFound("No result found");
} }
@ -125,7 +125,7 @@ public class SearchController(
noteHit = await noteSvc.ResolveNoteAsync(target); noteHit = await noteSvc.ResolveNoteAsync(target);
if (noteHit != null) return new RedirectResponse { TargetUrl = $"/notes/{noteHit.Id}" }; if (noteHit != null) return new RedirectResponse { TargetUrl = $"/notes/{noteHit.Id}" };
userHit = await userResolver.ResolveAsyncOrNull(target); userHit = await userResolver.ResolveAsyncOrNull(target, false);
if (userHit != null) return new RedirectResponse { TargetUrl = $"/users/{userHit.Id}" }; if (userHit != null) return new RedirectResponse { TargetUrl = $"/users/{userHit.Id}" };
throw GracefulException.NotFound("No result found"); throw GracefulException.NotFound("No result found");

View file

@ -112,9 +112,8 @@ public class AcceptHeaderOutputFormatterSelector(
) : OutputFormatterSelector ) : OutputFormatterSelector
{ {
private readonly DefaultOutputFormatterSelector _fallbackSelector = new(options, loggerFactory); private readonly DefaultOutputFormatterSelector _fallbackSelector = new(options, loggerFactory);
private readonly List<IOutputFormatter> _formatters = [..options.Value.OutputFormatters];
public override IOutputFormatter? SelectFormatter( public override IOutputFormatter SelectFormatter(
OutputFormatterCanWriteContext context, IList<IOutputFormatter> formatters, MediaTypeCollection mediaTypes OutputFormatterCanWriteContext context, IList<IOutputFormatter> formatters, MediaTypeCollection mediaTypes
) )
{ {

View file

@ -39,7 +39,7 @@ public class ActivityHandlerService(
if (activity.Object == null && activity is not ASBite) if (activity.Object == null && activity is not ASBite)
throw GracefulException.UnprocessableEntity("Activity object is null"); throw GracefulException.UnprocessableEntity("Activity object is null");
var resolvedActor = await userResolver.ResolveAsync(activity.Actor.Id); var resolvedActor = await userResolver.ResolveAsync(activity.Actor.Id, true);
if (authenticatedUserId == null) if (authenticatedUserId == null)
throw GracefulException.UnprocessableEntity("Refusing to process activity without authenticatedUserId"); throw GracefulException.UnprocessableEntity("Refusing to process activity without authenticatedUserId");
@ -157,7 +157,7 @@ public class ActivityHandlerService(
if (activity.Object is not ASActor obj) if (activity.Object is not ASActor obj)
throw GracefulException.UnprocessableEntity("Follow activity object is invalid"); throw GracefulException.UnprocessableEntity("Follow activity object is invalid");
var followee = await userResolver.ResolveAsync(obj.Id); var followee = await userResolver.ResolveAsync(obj.Id, true);
if (followee.IsRemoteUser) if (followee.IsRemoteUser)
throw GracefulException.UnprocessableEntity("Cannot process follow for remote followee"); throw GracefulException.UnprocessableEntity("Cannot process follow for remote followee");
@ -223,7 +223,7 @@ public class ActivityHandlerService(
if (follow is not { Actor: not null }) if (follow is not { Actor: not null })
throw GracefulException.UnprocessableEntity("Refusing to reject object with invalid follow object"); throw GracefulException.UnprocessableEntity("Refusing to reject object with invalid follow object");
var resolvedFollower = await userResolver.ResolveAsync(follow.Actor.Id); var resolvedFollower = await userResolver.ResolveAsync(follow.Actor.Id, true);
if (resolvedFollower is not { IsLocalUser: true }) if (resolvedFollower is not { IsLocalUser: true })
throw GracefulException.UnprocessableEntity("Refusing to reject remote follow"); throw GracefulException.UnprocessableEntity("Refusing to reject remote follow");
if (resolvedActor.Uri == null) if (resolvedActor.Uri == null)
@ -355,7 +355,7 @@ public class ActivityHandlerService(
Uri = activity.Id, Uri = activity.Id,
User = resolvedActor, User = resolvedActor,
UserHost = resolvedActor.Host, UserHost = resolvedActor.Host,
TargetUser = await userResolver.ResolveAsync(targetActor.Id) TargetUser = await userResolver.ResolveAsync(targetActor.Id, true)
}, },
ASNote targetNote => new Bite ASNote targetNote => new Bite
{ {
@ -385,7 +385,7 @@ public class ActivityHandlerService(
Uri = activity.Id, Uri = activity.Id,
User = resolvedActor, User = resolvedActor,
UserHost = resolvedActor.Host, UserHost = resolvedActor.Host,
TargetUser = await userResolver.ResolveAsync(activity.To.Id) TargetUser = await userResolver.ResolveAsync(activity.To.Id, true)
}, },
_ => throw GracefulException.UnprocessableEntity($"Invalid bite target {target.Id} with type {target.Type}") _ => throw GracefulException.UnprocessableEntity($"Invalid bite target {target.Id} with type {target.Type}")
@ -479,7 +479,7 @@ public class ActivityHandlerService(
private async Task UnfollowAsync(ASActor followeeActor, User follower) private async Task UnfollowAsync(ASActor followeeActor, User follower)
{ {
//TODO: send reject? or do we not want to copy that part of the old ap core //TODO: send reject? or do we not want to copy that part of the old ap core
var followee = await userResolver.ResolveAsync(followeeActor.Id); var followee = await userResolver.ResolveAsync(followeeActor.Id, true);
await db.FollowRequests.Where(p => p.Follower == follower && p.Followee == followee).ExecuteDeleteAsync(); await db.FollowRequests.Where(p => p.Follower == follower && p.Followee == followee).ExecuteDeleteAsync();

View file

@ -192,21 +192,23 @@ public class UserResolver(
return query; return query;
} }
public async Task<User> ResolveAsync(string username, string? host) public async Task<User> ResolveAsync(string username, string? host, bool skipUpdate)
{ {
return host != null ? await ResolveAsync($"acct:{username}@{host}") : await ResolveAsync($"acct:{username}"); return host != null
? await ResolveAsync($"acct:{username}@{host}", skipUpdate)
: await ResolveAsync($"acct:{username}", skipUpdate);
} }
public async Task<User?> LookupAsync(string query) public async Task<User?> LookupAsync(string query, bool skipUpdate)
{ {
query = NormalizeQuery(query); query = NormalizeQuery(query);
var user = await userSvc.GetUserFromQueryAsync(query); var user = await userSvc.GetUserFromQueryAsync(query);
if (user != null) if (user != null)
return await GetUpdatedUser(user); return skipUpdate ? user : await GetUpdatedUser(user);
return user; return user;
} }
public async Task<User> ResolveAsync(string query) public async Task<User> ResolveAsync(string query, bool skipUpdate)
{ {
query = NormalizeQuery(query); query = NormalizeQuery(query);
@ -217,16 +219,16 @@ public class UserResolver(
// First, let's see if we already know the user // First, let's see if we already know the user
var user = await userSvc.GetUserFromQueryAsync(query); var user = await userSvc.GetUserFromQueryAsync(query);
if (user != null) if (user != null)
return await GetUpdatedUser(user); return skipUpdate ? user : await GetUpdatedUser(user);
// We don't, so we need to run WebFinger // We don't, so we need to run WebFinger
var (acct, uri) = await WebFingerAsync(query); var (acct, uri) = await WebFingerAsync(query);
// Check the database again with the new data // Check the database again with the new data
if (uri != query) user = await userSvc.GetUserFromQueryAsync(uri); if (uri != query) user = await userSvc.GetUserFromQueryAsync(uri);
if (user == null && acct != query) await userSvc.GetUserFromQueryAsync(acct); if (user == null && acct != query) user = await userSvc.GetUserFromQueryAsync(acct);
if (user != null) if (user != null)
return await GetUpdatedUser(user); return skipUpdate ? user : await GetUpdatedUser(user);
using (await KeyedLocker.LockAsync(uri)) using (await KeyedLocker.LockAsync(uri))
{ {
@ -235,7 +237,7 @@ public class UserResolver(
} }
} }
public async Task<User?> ResolveAsync(string query, bool onlyExisting) public async Task<User?> ResolveAsync(string query, bool onlyExisting, bool skipUpdate)
{ {
query = NormalizeQuery(query); query = NormalizeQuery(query);
@ -246,7 +248,7 @@ public class UserResolver(
// First, let's see if we already know the user // First, let's see if we already know the user
var user = await userSvc.GetUserFromQueryAsync(query); var user = await userSvc.GetUserFromQueryAsync(query);
if (user != null) if (user != null)
return await GetUpdatedUser(user); return skipUpdate ? user : await GetUpdatedUser(user);
if (onlyExisting) if (onlyExisting)
return null; return null;
@ -255,10 +257,10 @@ public class UserResolver(
var (acct, uri) = await WebFingerAsync(query); var (acct, uri) = await WebFingerAsync(query);
// Check the database again with the new data // Check the database again with the new data
if (uri != query) user = await userSvc.GetUserFromQueryAsync(uri); if (uri != query) user = await userSvc.GetUserFromQueryAsync(uri);
if (user == null && acct != query) await userSvc.GetUserFromQueryAsync(acct); if (user == null && acct != query) user = await userSvc.GetUserFromQueryAsync(acct);
if (user != null) if (user != null)
return await GetUpdatedUser(user); return skipUpdate ? user : await GetUpdatedUser(user);
using (await KeyedLocker.LockAsync(uri)) using (await KeyedLocker.LockAsync(uri))
{ {
@ -267,7 +269,7 @@ public class UserResolver(
} }
} }
public async Task<User?> ResolveAsyncOrNull(string username, string? host) public async Task<User?> ResolveAsyncOrNull(string username, string? host, bool skipUpdate)
{ {
try try
{ {
@ -276,7 +278,7 @@ public class UserResolver(
// First, let's see if we already know the user // First, let's see if we already know the user
var user = await userSvc.GetUserFromQueryAsync(query); var user = await userSvc.GetUserFromQueryAsync(query);
if (user != null) if (user != null)
return await GetUpdatedUser(user); return skipUpdate ? user : await GetUpdatedUser(user);
if (host == null) return null; if (host == null) return null;
@ -284,10 +286,10 @@ public class UserResolver(
var (acct, uri) = await WebFingerAsync(query); var (acct, uri) = await WebFingerAsync(query);
// Check the database again with the new data // Check the database again with the new data
if (uri != query) user = await userSvc.GetUserFromQueryAsync(uri); if (uri != query) user = await userSvc.GetUserFromQueryAsync(uri);
if (user == null && acct != query) await userSvc.GetUserFromQueryAsync(acct); if (user == null && acct != query) user = await userSvc.GetUserFromQueryAsync(acct);
if (user != null) if (user != null)
return await GetUpdatedUser(user); return skipUpdate ? user : await GetUpdatedUser(user);
using (await KeyedLocker.LockAsync(uri)) using (await KeyedLocker.LockAsync(uri))
{ {
@ -301,7 +303,7 @@ public class UserResolver(
} }
} }
public async Task<User?> ResolveAsyncOrNull(string query) public async Task<User?> ResolveAsyncOrNull(string query, bool skipUpdate)
{ {
try try
{ {
@ -310,7 +312,7 @@ public class UserResolver(
// First, let's see if we already know the user // First, let's see if we already know the user
var user = await userSvc.GetUserFromQueryAsync(query); var user = await userSvc.GetUserFromQueryAsync(query);
if (user != null) if (user != null)
return await GetUpdatedUser(user); return skipUpdate ? user : await GetUpdatedUser(user);
if (query.StartsWith($"https://{config.Value.WebDomain}/")) return null; if (query.StartsWith($"https://{config.Value.WebDomain}/")) return null;
@ -321,7 +323,7 @@ public class UserResolver(
if (resolvedUri != query) user = await userSvc.GetUserFromQueryAsync(resolvedUri); if (resolvedUri != query) user = await userSvc.GetUserFromQueryAsync(resolvedUri);
if (user == null && acct != query) await userSvc.GetUserFromQueryAsync(acct); if (user == null && acct != query) await userSvc.GetUserFromQueryAsync(acct);
if (user != null) if (user != null)
return await GetUpdatedUser(user); return skipUpdate ? user : await GetUpdatedUser(user);
using (await KeyedLocker.LockAsync(resolvedUri)) using (await KeyedLocker.LockAsync(resolvedUri))
{ {

View file

@ -66,7 +66,7 @@ public class AuthorizedFetchMiddleware(
{ {
try try
{ {
var user = await userResolver.ResolveAsync(sig.KeyId).WaitAsync(ct); var user = await userResolver.ResolveAsync(sig.KeyId, skipUpdate: true).WaitAsync(ct);
key = await db.UserPublickeys.Include(p => p.User) key = await db.UserPublickeys.Include(p => p.User)
.FirstOrDefaultAsync(p => p.User == user, ct); .FirstOrDefaultAsync(p => p.User == user, ct);

View file

@ -108,7 +108,7 @@ public class InboxValidationMiddleware(
{ {
try try
{ {
var user = await userResolver.ResolveAsync(sig.KeyId, activity is ASDelete).WaitAsync(ct); var user = await userResolver.ResolveAsync(sig.KeyId, activity is ASDelete, true).WaitAsync(ct);
if (user == null) throw AuthFetchException.NotFound("Delete activity actor is unknown"); if (user == null) throw AuthFetchException.NotFound("Delete activity actor is unknown");
key = await db.UserPublickeys.Include(p => p.User) key = await db.UserPublickeys.Include(p => p.User)
.FirstOrDefaultAsync(p => p.User == user, ct); .FirstOrDefaultAsync(p => p.User == user, ct);
@ -185,7 +185,7 @@ public class InboxValidationMiddleware(
if (key == null) if (key == null)
{ {
var user = await userResolver.ResolveAsync(activity.Actor.Id, activity is ASDelete) var user = await userResolver.ResolveAsync(activity.Actor.Id, activity is ASDelete, true)
.WaitAsync(ct); .WaitAsync(ct);
if (user == null) throw AuthFetchException.NotFound("Delete activity actor is unknown"); if (user == null) throw AuthFetchException.NotFound("Delete activity actor is unknown");
key = await db.UserPublickeys key = await db.UserPublickeys

View file

@ -177,7 +177,7 @@ public class NoteService(
if (asNote != null) if (asNote != null)
{ {
visibleUserIds = (await asNote.GetRecipients(user) visibleUserIds = (await asNote.GetRecipients(user)
.Select(userResolver.ResolveAsync) .Select(p => userResolver.ResolveAsync(p, true))
.AwaitAllNoConcurrencyAsync()) .AwaitAllNoConcurrencyAsync())
.Select(p => p.Id) .Select(p => p.Id)
.Concat(mentionedUserIds) .Concat(mentionedUserIds)
@ -480,7 +480,7 @@ public class NoteService(
if (asNote != null) if (asNote != null)
{ {
visibleUserIds = (await asNote.GetRecipients(note.User) visibleUserIds = (await asNote.GetRecipients(note.User)
.Select(userResolver.ResolveAsync) .Select(p => userResolver.ResolveAsync(p, true))
.AwaitAllNoConcurrencyAsync()) .AwaitAllNoConcurrencyAsync())
.Select(p => p.Id) .Select(p => p.Id)
.Concat(visibleUserIds) .Concat(visibleUserIds)
@ -907,7 +907,7 @@ public class NoteService(
{ {
try try
{ {
return await userResolver.ResolveAsync(p.Href!.Id!); return await userResolver.ResolveAsync(p.Href!.Id!, true);
} }
catch catch
{ {
@ -930,7 +930,7 @@ public class NoteService(
{ {
try try
{ {
return await userResolver.ResolveAsync(p.Acct); return await userResolver.ResolveAsync(p.Acct, true);
} }
catch catch
{ {
@ -1096,7 +1096,7 @@ public class NoteService(
if (res != null && !forceRefresh) return res; if (res != null && !forceRefresh) return res;
} }
var actor = await userResolver.ResolveAsync(attrTo.Id); var actor = await userResolver.ResolveAsync(attrTo.Id, true);
using (await KeyedLocker.LockAsync(uri)) using (await KeyedLocker.LockAsync(uri))
{ {

View file

@ -39,12 +39,12 @@ public class UserProfileMentionsResolver(ActivityPub.UserResolver userResolver,
var users = await mentionNodes var users = await mentionNodes
.DistinctBy(p => p.Acct) .DistinctBy(p => p.Acct)
.Select(async p => await userResolver.ResolveAsyncOrNull(p.Username, p.Host?.Value ?? host)) .Select(p => userResolver.ResolveAsyncOrNull(p.Username, p.Host?.Value ?? host, true))
.AwaitAllNoConcurrencyAsync(); .AwaitAllNoConcurrencyAsync();
users.AddRange(await userUris users.AddRange(await userUris
.Distinct() .Distinct()
.Select(async p => await userResolver.ResolveAsyncOrNull(p)) .Select(p => userResolver.ResolveAsyncOrNull(p, true))
.AwaitAllNoConcurrencyAsync()); .AwaitAllNoConcurrencyAsync());
var mentions = users.Where(p => p != null) var mentions = users.Where(p => p != null)
@ -78,11 +78,11 @@ public class UserProfileMentionsResolver(ActivityPub.UserResolver userResolver,
.Cast<string>() .Cast<string>()
.ToList(); .ToList();
var nodes = input.SelectMany(p => MfmParser.Parse(p)); var nodes = input.SelectMany(MfmParser.Parse);
var mentionNodes = EnumerateMentions(nodes); var mentionNodes = EnumerateMentions(nodes);
var users = await mentionNodes var users = await mentionNodes
.DistinctBy(p => p.Acct) .DistinctBy(p => p.Acct)
.Select(async p => await userResolver.ResolveAsyncOrNull(p.Username, p.Host?.Value ?? host)) .Select(p => userResolver.ResolveAsyncOrNull(p.Username, p.Host?.Value ?? host, true))
.AwaitAllNoConcurrencyAsync(); .AwaitAllNoConcurrencyAsync();
return users.Where(p => p != null) return users.Where(p => p != null)