Iceshrimp.NET/Iceshrimp.Backend/Core/Services/CustomHttpClient.cs
2024-11-20 00:48:29 +01:00

374 lines
No EOL
12 KiB
C#

using System.Diagnostics;
using System.Net;
using System.Net.Sockets;
using Iceshrimp.Backend.Core.Configuration;
using Iceshrimp.Backend.Core.Extensions;
using Microsoft.Extensions.Options;
namespace Iceshrimp.Backend.Core.Services;
public class CustomHttpClient : HttpClient, IService<HttpClient>, ISingletonService
{
private static readonly FastFallback FastFallbackHandler = new();
private static readonly HttpMessageHandler InnerHandler = new SocketsHttpHandler
{
AutomaticDecompression = DecompressionMethods.All,
ConnectCallback = FastFallbackHandler.ConnectCallbackAsync,
PooledConnectionIdleTimeout = TimeSpan.FromMinutes(5),
PooledConnectionLifetime = TimeSpan.FromMinutes(60)
};
private static readonly HttpMessageHandler Handler = new RedirectHandler(InnerHandler);
public CustomHttpClient(
IOptions<Config.InstanceSection> options,
IOptionsMonitor<Config.SecuritySection> security,
ILoggerFactory loggerFactory
) : base(Handler)
{
// Configure HTTP client options
DefaultRequestHeaders.TryAddWithoutValidation("User-Agent", options.Value.UserAgent);
Timeout = TimeSpan.FromSeconds(30);
// Default to HTTP/2, but allow for down-negotiation to HTTP/1.1 or HTTP/1.0
DefaultRequestVersion = HttpVersion.Version20;
DefaultVersionPolicy = HttpVersionPolicy.RequestVersionOrLower;
// Protect against DoS attacks
MaxResponseContentBufferSize = 1024 * 1024; // 1MiB
// Configure FastFallback
FastFallbackHandler.Logger = loggerFactory.CreateLogger<FastFallback>();
FastFallbackHandler.Security = security;
}
// Adapted from https://github.com/KazWolfe/Dalamud/blob/767cc49ecb80e29dbdda2fa8329d3c3341c964fe/Dalamud/Networking/Http/HappyEyeballsCallback.cs
private class FastFallback(int connectionBackoff = 75)
{
public ILogger<FastFallback>? Logger { private get; set; }
public IOptionsMonitor<Config.SecuritySection>? Security { private get; set; }
private bool AllowLoopback => Security?.CurrentValue.AllowLoopback ?? false;
private bool AllowLocalIPv4 => Security?.CurrentValue.AllowLocalIPv4 ?? false;
private bool AllowLocalIPv6 => Security?.CurrentValue.AllowLocalIPv6 ?? false;
public async ValueTask<Stream> ConnectCallbackAsync(
SocketsHttpConnectionContext context, CancellationToken token
)
{
var sortedRecords = await GetSortedAddressesAsync(context.DnsEndPoint.Host, token);
var linkedToken = CancellationTokenSource.CreateLinkedTokenSource(token);
var tasks = new List<Task<(NetworkStream? stream, Exception? exception)>>();
var delayCts = CancellationTokenSource.CreateLinkedTokenSource(linkedToken.Token);
for (var i = 0; i < sortedRecords.Count; i++)
{
var record = sortedRecords[i];
if (record.IsIPv4MappedToIPv6)
record = record.MapToIPv4();
if (!AllowLoopback && record.IsLoopback())
{
Logger?.LogWarning("Refusing to connect to loopback address {address} due to possible SSRF",
record.ToString());
continue;
}
if (!AllowLocalIPv6 && record.IsLocalIPv6())
{
Logger?.LogWarning("Refusing to connect to local IPv6 address {address} due to possible SSRF",
record.ToString());
continue;
}
if (!AllowLocalIPv4 && record.IsLocalIPv4())
{
Logger?.LogWarning("Refusing to connect to local IPv4 address {address} due to possible SSRF",
record.ToString());
continue;
}
delayCts.CancelAfter(connectionBackoff * i);
var task = AttemptConnectionAsync(record, context.DnsEndPoint.Port, linkedToken.Token, delayCts.Token);
tasks.Add(task);
var nextDelayCts = CancellationTokenSource.CreateLinkedTokenSource(linkedToken.Token);
_ = task.ContinueWith(_ => { nextDelayCts.Cancel(); }, TaskContinuationOptions.OnlyOnFaulted);
delayCts = nextDelayCts;
}
if (tasks.Count == 0)
throw new Exception($"Can't connect to {context.DnsEndPoint.Host}: no candidate addresses remaining");
NetworkStream? stream = null;
Exception? lastException = null;
while (tasks.Count > 0 && stream == null)
{
var task = await Task.WhenAny(tasks).ConfigureAwait(false);
var res = await task;
tasks.Remove(task);
stream = res.stream;
lastException = res.exception;
}
if (stream == null)
{
throw lastException ??
new Exception("An unknown exception occured during fast fallback connection attempt");
}
await linkedToken.CancelAsync();
tasks.ForEach(task => { task.ContinueWith(_ => Task.CompletedTask, CancellationToken.None); });
return stream;
}
private static async Task<(NetworkStream? stream, Exception? exception)> AttemptConnectionAsync(
IPAddress address, int port, CancellationToken token, CancellationToken delayToken
)
{
try
{
await Task.Delay(-1, delayToken).ConfigureAwait(false);
}
catch (TaskCanceledException) { }
token.ThrowIfCancellationRequested();
var socket = new Socket(address.AddressFamily, SocketType.Stream, ProtocolType.Tcp) { NoDelay = true };
try
{
await socket.ConnectAsync(address, port, token).ConfigureAwait(false);
return (new NetworkStream(socket, true), null);
}
catch (Exception e)
{
socket.Dispose();
return (null, e);
}
}
private static async Task<List<IPAddress>> GetSortedAddressesAsync(string hostname, CancellationToken token)
{
// This method abuses DNS ordering and LINQ a bit. We can normally assume that addresses will be provided in
// the order the system wants to use. GroupBy will return its groups *in the order they're discovered*. Meaning,
// the first group created will always be the preferred group, and all other groups are in preference order.
// This means a straight zipper merge is nice and clean and gives us most -> least preferred, repeating.
var dnsRecords = await Dns.GetHostAddressesAsync(hostname, AddressFamily.Unspecified, token);
var groups = dnsRecords
.GroupBy(a => a.AddressFamily)
.Select(g => g.Select(v => v))
.ToArray();
return ZipperMerge(groups).ToList();
}
private static IEnumerable<TSource> ZipperMerge<TSource>(params IEnumerable<TSource>[] sources)
{
// Adapted from https://github.com/KazWolfe/Dalamud/blob/767cc49ecb80e29dbdda2fa8329d3c3341c964fe/Dalamud/Utility/Util.cs
var enumerators = new IEnumerator<TSource>[sources.Length];
try
{
for (var i = 0; i < sources.Length; i++)
{
enumerators[i] = sources[i].GetEnumerator();
}
var hasNext = new bool[enumerators.Length];
bool MoveNext()
{
var anyHasNext = false;
for (var i = 0; i < enumerators.Length; i++)
{
anyHasNext |= hasNext[i] = enumerators[i].MoveNext();
}
return anyHasNext;
}
while (MoveNext())
{
for (var i = 0; i < enumerators.Length; i++)
{
if (hasNext[i])
{
yield return enumerators[i].Current;
}
}
}
}
finally
{
foreach (var enumerator in enumerators)
{
enumerator.Dispose();
}
}
}
}
private class RedirectHandler : DelegatingHandler
{
public RedirectHandler(HttpMessageHandler innerHandler) : base(innerHandler)
{
var mostInnerHandler = innerHandler.GetMostInnerHandler();
SetupCustomAutoRedirect(mostInnerHandler);
}
private int MaxAutomaticRedirections { get; set; }
private bool InitialAutoRedirect { get; set; }
private void SetupCustomAutoRedirect(HttpMessageHandler? mostInnerHandler)
{
//Store the initial auto-redirect & max-auto-redirect values.
//Disabling auto-redirect and handle redirects manually.
try
{
switch (mostInnerHandler)
{
case HttpClientHandler hch:
InitialAutoRedirect = hch.AllowAutoRedirect;
MaxAutomaticRedirections = hch.MaxAutomaticRedirections;
hch.AllowAutoRedirect = false;
break;
case SocketsHttpHandler shh:
InitialAutoRedirect = shh.AllowAutoRedirect;
MaxAutomaticRedirections = shh.MaxAutomaticRedirections;
shh.AllowAutoRedirect = false;
break;
default:
Debug.WriteLine("[SetupCustomAutoRedirect] Unknown handler type: {0}",
mostInnerHandler?.GetType().FullName);
InitialAutoRedirect = true;
MaxAutomaticRedirections = 17;
break;
}
}
catch (Exception e)
{
Debug.WriteLine(e.Message);
InitialAutoRedirect = true;
MaxAutomaticRedirections = 17;
}
}
private bool IsRedirectAllowed(HttpRequestMessage request)
{
var value = request.GetAutoRedirect();
if (value == null)
return InitialAutoRedirect;
return value == true;
}
protected override async Task<HttpResponseMessage> SendAsync(
HttpRequestMessage request, CancellationToken cancellationToken
)
{
var redirectCount = 0;
var response = await base.SendAsync(request, cancellationToken).ConfigureAwait(false);
//Manual Redirect
//https://github.com/dotnet/runtime/blob/ccfe21882e4a2206ce49cd5b32d3eb3cab3e530f/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/RedirectHandler.cs
Uri? redirectUri;
while (IsRedirect(response) &&
IsRedirectAllowed(request) &&
(redirectUri = GetUriForRedirect(request.RequestUri!, response)) != null)
{
redirectCount++;
if (redirectCount > MaxAutomaticRedirections)
break;
response.Dispose();
// Clear the authorization header.
request.Headers.Authorization = null;
// Set up for the redirect
request.RequestUri = redirectUri;
if (RequestRequiresForceGet(response.StatusCode, request.Method))
{
request.Method = HttpMethod.Get;
request.Content = null;
if (request.Headers.TransferEncodingChunked == true)
request.Headers.TransferEncodingChunked = false;
}
// Issue the redirected request.
response = await base.SendAsync(request, cancellationToken).ConfigureAwait(false);
}
return response;
}
private static bool IsRedirect(HttpResponseMessage response)
{
switch (response.StatusCode)
{
case HttpStatusCode.MultipleChoices:
case HttpStatusCode.Moved:
case HttpStatusCode.Found:
case HttpStatusCode.SeeOther:
case HttpStatusCode.TemporaryRedirect:
case HttpStatusCode.PermanentRedirect:
return true;
default:
return false;
}
}
private static Uri? GetUriForRedirect(Uri requestUri, HttpResponseMessage response)
{
var location = response.Headers.Location;
if (location == null)
{
return null;
}
// Ensure the redirect location is an absolute URI.
if (!location.IsAbsoluteUri)
{
location = new Uri(requestUri, location);
}
// Per https://tools.ietf.org/html/rfc7231#section-7.1.2, a redirect location without a
// fragment should inherit the fragment from the original URI.
var requestFragment = requestUri.Fragment;
if (!string.IsNullOrEmpty(requestFragment))
{
var redirectFragment = location.Fragment;
if (string.IsNullOrEmpty(redirectFragment))
{
location = new UriBuilder(location) { Fragment = requestFragment }.Uri;
}
}
// Reject circular redirects
return location == requestUri ? null : location;
}
private static bool RequestRequiresForceGet(HttpStatusCode statusCode, HttpMethod requestMethod)
{
switch (statusCode)
{
case HttpStatusCode.Moved:
case HttpStatusCode.Found:
case HttpStatusCode.MultipleChoices:
return requestMethod == HttpMethod.Post;
case HttpStatusCode.SeeOther:
return requestMethod != HttpMethod.Get && requestMethod != HttpMethod.Head;
default:
return false;
}
}
}
}