From 550e3cc229d3d29e58037fc23cd34158f9d8bd6e Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Tue, 14 Jan 2025 12:31:10 -0500 Subject: [PATCH 01/10] Add multiaddr expression group matching Support captures export some things wip thinking about public API Think about exposing meg as a public API doc comments Finish rename Add helper for meg and add test add comment for devs --- component.go | 7 ++ meg/meg.go | 126 ++++++++++++++++++++++++++ meg/meg_test.go | 233 ++++++++++++++++++++++++++++++++++++++++++++++++ meg/sugar.go | 130 +++++++++++++++++++++++++++ meg_test.go | 45 ++++++++++ util.go | 7 ++ 6 files changed, 548 insertions(+) create mode 100644 meg/meg.go create mode 100644 meg/meg_test.go create mode 100644 meg/sugar.go create mode 100644 meg_test.go diff --git a/component.go b/component.go index 6e8a640..c9b618f 100644 --- a/component.go +++ b/component.go @@ -162,6 +162,13 @@ func (c *Component) Protocol() Protocol { return *c.protocol } +func (c *Component) Code() int { + if c == nil { + return 0 + } + return c.Protocol().Code +} + func (c *Component) RawValue() []byte { if c == nil { return nil diff --git a/meg/meg.go b/meg/meg.go new file mode 100644 index 0000000..b3b6db4 --- /dev/null +++ b/meg/meg.go @@ -0,0 +1,126 @@ +// package meg implements Regular Expressions for multiaddr Components. It's short for "Megular Expressions" +package meg + +// The developer is assumed to be familiar with the Thompson NFA approach to +// regex before making changes to this file. Refer to +// https://swtch.com/~rsc/regexp/regexp1.html for an introduction. + +import ( + "fmt" + "slices" +) + +type stateKind uint8 + +const ( + matchCode stateKind = iota + split + done +) + +// MatchState is the Thompson NFA for a regular expression. +type MatchState struct { + capture captureFunc + next *MatchState + nextSplit *MatchState + + kind stateKind + generation int + code int +} + +type captureFunc *func(string) error +type captureMap map[captureFunc][]string + +func (cm captureMap) clone() captureMap { + out := make(captureMap, len(cm)) + for k, v := range cm { + out[k] = slices.Clone(v) + } + return out +} + +type statesAndCaptures struct { + states []*MatchState + captures []captureMap +} + +func (s *MatchState) String() string { + return fmt.Sprintf("state{kind: %d, generation: %d, code: %d}", s.kind, s.generation, s.code) +} + +type Matchable interface { + Code() int + Value() string // Used when capturing the value +} + +// Match returns whether the given Components match the Pattern defined in MatchState. +// Errors are used to communicate capture errors. +// If the error is non-nil the returned bool will be false. +func Match[S ~[]T, T Matchable](s *MatchState, components S) (bool, error) { + listGeneration := s.generation + 1 // Start at the last generation + 1 + defer func() { s.generation = listGeneration }() // In case we reuse this state, store our highest generation number + + currentStates := statesAndCaptures{ + states: make([]*MatchState, 0, 16), + captures: make([]captureMap, 0, 16), + } + nextStates := statesAndCaptures{ + states: make([]*MatchState, 0, 16), + captures: make([]captureMap, 0, 16), + } + + currentStates = appendState(currentStates, s, nil, listGeneration) + + for _, c := range components { + if len(currentStates.states) == 0 { + return false, nil + } + for i, s := range currentStates.states { + if s.kind == matchCode && s.code == c.Code() { + cm := currentStates.captures[i] + if s.capture != nil { + cm[s.capture] = append(cm[s.capture], c.Value()) + } + nextStates = appendState(nextStates, s.next, currentStates.captures[i], listGeneration) + } + } + currentStates, nextStates = nextStates, currentStates + nextStates.states = nextStates.states[:0] + nextStates.captures = nextStates.captures[:0] + listGeneration++ + } + + for i, s := range currentStates.states { + if s.kind == done { + // We found a complete path. Run the captures now + for f, v := range currentStates.captures[i] { + for _, s := range v { + if err := (*f)(s); err != nil { + return false, err + } + } + } + return true, nil + } + } + return false, nil +} + +func appendState(arr statesAndCaptures, s *MatchState, c captureMap, listGeneration int) statesAndCaptures { + if s == nil || s.generation == listGeneration { + return arr + } + if c == nil { + c = make(captureMap) + } + s.generation = listGeneration + if s.kind == split { + arr = appendState(arr, s.next, c, listGeneration) + arr = appendState(arr, s.nextSplit, c.clone(), listGeneration) + } else { + arr.states = append(arr.states, s) + arr.captures = append(arr.captures, c) + } + return arr +} diff --git a/meg/meg_test.go b/meg/meg_test.go new file mode 100644 index 0000000..b47ba02 --- /dev/null +++ b/meg/meg_test.go @@ -0,0 +1,233 @@ +package meg + +import ( + "regexp" + "slices" + "testing" + "testing/quick" +) + +type codeAndValue struct { + code int + val string // Uses the string type to ensure immutability. +} + +// Code implements Matchable. +func (c codeAndValue) Code() int { + return c.code +} + +// Value implements Matchable. +func (c codeAndValue) Value() string { + return c.val +} + +var _ Matchable = codeAndValue{} + +func TestSimple(t *testing.T) { + type testCase struct { + pattern *MatchState + skipQuickCheck bool + shouldMatch [][]int + shouldNotMatch [][]int + } + testCases := + []testCase{ + { + pattern: PatternToMatchState(Val(0), Val(1)), + shouldMatch: [][]int{{0, 1}}, + shouldNotMatch: [][]int{ + {0}, + {0, 0}, + {0, 1, 0}, + }}, { + pattern: PatternToMatchState(Val(0), Val(1), Optional(Val(2))), + shouldMatch: [][]int{ + {0, 1, 2}, + {0, 1}, + }, + shouldNotMatch: [][]int{ + {0}, + {0, 0}, + {0, 1, 0}, + {0, 1, 2, 0}, + }}, { + pattern: PatternToMatchState(Val(0), Val(1), OneOrMore(2)), + skipQuickCheck: true, + shouldMatch: [][]int{ + {0, 1, 2, 2, 2, 2}, + {0, 1, 2}, + }, + shouldNotMatch: [][]int{ + {0}, + {0, 0}, + {0, 1}, + {0, 1, 0}, + {0, 1, 1, 0}, + {0, 1, 2, 0}, + }}, + } + + for i, tc := range testCases { + for _, m := range tc.shouldMatch { + if matches, _ := Match(tc.pattern, codesToCodeAndValue(m)); !matches { + t.Fatalf("failed to match %v with %s. idx=%d", m, tc.pattern, i) + } + } + for _, m := range tc.shouldNotMatch { + if matches, _ := Match(tc.pattern, codesToCodeAndValue(m)); matches { + t.Fatalf("failed to not match %v with %s. idx=%d", m, tc.pattern, i) + } + } + if tc.skipQuickCheck { + continue + } + if err := quick.Check(func(notMatch []int) bool { + for _, shouldMatch := range tc.shouldMatch { + if slices.Equal(notMatch, shouldMatch) { + // The random `notMatch` is actually something that shouldMatch. Skip it. + return true + } + } + matches, _ := Match(tc.pattern, codesToCodeAndValue(notMatch)) + return !matches + }, &quick.Config{}); err != nil { + t.Fatal(err) + } + } +} + +func TestCapture(t *testing.T) { + type setupStateAndAssert func() (*MatchState, func()) + type testCase struct { + setup setupStateAndAssert + parts []codeAndValue + } + + testCases := + []testCase{ + { + setup: func() (*MatchState, func()) { + var code0str string + return PatternToMatchState(CaptureVal(0, &code0str), Val(1)), func() { + if code0str != "hello" { + panic("unexpected value") + } + } + }, + parts: []codeAndValue{{0, "hello"}, {1, "world"}}, + }, + { + setup: func() (*MatchState, func()) { + var code0strs []string + return PatternToMatchState(CaptureOneOrMore(0, &code0strs), Val(1)), func() { + if code0strs[0] != "hello" { + panic("unexpected value") + } + if code0strs[1] != "world" { + panic("unexpected value") + } + } + }, + parts: []codeAndValue{{0, "hello"}, {0, "world"}, {1, ""}}, + }, + } + + _ = testCases + for _, tc := range testCases { + state, assert := tc.setup() + if matches, _ := Match(state, tc.parts); !matches { + t.Fatalf("failed to match %v with %s", tc.parts, state) + } + assert() + } +} + +func codesToCodeAndValue(codes []int) []codeAndValue { + out := make([]codeAndValue, len(codes)) + for i, c := range codes { + out[i] = codeAndValue{code: c} + } + return out +} + +func bytesToCodeAndValue(codes []byte) []codeAndValue { + out := make([]codeAndValue, len(codes)) + for i, c := range codes { + out[i] = codeAndValue{code: int(c)} + } + return out +} + +// FuzzMatchesRegexpBehavior fuzz tests the expression matcher by comparing it to the behavior of the regexp package. +func FuzzMatchesRegexpBehavior(f *testing.F) { + bytesToRegexpAndPattern := func(exp []byte) ([]byte, []Pattern) { + if len(exp) < 3 { + panic("regexp too short") + } + pattern := make([]Pattern, 0, len(exp)-2) + for i, b := range exp { + b = b % 32 + if i == 0 { + exp[i] = '^' + continue + } else if i == len(exp)-1 { + exp[i] = '$' + continue + } + switch { + case b < 26: + exp[i] = b + 'a' + pattern = append(pattern, Val(int(exp[i]))) + case i > 1 && b == 26: + exp[i] = '?' + pattern = pattern[:len(pattern)-1] + pattern = append(pattern, Optional(Val(int(exp[i-1])))) + case i > 1 && b == 27: + exp[i] = '*' + pattern = pattern[:len(pattern)-1] + pattern = append(pattern, ZeroOrMore(int(exp[i-1]))) + case i > 1 && b == 28: + exp[i] = '+' + pattern = pattern[:len(pattern)-1] + pattern = append(pattern, OneOrMore(int(exp[i-1]))) + default: + exp[i] = 'a' + pattern = append(pattern, Val(int(exp[i]))) + } + } + + return exp, pattern + } + + simplifyB := func(buf []byte) []byte { + for i, b := range buf { + buf[i] = (b % 26) + 'a' + } + return buf + } + + f.Fuzz(func(t *testing.T, expRules []byte, corpus []byte) { + if len(expRules) < 3 || len(expRules) > 1024 || len(corpus) > 1024 { + return + } + corpus = simplifyB(corpus) + regexpPattern, pattern := bytesToRegexpAndPattern(expRules) + matched, err := regexp.Match(string(regexpPattern), corpus) + if err != nil { + // Malformed regex. Ignore + return + } + p := PatternToMatchState(pattern...) + otherMatched, _ := Match(p, bytesToCodeAndValue(corpus)) + if otherMatched != matched { + t.Log("regexp", string(regexpPattern)) + t.Log("corpus", string(corpus)) + m2, err2 := regexp.Match(string(regexpPattern), corpus) + t.Logf("regexp matched %v. %v. %v, %v. \n%v - \n%v", matched, err, m2, err2, regexpPattern, corpus) + t.Logf("pattern %+v", pattern) + t.Fatalf("mismatched results: %v %v %v", otherMatched, matched, p) + } + }) + +} diff --git a/meg/sugar.go b/meg/sugar.go new file mode 100644 index 0000000..369a315 --- /dev/null +++ b/meg/sugar.go @@ -0,0 +1,130 @@ +package meg + +import ( + "errors" +) + +type Pattern = func(next *MatchState) *MatchState + +func PatternToMatchState(states ...Pattern) *MatchState { + nextState := &MatchState{kind: done} + for i := len(states) - 1; i >= 0; i-- { + nextState = states[i](nextState) + } + return nextState +} + +func Cat(left, right Pattern) Pattern { + return func(next *MatchState) *MatchState { + return left(right(next)) + } +} + +func Or(p ...Pattern) Pattern { + return func(next *MatchState) *MatchState { + if len(p) == 0 { + return next + } + if len(p) == 1 { + return p[0](next) + } + + return &MatchState{ + kind: split, + next: p[0](next), + nextSplit: Or(p[1:]...)(next), + } + } +} + +var errAlreadyCapture = errors.New("already captured") + +func captureOneValueOrErr(val *string) captureFunc { + if val == nil { + return nil + } + var set bool + f := func(s string) error { + if set { + *val = "" + return errAlreadyCapture + } + *val = s + return nil + } + return &f +} + +func captureMany(vals *[]string) captureFunc { + if vals == nil { + return nil + } + f := func(s string) error { + *vals = append(*vals, s) + return nil + } + return &f +} + +func captureValWithF(code int, f captureFunc) Pattern { + return func(next *MatchState) *MatchState { + return &MatchState{ + kind: matchCode, + capture: f, + code: code, + next: next, + } + } +} + +func Val(code int) Pattern { + return CaptureVal(code, nil) +} + +func CaptureVal(code int, val *string) Pattern { + return captureValWithF(code, captureOneValueOrErr(val)) +} + +func ZeroOrMore(code int) Pattern { + return CaptureZeroOrMore(code, nil) +} + +func captureZeroOrMoreWithF(code int, f captureFunc) Pattern { + return func(next *MatchState) *MatchState { + match := &MatchState{ + code: code, + capture: f, + } + s := &MatchState{ + kind: split, + next: match, + nextSplit: next, + } + match.next = s // Loop back to the split. + return s + } +} + +func CaptureZeroOrMore(code int, vals *[]string) Pattern { + return captureZeroOrMoreWithF(code, captureMany(vals)) +} + +func OneOrMore(code int) Pattern { + return CaptureOneOrMore(code, nil) +} +func CaptureOneOrMore(code int, vals *[]string) Pattern { + f := captureMany(vals) + return func(next *MatchState) *MatchState { + return captureValWithF(code, f)(captureZeroOrMoreWithF(code, f)(next)) + } +} + +func Optional(s Pattern) Pattern { + return func(next *MatchState) *MatchState { + return &MatchState{ + kind: split, + next: s(next), + nextSplit: next, + } + } +} diff --git a/meg_test.go b/meg_test.go new file mode 100644 index 0000000..817329d --- /dev/null +++ b/meg_test.go @@ -0,0 +1,45 @@ +package multiaddr + +import ( + "testing" + + "github.com/multiformats/go-multiaddr/meg" +) + +func TestMatchAndCaptureMultiaddr(t *testing.T) { + m := StringCast("/ip4/1.2.3.4/udp/8231/quic-v1/webtransport/certhash/b2uaraocy6yrdblb4sfptaddgimjmmpy/certhash/zQmbWTwYGcmdyK9CYfNBcfs9nhZs17a6FQ4Y8oea278xx41") + + var udpPort string + var certhashes []string + found, _ := m.Match( + meg.Or( + meg.Val(P_IP4), + meg.Val(P_IP6), + ), + meg.CaptureVal(P_UDP, &udpPort), + meg.Val(P_QUIC_V1), + meg.Val(P_WEBTRANSPORT), + meg.CaptureZeroOrMore(P_CERTHASH, &certhashes), + ) + if !found { + t.Fatal("failed to match") + } + if udpPort != "8231" { + t.Fatal("unexpected value") + } + + if len(certhashes) != 2 { + t.Fatal("Didn't capture all certhashes") + } + + { + m, c := SplitLast(m) + if c.Value() != certhashes[1] { + t.Fatal("unexpected value. Expected", c.RawValue(), "but got", []byte(certhashes[1])) + } + _, c = SplitLast(m) + if c.Value() != certhashes[0] { + t.Fatal("unexpected value. Expected", c.RawValue(), "but got", []byte(certhashes[0])) + } + } +} diff --git a/util.go b/util.go index d063e39..27814b0 100644 --- a/util.go +++ b/util.go @@ -2,6 +2,8 @@ package multiaddr import ( "fmt" + + "github.com/multiformats/go-multiaddr/meg" ) // Split returns the sub-address portions of a multiaddr. @@ -120,3 +122,8 @@ func ForEach(m Multiaddr, cb func(c Component) bool) { } } } + +func (m Multiaddr) Match(p ...meg.Pattern) (bool, error) { + s := meg.PatternToMatchState(p...) + return meg.Match(s, m) +} From b86244a12eee5c29c119bd4cc2bf7ae0a23695ce Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Thu, 6 Feb 2025 11:45:34 -0800 Subject: [PATCH 02/10] Move meg package to /x/ for experimental --- meg_test.go | 2 +- util.go | 2 +- {meg => x/meg}/meg.go | 0 {meg => x/meg}/meg_test.go | 0 {meg => x/meg}/sugar.go | 0 5 files changed, 2 insertions(+), 2 deletions(-) rename {meg => x/meg}/meg.go (100%) rename {meg => x/meg}/meg_test.go (100%) rename {meg => x/meg}/sugar.go (100%) diff --git a/meg_test.go b/meg_test.go index 817329d..32fcadd 100644 --- a/meg_test.go +++ b/meg_test.go @@ -3,7 +3,7 @@ package multiaddr import ( "testing" - "github.com/multiformats/go-multiaddr/meg" + "github.com/multiformats/go-multiaddr/x/meg" ) func TestMatchAndCaptureMultiaddr(t *testing.T) { diff --git a/util.go b/util.go index 27814b0..b4a7174 100644 --- a/util.go +++ b/util.go @@ -3,7 +3,7 @@ package multiaddr import ( "fmt" - "github.com/multiformats/go-multiaddr/meg" + "github.com/multiformats/go-multiaddr/x/meg" ) // Split returns the sub-address portions of a multiaddr. diff --git a/meg/meg.go b/x/meg/meg.go similarity index 100% rename from meg/meg.go rename to x/meg/meg.go diff --git a/meg/meg_test.go b/x/meg/meg_test.go similarity index 100% rename from meg/meg_test.go rename to x/meg/meg_test.go diff --git a/meg/sugar.go b/x/meg/sugar.go similarity index 100% rename from meg/sugar.go rename to x/meg/sugar.go From 0e1db6031342fb4d234d261e7e3b098c9fd70c1a Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Thu, 6 Feb 2025 11:45:40 -0800 Subject: [PATCH 03/10] Add fuzz cases --- x/meg/testdata/fuzz/FuzzMatchesRegexpBehavior/1e98195527a61d9c | 3 +++ x/meg/testdata/fuzz/FuzzMatchesRegexpBehavior/a92659f6360052e5 | 3 +++ x/meg/testdata/fuzz/FuzzMatchesRegexpBehavior/b02eb19d06b02202 | 3 +++ x/meg/testdata/fuzz/FuzzMatchesRegexpBehavior/b3cb6a0a30c58d03 | 3 +++ x/meg/testdata/fuzz/FuzzMatchesRegexpBehavior/b9eb3d27681dd876 | 3 +++ x/meg/testdata/fuzz/FuzzMatchesRegexpBehavior/f8e3578130e7a246 | 3 +++ 6 files changed, 18 insertions(+) create mode 100644 x/meg/testdata/fuzz/FuzzMatchesRegexpBehavior/1e98195527a61d9c create mode 100644 x/meg/testdata/fuzz/FuzzMatchesRegexpBehavior/a92659f6360052e5 create mode 100644 x/meg/testdata/fuzz/FuzzMatchesRegexpBehavior/b02eb19d06b02202 create mode 100644 x/meg/testdata/fuzz/FuzzMatchesRegexpBehavior/b3cb6a0a30c58d03 create mode 100644 x/meg/testdata/fuzz/FuzzMatchesRegexpBehavior/b9eb3d27681dd876 create mode 100644 x/meg/testdata/fuzz/FuzzMatchesRegexpBehavior/f8e3578130e7a246 diff --git a/x/meg/testdata/fuzz/FuzzMatchesRegexpBehavior/1e98195527a61d9c b/x/meg/testdata/fuzz/FuzzMatchesRegexpBehavior/1e98195527a61d9c new file mode 100644 index 0000000..550d18d --- /dev/null +++ b/x/meg/testdata/fuzz/FuzzMatchesRegexpBehavior/1e98195527a61d9c @@ -0,0 +1,3 @@ +go test fuzz v1 +[]byte("^c?") +[]byte("o") diff --git a/x/meg/testdata/fuzz/FuzzMatchesRegexpBehavior/a92659f6360052e5 b/x/meg/testdata/fuzz/FuzzMatchesRegexpBehavior/a92659f6360052e5 new file mode 100644 index 0000000..483a83b --- /dev/null +++ b/x/meg/testdata/fuzz/FuzzMatchesRegexpBehavior/a92659f6360052e5 @@ -0,0 +1,3 @@ +go test fuzz v1 +[]byte("") +[]byte("w") diff --git a/x/meg/testdata/fuzz/FuzzMatchesRegexpBehavior/b02eb19d06b02202 b/x/meg/testdata/fuzz/FuzzMatchesRegexpBehavior/b02eb19d06b02202 new file mode 100644 index 0000000..503a812 --- /dev/null +++ b/x/meg/testdata/fuzz/FuzzMatchesRegexpBehavior/b02eb19d06b02202 @@ -0,0 +1,3 @@ +go test fuzz v1 +[]byte("^zy") +[]byte("p ") diff --git a/x/meg/testdata/fuzz/FuzzMatchesRegexpBehavior/b3cb6a0a30c58d03 b/x/meg/testdata/fuzz/FuzzMatchesRegexpBehavior/b3cb6a0a30c58d03 new file mode 100644 index 0000000..fa67084 --- /dev/null +++ b/x/meg/testdata/fuzz/FuzzMatchesRegexpBehavior/b3cb6a0a30c58d03 @@ -0,0 +1,3 @@ +go test fuzz v1 +[]byte("^b?") +[]byte("pw") diff --git a/x/meg/testdata/fuzz/FuzzMatchesRegexpBehavior/b9eb3d27681dd876 b/x/meg/testdata/fuzz/FuzzMatchesRegexpBehavior/b9eb3d27681dd876 new file mode 100644 index 0000000..3d3ada3 --- /dev/null +++ b/x/meg/testdata/fuzz/FuzzMatchesRegexpBehavior/b9eb3d27681dd876 @@ -0,0 +1,3 @@ +go test fuzz v1 +[]byte("^b:") +[]byte("^b:MMMMMMMMMMMMMMMMMMpw") diff --git a/x/meg/testdata/fuzz/FuzzMatchesRegexpBehavior/f8e3578130e7a246 b/x/meg/testdata/fuzz/FuzzMatchesRegexpBehavior/f8e3578130e7a246 new file mode 100644 index 0000000..0b69bbd --- /dev/null +++ b/x/meg/testdata/fuzz/FuzzMatchesRegexpBehavior/f8e3578130e7a246 @@ -0,0 +1,3 @@ +go test fuzz v1 +[]byte("gw\xa6%\xc5\xc7kD\xdf\x14_\xebק\xd8H\xcf0\xcf~/\xf9\x1a\x1a\x1a\x1a\x1a\x1a\x1a\x1a\xd5{1\x98ױ\x841(") +[]byte("") From 6088fcd323b14360592ab442127167fea38a9428 Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Thu, 6 Feb 2025 12:15:40 -0800 Subject: [PATCH 04/10] faster Code() method twice as fast without the copy --- component.go | 1 - 1 file changed, 1 deletion(-) diff --git a/component.go b/component.go index c9b618f..0158bd1 100644 --- a/component.go +++ b/component.go @@ -167,7 +167,6 @@ func (c *Component) Code() int { return 0 } return c.Protocol().Code -} func (c *Component) RawValue() []byte { if c == nil { From 3342cbf2da6cc2659841d7e644a779cf97a29d62 Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Thu, 6 Feb 2025 13:45:10 -0800 Subject: [PATCH 05/10] avoid copying an empty map if no captures --- x/meg/meg.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/x/meg/meg.go b/x/meg/meg.go index b3b6db4..952111f 100644 --- a/x/meg/meg.go +++ b/x/meg/meg.go @@ -33,6 +33,9 @@ type captureFunc *func(string) error type captureMap map[captureFunc][]string func (cm captureMap) clone() captureMap { + if cm == nil { + return nil + } out := make(captureMap, len(cm)) for k, v := range cm { out[k] = slices.Clone(v) @@ -80,6 +83,10 @@ func Match[S ~[]T, T Matchable](s *MatchState, components S) (bool, error) { if s.kind == matchCode && s.code == c.Code() { cm := currentStates.captures[i] if s.capture != nil { + if cm == nil { + cm = make(captureMap) + currentStates.captures[i] = cm + } cm[s.capture] = append(cm[s.capture], c.Value()) } nextStates = appendState(nextStates, s.next, currentStates.captures[i], listGeneration) @@ -111,9 +118,6 @@ func appendState(arr statesAndCaptures, s *MatchState, c captureMap, listGenerat if s == nil || s.generation == listGeneration { return arr } - if c == nil { - c = make(captureMap) - } s.generation = listGeneration if s.kind == split { arr = appendState(arr, s.next, c, listGeneration) From 09e5347bee36e83e378c1ec9cc49958989e39f9e Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Wed, 19 Feb 2025 17:21:21 -0800 Subject: [PATCH 06/10] Faster megular expressions (#265) * much cheaper copies of captures * Add a benchmark * allocate to a slice. Use indexes as handles * cleanup * Add nocapture loop benchmark It's really fast. No surprise * cleanup * nits --- util.go | 4 +- x/meg/bench_test.go | 416 ++++++++++++++++++++++++++++++++++++++++++++ x/meg/meg.go | 187 +++++++++++++------- x/meg/meg_test.go | 44 +++-- x/meg/sugar.go | 132 +++++++++----- 5 files changed, 662 insertions(+), 121 deletions(-) create mode 100644 x/meg/bench_test.go diff --git a/util.go b/util.go index b4a7174..2cfed19 100644 --- a/util.go +++ b/util.go @@ -124,6 +124,6 @@ func ForEach(m Multiaddr, cb func(c Component) bool) { } func (m Multiaddr) Match(p ...meg.Pattern) (bool, error) { - s := meg.PatternToMatchState(p...) - return meg.Match(s, m) + matcher := meg.PatternToMatcher(p...) + return meg.Match(matcher, m) } diff --git a/x/meg/bench_test.go b/x/meg/bench_test.go new file mode 100644 index 0000000..15bf386 --- /dev/null +++ b/x/meg/bench_test.go @@ -0,0 +1,416 @@ +package meg_test + +import ( + "testing" + + "github.com/multiformats/go-multiaddr" + "github.com/multiformats/go-multiaddr/x/meg" +) + +type preallocatedCapture struct { + certHashes []string + matcher meg.Matcher +} + +func preallocateCapture() *preallocatedCapture { + p := &preallocatedCapture{} + p.matcher = meg.PatternToMatcher( + meg.Or( + meg.Val(multiaddr.P_IP4), + meg.Val(multiaddr.P_IP6), + meg.Val(multiaddr.P_DNS), + ), + meg.Val(multiaddr.P_UDP), + meg.Val(multiaddr.P_WEBRTC_DIRECT), + meg.CaptureZeroOrMore(multiaddr.P_CERTHASH, &p.certHashes), + ) + return p +} + +var webrtcMatchPrealloc *preallocatedCapture + +func (p *preallocatedCapture) IsWebRTCDirectMultiaddr(addr multiaddr.Multiaddr) (bool, int) { + found, _ := meg.Match(p.matcher, addr) + return found, len(p.certHashes) +} + +// IsWebRTCDirectMultiaddr returns whether addr is a /webrtc-direct multiaddr with the count of certhashes +// in addr +func IsWebRTCDirectMultiaddr(addr multiaddr.Multiaddr) (bool, int) { + if webrtcMatchPrealloc == nil { + webrtcMatchPrealloc = preallocateCapture() + } + return webrtcMatchPrealloc.IsWebRTCDirectMultiaddr(addr) +} + +// IsWebRTCDirectMultiaddrLoop returns whether addr is a /webrtc-direct multiaddr with the count of certhashes +// in addr +func IsWebRTCDirectMultiaddrLoop(addr multiaddr.Multiaddr) (bool, int) { + protos := [...]int{multiaddr.P_IP4, multiaddr.P_IP6, multiaddr.P_DNS, multiaddr.P_UDP, multiaddr.P_WEBRTC_DIRECT} + matchProtos := [...][]int{protos[:3], {protos[3]}, {protos[4]}} + certHashCount := 0 + for i, c := range addr { + if i >= len(matchProtos) { + if c.Code() == multiaddr.P_CERTHASH { + certHashCount++ + } else { + return false, 0 + } + } else { + found := false + for _, proto := range matchProtos[i] { + if c.Code() == proto { + found = true + break + } + } + if !found { + return false, 0 + } + } + } + return true, certHashCount +} + +var wtPrealloc *preallocatedCapture + +func isWebTransportMultiaddrPrealloc() *preallocatedCapture { + if wtPrealloc != nil { + return wtPrealloc + } + + p := &preallocatedCapture{} + var dnsName string + var ip4Addr string + var ip6Addr string + var udpPort string + var sni string + p.matcher = meg.PatternToMatcher( + meg.Or( + meg.CaptureVal(multiaddr.P_IP4, &ip4Addr), + meg.CaptureVal(multiaddr.P_IP6, &ip6Addr), + meg.CaptureVal(multiaddr.P_DNS4, &dnsName), + meg.CaptureVal(multiaddr.P_DNS6, &dnsName), + meg.CaptureVal(multiaddr.P_DNS, &dnsName), + ), + meg.CaptureVal(multiaddr.P_UDP, &udpPort), + meg.Val(multiaddr.P_QUIC_V1), + meg.Optional( + meg.CaptureVal(multiaddr.P_SNI, &sni), + ), + meg.Val(multiaddr.P_WEBTRANSPORT), + meg.CaptureZeroOrMore(multiaddr.P_CERTHASH, &p.certHashes), + ) + wtPrealloc = p + return p +} + +func IsWebTransportMultiaddrPrealloc(m multiaddr.Multiaddr) (bool, int) { + p := isWebTransportMultiaddrPrealloc() + found, _ := meg.Match(p.matcher, m) + return found, len(p.certHashes) +} + +func IsWebTransportMultiaddr(m multiaddr.Multiaddr) (bool, int) { + var dnsName string + var ip4Addr string + var ip6Addr string + var udpPort string + var sni string + var certHashesStr []string + matched, _ := m.Match( + meg.Or( + meg.CaptureVal(multiaddr.P_IP4, &ip4Addr), + meg.CaptureVal(multiaddr.P_IP6, &ip6Addr), + meg.CaptureVal(multiaddr.P_DNS4, &dnsName), + meg.CaptureVal(multiaddr.P_DNS6, &dnsName), + meg.CaptureVal(multiaddr.P_DNS, &dnsName), + ), + meg.CaptureVal(multiaddr.P_UDP, &udpPort), + meg.Val(multiaddr.P_QUIC_V1), + meg.Optional( + meg.CaptureVal(multiaddr.P_SNI, &sni), + ), + meg.Val(multiaddr.P_WEBTRANSPORT), + meg.CaptureZeroOrMore(multiaddr.P_CERTHASH, &certHashesStr), + ) + if !matched { + return false, 0 + } + return true, len(certHashesStr) +} + +func IsWebTransportMultiaddrNoCapture(m multiaddr.Multiaddr) (bool, int) { + matched, _ := m.Match( + meg.Or( + meg.Val(multiaddr.P_IP4), + meg.Val(multiaddr.P_IP6), + meg.Val(multiaddr.P_DNS4), + meg.Val(multiaddr.P_DNS6), + meg.Val(multiaddr.P_DNS), + ), + meg.Val(multiaddr.P_UDP), + meg.Val(multiaddr.P_QUIC_V1), + meg.Optional( + meg.Val(multiaddr.P_SNI), + ), + meg.Val(multiaddr.P_WEBTRANSPORT), + meg.ZeroOrMore(multiaddr.P_CERTHASH), + ) + if !matched { + return false, 0 + } + return true, 0 +} + +func IsWebTransportMultiaddrLoop(m multiaddr.Multiaddr) (bool, int) { + var ip4Addr string + var ip6Addr string + var dnsName string + var udpPort string + var sni string + + // Expected pattern: + // 0: one of: P_IP4, P_IP6, P_DNS4, P_DNS6, P_DNS + // 1: P_UDP + // 2: P_QUIC_V1 + // 3: optional P_SNI (if present) + // Next: P_WEBTRANSPORT + // Trailing: zero or more P_CERTHASH + + // Check minimum length (at least without SNI: 4 components) + if len(m) < 4 { + return false, 0 + } + + idx := 0 + + // Component 0: Must be one of IP or DNS protocols. + switch m[idx].Code() { + case multiaddr.P_IP4: + ip4Addr = m[idx].String() + case multiaddr.P_IP6: + ip6Addr = m[idx].String() + case multiaddr.P_DNS4, multiaddr.P_DNS6, multiaddr.P_DNS: + dnsName = m[idx].String() + default: + return false, 0 + } + idx++ + + // Component 1: Must be UDP. + if idx >= len(m) || m[idx].Code() != multiaddr.P_UDP { + return false, 0 + } + udpPort = m[idx].String() + idx++ + + // Component 2: Must be QUIC_V1. + if idx >= len(m) || m[idx].Code() != multiaddr.P_QUIC_V1 { + return false, 0 + } + idx++ + + // Optional component: SNI. + if idx < len(m) && m[idx].Code() == multiaddr.P_SNI { + sni = m[idx].String() + idx++ + } + + // Next component: Must be WEBTRANSPORT. + if idx >= len(m) || m[idx].Code() != multiaddr.P_WEBTRANSPORT { + return false, 0 + } + idx++ + + // All remaining components must be CERTHASH. + certHashCount := 0 + for ; idx < len(m); idx++ { + if m[idx].Code() != multiaddr.P_CERTHASH { + return false, 0 + } + _ = m[idx].String() + certHashCount++ + } + + _ = ip4Addr + _ = ip6Addr + _ = dnsName + _ = udpPort + _ = sni + + return true, certHashCount +} + +func IsWebTransportMultiaddrLoopNoCapture(m multiaddr.Multiaddr) (bool, int) { + // Expected pattern: + // 0: one of: P_IP4, P_IP6, P_DNS4, P_DNS6, P_DNS + // 1: P_UDP + // 2: P_QUIC_V1 + // 3: optional P_SNI (if present) + // Next: P_WEBTRANSPORT + // Trailing: zero or more P_CERTHASH + + // Check minimum length (at least without SNI: 4 components) + if len(m) < 4 { + return false, 0 + } + + idx := 0 + + // Component 0: Must be one of IP or DNS protocols. + switch m[idx].Code() { + case multiaddr.P_IP4: + case multiaddr.P_IP6: + case multiaddr.P_DNS4, multiaddr.P_DNS6, multiaddr.P_DNS: + default: + return false, 0 + } + idx++ + + // Component 1: Must be UDP. + if idx >= len(m) || m[idx].Code() != multiaddr.P_UDP { + return false, 0 + } + idx++ + + // Component 2: Must be QUIC_V1. + if idx >= len(m) || m[idx].Code() != multiaddr.P_QUIC_V1 { + return false, 0 + } + idx++ + + // Optional component: SNI. + if idx < len(m) && m[idx].Code() == multiaddr.P_SNI { + idx++ + } + + // Next component: Must be WEBTRANSPORT. + if idx >= len(m) || m[idx].Code() != multiaddr.P_WEBTRANSPORT { + return false, 0 + } + idx++ + + // All remaining components must be CERTHASH. + for ; idx < len(m); idx++ { + if m[idx].Code() != multiaddr.P_CERTHASH { + return false, 0 + } + _ = m[idx].String() + } + + return true, 0 +} + +func BenchmarkIsWebTransportMultiaddrPrealloc(b *testing.B) { + addr := multiaddr.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1/sni/example.com/webtransport") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + isWT, count := IsWebTransportMultiaddrPrealloc(addr) + if !isWT || count != 0 { + b.Fatal("unexpected result") + } + } +} + +func BenchmarkIsWebTransportMultiaddrNoCapturePrealloc(b *testing.B) { + addr := multiaddr.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1/sni/example.com/webtransport") + + wtPreallocNoCapture := meg.PatternToMatcher( + meg.Or( + meg.Val(multiaddr.P_IP4), + meg.Val(multiaddr.P_IP6), + meg.Val(multiaddr.P_DNS4), + meg.Val(multiaddr.P_DNS6), + meg.Val(multiaddr.P_DNS), + ), + meg.Val(multiaddr.P_UDP), + meg.Val(multiaddr.P_QUIC_V1), + meg.Optional( + meg.Val(multiaddr.P_SNI), + ), + meg.Val(multiaddr.P_WEBTRANSPORT), + meg.ZeroOrMore(multiaddr.P_CERTHASH), + ) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + isWT, _ := meg.Match(wtPreallocNoCapture, addr) + if !isWT { + b.Fatal("unexpected result") + } + } +} + +func BenchmarkIsWebTransportMultiaddrNoCapture(b *testing.B) { + addr := multiaddr.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1/sni/example.com/webtransport") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + isWT, count := IsWebTransportMultiaddrNoCapture(addr) + if !isWT || count != 0 { + b.Fatal("unexpected result") + } + } +} + +func BenchmarkIsWebTransportMultiaddr(b *testing.B) { + addr := multiaddr.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1/sni/example.com/webtransport") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + isWT, count := IsWebTransportMultiaddr(addr) + if !isWT || count != 0 { + b.Fatal("unexpected result") + } + } +} + +func BenchmarkIsWebTransportMultiaddrLoop(b *testing.B) { + addr := multiaddr.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1/sni/example.com/webtransport") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + isWT, count := IsWebTransportMultiaddrLoop(addr) + if !isWT || count != 0 { + b.Fatal("unexpected result") + } + } +} + +func BenchmarkIsWebTransportMultiaddrLoopNoCapture(b *testing.B) { + addr := multiaddr.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1/sni/example.com/webtransport") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + isWT, count := IsWebTransportMultiaddrLoopNoCapture(addr) + if !isWT || count != 0 { + b.Fatal("unexpected result") + } + } +} + +func BenchmarkIsWebRTCDirectMultiaddr(b *testing.B) { + addr := multiaddr.StringCast("/ip4/1.2.3.4/udp/1234/webrtc-direct/") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + isWebRTC, count := IsWebRTCDirectMultiaddr(addr) + if !isWebRTC || count != 0 { + b.Fatal("unexpected result") + } + } +} + +func BenchmarkIsWebRTCDirectMultiaddrLoop(b *testing.B) { + addr := multiaddr.StringCast("/ip4/1.2.3.4/udp/1234/webrtc-direct/") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + isWebRTC, count := IsWebRTCDirectMultiaddrLoop(addr) + if !isWebRTC || count != 0 { + b.Fatal("unexpected result") + } + } +} diff --git a/x/meg/meg.go b/x/meg/meg.go index 952111f..ccc2275 100644 --- a/x/meg/meg.go +++ b/x/meg/meg.go @@ -7,49 +7,48 @@ package meg import ( "fmt" - "slices" ) -type stateKind uint8 +type stateKind = int const ( - matchCode stateKind = iota - split - done + done stateKind = (iota * -1) - 1 + // split anything else that is negative ) // MatchState is the Thompson NFA for a regular expression. type MatchState struct { - capture captureFunc - next *MatchState - nextSplit *MatchState - - kind stateKind - generation int - code int + capture captureFunc + // next is is the index of the next state. in the MatchState array. + next int + // If codeOrKind is negative, it is a kind. + // If it is negative, but not a `done`, then it is the index to the next split. + // This is done to keep the `MatchState` struct small and cache friendly. + codeOrKind int } -type captureFunc *func(string) error -type captureMap map[captureFunc][]string +type captureFunc func(string) error -func (cm captureMap) clone() captureMap { - if cm == nil { - return nil - } - out := make(captureMap, len(cm)) - for k, v := range cm { - out[k] = slices.Clone(v) - } - return out +// capture is a linked list of capture funcs with values. +type capture struct { + f captureFunc + v string + prev *capture } type statesAndCaptures struct { - states []*MatchState - captures []captureMap + states []int + captures []*capture } -func (s *MatchState) String() string { - return fmt.Sprintf("state{kind: %d, generation: %d, code: %d}", s.kind, s.generation, s.code) +func (s MatchState) String() string { + if s.codeOrKind == done { + return "done" + } + if s.codeOrKind < done { + return fmt.Sprintf("split{left: %d, right: %d}", s.next, restoreSplitIdx(s.codeOrKind)) + } + return fmt.Sprintf("match{code: %d, next: %d}", s.codeOrKind, s.next) } type Matchable interface { @@ -60,52 +59,80 @@ type Matchable interface { // Match returns whether the given Components match the Pattern defined in MatchState. // Errors are used to communicate capture errors. // If the error is non-nil the returned bool will be false. -func Match[S ~[]T, T Matchable](s *MatchState, components S) (bool, error) { - listGeneration := s.generation + 1 // Start at the last generation + 1 - defer func() { s.generation = listGeneration }() // In case we reuse this state, store our highest generation number +func Match[S ~[]T, T Matchable](matcher Matcher, components S) (bool, error) { + states := matcher.states + startStateIdx := matcher.startIdx + + // Fast case for a small number of states (<128) + // Avoids allocation of a slice for the visitedBitSet. + stackBitSet := [2]uint64{} + visitedBitSet := stackBitSet[:] + if len(states) > 128 { + visitedBitSet = make([]uint64, (len(states)+63)/64) + } currentStates := statesAndCaptures{ - states: make([]*MatchState, 0, 16), - captures: make([]captureMap, 0, 16), + states: make([]int, 0, 16), + captures: make([]*capture, 0, 16), } nextStates := statesAndCaptures{ - states: make([]*MatchState, 0, 16), - captures: make([]captureMap, 0, 16), + states: make([]int, 0, 16), + captures: make([]*capture, 0, 16), } - currentStates = appendState(currentStates, s, nil, listGeneration) + currentStates = appendState(currentStates, states, startStateIdx, nil, visitedBitSet) for _, c := range components { + clear(visitedBitSet) if len(currentStates.states) == 0 { return false, nil } - for i, s := range currentStates.states { - if s.kind == matchCode && s.code == c.Code() { + for i, stateIndex := range currentStates.states { + s := states[stateIndex] + if s.codeOrKind >= 0 && s.codeOrKind == c.Code() { cm := currentStates.captures[i] if s.capture != nil { + next := &capture{ + f: s.capture, + v: c.Value(), + } if cm == nil { - cm = make(captureMap) - currentStates.captures[i] = cm + cm = next + } else { + next.prev = cm + cm = next } - cm[s.capture] = append(cm[s.capture], c.Value()) + currentStates.captures[i] = cm } - nextStates = appendState(nextStates, s.next, currentStates.captures[i], listGeneration) + nextStates = appendState(nextStates, states, s.next, cm, visitedBitSet) } } currentStates, nextStates = nextStates, currentStates nextStates.states = nextStates.states[:0] nextStates.captures = nextStates.captures[:0] - listGeneration++ } - for i, s := range currentStates.states { - if s.kind == done { + for i, stateIndex := range currentStates.states { + s := states[stateIndex] + if s.codeOrKind == done { + // We found a complete path. Run the captures now - for f, v := range currentStates.captures[i] { - for _, s := range v { - if err := (*f)(s); err != nil { - return false, err - } + c := currentStates.captures[i] + + // Flip the order of the captures because we see captures from right + // to left, but users expect them left to right. + type captureWithVal struct { + f captureFunc + v string + } + reversedCaptures := make([]captureWithVal, 0, 16) + for c != nil { + reversedCaptures = append(reversedCaptures, captureWithVal{c.f, c.v}) + c = c.prev + } + for i := len(reversedCaptures) - 1; i >= 0; i-- { + if err := reversedCaptures[i].f(reversedCaptures[i].v); err != nil { + return false, err } } return true, nil @@ -114,17 +141,59 @@ func Match[S ~[]T, T Matchable](s *MatchState, components S) (bool, error) { return false, nil } -func appendState(arr statesAndCaptures, s *MatchState, c captureMap, listGeneration int) statesAndCaptures { - if s == nil || s.generation == listGeneration { - return arr +// appendState is a non-recursive way of appending states to statesAndCaptures. +// If a state is a split, both branches are appended to statesAndCaptures. +func appendState(arr statesAndCaptures, states []MatchState, stateIndex int, c *capture, visitedBitSet []uint64) statesAndCaptures { + // Local struct to hold state index and the associated capture pointer. + type task struct { + idx int + cap *capture } - s.generation = listGeneration - if s.kind == split { - arr = appendState(arr, s.next, c, listGeneration) - arr = appendState(arr, s.nextSplit, c.clone(), listGeneration) - } else { - arr.states = append(arr.states, s) - arr.captures = append(arr.captures, c) + + // Initialize the stack with the starting task. + stack := make([]task, 0, 16) + stack = append(stack, task{stateIndex, c}) + + // Process the stack until empty. + for len(stack) > 0 { + // Pop the last element (LIFO order). + n := len(stack) - 1 + t := stack[n] + stack = stack[:n] + + // If the state index is out of bounds, skip. + if t.idx >= len(states) { + continue + } + s := states[t.idx] + + // Check if this state has already been visited. + if visitedBitSet[t.idx/64]&(1<<(t.idx%64)) != 0 { + continue + } + // Mark the state as visited. + visitedBitSet[t.idx/64] |= 1 << (t.idx % 64) + + // If it's a split state (the value is less than done) then push both branches. + if s.codeOrKind < done { + // Get the second branch from the split. + splitIdx := restoreSplitIdx(s.codeOrKind) + // To preserve order (s.next processed first), push the split branch first. + stack = append(stack, task{splitIdx, t.cap}) + stack = append(stack, task{s.next, t.cap}) + } else { + // Otherwise, it's a valid final state -- append it. + arr.states = append(arr.states, t.idx) + arr.captures = append(arr.captures, t.cap) + } } return arr } + +func storeSplitIdx(codeOrKind int) int { + return (codeOrKind + 2) * -1 +} + +func restoreSplitIdx(splitIdx int) int { + return (splitIdx * -1) - 2 +} diff --git a/x/meg/meg_test.go b/x/meg/meg_test.go index b47ba02..e0265c4 100644 --- a/x/meg/meg_test.go +++ b/x/meg/meg_test.go @@ -26,7 +26,7 @@ var _ Matchable = codeAndValue{} func TestSimple(t *testing.T) { type testCase struct { - pattern *MatchState + pattern Matcher skipQuickCheck bool shouldMatch [][]int shouldNotMatch [][]int @@ -34,14 +34,24 @@ func TestSimple(t *testing.T) { testCases := []testCase{ { - pattern: PatternToMatchState(Val(0), Val(1)), + pattern: PatternToMatcher(Val(0), Val(1)), shouldMatch: [][]int{{0, 1}}, shouldNotMatch: [][]int{ {0}, {0, 0}, {0, 1, 0}, - }}, { - pattern: PatternToMatchState(Val(0), Val(1), Optional(Val(2))), + }, + }, + { + pattern: PatternToMatcher(Optional(Val(1))), + shouldMatch: [][]int{ + {1}, + {}, + }, + shouldNotMatch: [][]int{{0}}, + }, + { + pattern: PatternToMatcher(Val(0), Val(1), Optional(Val(2))), shouldMatch: [][]int{ {0, 1, 2}, {0, 1}, @@ -52,7 +62,7 @@ func TestSimple(t *testing.T) { {0, 1, 0}, {0, 1, 2, 0}, }}, { - pattern: PatternToMatchState(Val(0), Val(1), OneOrMore(2)), + pattern: PatternToMatcher(Val(0), Val(1), OneOrMore(2)), skipQuickCheck: true, shouldMatch: [][]int{ {0, 1, 2, 2, 2, 2}, @@ -70,13 +80,13 @@ func TestSimple(t *testing.T) { for i, tc := range testCases { for _, m := range tc.shouldMatch { - if matches, _ := Match(tc.pattern, codesToCodeAndValue(m)); !matches { - t.Fatalf("failed to match %v with %s. idx=%d", m, tc.pattern, i) + if matches, err := Match(tc.pattern, codesToCodeAndValue(m)); !matches { + t.Fatalf("failed to match %v with %v. idx=%d. err=%v", m, tc.pattern, i, err) } } for _, m := range tc.shouldNotMatch { if matches, _ := Match(tc.pattern, codesToCodeAndValue(m)); matches { - t.Fatalf("failed to not match %v with %s. idx=%d", m, tc.pattern, i) + t.Fatalf("failed to not match %v with %v. idx=%d", m, tc.pattern, i) } } if tc.skipQuickCheck { @@ -98,7 +108,7 @@ func TestSimple(t *testing.T) { } func TestCapture(t *testing.T) { - type setupStateAndAssert func() (*MatchState, func()) + type setupStateAndAssert func() (Matcher, func()) type testCase struct { setup setupStateAndAssert parts []codeAndValue @@ -107,9 +117,9 @@ func TestCapture(t *testing.T) { testCases := []testCase{ { - setup: func() (*MatchState, func()) { + setup: func() (Matcher, func()) { var code0str string - return PatternToMatchState(CaptureVal(0, &code0str), Val(1)), func() { + return PatternToMatcher(CaptureVal(0, &code0str), Val(1)), func() { if code0str != "hello" { panic("unexpected value") } @@ -118,9 +128,9 @@ func TestCapture(t *testing.T) { parts: []codeAndValue{{0, "hello"}, {1, "world"}}, }, { - setup: func() (*MatchState, func()) { + setup: func() (Matcher, func()) { var code0strs []string - return PatternToMatchState(CaptureOneOrMore(0, &code0strs), Val(1)), func() { + return PatternToMatcher(CaptureOneOrMore(0, &code0strs), Val(1)), func() { if code0strs[0] != "hello" { panic("unexpected value") } @@ -137,7 +147,7 @@ func TestCapture(t *testing.T) { for _, tc := range testCases { state, assert := tc.setup() if matches, _ := Match(state, tc.parts); !matches { - t.Fatalf("failed to match %v with %s", tc.parts, state) + t.Fatalf("failed to match %v with %v", tc.parts, state) } assert() } @@ -161,7 +171,7 @@ func bytesToCodeAndValue(codes []byte) []codeAndValue { // FuzzMatchesRegexpBehavior fuzz tests the expression matcher by comparing it to the behavior of the regexp package. func FuzzMatchesRegexpBehavior(f *testing.F) { - bytesToRegexpAndPattern := func(exp []byte) ([]byte, []Pattern) { + bytesToRegexpAndPattern := func(exp []byte) (string, []Pattern) { if len(exp) < 3 { panic("regexp too short") } @@ -197,7 +207,7 @@ func FuzzMatchesRegexpBehavior(f *testing.F) { } } - return exp, pattern + return string(exp), pattern } simplifyB := func(buf []byte) []byte { @@ -218,7 +228,7 @@ func FuzzMatchesRegexpBehavior(f *testing.F) { // Malformed regex. Ignore return } - p := PatternToMatchState(pattern...) + p := PatternToMatcher(pattern...) otherMatched, _ := Match(p, bytesToCodeAndValue(corpus)) if otherMatched != matched { t.Log("regexp", string(regexpPattern)) diff --git a/x/meg/sugar.go b/x/meg/sugar.go index 369a315..ee961cc 100644 --- a/x/meg/sugar.go +++ b/x/meg/sugar.go @@ -2,38 +2,69 @@ package meg import ( "errors" + "fmt" + "strconv" + "strings" ) -type Pattern = func(next *MatchState) *MatchState +// Pattern is essentially a curried MatchState. +// Given the slice of current MatchStates and a handle (int index) to the next +// MatchState, it returns a handle to the inserted MatchState. +type Pattern = func(states *[]MatchState, nextIdx int) int -func PatternToMatchState(states ...Pattern) *MatchState { - nextState := &MatchState{kind: done} - for i := len(states) - 1; i >= 0; i-- { - nextState = states[i](nextState) +type Matcher struct { + states []MatchState + startIdx int +} + +func (s Matcher) String() string { + states := make([]string, len(s.states)) + for i, state := range s.states { + states[i] = state.String() + "@" + strconv.Itoa(i) + } + return fmt.Sprintf("RootMatchState{states: [%s], startIdx: %d}", strings.Join(states, ", "), s.startIdx) +} + +func PatternToMatcher(patterns ...Pattern) Matcher { + // Preallocate a slice to hold the MatchStates. + // Avoids small allocations for each pattern. + // The number is chosen experimentally. It is subject to change. + states := make([]MatchState, 0, len(patterns)*3) + // Append the done state. + states = append(states, MatchState{codeOrKind: done}) + nextIdx := len(states) - 1 + // Build the chain by composing patterns from right to left. + for i := len(patterns) - 1; i >= 0; i-- { + nextIdx = patterns[i](&states, nextIdx) } - return nextState + return Matcher{states: states, startIdx: nextIdx} } func Cat(left, right Pattern) Pattern { - return func(next *MatchState) *MatchState { - return left(right(next)) + return func(states *[]MatchState, nextIdx int) int { + // First run the right pattern, then feed the result into left. + return left(states, right(states, nextIdx)) } } func Or(p ...Pattern) Pattern { - return func(next *MatchState) *MatchState { + return func(states *[]MatchState, nextIdx int) int { if len(p) == 0 { - return next + return nextIdx } - if len(p) == 1 { - return p[0](next) - } - - return &MatchState{ - kind: split, - next: p[0](next), - nextSplit: Or(p[1:]...)(next), + // Evaluate the last pattern and use its result as the initial accumulator. + accum := p[len(p)-1](states, nextIdx) + // Iterate backwards from the second-to-last pattern to the first. + for i := len(p) - 2; i >= 0; i-- { + leftIdx := p[i](states, nextIdx) + newState := MatchState{ + next: leftIdx, + codeOrKind: storeSplitIdx(accum), + } + *states = append(*states, newState) + accum = len(*states) - 1 } + return accum } } @@ -52,7 +83,7 @@ func captureOneValueOrErr(val *string) captureFunc { *val = s return nil } - return &f + return f } func captureMany(vals *[]string) captureFunc { @@ -63,17 +94,18 @@ func captureMany(vals *[]string) captureFunc { *vals = append(*vals, s) return nil } - return &f + return f } func captureValWithF(code int, f captureFunc) Pattern { - return func(next *MatchState) *MatchState { - return &MatchState{ - kind: matchCode, - capture: f, - code: code, - next: next, + return func(states *[]MatchState, nextIdx int) int { + newState := MatchState{ + capture: f, + codeOrKind: code, + next: nextIdx, } + *states = append(*states, newState) + return len(*states) - 1 } } @@ -90,18 +122,27 @@ func ZeroOrMore(code int) Pattern { } func captureZeroOrMoreWithF(code int, f captureFunc) Pattern { - return func(next *MatchState) *MatchState { - match := &MatchState{ - code: code, - capture: f, + return func(states *[]MatchState, nextIdx int) int { + // Create the match state. + matchState := MatchState{ + codeOrKind: code, + capture: f, } - s := &MatchState{ - kind: split, - next: match, - nextSplit: next, + *states = append(*states, matchState) + matchIdx := len(*states) - 1 + + // Create the split state that branches to the match state and to the next state. + s := MatchState{ + next: matchIdx, + codeOrKind: storeSplitIdx(nextIdx), } - match.next = s // Loop back to the split. - return s + *states = append(*states, s) + splitIdx := len(*states) - 1 + + // Close the loop: update the match state's next field. + (*states)[matchIdx].next = splitIdx + + return splitIdx } } @@ -112,19 +153,24 @@ func CaptureZeroOrMore(code int, vals *[]string) Pattern { func OneOrMore(code int) Pattern { return CaptureOneOrMore(code, nil) } + func CaptureOneOrMore(code int, vals *[]string) Pattern { f := captureMany(vals) - return func(next *MatchState) *MatchState { - return captureValWithF(code, f)(captureZeroOrMoreWithF(code, f)(next)) + return func(states *[]MatchState, nextIdx int) int { + // First attach the zero-or-more loop. + zeroOrMoreIdx := captureZeroOrMoreWithF(code, f)(states, nextIdx) + // Then put the capture state before the loop. + return captureValWithF(code, f)(states, zeroOrMoreIdx) } } func Optional(s Pattern) Pattern { - return func(next *MatchState) *MatchState { - return &MatchState{ - kind: split, - next: s(next), - nextSplit: next, + return func(states *[]MatchState, nextIdx int) int { + newState := MatchState{ + next: s(states, nextIdx), + codeOrKind: storeSplitIdx(nextIdx), } + *states = append(*states, newState) + return len(*states) - 1 } } From ae47e22f6256d1dd44f38410357dc494ffacb785 Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Tue, 25 Feb 2025 16:16:10 -0800 Subject: [PATCH 07/10] feat(x/meg): Support capturing components (#269) * Use Matchable interface * Add Bytes to Matchable interface * feat(x/meg): Support capturing bytes * Export CaptureWithF Can be used by more specific capturers (e.g capture net.AddrIP) * Support Any match, RawValue, and multiple Concatenations * Add CaptureAddrPort --- meg_capturers.go | 54 ++++++++++++++++++++++ meg_test.go | 31 ++++++++++++- x/meg/bench_test.go | 75 +++++++++++++++++++++++------- x/meg/meg.go | 40 ++++++++++------ x/meg/meg_test.go | 30 +++++++++++- x/meg/sugar.go | 110 +++++++++++++++++++++++++++++++++++--------- 6 files changed, 282 insertions(+), 58 deletions(-) create mode 100644 meg_capturers.go diff --git a/meg_capturers.go b/meg_capturers.go new file mode 100644 index 0000000..21581e7 --- /dev/null +++ b/meg_capturers.go @@ -0,0 +1,54 @@ +package multiaddr + +import ( + "encoding/binary" + "fmt" + "net/netip" + + "github.com/multiformats/go-multiaddr/x/meg" +) + +func CaptureAddrPort(network *string, ipPort *netip.AddrPort) (capturePattern meg.Pattern) { + var ipOnly netip.Addr + capturePort := func(s meg.Matchable) error { + switch s.Code() { + case P_UDP: + *network = "udp" + case P_TCP: + *network = "tcp" + default: + return fmt.Errorf("invalid network: %s", s.Value()) + } + + port := binary.BigEndian.Uint16(s.RawValue()) + *ipPort = netip.AddrPortFrom(ipOnly, port) + return nil + } + + pattern := meg.Cat( + meg.Or( + meg.CaptureWithF(P_IP4, func(s meg.Matchable) error { + var ok bool + ipOnly, ok = netip.AddrFromSlice(s.RawValue()) + if !ok { + return fmt.Errorf("invalid ip4 address: %s", s.Value()) + } + return nil + }), + meg.CaptureWithF(P_IP6, func(s meg.Matchable) error { + var ok bool + ipOnly, ok = netip.AddrFromSlice(s.RawValue()) + if !ok { + return fmt.Errorf("invalid ip6 address: %s", s.Value()) + } + return nil + }), + ), + meg.Or( + meg.CaptureWithF(P_UDP, capturePort), + meg.CaptureWithF(P_TCP, capturePort), + ), + ) + + return pattern +} diff --git a/meg_test.go b/meg_test.go index 32fcadd..d5e0ee0 100644 --- a/meg_test.go +++ b/meg_test.go @@ -1,6 +1,7 @@ package multiaddr import ( + "net/netip" "testing" "github.com/multiformats/go-multiaddr/x/meg" @@ -16,10 +17,10 @@ func TestMatchAndCaptureMultiaddr(t *testing.T) { meg.Val(P_IP4), meg.Val(P_IP6), ), - meg.CaptureVal(P_UDP, &udpPort), + meg.CaptureStringVal(P_UDP, &udpPort), meg.Val(P_QUIC_V1), meg.Val(P_WEBTRANSPORT), - meg.CaptureZeroOrMore(P_CERTHASH, &certhashes), + meg.CaptureZeroOrMoreStringVals(P_CERTHASH, &certhashes), ) if !found { t.Fatal("failed to match") @@ -43,3 +44,29 @@ func TestMatchAndCaptureMultiaddr(t *testing.T) { } } } + +func TestCaptureAddrPort(t *testing.T) { + m := StringCast("/ip4/1.2.3.4/udp/8231/quic-v1/webtransport") + var addrPort netip.AddrPort + var network string + + found, err := m.Match( + CaptureAddrPort(&network, &addrPort), + meg.ZeroOrMore(meg.Any), + ) + if err != nil { + t.Fatal("error", err) + } + if !found { + t.Fatal("failed to match") + } + if !addrPort.IsValid() { + t.Fatal("failed to capture addrPort") + } + if network != "udp" { + t.Fatal("unexpected network", network) + } + if addrPort.String() != "1.2.3.4:8231" { + t.Fatal("unexpected ipPort", addrPort) + } +} diff --git a/x/meg/bench_test.go b/x/meg/bench_test.go index 15bf386..5eff21f 100644 --- a/x/meg/bench_test.go +++ b/x/meg/bench_test.go @@ -22,7 +22,7 @@ func preallocateCapture() *preallocatedCapture { ), meg.Val(multiaddr.P_UDP), meg.Val(multiaddr.P_WEBRTC_DIRECT), - meg.CaptureZeroOrMore(multiaddr.P_CERTHASH, &p.certHashes), + meg.CaptureZeroOrMoreStringVals(multiaddr.P_CERTHASH, &p.certHashes), ) return p } @@ -87,19 +87,19 @@ func isWebTransportMultiaddrPrealloc() *preallocatedCapture { var sni string p.matcher = meg.PatternToMatcher( meg.Or( - meg.CaptureVal(multiaddr.P_IP4, &ip4Addr), - meg.CaptureVal(multiaddr.P_IP6, &ip6Addr), - meg.CaptureVal(multiaddr.P_DNS4, &dnsName), - meg.CaptureVal(multiaddr.P_DNS6, &dnsName), - meg.CaptureVal(multiaddr.P_DNS, &dnsName), + meg.CaptureStringVal(multiaddr.P_IP4, &ip4Addr), + meg.CaptureStringVal(multiaddr.P_IP6, &ip6Addr), + meg.CaptureStringVal(multiaddr.P_DNS4, &dnsName), + meg.CaptureStringVal(multiaddr.P_DNS6, &dnsName), + meg.CaptureStringVal(multiaddr.P_DNS, &dnsName), ), - meg.CaptureVal(multiaddr.P_UDP, &udpPort), + meg.CaptureStringVal(multiaddr.P_UDP, &udpPort), meg.Val(multiaddr.P_QUIC_V1), meg.Optional( - meg.CaptureVal(multiaddr.P_SNI, &sni), + meg.CaptureStringVal(multiaddr.P_SNI, &sni), ), meg.Val(multiaddr.P_WEBTRANSPORT), - meg.CaptureZeroOrMore(multiaddr.P_CERTHASH, &p.certHashes), + meg.CaptureZeroOrMoreStringVals(multiaddr.P_CERTHASH, &p.certHashes), ) wtPrealloc = p return p @@ -120,19 +120,19 @@ func IsWebTransportMultiaddr(m multiaddr.Multiaddr) (bool, int) { var certHashesStr []string matched, _ := m.Match( meg.Or( - meg.CaptureVal(multiaddr.P_IP4, &ip4Addr), - meg.CaptureVal(multiaddr.P_IP6, &ip6Addr), - meg.CaptureVal(multiaddr.P_DNS4, &dnsName), - meg.CaptureVal(multiaddr.P_DNS6, &dnsName), - meg.CaptureVal(multiaddr.P_DNS, &dnsName), + meg.CaptureStringVal(multiaddr.P_IP4, &ip4Addr), + meg.CaptureStringVal(multiaddr.P_IP6, &ip6Addr), + meg.CaptureStringVal(multiaddr.P_DNS4, &dnsName), + meg.CaptureStringVal(multiaddr.P_DNS6, &dnsName), + meg.CaptureStringVal(multiaddr.P_DNS, &dnsName), ), - meg.CaptureVal(multiaddr.P_UDP, &udpPort), + meg.CaptureStringVal(multiaddr.P_UDP, &udpPort), meg.Val(multiaddr.P_QUIC_V1), meg.Optional( - meg.CaptureVal(multiaddr.P_SNI, &sni), + meg.CaptureStringVal(multiaddr.P_SNI, &sni), ), meg.Val(multiaddr.P_WEBTRANSPORT), - meg.CaptureZeroOrMore(multiaddr.P_CERTHASH, &certHashesStr), + meg.CaptureZeroOrMoreStringVals(multiaddr.P_CERTHASH, &certHashesStr), ) if !matched { return false, 0 @@ -140,6 +140,35 @@ func IsWebTransportMultiaddr(m multiaddr.Multiaddr) (bool, int) { return true, len(certHashesStr) } +func IsWebTransportMultiaddrCaptureBytes(m multiaddr.Multiaddr) (bool, int) { + var dnsName []byte + var ip4Addr []byte + var ip6Addr []byte + var udpPort []byte + var sni []byte + var certHashes [][]byte + matched, _ := m.Match( + meg.Or( + meg.CaptureBytes(multiaddr.P_IP4, &ip4Addr), + meg.CaptureBytes(multiaddr.P_IP6, &ip6Addr), + meg.CaptureBytes(multiaddr.P_DNS4, &dnsName), + meg.CaptureBytes(multiaddr.P_DNS6, &dnsName), + meg.CaptureBytes(multiaddr.P_DNS, &dnsName), + ), + meg.CaptureBytes(multiaddr.P_UDP, &udpPort), + meg.Val(multiaddr.P_QUIC_V1), + meg.Optional( + meg.CaptureBytes(multiaddr.P_SNI, &sni), + ), + meg.Val(multiaddr.P_WEBTRANSPORT), + meg.CaptureZeroOrMoreBytes(multiaddr.P_CERTHASH, &certHashes), + ) + if !matched { + return false, 0 + } + return true, len(certHashes) +} + func IsWebTransportMultiaddrNoCapture(m multiaddr.Multiaddr) (bool, int) { matched, _ := m.Match( meg.Or( @@ -355,6 +384,18 @@ func BenchmarkIsWebTransportMultiaddrNoCapture(b *testing.B) { } } +func BenchmarkIsWebTransportMultiaddrCaptureBytes(b *testing.B) { + addr := multiaddr.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1/sni/example.com/webtransport") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + isWT, count := IsWebTransportMultiaddrCaptureBytes(addr) + if !isWT || count != 0 { + b.Fatal("unexpected result") + } + } +} + func BenchmarkIsWebTransportMultiaddr(b *testing.B) { addr := multiaddr.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1/sni/example.com/webtransport") diff --git a/x/meg/meg.go b/x/meg/meg.go index ccc2275..c70a139 100644 --- a/x/meg/meg.go +++ b/x/meg/meg.go @@ -12,27 +12,30 @@ import ( type stateKind = int const ( - done stateKind = (iota * -1) - 1 - // split anything else that is negative + matchAny stateKind = (iota * -1) - 1 + // done MUST be the last stateKind in this list. We use it to determine if a + // state is a split index. + done + // Anything that is less than done is a split index ) // MatchState is the Thompson NFA for a regular expression. type MatchState struct { - capture captureFunc + capture CaptureFunc // next is is the index of the next state. in the MatchState array. next int // If codeOrKind is negative, it is a kind. - // If it is negative, but not a `done`, then it is the index to the next split. + // If it is negative, and less than `done`, then it is the index to the next split. // This is done to keep the `MatchState` struct small and cache friendly. codeOrKind int } -type captureFunc func(string) error +type CaptureFunc func(Matchable) error // capture is a linked list of capture funcs with values. type capture struct { - f captureFunc - v string + f CaptureFunc + v Matchable prev *capture } @@ -53,7 +56,14 @@ func (s MatchState) String() string { type Matchable interface { Code() int - Value() string // Used when capturing the value + // Value() returns the string representation of the matchable. + Value() string + // RawValue() returns the byte representation of the Value + RawValue() []byte + // Bytes() returns the underlying bytes of the matchable. For multiaddr + // Components, this includes the protocol code and possibly the varint + // encoded size. + Bytes() []byte } // Match returns whether the given Components match the Pattern defined in MatchState. @@ -89,12 +99,12 @@ func Match[S ~[]T, T Matchable](matcher Matcher, components S) (bool, error) { } for i, stateIndex := range currentStates.states { s := states[stateIndex] - if s.codeOrKind >= 0 && s.codeOrKind == c.Code() { + if s.codeOrKind == matchAny || (s.codeOrKind >= 0 && s.codeOrKind == c.Code()) { cm := currentStates.captures[i] if s.capture != nil { next := &capture{ f: s.capture, - v: c.Value(), + v: c, } if cm == nil { cm = next @@ -122,8 +132,8 @@ func Match[S ~[]T, T Matchable](matcher Matcher, components S) (bool, error) { // Flip the order of the captures because we see captures from right // to left, but users expect them left to right. type captureWithVal struct { - f captureFunc - v string + f CaptureFunc + v Matchable } reversedCaptures := make([]captureWithVal, 0, 16) for c != nil { @@ -190,10 +200,12 @@ func appendState(arr statesAndCaptures, states []MatchState, stateIndex int, c * return arr } +const splitIdxOffset = (-1 * (done - 1)) + func storeSplitIdx(codeOrKind int) int { - return (codeOrKind + 2) * -1 + return (codeOrKind + splitIdxOffset) * -1 } func restoreSplitIdx(splitIdx int) int { - return (splitIdx * -1) - 2 + return (splitIdx * -1) - splitIdxOffset } diff --git a/x/meg/meg_test.go b/x/meg/meg_test.go index e0265c4..8f5cff3 100644 --- a/x/meg/meg_test.go +++ b/x/meg/meg_test.go @@ -22,6 +22,16 @@ func (c codeAndValue) Value() string { return c.val } +// Bytes implements Matchable. +func (c codeAndValue) Bytes() []byte { + return []byte(c.val) +} + +// RawValue implements Matchable. +func (c codeAndValue) RawValue() []byte { + return []byte(c.val) +} + var _ Matchable = codeAndValue{} func TestSimple(t *testing.T) { @@ -33,6 +43,22 @@ func TestSimple(t *testing.T) { } testCases := []testCase{ + { + pattern: PatternToMatcher(Val(Any), Val(1)), + shouldMatch: [][]int{ + {0, 1}, + {1, 1}, + {2, 1}, + {3, 1}, + {4, 1}, + }, + shouldNotMatch: [][]int{ + {0}, + {0, 0}, + {0, 1, 0}, + }, + skipQuickCheck: true, + }, { pattern: PatternToMatcher(Val(0), Val(1)), shouldMatch: [][]int{{0, 1}}, @@ -119,7 +145,7 @@ func TestCapture(t *testing.T) { { setup: func() (Matcher, func()) { var code0str string - return PatternToMatcher(CaptureVal(0, &code0str), Val(1)), func() { + return PatternToMatcher(CaptureStringVal(0, &code0str), Val(1)), func() { if code0str != "hello" { panic("unexpected value") } @@ -130,7 +156,7 @@ func TestCapture(t *testing.T) { { setup: func() (Matcher, func()) { var code0strs []string - return PatternToMatcher(CaptureOneOrMore(0, &code0strs), Val(1)), func() { + return PatternToMatcher(CaptureOneOrMoreStringVals(0, &code0strs), Val(1)), func() { if code0strs[0] != "hello" { panic("unexpected value") } diff --git a/x/meg/sugar.go b/x/meg/sugar.go index ee961cc..41e18e3 100644 --- a/x/meg/sugar.go +++ b/x/meg/sugar.go @@ -40,10 +40,26 @@ func PatternToMatcher(patterns ...Pattern) Matcher { return Matcher{states: states, startIdx: nextIdx} } -func Cat(left, right Pattern) Pattern { - return func(states *[]MatchState, nextIdx int) int { - // First run the right pattern, then feed the result into left. - return left(states, right(states, nextIdx)) +func Cat(patterns ...Pattern) Pattern { + switch len(patterns) { + case 0: + return func(states *[]MatchState, nextIdx int) int { + return nextIdx + } + case 1: + return patterns[0] + case 2: + return func(states *[]MatchState, nextIdx int) int { + left := patterns[0] + right := patterns[1] + // First run the right pattern, then feed the result into left. + return left(states, right(states, nextIdx)) + } + default: + return Cat( + Cat(patterns[:len(patterns)-1]...), + patterns[len(patterns)-1], + ) } } @@ -70,34 +86,61 @@ func Or(p ...Pattern) Pattern { var errAlreadyCapture = errors.New("already captured") -func captureOneValueOrErr(val *string) captureFunc { +func captureOneBytesOrErr(val *[]byte) CaptureFunc { + if val == nil { + return nil + } + var set bool + f := func(s Matchable) error { + if set { + *val = nil + return errAlreadyCapture + } + *val = s.RawValue() + return nil + } + return f +} + +func captureOneStringValueOrErr(val *string) CaptureFunc { if val == nil { return nil } var set bool - f := func(s string) error { + f := func(s Matchable) error { if set { *val = "" return errAlreadyCapture } - *val = s + *val = s.Value() return nil } return f } -func captureMany(vals *[]string) captureFunc { +func captureManyBytes(vals *[][]byte) CaptureFunc { if vals == nil { return nil } - f := func(s string) error { - *vals = append(*vals, s) + f := func(s Matchable) error { + *vals = append(*vals, s.RawValue()) return nil } return f } -func captureValWithF(code int, f captureFunc) Pattern { +func captureManyStrings(vals *[]string) CaptureFunc { + if vals == nil { + return nil + } + f := func(s Matchable) error { + *vals = append(*vals, s.Value()) + return nil + } + return f +} + +func CaptureWithF(code int, f CaptureFunc) Pattern { return func(states *[]MatchState, nextIdx int) int { newState := MatchState{ capture: f, @@ -110,18 +153,25 @@ func captureValWithF(code int, f captureFunc) Pattern { } func Val(code int) Pattern { - return CaptureVal(code, nil) + return CaptureStringVal(code, nil) +} + +// Any is a special code that matches any value. +var Any int = matchAny + +func CaptureStringVal(code int, val *string) Pattern { + return CaptureWithF(code, captureOneStringValueOrErr(val)) } -func CaptureVal(code int, val *string) Pattern { - return captureValWithF(code, captureOneValueOrErr(val)) +func CaptureBytes(code int, val *[]byte) Pattern { + return CaptureWithF(code, captureOneBytesOrErr(val)) } func ZeroOrMore(code int) Pattern { - return CaptureZeroOrMore(code, nil) + return CaptureZeroOrMoreStringVals(code, nil) } -func captureZeroOrMoreWithF(code int, f captureFunc) Pattern { +func CaptureZeroOrMoreWithF(code int, f CaptureFunc) Pattern { return func(states *[]MatchState, nextIdx int) int { // Create the match state. matchState := MatchState{ @@ -146,21 +196,35 @@ func captureZeroOrMoreWithF(code int, f captureFunc) Pattern { } } -func CaptureZeroOrMore(code int, vals *[]string) Pattern { - return captureZeroOrMoreWithF(code, captureMany(vals)) +func CaptureZeroOrMoreBytes(code int, vals *[][]byte) Pattern { + return CaptureZeroOrMoreWithF(code, captureManyBytes(vals)) +} + +func CaptureZeroOrMoreStringVals(code int, vals *[]string) Pattern { + return CaptureZeroOrMoreWithF(code, captureManyStrings(vals)) } func OneOrMore(code int) Pattern { - return CaptureOneOrMore(code, nil) + return CaptureOneOrMoreStringVals(code, nil) +} + +func CaptureOneOrMoreStringVals(code int, vals *[]string) Pattern { + f := captureManyStrings(vals) + return func(states *[]MatchState, nextIdx int) int { + // First attach the zero-or-more loop. + zeroOrMoreIdx := CaptureZeroOrMoreWithF(code, f)(states, nextIdx) + // Then put the capture state before the loop. + return CaptureWithF(code, f)(states, zeroOrMoreIdx) + } } -func CaptureOneOrMore(code int, vals *[]string) Pattern { - f := captureMany(vals) +func CaptureOneOrMoreBytes(code int, vals *[][]byte) Pattern { + f := captureManyBytes(vals) return func(states *[]MatchState, nextIdx int) int { // First attach the zero-or-more loop. - zeroOrMoreIdx := captureZeroOrMoreWithF(code, f)(states, nextIdx) + zeroOrMoreIdx := CaptureZeroOrMoreWithF(code, f)(states, nextIdx) // Then put the capture state before the loop. - return captureValWithF(code, f)(states, zeroOrMoreIdx) + return CaptureWithF(code, f)(states, zeroOrMoreIdx) } } From 6a85b40a53e55027daf218075601fc1ed9760bc3 Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Tue, 25 Feb 2025 16:22:44 -0800 Subject: [PATCH 08/10] Fix typo in rebase --- component.go | 1 + 1 file changed, 1 insertion(+) diff --git a/component.go b/component.go index 0158bd1..c9b618f 100644 --- a/component.go +++ b/component.go @@ -167,6 +167,7 @@ func (c *Component) Code() int { return 0 } return c.Protocol().Code +} func (c *Component) RawValue() []byte { if c == nil { From 0c5383d595a8035fa1fd9b944d0924ae8627231f Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Tue, 25 Feb 2025 18:39:49 -0800 Subject: [PATCH 09/10] workaround for go generics --- util.go | 2 +- x/meg/bench_test.go | 10 +++++++--- x/meg/meg.go | 12 ++++++++---- x/meg/meg_test.go | 24 ++++++++++++++---------- 4 files changed, 30 insertions(+), 18 deletions(-) diff --git a/util.go b/util.go index 2cfed19..176c2d5 100644 --- a/util.go +++ b/util.go @@ -125,5 +125,5 @@ func ForEach(m Multiaddr, cb func(c Component) bool) { func (m Multiaddr) Match(p ...meg.Pattern) (bool, error) { matcher := meg.PatternToMatcher(p...) - return meg.Match(matcher, m) + return meg.Match(matcher, m, func(c *Component) meg.Matchable { return c }) } diff --git a/x/meg/bench_test.go b/x/meg/bench_test.go index 5eff21f..c898a26 100644 --- a/x/meg/bench_test.go +++ b/x/meg/bench_test.go @@ -29,8 +29,12 @@ func preallocateCapture() *preallocatedCapture { var webrtcMatchPrealloc *preallocatedCapture +func componentPtrToMatchable(c *multiaddr.Component) *multiaddr.Component { + return c +} + func (p *preallocatedCapture) IsWebRTCDirectMultiaddr(addr multiaddr.Multiaddr) (bool, int) { - found, _ := meg.Match(p.matcher, addr) + found, _ := meg.Match(p.matcher, addr, componentPtrToMatchable) return found, len(p.certHashes) } @@ -107,7 +111,7 @@ func isWebTransportMultiaddrPrealloc() *preallocatedCapture { func IsWebTransportMultiaddrPrealloc(m multiaddr.Multiaddr) (bool, int) { p := isWebTransportMultiaddrPrealloc() - found, _ := meg.Match(p.matcher, m) + found, _ := meg.Match(p.matcher, m, componentPtrToMatchable) return found, len(p.certHashes) } @@ -365,7 +369,7 @@ func BenchmarkIsWebTransportMultiaddrNoCapturePrealloc(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - isWT, _ := meg.Match(wtPreallocNoCapture, addr) + isWT, _ := meg.Match(wtPreallocNoCapture, addr, componentPtrToMatchable) if !isWT { b.Fatal("unexpected result") } diff --git a/x/meg/meg.go b/x/meg/meg.go index c70a139..ef8b58c 100644 --- a/x/meg/meg.go +++ b/x/meg/meg.go @@ -69,7 +69,10 @@ type Matchable interface { // Match returns whether the given Components match the Pattern defined in MatchState. // Errors are used to communicate capture errors. // If the error is non-nil the returned bool will be false. -func Match[S ~[]T, T Matchable](matcher Matcher, components S) (bool, error) { +// The ptrToMatchable function is used to convert type *T to a Matchable.. +// This is due to a limitation of Go generics, where we cannot say *T implements Matchable. +// When meg moves out of the x/ directory, we can reference the `*Component` type directly and avoid this limitation. +func Match[S ~[]T, T any, G Matchable](matcher Matcher, components S, ptrToMatchable func(*T) G) (bool, error) { states := matcher.states startStateIdx := matcher.startIdx @@ -92,19 +95,20 @@ func Match[S ~[]T, T Matchable](matcher Matcher, components S) (bool, error) { currentStates = appendState(currentStates, states, startStateIdx, nil, visitedBitSet) - for _, c := range components { + for ic := range components { clear(visitedBitSet) if len(currentStates.states) == 0 { return false, nil } for i, stateIndex := range currentStates.states { s := states[stateIndex] - if s.codeOrKind == matchAny || (s.codeOrKind >= 0 && s.codeOrKind == c.Code()) { + cPtr := ptrToMatchable(&components[ic]) + if s.codeOrKind == matchAny || (s.codeOrKind >= 0 && s.codeOrKind == cPtr.Code()) { cm := currentStates.captures[i] if s.capture != nil { next := &capture{ f: s.capture, - v: c, + v: cPtr, } if cm == nil { cm = next diff --git a/x/meg/meg_test.go b/x/meg/meg_test.go index 8f5cff3..b107866 100644 --- a/x/meg/meg_test.go +++ b/x/meg/meg_test.go @@ -13,26 +13,30 @@ type codeAndValue struct { } // Code implements Matchable. -func (c codeAndValue) Code() int { +func (c *codeAndValue) Code() int { return c.code } // Value implements Matchable. -func (c codeAndValue) Value() string { +func (c *codeAndValue) Value() string { return c.val } // Bytes implements Matchable. -func (c codeAndValue) Bytes() []byte { +func (c *codeAndValue) Bytes() []byte { return []byte(c.val) } // RawValue implements Matchable. -func (c codeAndValue) RawValue() []byte { +func (c *codeAndValue) RawValue() []byte { return []byte(c.val) } -var _ Matchable = codeAndValue{} +var _ Matchable = &codeAndValue{} + +func codeAndValuePtrToMatchable(c *codeAndValue) *codeAndValue { + return c +} func TestSimple(t *testing.T) { type testCase struct { @@ -106,12 +110,12 @@ func TestSimple(t *testing.T) { for i, tc := range testCases { for _, m := range tc.shouldMatch { - if matches, err := Match(tc.pattern, codesToCodeAndValue(m)); !matches { + if matches, err := Match(tc.pattern, codesToCodeAndValue(m), codeAndValuePtrToMatchable); !matches { t.Fatalf("failed to match %v with %v. idx=%d. err=%v", m, tc.pattern, i, err) } } for _, m := range tc.shouldNotMatch { - if matches, _ := Match(tc.pattern, codesToCodeAndValue(m)); matches { + if matches, _ := Match(tc.pattern, codesToCodeAndValue(m), codeAndValuePtrToMatchable); matches { t.Fatalf("failed to not match %v with %v. idx=%d", m, tc.pattern, i) } } @@ -125,7 +129,7 @@ func TestSimple(t *testing.T) { return true } } - matches, _ := Match(tc.pattern, codesToCodeAndValue(notMatch)) + matches, _ := Match(tc.pattern, codesToCodeAndValue(notMatch), codeAndValuePtrToMatchable) return !matches }, &quick.Config{}); err != nil { t.Fatal(err) @@ -172,7 +176,7 @@ func TestCapture(t *testing.T) { _ = testCases for _, tc := range testCases { state, assert := tc.setup() - if matches, _ := Match(state, tc.parts); !matches { + if matches, _ := Match(state, tc.parts, codeAndValuePtrToMatchable); !matches { t.Fatalf("failed to match %v with %v", tc.parts, state) } assert() @@ -255,7 +259,7 @@ func FuzzMatchesRegexpBehavior(f *testing.F) { return } p := PatternToMatcher(pattern...) - otherMatched, _ := Match(p, bytesToCodeAndValue(corpus)) + otherMatched, _ := Match(p, bytesToCodeAndValue(corpus), codeAndValuePtrToMatchable) if otherMatched != matched { t.Log("regexp", string(regexpPattern)) t.Log("corpus", string(corpus)) From c6d7d99ee23917b4563b3e14c5c9a0386810e57e Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Mon, 17 Mar 2025 15:05:35 -0700 Subject: [PATCH 10/10] less hacky workaround for go generics --- util.go | 11 ++++++++++- x/meg/bench_test.go | 16 +++++++++++----- x/meg/meg.go | 19 +++++++++++++------ x/meg/meg_test.go | 28 +++++++++++++++++----------- 4 files changed, 51 insertions(+), 23 deletions(-) diff --git a/util.go b/util.go index 176c2d5..c2a3b80 100644 --- a/util.go +++ b/util.go @@ -123,7 +123,16 @@ func ForEach(m Multiaddr, cb func(c Component) bool) { } } +type componentList []Component + +func (m componentList) Get(i int) meg.Matchable { + return &m[i] +} + +func (m componentList) Len() int { + return len(m) +} func (m Multiaddr) Match(p ...meg.Pattern) (bool, error) { matcher := meg.PatternToMatcher(p...) - return meg.Match(matcher, m, func(c *Component) meg.Matchable { return c }) + return meg.Match(matcher, componentList(m)) } diff --git a/x/meg/bench_test.go b/x/meg/bench_test.go index c898a26..4996f58 100644 --- a/x/meg/bench_test.go +++ b/x/meg/bench_test.go @@ -29,12 +29,18 @@ func preallocateCapture() *preallocatedCapture { var webrtcMatchPrealloc *preallocatedCapture -func componentPtrToMatchable(c *multiaddr.Component) *multiaddr.Component { - return c +type componentList []multiaddr.Component + +func (m componentList) Get(i int) meg.Matchable { + return &m[i] +} + +func (m componentList) Len() int { + return len(m) } func (p *preallocatedCapture) IsWebRTCDirectMultiaddr(addr multiaddr.Multiaddr) (bool, int) { - found, _ := meg.Match(p.matcher, addr, componentPtrToMatchable) + found, _ := meg.Match(p.matcher, componentList(addr)) return found, len(p.certHashes) } @@ -111,7 +117,7 @@ func isWebTransportMultiaddrPrealloc() *preallocatedCapture { func IsWebTransportMultiaddrPrealloc(m multiaddr.Multiaddr) (bool, int) { p := isWebTransportMultiaddrPrealloc() - found, _ := meg.Match(p.matcher, m, componentPtrToMatchable) + found, _ := meg.Match(p.matcher, componentList(m)) return found, len(p.certHashes) } @@ -369,7 +375,7 @@ func BenchmarkIsWebTransportMultiaddrNoCapturePrealloc(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - isWT, _ := meg.Match(wtPreallocNoCapture, addr, componentPtrToMatchable) + isWT, _ := meg.Match(wtPreallocNoCapture, componentList(addr)) if !isWT { b.Fatal("unexpected result") } diff --git a/x/meg/meg.go b/x/meg/meg.go index ef8b58c..7bd825d 100644 --- a/x/meg/meg.go +++ b/x/meg/meg.go @@ -66,13 +66,20 @@ type Matchable interface { Bytes() []byte } +// ListOfMatchable is anything list-like that contains Matchable items. +// This allows us to convert a slice of []T as a []Matchable when *T implements Matchable. +// In the future, this may not be required if Go generics allows us to say S ~[]T, and *T implements Matchable. +type ListOfMatchable interface { + Get(i int) Matchable + Len() int +} + // Match returns whether the given Components match the Pattern defined in MatchState. // Errors are used to communicate capture errors. // If the error is non-nil the returned bool will be false. -// The ptrToMatchable function is used to convert type *T to a Matchable.. -// This is due to a limitation of Go generics, where we cannot say *T implements Matchable. -// When meg moves out of the x/ directory, we can reference the `*Component` type directly and avoid this limitation. -func Match[S ~[]T, T any, G Matchable](matcher Matcher, components S, ptrToMatchable func(*T) G) (bool, error) { + +// Components must be a ListOfMatchable to allow us to use a slice of []T as a []Matchable when *T implements Matchable. +func Match[L ListOfMatchable](matcher Matcher, components L) (bool, error) { states := matcher.states startStateIdx := matcher.startIdx @@ -95,14 +102,14 @@ func Match[S ~[]T, T any, G Matchable](matcher Matcher, components S, ptrToMatch currentStates = appendState(currentStates, states, startStateIdx, nil, visitedBitSet) - for ic := range components { + for ic := range components.Len() { clear(visitedBitSet) if len(currentStates.states) == 0 { return false, nil } for i, stateIndex := range currentStates.states { s := states[stateIndex] - cPtr := ptrToMatchable(&components[ic]) + cPtr := components.Get(ic) if s.codeOrKind == matchAny || (s.codeOrKind >= 0 && s.codeOrKind == cPtr.Code()) { cm := currentStates.captures[i] if s.capture != nil { diff --git a/x/meg/meg_test.go b/x/meg/meg_test.go index b107866..c644007 100644 --- a/x/meg/meg_test.go +++ b/x/meg/meg_test.go @@ -34,10 +34,6 @@ func (c *codeAndValue) RawValue() []byte { var _ Matchable = &codeAndValue{} -func codeAndValuePtrToMatchable(c *codeAndValue) *codeAndValue { - return c -} - func TestSimple(t *testing.T) { type testCase struct { pattern Matcher @@ -110,12 +106,12 @@ func TestSimple(t *testing.T) { for i, tc := range testCases { for _, m := range tc.shouldMatch { - if matches, err := Match(tc.pattern, codesToCodeAndValue(m), codeAndValuePtrToMatchable); !matches { + if matches, err := Match(tc.pattern, codesToCodeAndValue(m)); !matches { t.Fatalf("failed to match %v with %v. idx=%d. err=%v", m, tc.pattern, i, err) } } for _, m := range tc.shouldNotMatch { - if matches, _ := Match(tc.pattern, codesToCodeAndValue(m), codeAndValuePtrToMatchable); matches { + if matches, _ := Match(tc.pattern, codesToCodeAndValue(m)); matches { t.Fatalf("failed to not match %v with %v. idx=%d", m, tc.pattern, i) } } @@ -129,7 +125,7 @@ func TestSimple(t *testing.T) { return true } } - matches, _ := Match(tc.pattern, codesToCodeAndValue(notMatch), codeAndValuePtrToMatchable) + matches, _ := Match(tc.pattern, codesToCodeAndValue(notMatch)) return !matches }, &quick.Config{}); err != nil { t.Fatal(err) @@ -176,14 +172,24 @@ func TestCapture(t *testing.T) { _ = testCases for _, tc := range testCases { state, assert := tc.setup() - if matches, _ := Match(state, tc.parts, codeAndValuePtrToMatchable); !matches { + if matches, _ := Match(state, codeAndValueList(tc.parts)); !matches { t.Fatalf("failed to match %v with %v", tc.parts, state) } assert() } } -func codesToCodeAndValue(codes []int) []codeAndValue { +type codeAndValueList []codeAndValue + +func (c codeAndValueList) Get(i int) Matchable { + return &c[i] +} + +func (c codeAndValueList) Len() int { + return len(c) +} + +func codesToCodeAndValue(codes []int) codeAndValueList { out := make([]codeAndValue, len(codes)) for i, c := range codes { out[i] = codeAndValue{code: c} @@ -191,7 +197,7 @@ func codesToCodeAndValue(codes []int) []codeAndValue { return out } -func bytesToCodeAndValue(codes []byte) []codeAndValue { +func bytesToCodeAndValue(codes []byte) codeAndValueList { out := make([]codeAndValue, len(codes)) for i, c := range codes { out[i] = codeAndValue{code: int(c)} @@ -259,7 +265,7 @@ func FuzzMatchesRegexpBehavior(f *testing.F) { return } p := PatternToMatcher(pattern...) - otherMatched, _ := Match(p, bytesToCodeAndValue(corpus), codeAndValuePtrToMatchable) + otherMatched, _ := Match(p, bytesToCodeAndValue(corpus)) if otherMatched != matched { t.Log("regexp", string(regexpPattern)) t.Log("corpus", string(corpus))