Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
250 changes: 238 additions & 12 deletions core/tnclient/transport_cre.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
package tnclient

import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/url"
"regexp"
"strconv"
"strings"
Expand All @@ -15,6 +17,7 @@ import (

clientType "github.com/trufnetwork/kwil-db/core/client/types"
"github.com/trufnetwork/kwil-db/core/crypto/auth"
"github.com/trufnetwork/kwil-db/core/rpc/client/gateway"
jsonrpc "github.com/trufnetwork/kwil-db/core/rpc/json"
"github.com/trufnetwork/kwil-db/core/types"

Expand Down Expand Up @@ -58,6 +61,8 @@ type CRETransport struct {
chainIDMu sync.RWMutex
chainIDInitialized bool
reqID atomic.Uint64
authCookie string // Cookie value for gateway authentication
authCookieMu sync.RWMutex
}

// Verify CRETransport implements Transport interface at compile time
Expand All @@ -81,6 +86,13 @@ var _ Transport = (*CRETransport)(nil)
// return err
// }
func NewCRETransport(runtime cre.NodeRuntime, endpoint string, signer auth.Signer) (*CRETransport, error) {
// Append /rpc/v1 if not already present (kwil-db client adds this automatically)
// First trim trailing slashes to prevent duplication (e.g., "/rpc/v1/" → "/rpc/v1/rpc/v1")
endpoint = strings.TrimRight(endpoint, "/")
if !strings.HasSuffix(endpoint, "/rpc/v1") {
endpoint = endpoint + "/rpc/v1"
}

return &CRETransport{
runtime: runtime,
client: &http.Client{},
Expand All @@ -97,7 +109,34 @@ func (t *CRETransport) nextReqID() string {
}

// callJSONRPC makes a JSON-RPC call via CRE HTTP client
// It automatically handles authentication if the endpoint returns 401
func (t *CRETransport) callJSONRPC(ctx context.Context, method string, params any, result any) error {
// Try the call
err := t.doJSONRPC(ctx, method, params, result)

// If we get a 401, try authenticating and retry once
if err != nil && strings.Contains(err.Error(), "401") {
if t.signer == nil {
return fmt.Errorf("%w [DEBUG: signer is nil, cannot authenticate]", err)
}
// Authenticate with gateway
authErr := t.authenticate(ctx)
if authErr != nil {
return fmt.Errorf("authentication failed: %w (original 401 for method %s)", authErr, method)
}
// Retry the call
retryErr := t.doJSONRPC(ctx, method, params, result)
if retryErr != nil {
return fmt.Errorf("retry after auth failed: %w (method: %s)", retryErr, method)
}
return nil
}

return err
}

// doJSONRPC performs the actual JSON-RPC call without authentication retry
func (t *CRETransport) doJSONRPC(ctx context.Context, method string, params any, result any) error {
// Marshal the params
paramsJSON, err := json.Marshal(params)
if err != nil {
Expand All @@ -114,14 +153,24 @@ func (t *CRETransport) callJSONRPC(ctx context.Context, method string, params an
return fmt.Errorf("failed to marshal JSON-RPC request: %w", err)
}

// Create headers
headers := map[string]string{
"Content-Type": "application/json",
}

// Add auth cookie if we have one
t.authCookieMu.RLock()
if t.authCookie != "" {
headers["Cookie"] = t.authCookie
}
t.authCookieMu.RUnlock()

// Create CRE HTTP request
httpReq := &http.Request{
Url: t.endpoint,
Method: "POST",
Body: requestBody,
Headers: map[string]string{
"Content-Type": "application/json",
},
Url: t.endpoint,
Method: "POST",
Body: requestBody,
Headers: headers,
}

// Execute via CRE client (returns Promise)
Expand Down Expand Up @@ -167,15 +216,39 @@ func (t *CRETransport) callJSONRPC(ctx context.Context, method string, params an
// The call is executed within CRE's consensus mechanism, ensuring all nodes in the DON
// reach agreement on the result.
func (t *CRETransport) Call(ctx context.Context, namespace string, action string, inputs []any) (*types.CallResult, error) {
// Build call params matching kwil-db's user/call endpoint
params := map[string]any{
"dbid": namespace,
"action": action,
"inputs": inputs,
// Use "main" as default namespace if empty (TRUF.NETWORK convention)
if namespace == "" {
namespace = "main"
}

// Encode inputs to EncodedValue array
var encodedInputs []*types.EncodedValue
for _, val := range inputs {
encoded, err := types.EncodeValue(val)
if err != nil {
return nil, fmt.Errorf("failed to encode input value: %w", err)
}
encodedInputs = append(encodedInputs, encoded)
}

// Build ActionCall payload
payload := &types.ActionCall{
Namespace: namespace,
Action: action,
Arguments: encodedInputs,
}

// Create CallMessage
// Call operations are read-only and typically don't require authentication,
// but we pass the signer (if configured) to support authenticated gateway calls.
// The challenge is nil for standard calls (vs. Execute which requires it).
callMsg, err := types.CreateCallMessage(payload, nil, t.signer)
if err != nil {
return nil, fmt.Errorf("failed to create call message: %w", err)
}

var result types.CallResult
if err := t.callJSONRPC(ctx, "user.call", params, &result); err != nil {
if err := t.callJSONRPC(ctx, "user.call", callMsg, &result); err != nil {
return nil, err
}

Expand Down Expand Up @@ -450,3 +523,156 @@ func (t *CRETransport) ChainID() string {
func (t *CRETransport) Signer() auth.Signer {
return t.signer
}

// authenticate performs gateway authentication and stores the cookie.
// This is called automatically when a 401 error is received.
func (t *CRETransport) authenticate(ctx context.Context) error {
if t.signer == nil {
return fmt.Errorf("cannot authenticate without a signer")
}

// Get authentication parameters from gateway
var authParam gateway.AuthnParameterResponse
if err := t.doJSONRPC(ctx, string(gateway.MethodAuthnParam), &struct{}{}, &authParam); err != nil {
return fmt.Errorf("failed to get auth parameters (kgw.authn_param): %w", err)
}

// Parse endpoint to get domain (remove /rpc/v1 path if present)
parsedURL, err := url.Parse(t.endpoint)
if err != nil {
return fmt.Errorf("failed to parse endpoint URL %s: %w", t.endpoint, err)
}
// Use just scheme + host, without the /rpc/v1 path
targetDomain := parsedURL.Scheme + "://" + parsedURL.Host

// Get chain ID
chainID := t.ChainID()
if chainID == "" {
return fmt.Errorf("failed to get chain ID for authentication")
}

// Compose authentication message (SIWE-like format)
msg := composeGatewayAuthMessage(&authParam, targetDomain, authParam.URI, "1", chainID)

// Sign the message
sig, err := t.signer.Sign([]byte(msg))
if err != nil {
return fmt.Errorf("failed to sign auth message: %w", err)
}

// Send authentication request
authReq := &gateway.AuthnRequest{
Nonce: authParam.Nonce,
Sender: t.signer.CompactID(),
Signature: sig,
}

// Make the auth request and capture the response headers
authResp, err := t.doJSONRPCWithResponse(ctx, string(gateway.MethodAuthn), authReq)
if err != nil {
return fmt.Errorf("kgw.authn request failed: %w [DEBUG: sender=%x, nonce=%s]",
err, authReq.Sender, authReq.Nonce)
}

// Extract Set-Cookie header from response
setCookie, ok := authResp["set-cookie"]
if !ok || setCookie == "" {
// Try other common header names
if sc, exists := authResp["Set-Cookie"]; exists {
setCookie = sc
ok = true
}
}

if ok && setCookie != "" {
// Parse the cookie (just extract the name=value part)
cookieParts := strings.Split(setCookie, ";")
if len(cookieParts) > 0 {
t.authCookieMu.Lock()
t.authCookie = cookieParts[0] // Store just the name=value part
t.authCookieMu.Unlock()
}
} else {
return fmt.Errorf("no Set-Cookie header in kgw.authn response [DEBUG: headers=%+v]", authResp)
}

return nil
}

// doJSONRPCWithResponse performs a JSON-RPC call and returns the response headers.
// This is used for authentication to extract the Set-Cookie header.
func (t *CRETransport) doJSONRPCWithResponse(ctx context.Context, method string, params any) (map[string]string, error) {
// Marshal the params
paramsJSON, err := json.Marshal(params)
if err != nil {
return nil, fmt.Errorf("failed to marshal params: %w", err)
}

// Create JSON-RPC request
reqID := t.nextReqID()
rpcReq := jsonrpc.NewRequest(reqID, method, paramsJSON)

// Marshal the full request
requestBody, err := json.Marshal(rpcReq)
if err != nil {
return nil, fmt.Errorf("failed to marshal JSON-RPC request: %w", err)
}

// Create CRE HTTP request
httpReq := &http.Request{
Url: t.endpoint,
Method: "POST",
Body: requestBody,
Headers: map[string]string{
"Content-Type": "application/json",
},
}

// Execute via CRE client (returns Promise)
httpResp, err := t.client.SendRequest(t.runtime, httpReq).Await()
if err != nil {
return nil, fmt.Errorf("CRE HTTP request failed: %w", err)
}

// Check HTTP status
if httpResp.StatusCode != 200 {
return nil, fmt.Errorf("unexpected HTTP status code: %d", httpResp.StatusCode)
}

// Parse JSON-RPC response
var rpcResp jsonrpc.Response
if err := json.Unmarshal(httpResp.Body, &rpcResp); err != nil {
return nil, fmt.Errorf("failed to unmarshal JSON-RPC response: %w", err)
}

// Check for JSON-RPC errors
if rpcResp.Error != nil {
return nil, fmt.Errorf("JSON-RPC error: %s (code: %d)", rpcResp.Error.Message, rpcResp.Error.Code)
}

// Return the response headers
return httpResp.GetHeaders(), nil
}

// composeGatewayAuthMessage composes a SIWE-like authentication message.
// This matches the format used by kwil-db gateway client.
// Note: This is a custom format, not standard SIWE - it omits the account address line
// and uses "Issue At" instead of "Issued At" to match kgw's expectations.
func composeGatewayAuthMessage(param *gateway.AuthnParameterResponse, domain string, uri string, version string, chainID string) string {
var msg bytes.Buffer
msg.WriteString(domain + " wants you to sign in with your account:\n")
msg.WriteString("\n")
if param.Statement != "" {
msg.WriteString(param.Statement + "\n")
}
msg.WriteString("\n")
msg.WriteString(fmt.Sprintf("URI: %s\n", uri))
msg.WriteString(fmt.Sprintf("Version: %s\n", version))
msg.WriteString(fmt.Sprintf("Chain ID: %s\n", chainID))
msg.WriteString(fmt.Sprintf("Nonce: %s\n", param.Nonce))
msg.WriteString(fmt.Sprintf("Issue At: %s\n", param.IssueAt)) // Note: "Issue At" not "Issued At" (kgw custom format)
if param.ExpirationTime != "" {
msg.WriteString(fmt.Sprintf("Expiration Time: %s\n", param.ExpirationTime))
}
return msg.String()
}
Loading