diff --git a/Iceshrimp.Backend/Controllers/Mastodon/StatusController.cs b/Iceshrimp.Backend/Controllers/Mastodon/StatusController.cs index 6bdcb811..fc235278 100644 --- a/Iceshrimp.Backend/Controllers/Mastodon/StatusController.cs +++ b/Iceshrimp.Backend/Controllers/Mastodon/StatusController.cs @@ -10,6 +10,7 @@ using Microsoft.AspNetCore.Cors; using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.RateLimiting; using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.Caching.Distributed; namespace Iceshrimp.Backend.Controllers.Mastodon; @@ -19,7 +20,12 @@ namespace Iceshrimp.Backend.Controllers.Mastodon; [EnableCors("mastodon")] [EnableRateLimiting("sliding")] [Produces("application/json")] -public class StatusController(DatabaseContext db, NoteRenderer noteRenderer, NoteService noteSvc) : Controller { +public class StatusController( + DatabaseContext db, + NoteRenderer noteRenderer, + NoteService noteSvc, + IDistributedCache cache +) : Controller { [HttpGet("{id}")] [Authenticate("read:statuses")] [Produces("application/json")] @@ -119,7 +125,25 @@ public class StatusController(DatabaseContext db, NoteRenderer noteRenderer, Not var user = HttpContext.GetUserOrFail(); //TODO: handle scheduled statuses - //TODO: handle Idempotency-Key + Request.Headers.TryGetValue("Idempotency-Key", out var idempotencyKeyHeader); + var idempotencyKey = idempotencyKeyHeader.FirstOrDefault(); + if (idempotencyKey != null) { + var hit = await cache.FetchAsync($"idempotency:{idempotencyKey}", TimeSpan.FromHours(24), + () => $"_:{HttpContext.TraceIdentifier}"); + + if (hit != $"_:{HttpContext.TraceIdentifier}") { + for (var i = 0; i <= 10; i++) { + if (!hit.StartsWith('_')) break; + await Task.Delay(100); + hit = await cache.GetAsync($"idempotency:{idempotencyKey}") + ?? throw new Exception("Idempotency key status disappeared in for loop"); + if (i >= 10) + throw GracefulException.RequestTimeout("Failed to resolve idempotency key note within 1000 ms"); + } + + return await GetNote(hit); + } + } if (request.Text == null && request.MediaIds is not { Count: > 0 } && request.Poll == null) throw GracefulException.BadRequest("Posts must have text, media or poll"); @@ -139,6 +163,10 @@ public class StatusController(DatabaseContext db, NoteRenderer noteRenderer, Not var note = await noteSvc.CreateNoteAsync(user, visibility, request.Text, request.Cw, reply, attachments: attachments); + + if (idempotencyKey != null) + await cache.SetAsync($"idempotency:{idempotencyKey}", note.Id, TimeSpan.FromHours(24)); + var res = await noteRenderer.RenderAsync(note, user); return Ok(res); diff --git a/Iceshrimp.Backend/Core/Extensions/DistributedCacheExtensions.cs b/Iceshrimp.Backend/Core/Extensions/DistributedCacheExtensions.cs index ec41d38d..0a5ff9cd 100644 --- a/Iceshrimp.Backend/Core/Extensions/DistributedCacheExtensions.cs +++ b/Iceshrimp.Backend/Core/Extensions/DistributedCacheExtensions.cs @@ -48,7 +48,13 @@ public static class DistributedCacheExtensions { await cache.SetAsync(key, fetched, ttl); return fetched; } - + + public static async Task FetchAsync( + this IDistributedCache cache, string key, TimeSpan ttl, Func fetcher + ) where T : class { + return await FetchAsync(cache, key, ttl, () => Task.FromResult(fetcher())); + } + public static async Task FetchAsyncValue( this IDistributedCache cache, string key, TimeSpan ttl, Func> fetcher ) where T : struct { @@ -59,6 +65,12 @@ public static class DistributedCacheExtensions { await cache.SetAsync(key, fetched, ttl); return fetched; } + + public static async Task FetchAsyncValue( + this IDistributedCache cache, string key, TimeSpan ttl, Func fetcher + ) where T : struct { + return await FetchAsyncValue(cache, key, ttl, () => Task.FromResult(fetcher())); + } public static async Task SetAsync(this IDistributedCache cache, string key, T data, TimeSpan ttl) { using var stream = new MemoryStream(); diff --git a/Iceshrimp.Backend/Core/Middleware/ErrorHandlerMiddleware.cs b/Iceshrimp.Backend/Core/Middleware/ErrorHandlerMiddleware.cs index b6be2679..735f1edb 100644 --- a/Iceshrimp.Backend/Core/Middleware/ErrorHandlerMiddleware.cs +++ b/Iceshrimp.Backend/Core/Middleware/ErrorHandlerMiddleware.cs @@ -125,6 +125,10 @@ public class GracefulException( public static GracefulException BadRequest(string message, string? details = null) { return new GracefulException(HttpStatusCode.BadRequest, message, details); } + + public static GracefulException RequestTimeout(string message, string? details = null) { + return new GracefulException(HttpStatusCode.RequestTimeout, message, details); + } public static GracefulException RecordNotFound() { return new GracefulException(HttpStatusCode.NotFound, "Record not found");