Skip to content

Commit da76b8e

Browse files
authored
fix: simplify data model, proofs_hashes table (#101)
1 parent 82f3158 commit da76b8e

File tree

6 files changed

+269
-124
lines changed

6 files changed

+269
-124
lines changed

api/migrations/migrations.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,4 +116,14 @@ var Migrations = []migrate.Migration{
116116
CREATE INDEX proofs_arr_idx ON trees_proofs USING GIN ((proofs_array(proofs)));
117117
`,
118118
},
119+
{
120+
Name: "2023-06-26.0.proofs_hashes.sql",
121+
SQL: `
122+
CREATE TABLE IF NOT EXISTS proofs_hashes (
123+
hash bytea,
124+
root bytea
125+
);
126+
CREATE INDEX IF NOT EXISTS proofs_hashes_hash_idx ON proofs_hashes (hash);
127+
`,
128+
},
119129
}
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"log"
8+
"os"
9+
"runtime"
10+
"runtime/debug"
11+
"sync"
12+
13+
"github.com/contextwtf/lanyard/merkle"
14+
"github.com/ethereum/go-ethereum/crypto"
15+
"github.com/jackc/pgx/v4"
16+
"github.com/jackc/pgx/v4/pgxpool"
17+
"golang.org/x/sync/errgroup"
18+
)
19+
20+
func check(err error) {
21+
if err != nil {
22+
fmt.Fprintf(os.Stderr, "processor error: %s", err)
23+
debug.PrintStack()
24+
os.Exit(1)
25+
}
26+
}
27+
28+
func hashProof(p [][]byte) []byte {
29+
return crypto.Keccak256(p...)
30+
}
31+
32+
func migrateTree(
33+
ctx context.Context,
34+
tx pgx.Tx,
35+
leaves [][]byte,
36+
) error {
37+
tree := merkle.New(leaves)
38+
39+
var (
40+
proofHashes = [][]any{}
41+
eg errgroup.Group
42+
pm sync.Mutex
43+
)
44+
eg.SetLimit(runtime.NumCPU())
45+
46+
for _, l := range leaves {
47+
l := l //avoid capture
48+
eg.Go(func() error {
49+
pf := tree.Proof(l)
50+
if !merkle.Valid(tree.Root(), pf, l) {
51+
return errors.New("invalid proof for tree")
52+
}
53+
proofHash := hashProof(pf)
54+
pm.Lock()
55+
proofHashes = append(proofHashes, []any{tree.Root(), proofHash})
56+
pm.Unlock()
57+
return nil
58+
})
59+
}
60+
err := eg.Wait()
61+
if err != nil {
62+
return err
63+
}
64+
65+
_, err = tx.CopyFrom(ctx, pgx.Identifier{"proofs_hashes"},
66+
[]string{"root", "hash"},
67+
pgx.CopyFromRows(proofHashes),
68+
)
69+
70+
return err
71+
}
72+
73+
func main() {
74+
ctx := context.Background()
75+
const defaultPGURL = "postgres:///al"
76+
dburl := os.Getenv("DATABASE_URL")
77+
if dburl == "" {
78+
dburl = defaultPGURL
79+
}
80+
dbc, err := pgxpool.ParseConfig(dburl)
81+
check(err)
82+
83+
db, err := pgxpool.ConnectConfig(ctx, dbc)
84+
check(err)
85+
86+
log.Println("fetching roots from db")
87+
const q = `
88+
SELECT unhashed_leaves
89+
FROM trees
90+
WHERE root not in (select root from proofs_hashes group by 1)
91+
`
92+
rows, err := db.Query(ctx, q)
93+
check(err)
94+
defer rows.Close()
95+
96+
trees := [][][]byte{}
97+
98+
for rows.Next() {
99+
var t [][]byte
100+
err := rows.Scan(&t)
101+
trees = append(trees, t)
102+
check(err)
103+
}
104+
105+
if len(trees) == 0 {
106+
log.Println("no trees to process")
107+
return
108+
}
109+
110+
log.Printf("migrating %d trees", len(trees))
111+
112+
tx, err := db.Begin(ctx)
113+
check(err)
114+
defer tx.Rollback(ctx)
115+
116+
var count int
117+
118+
for _, tree := range trees {
119+
err = migrateTree(ctx, tx, tree)
120+
check(err)
121+
count++
122+
if count%1000 == 0 {
123+
log.Printf("migrated %d/%d trees", count, len(trees))
124+
}
125+
}
126+
127+
log.Printf("committing %d trees", len(trees))
128+
err = tx.Commit(ctx)
129+
check(err)
130+
log.Printf("done")
131+
}

api/proof.go

Lines changed: 45 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
package api
22

33
import (
4+
"bytes"
45
"errors"
56
"net/http"
67

8+
"github.com/contextwtf/lanyard/merkle"
79
"github.com/ethereum/go-ethereum/common"
810
"github.com/ethereum/go-ethereum/common/hexutil"
911
"github.com/jackc/pgx/v4"
@@ -18,41 +20,19 @@ func (s *Server) GetProof(w http.ResponseWriter, r *http.Request) {
1820
var (
1921
ctx = r.Context()
2022
root = common.FromHex(r.URL.Query().Get("root"))
21-
leaf = r.URL.Query().Get("unhashedLeaf")
22-
addr = r.URL.Query().Get("address")
23+
leaf = common.FromHex(r.URL.Query().Get("unhashedLeaf"))
24+
addr = common.HexToAddress(r.URL.Query().Get("address"))
2325
)
2426
if len(root) == 0 {
2527
s.sendJSONError(r, w, nil, http.StatusBadRequest, "missing root")
2628
return
2729
}
28-
if leaf == "" && addr == "" {
30+
if len(leaf) == 0 && addr == (common.Address{}) {
2931
s.sendJSONError(r, w, nil, http.StatusBadRequest, "missing leaf")
3032
return
3133
}
3234

33-
const q = `
34-
WITH tree AS (
35-
SELECT jsonb_array_elements(proofs) proofs
36-
FROM trees
37-
WHERE root = $1
38-
)
39-
SELECT
40-
proofs->'leaf',
41-
proofs->'proof'
42-
FROM tree
43-
WHERE (
44-
--eth addresses contain mixed casing to
45-
--accommodate checksums. we sidestep
46-
--the casing issues for user queries
47-
lower(proofs->>'addr') = lower($2)
48-
OR lower(proofs->>'leaf') = lower($3)
49-
)
50-
`
51-
var (
52-
resp = &getProofResp{}
53-
row = s.db.QueryRow(ctx, q, root, addr, leaf)
54-
err = row.Scan(&resp.UnhashedLeaf, &resp.Proof)
55-
)
35+
td, err := getTree(ctx, s.db, root)
5636
if errors.Is(err, pgx.ErrNoRows) {
5737
s.sendJSONError(r, w, nil, http.StatusNotFound, "tree not found")
5838
w.Header().Set("Cache-Control", "public, max-age=60")
@@ -62,12 +42,49 @@ func (s *Server) GetProof(w http.ResponseWriter, r *http.Request) {
6242
return
6343
}
6444

45+
var (
46+
leaves [][]byte
47+
target []byte
48+
)
49+
// check if leaf is in tree and error if not
50+
for _, l := range td.UnhashedLeaves {
51+
if len(target) == 0 {
52+
if len(leaf) > 0 {
53+
if bytes.Equal(l, leaf) {
54+
target = l
55+
}
56+
} else if leaf2Addr(l, td.Ltd, td.Packed).Hex() == addr.Hex() {
57+
target = l
58+
}
59+
}
60+
61+
leaves = append(leaves, l)
62+
}
63+
64+
if len(target) == 0 {
65+
s.sendJSONError(r, w, nil, http.StatusNotFound, "leaf not found in tree")
66+
return
67+
}
68+
69+
var (
70+
p = merkle.New(leaves).Proof(target)
71+
phex = []hexutil.Bytes{}
72+
)
73+
74+
// convert [][]byte to []hexutil.Bytes
75+
for _, p := range p {
76+
phex = append(phex, p)
77+
}
78+
6579
// cache for 1 year if we're returning an unhashed leaf proof
6680
// or 60 seconds for an address proof
67-
if leaf != "" {
81+
if len(leaf) > 0 {
6882
w.Header().Set("Cache-Control", "public, max-age=31536000")
6983
} else {
7084
w.Header().Set("Cache-Control", "public, max-age=60")
7185
}
72-
s.sendJSON(r, w, resp)
86+
s.sendJSON(r, w, getProofResp{
87+
UnhashedLeaf: target,
88+
Proof: phex,
89+
})
7390
}

api/root.go

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,13 @@
11
package api
22

33
import (
4-
"encoding/json"
54
"net/http"
65
"strings"
76

87
"github.com/ethereum/go-ethereum/common/hexutil"
98
"github.com/jackc/pgx/v4"
109
)
1110

12-
func proofURLToDBQuery(param string) string {
13-
type proofLookup struct {
14-
Proof []string `json:"proof"`
15-
}
16-
17-
lookup := proofLookup{
18-
Proof: strings.Split(param, ","),
19-
}
20-
21-
q, err := json.Marshal([]proofLookup{lookup})
22-
if err != nil {
23-
return ""
24-
}
25-
26-
return string(q)
27-
}
28-
2911
func (s *Server) GetRoot(w http.ResponseWriter, r *http.Request) {
3012
type rootResp struct {
3113
Root hexutil.Bytes `json:"root"`
@@ -37,24 +19,40 @@ func (s *Server) GetRoot(w http.ResponseWriter, r *http.Request) {
3719
}
3820

3921
var (
40-
ctx = r.Context()
41-
proof = r.URL.Query().Get("proof")
42-
dbQuery = proofURLToDBQuery(proof)
22+
ctx = r.Context()
23+
err error
24+
proof = r.URL.Query().Get("proof")
25+
ps = strings.Split(proof, ",")
26+
pb = [][]byte{}
4327
)
44-
if proof == "" || dbQuery == "" {
45-
s.sendJSONError(r, w, nil, http.StatusBadRequest, "missing list of proofs")
28+
29+
for _, s := range ps {
30+
var b []byte
31+
b, err = hexutil.Decode(s)
32+
if err != nil {
33+
break
34+
}
35+
pb = append(pb, b)
36+
}
37+
38+
if len(pb) == 0 || err != nil {
39+
s.sendJSONError(r, w, nil, http.StatusBadRequest, "missing or malformed list of proofs")
4640
return
4741
}
4842

4943
const q = `
5044
SELECT root
51-
FROM trees_proofs
52-
WHERE proofs_array(proofs) @> proofs_array($1);
45+
FROM proofs_hashes
46+
WHERE hash = $1
47+
group by 1;
5348
`
54-
roots := make([]hexutil.Bytes, 0)
55-
rb := make(hexutil.Bytes, 0)
49+
var (
50+
roots []hexutil.Bytes
51+
rb hexutil.Bytes
52+
ph = hashProof(pb)
53+
)
5654

57-
_, err := s.db.QueryFunc(ctx, q, []interface{}{dbQuery}, []interface{}{&rb}, func(qfr pgx.QueryFuncRow) error {
55+
_, err = s.db.QueryFunc(ctx, q, []interface{}{&ph}, []interface{}{&rb}, func(qfr pgx.QueryFuncRow) error {
5856
roots = append(roots, rb)
5957
return nil
6058
})

0 commit comments

Comments
 (0)