[parsing/mfm] Limit inline node recursion to 100

This commit is contained in:
Laura Hausmann 2024-11-26 00:18:52 +01:00
parent 9036eacd98
commit aa593f78b8
No known key found for this signature in database
GPG key ID: D044E84C5BE01605
2 changed files with 88 additions and 36 deletions

View file

@ -110,9 +110,13 @@ module MfmNodeTypes =
type internal UserState = type internal UserState =
{ ParenthesisStack: char list { ParenthesisStack: char list
LastLine: int64 } LastLine: int64
Depth: int64 }
static member Default = { ParenthesisStack = []; LastLine = 0 } static member Default =
{ ParenthesisStack = []
LastLine = 0
Depth = 0 }
open MfmNodeTypes open MfmNodeTypes
@ -220,10 +224,10 @@ module private MfmParser =
let clearParen = updateUserState <| fun u -> { u with ParenthesisStack = [] } let clearParen = updateUserState <| fun u -> { u with ParenthesisStack = [] }
/// Partial active pattern: matches when `value` is greater than or equal to the threshold `k`.
let (|GreaterEqualThan|_|) k value =
    match value >= k with
    | true -> Some()
    | false -> None
// References // References
let node, nodeRef = createParserForwardedToRef ()
let inlineNode, inlineNodeRef = createParserForwardedToRef () let inlineNode, inlineNodeRef = createParserForwardedToRef ()
let simple, simpleRef = createParserForwardedToRef ()
let seqFlatten items = let seqFlatten items =
seq { seq {
@ -451,6 +455,9 @@ module private MfmParser =
let prefixedNode (m: ParseMode) : Parser<MfmNode, UserState> = let prefixedNode (m: ParseMode) : Parser<MfmNode, UserState> =
fun (stream: CharStream<_>) -> fun (stream: CharStream<_>) ->
match stream.UserState.Depth with
| GreaterEqualThan 100L -> stream |> charNode
| _ ->
match (stream.Peek(), m) with match (stream.Peek(), m) with
// Block nodes, ordered by expected frequency // Block nodes, ordered by expected frequency
| '`', Full -> codeBlockNode <|> codeNode | '`', Full -> codeBlockNode <|> codeNode
@ -481,13 +488,16 @@ module private MfmParser =
attempt <| prefixedNode m <|> charNode attempt <| prefixedNode m <|> charNode
// Populate references // Populate references
do nodeRef.Value <- parseNode Full let pushDepth = updateUserState (fun u -> { u with Depth = (u.Depth + 1L) })
do inlineNodeRef.Value <- parseNode Inline |>> fun v -> v :?> MfmInlineNode let popDepth = updateUserState (fun u -> { u with Depth = (u.Depth - 1L) })
do simpleRef.Value <- parseNode Simple do inlineNodeRef.Value <- pushDepth >>. (parseNode Inline |>> fun v -> v :?> MfmInlineNode) .>> popDepth
// Parser abstractions
// Parser produced by running parseNode in Full mode; used by `parse` below.
let node = parseNode Full
// Parser produced by running parseNode in Simple mode; used by `parseSimple` below.
let simple = parseNode Simple
// Final parse command // Final parse command
let parse = spaces >>. manyTill node eof .>> spaces let parse = spaces >>. manyTill node eof .>> spaces
let parseSimple = spaces >>. manyTill simple eof .>> spaces let parseSimple = spaces >>. manyTill simple eof .>> spaces
open MfmParser open MfmParser

View file

@ -1,5 +1,7 @@
using System.Text;
using Iceshrimp.Backend.Core.Helpers.LibMfm.Serialization; using Iceshrimp.Backend.Core.Helpers.LibMfm.Serialization;
using Iceshrimp.Parsing; using Iceshrimp.Parsing;
using Microsoft.FSharp.Collections;
using static Iceshrimp.Parsing.MfmNodeTypes; using static Iceshrimp.Parsing.MfmNodeTypes;
using FSDict = System.Collections.Generic.Dictionary<string, Microsoft.FSharp.Core.FSharpOption<string>?>; using FSDict = System.Collections.Generic.Dictionary<string, Microsoft.FSharp.Core.FSharpOption<string>?>;
@ -516,10 +518,50 @@ public class MfmTests
</center> </center>
"""; """;
AssertionOptions.FormattingOptions.MaxDepth = 100;
var res = Mfm.parse(input); var res = Mfm.parse(input);
MfmSerializer.Serialize(res).Should().BeEquivalentTo(input); MfmSerializer.Serialize(res).Should().BeEquivalentTo(input);
} }
/// <summary>
/// Feeds the parser <c>$[test ...]</c> nodes nested deeper than the recursion limit and checks
/// that only <c>limit</c> levels come back as fn nodes, that the remaining openers are kept as
/// plain text, and that serializing the result reproduces the input.
/// </summary>
[TestMethod]
public void TestFnRecursionLimit()
{
    const int iterations = 150;
    const int limit = 100;

    var source = GetMfm(iterations);
    var parsed = Mfm.parse(source);

    // The fn node parser isn't greedy, so the closing brackets that are left over
    // beyond the limit are expected as a single text node at the very end.
    var trailingBrackets = new MfmTextNode(new string(']', iterations - limit));
    List<MfmNode> expected = [GetExpected(iterations), trailingBrackets];

    // Raise the assertion formatter limits for the deeply nested expected/actual trees.
    AssertionOptions.FormattingOptions.MaxDepth = 300;
    AssertionOptions.FormattingOptions.MaxLines = 1000;

    parsed.ToList().Should().Equal(expected, MfmNodeEqual);
    MfmSerializer.Serialize(parsed).Should().BeEquivalentTo(source);
    return;

    // Builds `count` fn openers followed by `count` closing brackets.
    string GetMfm(int count)
    {
        var builder = new StringBuilder();

        for (var openers = count; openers > 0; openers--)
        {
            builder.Append("$[test ");
        }

        for (var closers = count; closers > 0; closers--)
        {
            builder.Append(']');
        }

        return builder.ToString();
    }

    // Builds `remaining` nested fn nodes; the innermost child is a text node holding
    // the openers that are still left once `remaining` reaches zero.
    MfmInlineNode GetExpected(int count, int remaining = limit)
    {
        if (remaining > 0)
        {
            MfmInlineNode child = GetExpected(count - 1, remaining - 1);
            return new MfmFnNode("test", null, [child]);
        }

        return new MfmTextNode(GetMfm(count).TrimEnd(']'));
    }
}
private static bool MfmNodeEqual(MfmNode a, MfmNode b) private static bool MfmNodeEqual(MfmNode a, MfmNode b)
{ {
if (a.GetType() != b.GetType()) return false; if (a.GetType() != b.GetType()) return false;