diff --git a/Iceshrimp.Parsing/Mfm.fs b/Iceshrimp.Parsing/Mfm.fs index 7e70dbf2..fd648247 100644 --- a/Iceshrimp.Parsing/Mfm.fs +++ b/Iceshrimp.Parsing/Mfm.fs @@ -110,9 +110,13 @@ module MfmNodeTypes = type internal UserState = { ParenthesisStack: char list - LastLine: int64 } + LastLine: int64 + Depth: int64 } - static member Default = { ParenthesisStack = []; LastLine = 0 } + static member Default = + { ParenthesisStack = [] + LastLine = 0 + Depth = 0 } open MfmNodeTypes @@ -220,10 +224,10 @@ module private MfmParser = let clearParen = updateUserState <| fun u -> { u with ParenthesisStack = [] } + let (|GreaterEqualThan|_|) k value = if value >= k then Some() else None + // References - let node, nodeRef = createParserForwardedToRef () let inlineNode, inlineNodeRef = createParserForwardedToRef () - let simple, simpleRef = createParserForwardedToRef () let seqFlatten items = seq { @@ -451,43 +455,49 @@ module private MfmParser = let prefixedNode (m: ParseMode) : Parser = fun (stream: CharStream<_>) -> - match (stream.Peek(), m) with - // Block nodes, ordered by expected frequency - | '`', Full -> codeBlockNode <|> codeNode - | '\n', Full when stream.Match("\n```") -> codeBlockNode - | '\n', Full when stream.Match("\n\n```") -> codeBlockNode - | '>', Full -> quoteNode - | '<', Full when stream.Match "
" -> centerNode - | '\\', Full when stream.Match "\\[" -> mathBlockNode - // Inline nodes, ordered by expected frequency - | '*', (Full | Inline) -> italicAsteriskNode <|> boldAsteriskNode - | '_', (Full | Inline) -> italicUnderscoreNode <|> boldUnderscoreNode - | '@', (Full | Inline) -> mentionNode - | '#', (Full | Inline) -> hashtagNode - | '`', Inline -> codeNode - | 'h', (Full | Inline) when stream.Match "http" -> urlNode - | ':', (Full | Inline | Simple) -> emojiCodeNode - | '~', (Full | Inline) when stream.Match "~~" -> strikeNode - | '[', (Full | Inline) -> linkNode - | '<', (Full | Inline) -> choice inlineTagNodes - | '<', Simple when stream.Match "" -> plainNode - | '\\', (Full | Inline) when stream.Match "\\(" -> mathNode - | '$', (Full | Inline) when stream.Match "$[" -> fnNode - | '?', (Full | Inline) when stream.Match "?[" -> linkNode - // Fallback to char node - | _ -> charNode - <| stream + match stream.UserState.Depth with + | GreaterEqualThan 100L -> stream |> charNode + | _ -> + match (stream.Peek(), m) with + // Block nodes, ordered by expected frequency + | '`', Full -> codeBlockNode <|> codeNode + | '\n', Full when stream.Match("\n```") -> codeBlockNode + | '\n', Full when stream.Match("\n\n```") -> codeBlockNode + | '>', Full -> quoteNode + | '<', Full when stream.Match "
" -> centerNode + | '\\', Full when stream.Match "\\[" -> mathBlockNode + // Inline nodes, ordered by expected frequency + | '*', (Full | Inline) -> italicAsteriskNode <|> boldAsteriskNode + | '_', (Full | Inline) -> italicUnderscoreNode <|> boldUnderscoreNode + | '@', (Full | Inline) -> mentionNode + | '#', (Full | Inline) -> hashtagNode + | '`', Inline -> codeNode + | 'h', (Full | Inline) when stream.Match "http" -> urlNode + | ':', (Full | Inline | Simple) -> emojiCodeNode + | '~', (Full | Inline) when stream.Match "~~" -> strikeNode + | '[', (Full | Inline) -> linkNode + | '<', (Full | Inline) -> choice inlineTagNodes + | '<', Simple when stream.Match "" -> plainNode + | '\\', (Full | Inline) when stream.Match "\\(" -> mathNode + | '$', (Full | Inline) when stream.Match "$[" -> fnNode + | '?', (Full | Inline) when stream.Match "?[" -> linkNode + // Fallback to char node + | _ -> charNode + <| stream attempt <| prefixedNode m <|> charNode // Populate references - do nodeRef.Value <- parseNode Full - do inlineNodeRef.Value <- parseNode Inline |>> fun v -> v :?> MfmInlineNode - do simpleRef.Value <- parseNode Simple + let pushDepth = updateUserState (fun u -> { u with Depth = (u.Depth + 1L) }) + let popDepth = updateUserState (fun u -> { u with Depth = (u.Depth - 1L) }) + do inlineNodeRef.Value <- pushDepth >>. (parseNode Inline |>> fun v -> v :?> MfmInlineNode) .>> popDepth + + // Parser abstractions + let node = parseNode Full + let simple = parseNode Simple // Final parse command let parse = spaces >>. manyTill node eof .>> spaces - let parseSimple = spaces >>. manyTill simple eof .>> spaces open MfmParser diff --git a/Iceshrimp.Tests/Parsing/MfmTests.cs b/Iceshrimp.Tests/Parsing/MfmTests.cs index 67048216..f428c598 100644 --- a/Iceshrimp.Tests/Parsing/MfmTests.cs +++ b/Iceshrimp.Tests/Parsing/MfmTests.cs @@ -1,5 +1,7 @@ +using System.Text; using Iceshrimp.Backend.Core.Helpers.LibMfm.Serialization; using Iceshrimp.Parsing; +using Microsoft.FSharp.Collections; using static Iceshrimp.Parsing.MfmNodeTypes; using FSDict = System.Collections.Generic.Dictionary?>; @@ -365,7 +367,7 @@ public class MfmTests res.ToList().Should().Equal(expected, MfmNodeEqual); MfmSerializer.Serialize(res).Should().BeEquivalentTo(input); } - + [TestMethod] public void TestLinkSilent() { @@ -516,10 +518,50 @@ public class MfmTests
"""; + AssertionOptions.FormattingOptions.MaxDepth = 100; var res = Mfm.parse(input); MfmSerializer.Serialize(res).Should().BeEquivalentTo(input); } + [TestMethod] + public void TestFnRecursionLimit() + { + const int iterations = 150; + const int limit = 100; + + var input = GetMfm(iterations); + var result = Mfm.parse(input); + + // Closing brackets will be grouped at the end, since fn node parser isn't greedy + List expected = [GetExpected(iterations), new MfmTextNode(new string(']', iterations - limit))]; + + AssertionOptions.FormattingOptions.MaxDepth = 300; + AssertionOptions.FormattingOptions.MaxLines = 1000; + result.ToList().Should().Equal(expected, MfmNodeEqual); + MfmSerializer.Serialize(result).Should().BeEquivalentTo(input); + + return; + + string GetMfm(int count) + { + var sb = new StringBuilder(); + for (var i = 0; i < count; i++) + sb.Append("$[test "); + for (var i = 0; i < count; i++) + sb.Append(']'); + + return sb.ToString(); + } + + MfmInlineNode GetExpected(int count, int remaining = limit) + { + if (remaining <= 0) + return new MfmTextNode(GetMfm(count).TrimEnd(']')); + + return new MfmFnNode("test", null, [GetExpected(--count, --remaining)]); + } + } + private static bool MfmNodeEqual(MfmNode a, MfmNode b) { if (a.GetType() != b.GetType()) return false; @@ -662,4 +704,4 @@ public class MfmTests return obj.GetHashCode(); } } -} \ No newline at end of file +}