Skip to content

Commit 8b608fc

Browse files
authored
add profiling runtime flag, improve merkle performance (#105)
1 parent 5f28063 commit 8b608fc

File tree

10 files changed

+212
-82
lines changed

10 files changed

+212
-82
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
tmp/
22
.env
3+
*.pprof

api/migrations/scripts/000-rebuild-proofs-hashes/main.go

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package main
22

33
import (
44
"context"
5-
"errors"
65
"fmt"
76
"log"
87
"os"
@@ -43,13 +42,10 @@ func migrateTree(
4342
)
4443
eg.SetLimit(runtime.NumCPU())
4544

46-
for _, l := range leaves {
47-
l := l //avoid capture
45+
for i := range leaves {
46+
i := i
4847
eg.Go(func() error {
49-
pf := tree.Proof(l)
50-
if !merkle.Valid(tree.Root(), pf, l) {
51-
return errors.New("invalid proof for tree")
52-
}
48+
pf := tree.Proof(i)
5349
proofHash := hashProof(pf)
5450
pm.Lock()
5551
proofHashes = append(proofHashes, []any{tree.Root(), proofHash})

api/proof.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,14 @@ func (s *Server) GetProof(w http.ResponseWriter, r *http.Request) {
2121
ctx = r.Context()
2222
root = common.FromHex(r.URL.Query().Get("root"))
2323
leaf = common.FromHex(r.URL.Query().Get("unhashedLeaf"))
24-
addr = common.HexToAddress(r.URL.Query().Get("address"))
24+
addr = common.FromHex(r.URL.Query().Get("address"))
2525
)
26+
2627
if len(root) == 0 {
2728
s.sendJSONError(r, w, nil, http.StatusBadRequest, "missing root")
2829
return
2930
}
30-
if len(leaf) == 0 && addr == (common.Address{}) {
31+
if len(leaf) == 0 && len(addr) == 0 {
3132
s.sendJSONError(r, w, nil, http.StatusBadRequest, "missing leaf")
3233
return
3334
}
@@ -53,7 +54,7 @@ func (s *Server) GetProof(w http.ResponseWriter, r *http.Request) {
5354
if bytes.Equal(l, leaf) {
5455
target = l
5556
}
56-
} else if leaf2Addr(l, td.Ltd, td.Packed).Hex() == addr.Hex() {
57+
} else if bytes.Equal(leaf2Addr(l, td.Ltd, td.Packed), addr) {
5758
target = l
5859
}
5960
}
@@ -67,7 +68,8 @@ func (s *Server) GetProof(w http.ResponseWriter, r *http.Request) {
6768
}
6869

6970
var (
70-
p = merkle.New(leaves).Proof(target)
71+
mt = merkle.New(leaves)
72+
p = mt.Proof(mt.Index(target))
7173
phex = []hexutil.Bytes{}
7274
)
7375

api/tree.go

Lines changed: 25 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"encoding/json"
66
"errors"
77
"net/http"
8-
"sync"
98

109
"github.com/contextwtf/lanyard/merkle"
1110
"github.com/ethereum/go-ethereum/accounts/abi"
@@ -14,7 +13,6 @@ import (
1413
"github.com/ethereum/go-ethereum/crypto"
1514
"github.com/jackc/pgx/v4"
1615
"github.com/jackc/pgx/v4/pgxpool"
17-
"golang.org/x/sync/errgroup"
1816
)
1917

2018
func (s *Server) TreeHandler(w http.ResponseWriter, r *http.Request) {
@@ -31,12 +29,12 @@ func (s *Server) TreeHandler(w http.ResponseWriter, r *http.Request) {
3129
}
3230
}
3331

34-
func leaf2Addr(leaf []byte, ltd []string, packed bool) common.Address {
35-
if len(ltd) == 0 || (len(ltd) == 1 && ltd[0] == "address") {
36-
return common.BytesToAddress(leaf)
32+
func leaf2Addr(leaf []byte, ltd []string, packed bool) []byte {
33+
if len(ltd) == 0 || (len(ltd) == 1 && ltd[0] == "address" && len(leaf) == 20) {
34+
return leaf
3735
}
3836
if ltd[len(ltd)-1] == "address" && len(leaf) > 20 {
39-
return common.BytesToAddress(leaf[len(leaf)-20:])
37+
return leaf[len(leaf)-20:]
4038
}
4139

4240
if packed {
@@ -45,7 +43,7 @@ func leaf2Addr(leaf []byte, ltd []string, packed bool) common.Address {
4543
return addrUnpacked(leaf, ltd)
4644
}
4745

48-
func addrUnpacked(leaf []byte, ltd []string) common.Address {
46+
func addrUnpacked(leaf []byte, ltd []string) []byte {
4947
var addrStart, pos int
5048
for _, desc := range ltd {
5149
if desc == "address" {
@@ -54,32 +52,33 @@ func addrUnpacked(leaf []byte, ltd []string) common.Address {
5452
}
5553
pos += 32
5654
}
55+
5756
if len(leaf) >= addrStart+32 {
58-
return common.BytesToAddress(leaf[addrStart:(addrStart + 32)])
57+
l := leaf[addrStart:(addrStart + 32)]
58+
return l[len(l)-20:] // take last 20 bytes
5959
}
60-
return common.Address{}
60+
return []byte{}
6161
}
6262

63-
func addrPacked(leaf []byte, ltd []string) common.Address {
63+
func addrPacked(leaf []byte, ltd []string) []byte {
6464
var addrStart, pos int
6565
for _, desc := range ltd {
6666
t, err := abi.NewType(desc, "", nil)
6767
if err != nil {
68-
return common.Address{}
69-
}
70-
if desc == "address" {
68+
return []byte{}
69+
} else if desc == "address" {
7170
addrStart = pos
7271
break
7372
}
7473
pos += int(t.GetType().Size())
7574
}
7675
if addrStart == 0 && pos != 0 {
77-
return common.Address{}
76+
return []byte{}
7877
}
7978
if len(leaf) >= addrStart+20 {
80-
return common.BytesToAddress(leaf[addrStart:(addrStart + 20)])
79+
return leaf[addrStart:(addrStart + 20)]
8180
}
82-
return common.Address{}
81+
return []byte{}
8382
}
8483

8584
func hashProof(p [][]byte) []byte {
@@ -117,14 +116,14 @@ func (s *Server) CreateTree(w http.ResponseWriter, r *http.Request) {
117116

118117
var leaves [][]byte
119118
for _, l := range req.Leaves {
120-
// use the go-ethereum HexDecode method because it is more
119+
// use the go-ethereum FromHex method because it is more
121120
// lenient and will allow for odd-length hex strings (by padding them)
122121
leaves = append(leaves, common.FromHex(l))
123122
}
124123

125-
tree := merkle.New(leaves)
126-
root := tree.Root()
127124
var (
125+
tree = merkle.New(leaves)
126+
root = tree.Root()
128127
exists bool
129128
)
130129

@@ -147,25 +146,15 @@ func (s *Server) CreateTree(w http.ResponseWriter, r *http.Request) {
147146
}
148147

149148
var (
150-
proofHashes = [][]any{}
151-
eg errgroup.Group
152-
pm sync.Mutex
149+
proofHashes = make([][]any, 0, len(leaves))
150+
allProofs = tree.LeafProofs()
153151
)
154-
for _, l := range leaves {
155-
l := l //avoid capture
156-
eg.Go(func() error {
157-
pf := tree.Proof(l)
158-
if !merkle.Valid(tree.Root(), pf, l) {
159-
return errors.New("invalid proof for tree")
160-
}
161-
proofHash := hashProof(pf)
162-
pm.Lock()
163-
proofHashes = append(proofHashes, []any{tree.Root(), proofHash})
164-
pm.Unlock()
165-
return nil
166-
})
152+
153+
for _, p := range allProofs {
154+
proofHash := hashProof(p)
155+
proofHashes = append(proofHashes, []any{root, proofHash})
167156
}
168-
err = eg.Wait()
157+
169158
if err != nil {
170159
s.sendJSONError(r, w, err, http.StatusBadRequest, "generating proofs for tree")
171160
return

api/tree_test.go

Lines changed: 79 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package api
22

33
import (
4+
"bytes"
45
"testing"
56

67
"github.com/ethereum/go-ethereum/common"
@@ -10,23 +11,23 @@ func TestAddrUnpacked(t *testing.T) {
1011
cases := []struct {
1112
leaf []byte
1213
ltd []string
13-
want common.Address
14+
want []byte
1415
}{
1516
{
16-
common.Hex2Bytes("00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001"),
17+
common.FromHex("00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001"),
1718
[]string{"uint32", "address"},
18-
common.HexToAddress("0x0000000000000000000000000000000000000001"),
19+
common.FromHex("0x0000000000000000000000000000000000000001"),
1920
},
2021
{
21-
common.Hex2Bytes("00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000000"),
22+
common.FromHex("00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000000"),
2223
[]string{"address", "uint32"},
23-
common.HexToAddress("0x0000000000000000000000000000000000000001"),
24+
common.FromHex("0x0000000000000000000000000000000000000001"),
2425
},
2526
}
2627

2728
for _, c := range cases {
2829
addr := addrUnpacked(c.leaf, c.ltd)
29-
if addr != c.want {
30+
if !bytes.Equal(addr, c.want) {
3031
t.Errorf("expected: %v got: %v", c.want, addr)
3132
}
3233
}
@@ -36,24 +37,90 @@ func TestAddrPacked(t *testing.T) {
3637
cases := []struct {
3738
leaf []byte
3839
ltd []string
39-
want common.Address
40+
want []byte
4041
}{
4142
{
42-
common.Hex2Bytes("000000000000000000000000000000000000000000000001"),
43+
common.FromHex("000000000000000000000000000000000000000000000001"),
4344
[]string{"uint32", "address"},
44-
common.HexToAddress("0x0000000000000000000000000000000000000001"),
45+
common.FromHex("0x0000000000000000000000000000000000000001"),
4546
},
4647
{
47-
common.Hex2Bytes("000000000000000000000000000000000000000100000000"),
48+
common.FromHex("000000000000000000000000000000000000000100000000"),
4849
[]string{"address", "uint32"},
49-
common.HexToAddress("0x0000000000000000000000000000000000000001"),
50+
common.FromHex("0x0000000000000000000000000000000000000001"),
5051
},
5152
}
5253

5354
for _, c := range cases {
5455
addr := addrPacked(c.leaf, c.ltd)
55-
if addr != c.want {
56+
if !bytes.Equal(addr, c.want) {
5657
t.Errorf("expected: %v got: %v", c.want, addr)
5758
}
5859
}
5960
}
61+
62+
func TestLeaf2Addr(t *testing.T) {
63+
cases := []struct {
64+
leaf []byte
65+
ltd []string
66+
packed bool
67+
want []byte
68+
}{
69+
{
70+
common.FromHex("000000000000000000000000000000000000000000000001"),
71+
[]string{"uint32", "address"},
72+
true,
73+
common.FromHex("0x0000000000000000000000000000000000000001"),
74+
},
75+
{
76+
common.FromHex("000000000000000000000000000000000000000100000000"),
77+
[]string{"address", "uint32"},
78+
true,
79+
common.FromHex("0x0000000000000000000000000000000000000001"),
80+
},
81+
{
82+
common.FromHex("0x0000000000000000000000000000000000000001"),
83+
[]string{"address"},
84+
false,
85+
common.FromHex("0x0000000000000000000000000000000000000001"),
86+
},
87+
{
88+
common.FromHex("00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001"),
89+
[]string{"uint32", "address"},
90+
false,
91+
common.FromHex("0x0000000000000000000000000000000000000001"),
92+
},
93+
{
94+
common.FromHex("00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000000"),
95+
[]string{"address", "uint32"},
96+
false,
97+
common.FromHex("0x0000000000000000000000000000000000000001"),
98+
},
99+
{
100+
common.FromHex("00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000001"),
101+
[]string{"uint256", "address"},
102+
false,
103+
common.FromHex("0x0000000000000000000000000000000000000001"),
104+
},
105+
{
106+
common.FromHex("0x0000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000002d"),
107+
[]string{"address", "uint256"},
108+
false,
109+
common.FromHex("0x0000000000000000000000000000000000000001"),
110+
},
111+
{
112+
common.FromHex("0x0000000000000000000000000000000000000000000000000000000000000001"),
113+
[]string{"address"},
114+
true,
115+
common.FromHex("0x0000000000000000000000000000000000000001"),
116+
},
117+
}
118+
119+
for _, c := range cases {
120+
addr := leaf2Addr(c.leaf, c.ltd, c.packed)
121+
if !bytes.Equal(addr, c.want) {
122+
t.Errorf("expected: %v got: %v", c.want, addr)
123+
}
124+
}
125+
126+
}

cmd/api/main.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package main
33
import (
44
"context"
55
"database/sql"
6+
"flag"
67
"fmt"
78
"net"
89
"net/http"
@@ -17,6 +18,7 @@ import (
1718
"github.com/jackc/pgx/v4/pgxpool"
1819
_ "github.com/jackc/pgx/v4/stdlib"
1920
"github.com/opentracing/opentracing-go"
21+
"github.com/pkg/profile"
2022
"github.com/rs/zerolog"
2123
"github.com/rs/zerolog/log"
2224
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/opentracer"
@@ -39,6 +41,14 @@ func main() {
3941
env = "dev"
4042
}
4143

44+
shouldProfile := flag.Bool("profile", false, "enable profiling")
45+
flag.Parse()
46+
47+
if *shouldProfile {
48+
prof := profile.Start(profile.CPUProfile, profile.ProfilePath("."))
49+
defer prof.Stop()
50+
}
51+
4252
ddAgent := os.Getenv("DD_AGENT_HOST")
4353
if ddAgent != "" {
4454
t := opentracer.New(

go.mod

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ require (
88
github.com/jackc/pgx/v4 v4.16.1
99
github.com/lib/pq v1.10.9
1010
github.com/opentracing/opentracing-go v1.2.0
11+
github.com/pkg/profile v1.2.1
1112
github.com/rs/cors v1.8.2
1213
github.com/rs/zerolog v1.29.1
1314
golang.org/x/sync v0.3.0
@@ -43,6 +44,7 @@ require (
4344
github.com/philhofer/fwd v1.1.1 // indirect
4445
github.com/pkg/errors v0.9.1 // indirect
4546
github.com/rs/xid v1.4.0 // indirect
47+
github.com/stretchr/testify v1.8.0 // indirect
4648
github.com/tinylib/msgp v1.1.2 // indirect
4749
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa // indirect
4850
golang.org/x/sys v0.0.0-20220829200755-d48e67d00261 // indirect

0 commit comments

Comments
 (0)