diff --git a/Iceshrimp.Backend/Controllers/AuthController.cs b/Iceshrimp.Backend/Controllers/AuthController.cs index 3a2b6451..248be8fb 100644 --- a/Iceshrimp.Backend/Controllers/AuthController.cs +++ b/Iceshrimp.Backend/Controllers/AuthController.cs @@ -46,7 +46,7 @@ public class AuthController(DatabaseContext db, UserService userSvc) : Controlle [ProducesResponseType(StatusCodes.Status200OK, Type = typeof(AuthResponse))] [ProducesResponseType(StatusCodes.Status400BadRequest, Type = typeof(ErrorResponse))] [ProducesResponseType(StatusCodes.Status403Forbidden, Type = typeof(ErrorResponse))] - public async Task Login([FromBody] AuthRequest request) { + public async Task Login([FromBody] AuthRequest request, Session? session = null) { var user = await db.Users.FirstOrDefaultAsync(p => p.Host == null && p.UsernameLower == request.Username.ToLowerInvariant()); if (user == null) @@ -57,17 +57,17 @@ public class AuthController(DatabaseContext db, UserService userSvc) : Controlle if (!AuthHelpers.ComparePassword(request.Password, profile.Password)) return StatusCode(StatusCodes.Status403Forbidden); - var res = await db.AddAsync(new Session { - Id = IdHelpers.GenerateSlowflakeId(), - UserId = user.Id, - Active = !profile.TwoFactorEnabled, - CreatedAt = new DateTime(), - Token = CryptographyHelpers.GenerateRandomString(32) - }); - - var session = res.Entity; - await db.AddAsync(session); - await db.SaveChangesAsync(); + if (session == null) { + session = new Session { + Id = IdHelpers.GenerateSlowflakeId(), + UserId = user.Id, + Active = !profile.TwoFactorEnabled, + CreatedAt = new DateTime(), + Token = CryptographyHelpers.GenerateRandomString(32) + }; + await db.AddAsync(session); + await db.SaveChangesAsync(); + } return Ok(new AuthResponse { Status = session.Active ? AuthStatusEnum.Authenticated : AuthStatusEnum.TwoFactor,