diff --git a/Iceshrimp.Backend/Core/Extensions/ModelBinderProviderExtensions.cs b/Iceshrimp.Backend/Core/Extensions/ModelBinderProviderExtensions.cs index 69024768..e6633c83 100644 --- a/Iceshrimp.Backend/Core/Extensions/ModelBinderProviderExtensions.cs +++ b/Iceshrimp.Backend/Core/Extensions/ModelBinderProviderExtensions.cs @@ -24,9 +24,11 @@ public static class ModelBinderProviderExtensions var hybridProvider = new HybridModelBinderProvider(bodyProvider, complexProvider, dictionaryProvider); var customCollectionProvider = new CustomCollectionModelBinderProvider(collectionProvider); + var intToBoolProvider = new IntToBoolModelBinderProvider(); providers.Insert(0, hybridProvider); providers.Insert(1, customCollectionProvider); + providers.Insert(2, intToBoolProvider); } } @@ -52,6 +54,44 @@ public class HybridModelBinderProvider( } } +public class IntToBoolModelBinderProvider : IModelBinderProvider +{ + public IModelBinder? GetBinder(ModelBinderProviderContext context) + { + var isBool = context.Metadata.ModelType.IsAssignableFrom(typeof(bool)); + var isNullableBool = context.Metadata.ModelType.IsAssignableFrom(typeof(bool?)); + if (!isBool && !isNullableBool) return null; + + return new IntToBoolModelBinder(); + } +} + +public class IntToBoolModelBinder : IModelBinder +{ + public Task BindModelAsync(ModelBindingContext bindingContext) + { + var valueProviderResult = bindingContext.ValueProvider.GetValue(bindingContext.ModelName); + var value = valueProviderResult.FirstValue; + var error = $"{bindingContext.ModelName} must be one of: [true, false, 1, 0]."; + + if (bool.TryParse(value, out var boolValue)) + { + bindingContext.Result = ModelBindingResult.Success(boolValue); + } + else if (int.TryParse(value, out var intValue)) + { + if (intValue is 0 or 1) bindingContext.Result = ModelBindingResult.Success(intValue != 0); + else bindingContext.ModelState.TryAddModelError(bindingContext.ModelName, error); + } + else if (bindingContext.ModelMetadata.IsNullableValueType && value != null && string.IsNullOrWhiteSpace(value)) + { + bindingContext.Result = ModelBindingResult.Success(null); + } + + return Task.CompletedTask; + } +} + public class CustomCollectionModelBinderProvider(IModelBinderProvider provider) : IModelBinderProvider { public IModelBinder? GetBinder(ModelBinderProviderContext context)