diff --git a/internal/api/syncapi/identity.go b/internal/api/syncapi/identity.go index b67ecedc..94270378 100644 --- a/internal/api/syncapi/identity.go +++ b/internal/api/syncapi/identity.go @@ -63,7 +63,7 @@ func (i *Identity) loadOrGenerateKey() error { if pubKeyBlock == nil { return errors.New("no public key found in pem") } - pubKey, err := x509.ParsePKIXPublicKey(pubKeyBytes) + pubKey, err := x509.ParsePKIXPublicKey(pubKeyBlock.Bytes) if err != nil { return fmt.Errorf("parse public key: %w", err) } diff --git a/internal/api/syncapi/identity_test.go b/internal/api/syncapi/identity_test.go index 4a56c58a..b66c3d92 100644 --- a/internal/api/syncapi/identity_test.go +++ b/internal/api/syncapi/identity_test.go @@ -2,12 +2,11 @@ package syncapi import ( "fmt" - "os" "path/filepath" "testing" ) -func TestIdentity(t *testing.T) { +func TestCreateAndLoad(t *testing.T) { dir := t.TempDir() // Create a new identity @@ -16,14 +15,36 @@ func TestIdentity(t *testing.T) { t.Fatalf("failed to create identity: %v", err) } + // Load the identity + loaded, err := NewIdentity("test-instance", filepath.Join(dir, "myidentity.pem")) + if err != nil { + t.Fatalf("failed to load identity: %v", err) + } + + // Verify the identity + if !ident.privateKey.Equal(loaded.privateKey) { + t.Fatalf("identities do not match") + } +} + +func TestSignatures(t *testing.T) { + dir := t.TempDir() + + // Create a new identity + ident, err := NewIdentity("test-instance", filepath.Join(dir, "myidentity.pem")) + if err != nil { + t.Fatalf("failed to create identity: %v", err) + } + + // Sign a message signature, err := ident.SignMessage([]byte("hello world!")) + if err != nil { + t.Fatalf("failed to sign message: %v", err) + } fmt.Printf("signed message: %x\n", signature) - // Load and print identity file - bytes, _ := os.ReadFile(filepath.Join(dir, "myidentity.pem")) - t.Log(string(bytes)) - - // Load and print public key file - bytes, _ = os.ReadFile(filepath.Join(dir, "myidentity.pem.pub")) - t.Log(string(bytes)) + // verify the signature + if err := ident.VerifySignature([]byte("hello world!"), signature); err != nil { + t.Fatalf("failed to verify signature: %v", err) + } } diff --git a/internal/api/syncapi/syncclient.go b/internal/api/syncapi/syncclient.go index cdd3feed..4264536f 100644 --- a/internal/api/syncapi/syncclient.go +++ b/internal/api/syncapi/syncclient.go @@ -34,7 +34,6 @@ type SyncClient struct { // mutable properties mu sync.Mutex - remoteConfigStore RemoteConfigStore connectionStatus v1.SyncConnectionState connectionStatusMessage string } diff --git a/internal/api/syncapi/synchandler.go b/internal/api/syncapi/synchandler.go index 59175c85..317fb763 100644 --- a/internal/api/syncapi/synchandler.go +++ b/internal/api/syncapi/synchandler.go @@ -100,8 +100,8 @@ func (h *BackrestSyncHandler) Sync(ctx context.Context, stream *connect.BidiStre } zap.S().Infof("syncserver accepted a connection from client instance ID %q", authorizedClientPeer.InstanceId) - opIDLru, _ := lru.New[int64, int64](128) // original ID -> local ID - flowIDLru, _ := lru.New[int64, int64](128) // original flow ID -> local flow ID + opIDLru, _ := lru.New[int64, int64](2048) // original ID -> local ID + flowIDLru, _ := lru.New[int64, int64](2048) // original flow ID -> local flow ID insertOrUpdate := func(op *v1.Operation) error { op.OriginalId = op.Id @@ -123,18 +123,38 @@ func (h *BackrestSyncHandler) Sync(ctx context.Context, stream *connect.BidiStre } } if op.FlowId, ok = flowIDLru.Get(op.OriginalFlowId); !ok { - var flowOp *v1.Operation - if err := h.mgr.oplog.Query(oplog.Query{}. - SetOriginalFlowID(op.OriginalFlowId). - SetInstanceID(op.InstanceId), func(o *v1.Operation) error { - flowOp = o - return nil - }); err != nil { - return fmt.Errorf("mapping remote flow ID to local ID: %w", err) + tryFindFlowID := func(q oplog.Query) (int64, error) { + var flowOp *v1.Operation + if err := h.mgr.oplog.Query(q, func(o *v1.Operation) error { + flowOp = o + return nil + }); err != nil { + return 0, fmt.Errorf("mapping remote flow ID to local ID: %w", err) + } + if flowOp != nil { + return flowOp.FlowId, nil + } + return 0, nil } - if flowOp != nil { - op.FlowId = flowOp.FlowId - flowIDLru.Add(op.OriginalFlowId, flowOp.FlowId) + + var err error + var flowId int64 + flowId, err = tryFindFlowID(oplog.Query{}.SetSnapshotID(op.SnapshotId)) + if err != nil { + return err + } + if flowId == 0 { + flowId, err = tryFindFlowID(oplog.Query{}. + SetOriginalFlowID(op.OriginalFlowId). + SetInstanceID(op.InstanceId)) + if err != nil { + return err + } + } + + if flowId != 0 { + op.FlowId = flowId + flowIDLru.Add(op.OriginalFlowId, flowId) } }