using System.Diagnostics.CodeAnalysis;
using EntityFramework.Exceptions.PostgreSQL;
using EntityFrameworkCore.Projectables.Infrastructure;
using Iceshrimp.Backend.Core.Configuration;
using Iceshrimp.Backend.Core.Database.Tables;
using Iceshrimp.Backend.Core.Extensions;
using Microsoft.AspNetCore.DataProtection.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Design;
using Npgsql;

namespace Iceshrimp.Backend.Core.Database;

// NOTE(review): generic type arguments (e.g. on DbSet, MapEnum, HasPostgresEnum, IQueryable, Task,
// Get, Entity) appear to have been lost from this copy of the file — restore them from version control.

/// <summary>
/// EF Core database context for the Iceshrimp backend (PostgreSQL via Npgsql).
/// Also serves as the ASP.NET Core Data Protection key store via <see cref="IDataProtectionKeyContext"/>.
/// </summary>
[SuppressMessage("ReSharper", "StringLiteralTypo")]
[SuppressMessage("ReSharper", "IdentifierTypo")]
public class DatabaseContext(DbContextOptions options) : DbContext(options), IDataProtectionKeyContext
{
    public virtual DbSet AbuseUserReports      { get; init; } = null!;
    public virtual DbSet Announcements         { get; init; } = null!;
    public virtual DbSet AnnouncementReads     { get; init; } = null!;
    public virtual DbSet Antennas              { get; init; } = null!;
    public virtual DbSet AttestationChallenges { get; init; } = null!;
    public virtual DbSet Bites                 { get; init; } = null!;
    public virtual DbSet Blockings             { get; init; } = null!;
    public virtual DbSet Channels              { get; init; } = null!;
    public virtual DbSet ChannelFollowings     { get; init; } = null!;
    public virtual DbSet ChannelNotePins       { get; init; } = null!;
    public virtual DbSet Clips                 { get; init; } = null!;
    public virtual DbSet ClipNotes             { get; init; } = null!;
    public virtual DbSet DriveFiles            { get; init; } = null!;
    public virtual DbSet DriveFolders          { get; init; } = null!;
    public virtual DbSet Emojis                { get; init; } = null!;
    public virtual DbSet FollowRequests        { get; init; } = null!;
    public virtual DbSet Followings            { get; init; } = null!;
    public virtual DbSet GalleryLikes          { get; init; } = null!;
    public virtual DbSet GalleryPosts          { get; init; } = null!;
    public virtual DbSet Hashtags              { get; init; } = null!;
    public virtual DbSet Instances             { get; init; } = null!;
    public virtual DbSet Markers               { get; init; } = null!;
    public virtual DbSet MessagingMessages     { get; init; } = null!;
    public virtual DbSet Meta                  { get; init; } = null!;
    public virtual DbSet ModerationLogs        { get; init; } = null!;
    public virtual DbSet Mutings               { get; init; } = null!;
    public virtual DbSet Notes                 { get; init; } = null!;
    public virtual DbSet NoteThreads           { get; init; } = null!;
    public virtual DbSet NoteBookmarks         { get; init; } = null!;
    public virtual DbSet NoteEdits             { get; init; } = null!;
    public virtual DbSet NoteLikes             { get; init; } = null!;
    public virtual DbSet NoteReactions         { get; init; } = null!;
    public virtual DbSet NoteThreadMutings     { get; init; } = null!;
    public virtual DbSet NoteUnreads           { get; init; } = null!;
    public virtual DbSet NoteWatchings         { get; init; } = null!;
    public virtual DbSet Notifications         { get; init; } = null!;
    public virtual DbSet OauthApps             { get; init; } = null!;
    public virtual DbSet OauthTokens           { get; init; } = null!;
    public virtual DbSet Pages                 { get; init; } = null!;
    public virtual DbSet PageLikes             { get; init; } = null!;
    public virtual DbSet PasswordResetRequests { get; init; } = null!;
    public virtual DbSet Polls                 { get; init; } = null!;
    public virtual DbSet PollVotes             { get; init; } = null!;
    public virtual DbSet PromoNotes            { get; init; } = null!;
    public virtual DbSet PromoReads            { get; init; } = null!;
    public virtual DbSet RegistrationInvites   { get; init; } = null!;
    public virtual DbSet RegistryItems         { get; init; } = null!;
    public virtual DbSet Relays                { get; init; } = null!;
    public virtual DbSet RenoteMutings         { get; init; } = null!;
    public virtual DbSet Sessions              { get; init; } = null!;
    public virtual DbSet SwSubscriptions       { get; init; } = null!;
    public virtual DbSet PushSubscriptions     { get; init; } = null!;
    public virtual DbSet UsedUsernames         { get; init; } = null!;
    public virtual DbSet Users                 { get; init; } = null!;
    public virtual DbSet UserGroups            { get; init; } = null!;
    public virtual DbSet UserGroupInvitations  { get; init; } = null!;
    public virtual DbSet UserGroupMembers      { get; init; } = null!;
    public virtual DbSet UserKeypairs          { get; init; } = null!;
    public virtual DbSet UserLists             { get; init; } = null!;
    public virtual DbSet UserListMembers       { get; init; } = null!;
    public virtual DbSet UserNotePins          { get; init; } = null!;
    public virtual DbSet UserPendings          { get; init; } = null!;
    public virtual DbSet UserProfiles          { get; init; } = null!;
    public virtual DbSet UserPublickeys        { get; init; } = null!;
    public virtual DbSet UserSecurityKeys      { get; init; } = null!;
    public virtual DbSet UserSettings          { get; init; } = null!;
    public virtual DbSet Webhooks              { get; init; } = null!;
    public virtual DbSet AllowedInstances      { get; init; } = null!;
    public virtual DbSet BlockedInstances      { get; init; } = null!;
    public virtual DbSet MetaStore             { get; init; } = null!;
    public virtual DbSet CacheStore            { get; init; } = null!;
    public virtual DbSet Jobs                  { get; init; } = null!;
    public virtual DbSet Filters               { get; init; } = null!;
    public virtual DbSet PluginStore           { get; init; } = null!;
    public virtual DbSet PolicyConfiguration   { get; init; } = null!;

    /// <summary>Backing store for ASP.NET Core Data Protection keys (required by <see cref="IDataProtectionKeyContext"/>).</summary>
    public virtual DbSet DataProtectionKeys { get; init; } = null!;

    /// <summary>
    /// Builds an <see cref="NpgsqlDataSource"/> from the database section of the configuration.
    /// </summary>
    /// <param name="config">Connection settings (host, port, credentials, pool size, multiplexing, logging).</param>
    /// <returns>A fully configured data source with enum and JSON mappings registered.</returns>
    public static NpgsqlDataSource GetDataSource(Config.DatabaseSection config)
    {
        var dataSourceBuilder = new NpgsqlDataSourceBuilder
        {
            ConnectionStringBuilder =
            {
                Host         = config.Host,
                Port         = config.Port,
                Username     = config.Username,
                Password     = config.Password,
                Database     = config.Database,
                MaxPoolSize  = config.MaxConnections,
                Multiplexing = config.Multiplexing,
                // Passed to the server at connection start-up: disables PostgreSQL's JIT compiler for these sessions.
                Options = "-c jit=off"
            }
        };

        return ConfigureDataSource(dataSourceBuilder, config);
    }

    /// <summary>
    /// Registers enum mappings and dynamic JSON support on the builder, optionally enables
    /// parameter logging, and builds the data source.
    /// </summary>
    private static NpgsqlDataSource ConfigureDataSource(
        NpgsqlDataSourceBuilder dataSourceBuilder, Config.DatabaseSection config
    )
    {
        dataSourceBuilder.MapEnum();
        dataSourceBuilder.MapEnum();
        dataSourceBuilder.MapEnum();
        dataSourceBuilder.MapEnum();
        dataSourceBuilder.MapEnum();
        dataSourceBuilder.MapEnum();
        dataSourceBuilder.MapEnum();
        dataSourceBuilder.MapEnum();
        dataSourceBuilder.MapEnum();
        dataSourceBuilder.MapEnum();
        dataSourceBuilder.MapEnum();

        dataSourceBuilder.EnableDynamicJson();

        // Parameter values may contain sensitive data; only log them when explicitly configured.
        if (config.ParameterLogging)
        {
            dataSourceBuilder.EnableParameterLogging();
        }

        return dataSourceBuilder.Build();
    }

    /// <summary>
    /// Configures EF Core options: Npgsql provider with PostgreSQL enum type names, Projectables,
    /// the exception processor, and (optionally) sensitive-data logging.
    /// </summary>
    /// <param name="optionsBuilder">The options builder to configure.</param>
    /// <param name="dataSource">Data source created by <see cref="GetDataSource"/>.</param>
    /// <param name="config">Database configuration section.</param>
    public static void Configure(
        DbContextOptionsBuilder optionsBuilder, NpgsqlDataSource dataSource, Config.DatabaseSection config
    )
    {
        optionsBuilder.UseNpgsql(dataSource, options =>
        {
            // The string arguments are the PostgreSQL enum type names in the schema.
            options.MapEnum("antenna_src_enum");
            options.MapEnum("note_visibility_enum");
            options.MapEnum("notification_type_enum");
            options.MapEnum("page_visibility_enum");
            options.MapEnum("relay_status_enum");
            options.MapEnum("user_profile_ffvisibility_enum");
            options.MapEnum("marker_type_enum");
            options.MapEnum("push_subscription_policy_enum");
            options.MapEnum("job_status");
            options.MapEnum("filter_context_enum");
            options.MapEnum("filter_action_enum");
        });

        optionsBuilder.UseProjectables(options => { options.CompatibilityMode(CompatibilityMode.Full); });
        optionsBuilder.UseExceptionProcessor();

        if (config.ParameterLogging)
        {
            optionsBuilder.EnableSensitiveDataLogging();
        }
    }

    /// <summary>
    /// Declares PostgreSQL enums and the pg_trgm extension, maps the database functions used by
    /// <see cref="NoteAncestors(string,int)"/>, <see cref="Conversations(string)"/> and
    /// <c>Note.InternalRawAttachments</c>, and applies entity configurations from this assembly.
    /// </summary>
    protected override void OnModelCreating(ModelBuilder modelBuilder)
    {
        modelBuilder
            .HasPostgresEnum()
            .HasPostgresEnum()
            .HasPostgresEnum()
            .HasPostgresEnum()
            .HasPostgresEnum()
            .HasPostgresEnum()
            .HasPostgresEnum()
            .HasPostgresEnum()
            .HasPostgresEnum()
            .HasPostgresEnum()
            .HasPostgresEnum()
            .HasPostgresExtension("pg_trgm");

        // Explicit parameter-type arrays select the string-keyed overloads; the Note/User overloads
        // are plain C# convenience wrappers and are not mapped.
        modelBuilder
            .HasDbFunction(typeof(DatabaseContext).GetMethod(nameof(NoteAncestors), [typeof(string), typeof(int)])!)
            .HasName("note_ancestors");
        modelBuilder
            .HasDbFunction(typeof(DatabaseContext).GetMethod(nameof(Conversations), [typeof(string)])!)
            .HasName("conversations");
        modelBuilder
            .HasDbFunction(typeof(Note).GetMethod(nameof(Note.InternalRawAttachments), [typeof(string)])!)
            .HasName("note_attachments_raw");

        modelBuilder.Entity().ToTable("data_protection_keys");
        modelBuilder.ApplyConfigurationsFromAssembly(typeof(DatabaseContext).Assembly);
    }

    /// <summary>Reloads a tracked entity's values from the database.</summary>
    /// <param name="entity">An entity instance tracked by this context.</param>
    public async Task ReloadEntityAsync(object entity)
    {
        await Entry(entity).ReloadAsync();
    }

    /// <summary>
    /// Reloads <paramref name="entity"/> and then the target entity of each of its reference
    /// navigations that is loaded and non-null.
    /// </summary>
    /// <remarks>
    /// Despite the name, this goes only one level deep and only through reference navigations
    /// (<c>Entry(entity).References</c>); collection navigations are not reloaded.
    /// </remarks>
    public async Task ReloadEntityRecursivelyAsync(object entity)
    {
        await ReloadEntityAsync(entity);
        // NOTE(review): presumably AwaitAllNoConcurrencyAsync awaits the reloads one at a time,
        // since a DbContext does not support concurrent operations — confirm in Extensions.
        await Entry(entity)
              .References.Where(p => p is { IsLoaded: true, TargetEntry: not null })
              .Select(p => p.TargetEntry!.ReloadAsync())
              .AwaitAllNoConcurrencyAsync();
    }

    /// <summary>
    /// Queryable backed by the <c>note_ancestors</c> database function (mapped in <see cref="OnModelCreating"/>).
    /// </summary>
    public IQueryable NoteAncestors(string noteId, int depth)
        => FromExpression(() => NoteAncestors(noteId, depth));

    /// <summary>Convenience overload of <see cref="NoteAncestors(string,int)"/> taking a note.</summary>
    public IQueryable NoteAncestors(Note note, int depth)
        => FromExpression(() => NoteAncestors(note.Id, depth));

    /// <summary>
    /// Returns replies to <paramref name="noteId"/>, walked recursively via <c>note."replyId"</c>.
    /// </summary>
    /// <param name="noteId">Root note id; the root itself is excluded from the result.</param>
    /// <param name="depth">Maximum reply depth below the root. The NULL case is coalesced to TRUE.</param>
    /// <param name="limit">Maximum number of descendant ids returned.</param>
    /// <remarks>
    /// The recursive CTE carries the visited-id path and skips ids already on it, so reply cycles
    /// cannot recurse forever. Interpolated values are sent as SQL parameters by <c>FromSql</c>.
    /// </remarks>
    public IQueryable NoteDescendants(string noteId, int depth, int limit)
        => Notes.FromSql($""" SELECT * FROM note WHERE id IN ( WITH RECURSIVE search_tree(id, path) AS ( SELECT id, ARRAY[id]::VARCHAR[] FROM note WHERE id = {noteId} UNION ALL ( SELECT note.id, path || note.id FROM search_tree JOIN note ON note."replyId" = search_tree.id WHERE COALESCE(array_length(path, 1) < {depth + 1}, TRUE) AND NOT note.id = ANY(path) ) ) SELECT id FROM search_tree WHERE id <> {noteId} LIMIT {limit} ) """);

    /// <summary>Convenience overload of <see cref="NoteDescendants(string,int,int)"/> taking a note.</summary>
    /// <param name="breadth">
    /// NOTE(review): forwarded as the total row <c>limit</c>, not a per-level breadth.
    /// The name is kept for caller compatibility.
    /// </param>
    public IQueryable NoteDescendants(Note note, int depth, int breadth)
        => NoteDescendants(note.Id, depth, breadth);

    /// <summary>
    /// Queryable backed by the <c>conversations</c> database function (mapped in <see cref="OnModelCreating"/>).
    /// </summary>
    public IQueryable Conversations(string userId)
        => FromExpression(() => Conversations(userId));

    /// <summary>Convenience overload of <see cref="Conversations(string)"/> taking a user.</summary>
    public IQueryable Conversations(User user)
        => FromExpression(() => Conversations(user.Id));

    /// <summary>
    /// Claims the next queued job in <paramref name="queue"/>. The claimed job is the one with the
    /// earliest <c>COALESCE(delayed_until, queued_at)</c>. The query marks it <c>running</c>, sets
    /// <c>started_at</c> and returns the updated row.
    /// </summary>
    /// <remarks>
    /// <c>FOR UPDATE SKIP LOCKED</c> makes concurrent workers skip rows already locked by another
    /// claimer instead of blocking on them. The UPDATE runs only when the returned query is executed.
    /// NOTE(review): because this is an <c>UPDATE ... RETURNING</c>, composing further LINQ operators
    /// on the result likely makes EF wrap it in a subquery, which PostgreSQL rejects — materialize it directly.
    /// </remarks>
    public IQueryable GetJob(string queue) => Database.SqlQuery($""" UPDATE "jobs" SET "status" = 'running', "started_at" = now() WHERE "id" = ( SELECT "id" FROM "jobs" WHERE queue = {queue} AND status = 'queued' ORDER BY COALESCE("delayed_until", "queued_at") LIMIT 1 FOR UPDATE SKIP LOCKED) RETURNING "jobs".*; """);

    /// <summary>Counts jobs in <paramref name="queue"/> whose status is <c>Running</c>.</summary>
    public Task GetJobRunningCountAsync(string queue, CancellationToken token) =>
        Jobs.CountAsync(p => p.Queue == queue && p.Status == Job.JobStatus.Running, token);

    /// <summary>Counts jobs in <paramref name="queue"/> whose status is <c>Queued</c>.</summary>
    public Task GetJobQueuedCountAsync(string queue, CancellationToken token) =>
        Jobs.CountAsync(p => p.Queue == queue && p.Status == Job.JobStatus.Queued, token);

    /// <summary>
    /// Returns true when the <c>public</c> schema contains no relations at all, i.e. no rows in
    /// <c>pg_class</c> belonging to that namespace (tables, indexes, sequences, views, ...).
    /// </summary>
    public async Task IsDatabaseEmptyAsync()
        => !await Database.SqlQuery($""" select s.nspname from pg_class c join pg_namespace s on s.oid = c.relnamespace where s.nspname in ('public') """)
                          .AnyAsync();
}

/// <summary>
/// Design-time factory used by the <c>dotnet ef</c> tooling to construct a <see cref="DatabaseContext"/>
/// from the configuration found in the current working directory.
/// </summary>
[SuppressMessage("ReSharper", "UnusedType.Global",
                 Justification = "Constructed using reflection by the dotnet-ef CLI tool")]
public class DesignTimeDatabaseContextFactory : IDesignTimeDbContextFactory
{
    /// <exception cref="InvalidOperationException">
    /// Thrown when the "Database" configuration section is missing or cannot be bound.
    /// </exception>
    DatabaseContext IDesignTimeDbContextFactory.CreateDbContext(string[] args)
    {
        var configuration = new ConfigurationBuilder()
                            .SetBasePath(Directory.GetCurrentDirectory())
                            .AddCustomConfiguration()
                            .Build();

        // A missing configuration section is an invalid tool state, not a generic failure.
        var config = configuration.GetSection("Database").Get() ??
                     throw new InvalidOperationException("Failed to initialize database: Failed to load configuration");

        // Required to make `dotnet ef database update` work correctly
        config.Multiplexing = false;

        var dataSource = DatabaseContext.GetDataSource(config);
        var builder    = new DbContextOptionsBuilder();
        DatabaseContext.Configure(builder, dataSource, config);
        return new DatabaseContext(builder.Options);
    }
}