From 3e73f118c5e0432b4c0ef781af433694396c7424 Mon Sep 17 00:00:00 2001 From: Silas Lenihan Date: Fri, 25 Apr 2025 12:45:44 -0400 Subject: [PATCH 01/16] Updated ocr3 Metadata type to include Encoding and Decoding --- .../consensus/ocr3/types/aggregator.go | 149 ++++++++++++++++++ 1 file changed, 149 insertions(+) diff --git a/pkg/capabilities/consensus/ocr3/types/aggregator.go b/pkg/capabilities/consensus/ocr3/types/aggregator.go index 96ba82be0..db53c0239 100644 --- a/pkg/capabilities/consensus/ocr3/types/aggregator.go +++ b/pkg/capabilities/consensus/ocr3/types/aggregator.go @@ -1,6 +1,11 @@ package types import ( + "bytes" + "encoding/binary" + "encoding/hex" + "fmt" + "io" "strings" ocrcommon "github.com/smartcontractkit/libocr/commontypes" @@ -34,6 +39,150 @@ func (m *Metadata) padWorkflowName() { } } +// Encode serializes Metadata in contract order: +// 1B Version, 32B ExecutionID, 4B Timestamp, 4B DONID, 4B DONConfigVersion, +// 32B WorkflowID, 10B WorkflowName, 20B WorkflowOwner, 2B ReportID +func (m Metadata) Encode() ([]byte, error) { + m.padWorkflowName() + buf := new(bytes.Buffer) + + // 1) Version as a single byte + if err := buf.WriteByte(byte(m.Version)); err != nil { + return nil, err + } + + // 2) Helper to decode a hex string and ensure length + writeHex := func(field string, expectedBytes int) error { + s := strings.TrimPrefix(field, "0x") + b, err := hex.DecodeString(s) + if err != nil { + return fmt.Errorf("invalid hex in field: %w", err) + } + if len(b) != expectedBytes { + return fmt.Errorf("wrong length: expected %d bytes, got %d", expectedBytes, len(b)) + } + _, err = buf.Write(b) + return err + } + + // ExecutionID: 32 bytes + if err := writeHex(m.ExecutionID, 32); err != nil { + return nil, fmt.Errorf("ExecutionID: %w", err) + } + + // Timestamp, DONID, DONConfigVersion—all 4‐byte big endian + for _, v := range []uint32{m.Timestamp, m.DONID, m.DONConfigVersion} { + if err := binary.Write(buf, binary.BigEndian, v); err != nil { + return nil, err + } + } + + // WorkflowID: 32 bytes + if err := writeHex(m.WorkflowID, 32); err != nil { + return nil, fmt.Errorf("WorkflowID: %w", err) + } + + // Workflow Name: 10 bytes + if err := writeHex(m.WorkflowName, 10); err != nil { + return nil, fmt.Errorf("WorkflowName: %w", err) + } + + // WorkflowOwner: 20 bytes + if err := writeHex(m.WorkflowOwner, 20); err != nil { + return nil, fmt.Errorf("WorkflowOwner: %w", err) + } + + // ReportID: 2 bytes + if err := writeHex(m.ReportID, 2); err != nil { + return nil, fmt.Errorf("ReportID: %w", err) + } + + return buf.Bytes(), nil +} + +const MetadataLen = 1 + 32 + 4 + 4 + 4 + 32 + 10 + 20 + 2 // =109 + +// Decode parses exactly MetadataLen bytes from raw, returns a Metadata struct +// and any trailing data. 
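+//
+// Illustrative round-trip sketch (the names meta and payload below are
+// assumptions for the example, not part of this patch): Encode produces the
+// fixed 109-byte header, and any bytes appended after it come back as the
+// tail from Decode:
+//
+//	raw, _ := meta.Encode()
+//	decoded, tail, _ := Decode(append(raw, payload...))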
+func Decode(raw []byte) (Metadata, []byte, error) { + m := Metadata{} + + if len(raw) < MetadataLen { + return m, nil, fmt.Errorf("metadata: raw too short, want ≥%d, got %d", MetadataLen, len(raw)) + } + + buf := bytes.NewReader(raw[:MetadataLen]) + + // 1) Version (1 byte) + var vb byte + if err := binary.Read(buf, binary.BigEndian, &vb); err != nil { + return m, nil, err + } + m.Version = uint32(vb) + + // helper to read N bytes and hex-decode + readHex := func(n int) (string, error) { + tmp := make([]byte, n) + if _, err := io.ReadFull(buf, tmp); err != nil { + return "", err + } + return hex.EncodeToString(tmp), nil + } + + // 2) ExecutionID (32 bytes hex) + var err error + if m.ExecutionID, err = readHex(32); err != nil { + return m, nil, fmt.Errorf("ExecutionID: %w", err) + } + + // 3) Timestamp, DONID, DONConfigVersion (each 4 bytes BE) + for _, ptr := range []*uint32{&m.Timestamp, &m.DONID, &m.DONConfigVersion} { + if err := binary.Read(buf, binary.BigEndian, ptr); err != nil { + return m, nil, err + } + } + + // 4) WorkflowID (32 bytes hex) + if m.WorkflowID, err = readHex(32); err != nil { + return m, nil, fmt.Errorf("WorkflowID: %w", err) + } + + nameBytes := make([]byte, 10) + if _, err := io.ReadFull(buf, nameBytes); err != nil { + return m, nil, err + } + // hex-encode those 10 bytes into a 20-char string + m.WorkflowName = hex.EncodeToString(nameBytes) + + // 6) WorkflowOwner (20 bytes hex) + if m.WorkflowOwner, err = readHex(20); err != nil { + return m, nil, fmt.Errorf("WorkflowOwner: %w", err) + } + + // 7) ReportID (2 bytes hex) + if m.ReportID, err = readHex(2); err != nil { + return m, nil, fmt.Errorf("ReportID: %w", err) + } + + // strip any stray "0x" prefixes just in case + m.ExecutionID = strings.TrimPrefix(m.ExecutionID, "0x") + m.WorkflowID = strings.TrimPrefix(m.WorkflowID, "0x") + m.WorkflowOwner = strings.TrimPrefix(m.WorkflowOwner, "0x") + m.ReportID = strings.TrimPrefix(m.ReportID, "0x") + + // the rest is payload + tail := raw[MetadataLen:] + return m, tail, nil +} + +func (m Metadata) Length() int { + b, err := m.Encode() + if err != nil { + return 0 + } + return len(b) +} + // Aggregator is the interface that enables a hook to the Outcome() phase of OCR reporting. type Aggregator interface { // Called by the Outcome() phase of OCR reporting. From f3ed96e428724590d600ab73b9c6be5fce146d10 Mon Sep 17 00:00:00 2001 From: Silas Lenihan Date: Mon, 19 May 2025 13:16:02 -0400 Subject: [PATCH 02/16] Revert "add GetEstimateFee (#1196)" This reverts commit 80bc8b13c0e7839849182ef9e55b8ac0063145c9. --- pkg/loop/internal/pb/contract_writer.pb.go | 226 +++--------------- pkg/loop/internal/pb/contract_writer.proto | 17 -- .../internal/pb/contract_writer_grpc.pb.go | 38 --- .../contractwriter/contract_writer.go | 43 ---- pkg/types/contract_writer.go | 10 - 5 files changed, 34 insertions(+), 300 deletions(-) diff --git a/pkg/loop/internal/pb/contract_writer.pb.go b/pkg/loop/internal/pb/contract_writer.pb.go index 372380200..5b61982db 100644 --- a/pkg/loop/internal/pb/contract_writer.pb.go +++ b/pkg/loop/internal/pb/contract_writer.pb.go @@ -271,91 +271,6 @@ func (x *GetTransactionStatusRequest) GetTransactionId() string { return "" } -// GetEstimateFeeReply has arguments for [github.com/smartcontractkit/chainlink-common/pkg/types.ContractWriter.GetEstimateFee]. 
-type GetEstimateFeeRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - ContractName string `protobuf:"bytes,1,opt,name=contract_name,json=contractName,proto3" json:"contract_name,omitempty"` - Method string `protobuf:"bytes,2,opt,name=method,proto3" json:"method,omitempty"` - Params *VersionedBytes `protobuf:"bytes,3,opt,name=params,proto3" json:"params,omitempty"` - ToAddress string `protobuf:"bytes,4,opt,name=to_address,json=toAddress,proto3" json:"to_address,omitempty"` - Meta *TransactionMeta `protobuf:"bytes,5,opt,name=meta,proto3" json:"meta,omitempty"` - Value *BigInt `protobuf:"bytes,6,opt,name=value,proto3" json:"value,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *GetEstimateFeeRequest) Reset() { - *x = GetEstimateFeeRequest{} - mi := &file_contract_writer_proto_msgTypes[3] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *GetEstimateFeeRequest) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*GetEstimateFeeRequest) ProtoMessage() {} - -func (x *GetEstimateFeeRequest) ProtoReflect() protoreflect.Message { - mi := &file_contract_writer_proto_msgTypes[3] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use GetEstimateFeeRequest.ProtoReflect.Descriptor instead. -func (*GetEstimateFeeRequest) Descriptor() ([]byte, []int) { - return file_contract_writer_proto_rawDescGZIP(), []int{3} -} - -func (x *GetEstimateFeeRequest) GetContractName() string { - if x != nil { - return x.ContractName - } - return "" -} - -func (x *GetEstimateFeeRequest) GetMethod() string { - if x != nil { - return x.Method - } - return "" -} - -func (x *GetEstimateFeeRequest) GetParams() *VersionedBytes { - if x != nil { - return x.Params - } - return nil -} - -func (x *GetEstimateFeeRequest) GetToAddress() string { - if x != nil { - return x.ToAddress - } - return "" -} - -func (x *GetEstimateFeeRequest) GetMeta() *TransactionMeta { - if x != nil { - return x.Meta - } - return nil -} - -func (x *GetEstimateFeeRequest) GetValue() *BigInt { - if x != nil { - return x.Value - } - return nil -} - // GetTransactionStatusReply has return arguments for [github.com/smartcontractkit/chainlink-common/pkg/types.ContractWriter.GetTransactionStatus]. type GetTransactionStatusReply struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -366,7 +281,7 @@ type GetTransactionStatusReply struct { func (x *GetTransactionStatusReply) Reset() { *x = GetTransactionStatusReply{} - mi := &file_contract_writer_proto_msgTypes[4] + mi := &file_contract_writer_proto_msgTypes[3] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -378,7 +293,7 @@ func (x *GetTransactionStatusReply) String() string { func (*GetTransactionStatusReply) ProtoMessage() {} func (x *GetTransactionStatusReply) ProtoReflect() protoreflect.Message { - mi := &file_contract_writer_proto_msgTypes[4] + mi := &file_contract_writer_proto_msgTypes[3] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -391,7 +306,7 @@ func (x *GetTransactionStatusReply) ProtoReflect() protoreflect.Message { // Deprecated: Use GetTransactionStatusReply.ProtoReflect.Descriptor instead. 
func (*GetTransactionStatusReply) Descriptor() ([]byte, []int) { - return file_contract_writer_proto_rawDescGZIP(), []int{4} + return file_contract_writer_proto_rawDescGZIP(), []int{3} } func (x *GetTransactionStatusReply) GetTransactionStatus() TransactionStatus { @@ -412,7 +327,7 @@ type GetFeeComponentsReply struct { func (x *GetFeeComponentsReply) Reset() { *x = GetFeeComponentsReply{} - mi := &file_contract_writer_proto_msgTypes[5] + mi := &file_contract_writer_proto_msgTypes[4] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -424,7 +339,7 @@ func (x *GetFeeComponentsReply) String() string { func (*GetFeeComponentsReply) ProtoMessage() {} func (x *GetFeeComponentsReply) ProtoReflect() protoreflect.Message { - mi := &file_contract_writer_proto_msgTypes[5] + mi := &file_contract_writer_proto_msgTypes[4] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -437,7 +352,7 @@ func (x *GetFeeComponentsReply) ProtoReflect() protoreflect.Message { // Deprecated: Use GetFeeComponentsReply.ProtoReflect.Descriptor instead. func (*GetFeeComponentsReply) Descriptor() ([]byte, []int) { - return file_contract_writer_proto_rawDescGZIP(), []int{5} + return file_contract_writer_proto_rawDescGZIP(), []int{4} } func (x *GetFeeComponentsReply) GetExecutionFee() *BigInt { @@ -454,59 +369,6 @@ func (x *GetFeeComponentsReply) GetDataAvailabilityFee() *BigInt { return nil } -// GetEstimateFeeReply has return arguments for [github.com/smartcontractkit/chainlink-common/pkg/types.ContractWriter.GetEstimateFee]. -type GetEstimateFeeReply struct { - state protoimpl.MessageState `protogen:"open.v1"` - Fee *BigInt `protobuf:"bytes,1,opt,name=fee,proto3" json:"fee,omitempty"` - Decimals uint32 `protobuf:"varint,2,opt,name=decimals,proto3" json:"decimals,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *GetEstimateFeeReply) Reset() { - *x = GetEstimateFeeReply{} - mi := &file_contract_writer_proto_msgTypes[6] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *GetEstimateFeeReply) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*GetEstimateFeeReply) ProtoMessage() {} - -func (x *GetEstimateFeeReply) ProtoReflect() protoreflect.Message { - mi := &file_contract_writer_proto_msgTypes[6] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use GetEstimateFeeReply.ProtoReflect.Descriptor instead. 
-func (*GetEstimateFeeReply) Descriptor() ([]byte, []int) { - return file_contract_writer_proto_rawDescGZIP(), []int{6} -} - -func (x *GetEstimateFeeReply) GetFee() *BigInt { - if x != nil { - return x.Fee - } - return nil -} - -func (x *GetEstimateFeeReply) GetDecimals() uint32 { - if x != nil { - return x.Decimals - } - return 0 -} - var File_contract_writer_proto protoreflect.FileDescriptor const file_contract_writer_proto_rawDesc = "" + @@ -525,35 +387,23 @@ const file_contract_writer_proto_rawDesc = "" + "\x15workflow_execution_id\x18\x01 \x01(\tR\x13workflowExecutionId\x12)\n" + "\tgas_limit\x18\x02 \x01(\v2\f.loop.BigIntR\bgasLimit\"D\n" + "\x1bGetTransactionStatusRequest\x12%\n" + - "\x0etransaction_id\x18\x01 \x01(\tR\rtransactionId\"\xf0\x01\n" + - "\x15GetEstimateFeeRequest\x12#\n" + - "\rcontract_name\x18\x01 \x01(\tR\fcontractName\x12\x16\n" + - "\x06method\x18\x02 \x01(\tR\x06method\x12,\n" + - "\x06params\x18\x03 \x01(\v2\x14.loop.VersionedBytesR\x06params\x12\x1d\n" + - "\n" + - "to_address\x18\x04 \x01(\tR\ttoAddress\x12)\n" + - "\x04meta\x18\x05 \x01(\v2\x15.loop.TransactionMetaR\x04meta\x12\"\n" + - "\x05value\x18\x06 \x01(\v2\f.loop.BigIntR\x05value\"c\n" + + "\x0etransaction_id\x18\x01 \x01(\tR\rtransactionId\"c\n" + "\x19GetTransactionStatusReply\x12F\n" + "\x12transaction_status\x18\x01 \x01(\x0e2\x17.loop.TransactionStatusR\x11transactionStatus\"\x8c\x01\n" + "\x15GetFeeComponentsReply\x121\n" + "\rexecution_fee\x18\x01 \x01(\v2\f.loop.BigIntR\fexecutionFee\x12@\n" + - "\x15data_availability_fee\x18\x02 \x01(\v2\f.loop.BigIntR\x13dataAvailabilityFee\"Q\n" + - "\x13GetEstimateFeeReply\x12\x1e\n" + - "\x03fee\x18\x01 \x01(\v2\f.loop.BigIntR\x03fee\x12\x1a\n" + - "\bdecimals\x18\x02 \x01(\rR\bdecimals*\xd6\x01\n" + + "\x15data_availability_fee\x18\x02 \x01(\v2\f.loop.BigIntR\x13dataAvailabilityFee*\xd6\x01\n" + "\x11TransactionStatus\x12\x1e\n" + "\x1aTRANSACTION_STATUS_UNKNOWN\x10\x00\x12\x1e\n" + "\x1aTRANSACTION_STATUS_PENDING\x10\x01\x12\"\n" + "\x1eTRANSACTION_STATUS_UNCONFIRMED\x10\x02\x12 \n" + "\x1cTRANSACTION_STATUS_FINALIZED\x10\x03\x12\x1d\n" + "\x19TRANSACTION_STATUS_FAILED\x10\x04\x12\x1c\n" + - "\x18TRANSACTION_STATUS_FATAL\x10\x052\xd4\x02\n" + + "\x18TRANSACTION_STATUS_FATAL\x10\x052\x88\x02\n" + "\x0eContractWriter\x12M\n" + "\x11SubmitTransaction\x12\x1e.loop.SubmitTransactionRequest\x1a\x16.google.protobuf.Empty\"\x00\x12\\\n" + "\x14GetTransactionStatus\x12!.loop.GetTransactionStatusRequest\x1a\x1f.loop.GetTransactionStatusReply\"\x00\x12I\n" + - "\x10GetFeeComponents\x12\x16.google.protobuf.Empty\x1a\x1b.loop.GetFeeComponentsReply\"\x00\x12J\n" + - "\x0eGetEstimateFee\x12\x1b.loop.GetEstimateFeeRequest\x1a\x19.loop.GetEstimateFeeReply\"\x00BCZAgithub.com/smartcontractkit/chainlink-common/pkg/loop/internal/pbb\x06proto3" + "\x10GetFeeComponents\x12\x16.google.protobuf.Empty\x1a\x1b.loop.GetFeeComponentsReply\"\x00BCZAgithub.com/smartcontractkit/chainlink-common/pkg/loop/internal/pbb\x06proto3" var ( file_contract_writer_proto_rawDescOnce sync.Once @@ -568,45 +418,37 @@ func file_contract_writer_proto_rawDescGZIP() []byte { } var file_contract_writer_proto_enumTypes = make([]protoimpl.EnumInfo, 1) -var file_contract_writer_proto_msgTypes = make([]protoimpl.MessageInfo, 7) +var file_contract_writer_proto_msgTypes = make([]protoimpl.MessageInfo, 5) var file_contract_writer_proto_goTypes = []any{ (TransactionStatus)(0), // 0: loop.TransactionStatus (*SubmitTransactionRequest)(nil), // 1: loop.SubmitTransactionRequest (*TransactionMeta)(nil), // 2: 
loop.TransactionMeta (*GetTransactionStatusRequest)(nil), // 3: loop.GetTransactionStatusRequest - (*GetEstimateFeeRequest)(nil), // 4: loop.GetEstimateFeeRequest - (*GetTransactionStatusReply)(nil), // 5: loop.GetTransactionStatusReply - (*GetFeeComponentsReply)(nil), // 6: loop.GetFeeComponentsReply - (*GetEstimateFeeReply)(nil), // 7: loop.GetEstimateFeeReply - (*VersionedBytes)(nil), // 8: loop.VersionedBytes - (*BigInt)(nil), // 9: loop.BigInt - (*emptypb.Empty)(nil), // 10: google.protobuf.Empty + (*GetTransactionStatusReply)(nil), // 4: loop.GetTransactionStatusReply + (*GetFeeComponentsReply)(nil), // 5: loop.GetFeeComponentsReply + (*VersionedBytes)(nil), // 6: loop.VersionedBytes + (*BigInt)(nil), // 7: loop.BigInt + (*emptypb.Empty)(nil), // 8: google.protobuf.Empty } var file_contract_writer_proto_depIdxs = []int32{ - 8, // 0: loop.SubmitTransactionRequest.params:type_name -> loop.VersionedBytes + 6, // 0: loop.SubmitTransactionRequest.params:type_name -> loop.VersionedBytes 2, // 1: loop.SubmitTransactionRequest.meta:type_name -> loop.TransactionMeta - 9, // 2: loop.SubmitTransactionRequest.value:type_name -> loop.BigInt - 9, // 3: loop.TransactionMeta.gas_limit:type_name -> loop.BigInt - 8, // 4: loop.GetEstimateFeeRequest.params:type_name -> loop.VersionedBytes - 2, // 5: loop.GetEstimateFeeRequest.meta:type_name -> loop.TransactionMeta - 9, // 6: loop.GetEstimateFeeRequest.value:type_name -> loop.BigInt - 0, // 7: loop.GetTransactionStatusReply.transaction_status:type_name -> loop.TransactionStatus - 9, // 8: loop.GetFeeComponentsReply.execution_fee:type_name -> loop.BigInt - 9, // 9: loop.GetFeeComponentsReply.data_availability_fee:type_name -> loop.BigInt - 9, // 10: loop.GetEstimateFeeReply.fee:type_name -> loop.BigInt - 1, // 11: loop.ContractWriter.SubmitTransaction:input_type -> loop.SubmitTransactionRequest - 3, // 12: loop.ContractWriter.GetTransactionStatus:input_type -> loop.GetTransactionStatusRequest - 10, // 13: loop.ContractWriter.GetFeeComponents:input_type -> google.protobuf.Empty - 4, // 14: loop.ContractWriter.GetEstimateFee:input_type -> loop.GetEstimateFeeRequest - 10, // 15: loop.ContractWriter.SubmitTransaction:output_type -> google.protobuf.Empty - 5, // 16: loop.ContractWriter.GetTransactionStatus:output_type -> loop.GetTransactionStatusReply - 6, // 17: loop.ContractWriter.GetFeeComponents:output_type -> loop.GetFeeComponentsReply - 7, // 18: loop.ContractWriter.GetEstimateFee:output_type -> loop.GetEstimateFeeReply - 15, // [15:19] is the sub-list for method output_type - 11, // [11:15] is the sub-list for method input_type - 11, // [11:11] is the sub-list for extension type_name - 11, // [11:11] is the sub-list for extension extendee - 0, // [0:11] is the sub-list for field type_name + 7, // 2: loop.SubmitTransactionRequest.value:type_name -> loop.BigInt + 7, // 3: loop.TransactionMeta.gas_limit:type_name -> loop.BigInt + 0, // 4: loop.GetTransactionStatusReply.transaction_status:type_name -> loop.TransactionStatus + 7, // 5: loop.GetFeeComponentsReply.execution_fee:type_name -> loop.BigInt + 7, // 6: loop.GetFeeComponentsReply.data_availability_fee:type_name -> loop.BigInt + 1, // 7: loop.ContractWriter.SubmitTransaction:input_type -> loop.SubmitTransactionRequest + 3, // 8: loop.ContractWriter.GetTransactionStatus:input_type -> loop.GetTransactionStatusRequest + 8, // 9: loop.ContractWriter.GetFeeComponents:input_type -> google.protobuf.Empty + 8, // 10: loop.ContractWriter.SubmitTransaction:output_type -> google.protobuf.Empty + 4, // 11: 
loop.ContractWriter.GetTransactionStatus:output_type -> loop.GetTransactionStatusReply + 5, // 12: loop.ContractWriter.GetFeeComponents:output_type -> loop.GetFeeComponentsReply + 10, // [10:13] is the sub-list for method output_type + 7, // [7:10] is the sub-list for method input_type + 7, // [7:7] is the sub-list for extension type_name + 7, // [7:7] is the sub-list for extension extendee + 0, // [0:7] is the sub-list for field type_name } func init() { file_contract_writer_proto_init() } @@ -622,7 +464,7 @@ func file_contract_writer_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_contract_writer_proto_rawDesc), len(file_contract_writer_proto_rawDesc)), NumEnums: 1, - NumMessages: 7, + NumMessages: 5, NumExtensions: 0, NumServices: 1, }, diff --git a/pkg/loop/internal/pb/contract_writer.proto b/pkg/loop/internal/pb/contract_writer.proto index 4bd97e014..dd33f99cd 100644 --- a/pkg/loop/internal/pb/contract_writer.proto +++ b/pkg/loop/internal/pb/contract_writer.proto @@ -12,7 +12,6 @@ service ContractWriter { rpc SubmitTransaction(SubmitTransactionRequest) returns (google.protobuf.Empty) {} rpc GetTransactionStatus(GetTransactionStatusRequest) returns (GetTransactionStatusReply) {} rpc GetFeeComponents(google.protobuf.Empty) returns (GetFeeComponentsReply) {} - rpc GetEstimateFee(GetEstimateFeeRequest) returns (GetEstimateFeeReply) {} } message SubmitTransactionRequest { @@ -35,16 +34,6 @@ message GetTransactionStatusRequest { string transaction_id = 1; } -// GetEstimateFeeReply has arguments for [github.com/smartcontractkit/chainlink-common/pkg/types.ContractWriter.GetEstimateFee]. -message GetEstimateFeeRequest { - string contract_name = 1; - string method = 2; - VersionedBytes params = 3; - string to_address = 4; - TransactionMeta meta = 5; - BigInt value = 6; -} - // TransactionStatus is an enum for the status of a transaction. // This should always be a 1-1 mapping to: [github.com/smartcontractkit/chainlink-common/pkg/types.TransactionStatus]. enum TransactionStatus { @@ -66,9 +55,3 @@ message GetFeeComponentsReply { BigInt execution_fee = 1; BigInt data_availability_fee = 2; } - -// GetEstimateFeeReply has return arguments for [github.com/smartcontractkit/chainlink-common/pkg/types.ContractWriter.GetEstimateFee]. -message GetEstimateFeeReply { - BigInt fee = 1; - uint32 decimals = 2; -} diff --git a/pkg/loop/internal/pb/contract_writer_grpc.pb.go b/pkg/loop/internal/pb/contract_writer_grpc.pb.go index e9bcba861..6c2da2796 100644 --- a/pkg/loop/internal/pb/contract_writer_grpc.pb.go +++ b/pkg/loop/internal/pb/contract_writer_grpc.pb.go @@ -23,7 +23,6 @@ const ( ContractWriter_SubmitTransaction_FullMethodName = "/loop.ContractWriter/SubmitTransaction" ContractWriter_GetTransactionStatus_FullMethodName = "/loop.ContractWriter/GetTransactionStatus" ContractWriter_GetFeeComponents_FullMethodName = "/loop.ContractWriter/GetFeeComponents" - ContractWriter_GetEstimateFee_FullMethodName = "/loop.ContractWriter/GetEstimateFee" ) // ContractWriterClient is the client API for ContractWriter service. 
@@ -33,7 +32,6 @@ type ContractWriterClient interface { SubmitTransaction(ctx context.Context, in *SubmitTransactionRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) GetTransactionStatus(ctx context.Context, in *GetTransactionStatusRequest, opts ...grpc.CallOption) (*GetTransactionStatusReply, error) GetFeeComponents(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*GetFeeComponentsReply, error) - GetEstimateFee(ctx context.Context, in *GetEstimateFeeRequest, opts ...grpc.CallOption) (*GetEstimateFeeReply, error) } type contractWriterClient struct { @@ -74,16 +72,6 @@ func (c *contractWriterClient) GetFeeComponents(ctx context.Context, in *emptypb return out, nil } -func (c *contractWriterClient) GetEstimateFee(ctx context.Context, in *GetEstimateFeeRequest, opts ...grpc.CallOption) (*GetEstimateFeeReply, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) - out := new(GetEstimateFeeReply) - err := c.cc.Invoke(ctx, ContractWriter_GetEstimateFee_FullMethodName, in, out, cOpts...) - if err != nil { - return nil, err - } - return out, nil -} - // ContractWriterServer is the server API for ContractWriter service. // All implementations must embed UnimplementedContractWriterServer // for forward compatibility. @@ -91,7 +79,6 @@ type ContractWriterServer interface { SubmitTransaction(context.Context, *SubmitTransactionRequest) (*emptypb.Empty, error) GetTransactionStatus(context.Context, *GetTransactionStatusRequest) (*GetTransactionStatusReply, error) GetFeeComponents(context.Context, *emptypb.Empty) (*GetFeeComponentsReply, error) - GetEstimateFee(context.Context, *GetEstimateFeeRequest) (*GetEstimateFeeReply, error) mustEmbedUnimplementedContractWriterServer() } @@ -111,9 +98,6 @@ func (UnimplementedContractWriterServer) GetTransactionStatus(context.Context, * func (UnimplementedContractWriterServer) GetFeeComponents(context.Context, *emptypb.Empty) (*GetFeeComponentsReply, error) { return nil, status.Errorf(codes.Unimplemented, "method GetFeeComponents not implemented") } -func (UnimplementedContractWriterServer) GetEstimateFee(context.Context, *GetEstimateFeeRequest) (*GetEstimateFeeReply, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetEstimateFee not implemented") -} func (UnimplementedContractWriterServer) mustEmbedUnimplementedContractWriterServer() {} func (UnimplementedContractWriterServer) testEmbeddedByValue() {} @@ -189,24 +173,6 @@ func _ContractWriter_GetFeeComponents_Handler(srv interface{}, ctx context.Conte return interceptor(ctx, in, info, handler) } -func _ContractWriter_GetEstimateFee_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(GetEstimateFeeRequest) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(ContractWriterServer).GetEstimateFee(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: ContractWriter_GetEstimateFee_FullMethodName, - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(ContractWriterServer).GetEstimateFee(ctx, req.(*GetEstimateFeeRequest)) - } - return interceptor(ctx, in, info, handler) -} - // ContractWriter_ServiceDesc is the grpc.ServiceDesc for ContractWriter service. 
// It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -226,10 +192,6 @@ var ContractWriter_ServiceDesc = grpc.ServiceDesc{ MethodName: "GetFeeComponents", Handler: _ContractWriter_GetFeeComponents_Handler, }, - { - MethodName: "GetEstimateFee", - Handler: _ContractWriter_GetEstimateFee_Handler, - }, }, Streams: []grpc.StreamDesc{}, Metadata: "contract_writer.proto", diff --git a/pkg/loop/internal/relayer/pluginprovider/contractwriter/contract_writer.go b/pkg/loop/internal/relayer/pluginprovider/contractwriter/contract_writer.go index 99f975f8f..bbf890f6b 100644 --- a/pkg/loop/internal/relayer/pluginprovider/contractwriter/contract_writer.go +++ b/pkg/loop/internal/relayer/pluginprovider/contractwriter/contract_writer.go @@ -89,32 +89,6 @@ func (c *Client) GetFeeComponents(ctx context.Context) (*types.ChainFeeComponent }, nil } -func (c *Client) GetEstimateFee(ctx context.Context, contract, method string, params any, toAddress string, meta *types.TxMeta, val *big.Int) (types.EstimateFee, error) { - versionedParams, err := contractreader.EncodeVersionedBytes(params, c.encodeWith) - if err != nil { - return types.EstimateFee{}, err - } - - req := &pb.GetEstimateFeeRequest{ - ContractName: contract, - Method: method, - Params: versionedParams, - ToAddress: toAddress, - Meta: TxMetaToProto(meta), - Value: pb.NewBigIntFromInt(val), - } - - reply, err := c.grpc.GetEstimateFee(ctx, req) - if err != nil { - return types.EstimateFee{}, net.WrapRPCErr(err) - } - - return types.EstimateFee{ - Fee: reply.Fee.Int(), - Decimals: reply.Decimals, - }, nil -} - // Server. var _ pb.ContractWriterServer = (*Server)(nil) @@ -181,23 +155,6 @@ func (s *Server) GetFeeComponents(ctx context.Context, _ *emptypb.Empty) (*pb.Ge }, nil } -func (s *Server) GetEstimateFee(ctx context.Context, req *pb.GetEstimateFeeRequest) (*pb.GetEstimateFeeReply, error) { - params := map[string]any{} - if err := contractreader.DecodeVersionedBytes(¶ms, req.Params); err != nil { - return nil, err - } - - estimateFee, err := s.impl.GetEstimateFee(ctx, req.ContractName, req.Method, params, req.ToAddress, TxMetaFromProto(req.Meta), req.Value.Int()) - if err != nil { - return nil, err - } - - return &pb.GetEstimateFeeReply{ - Fee: pb.NewBigIntFromInt(estimateFee.Fee), - Decimals: estimateFee.Decimals, - }, nil -} - func RegisterContractWriterService(s *grpc.Server, contractWriter types.ContractWriter) { pb.RegisterServiceServer(s, &goplugin.ServiceServer{Srv: contractWriter}) pb.RegisterContractWriterServer(s, NewServer(contractWriter)) diff --git a/pkg/types/contract_writer.go b/pkg/types/contract_writer.go index d8e4ea291..a5ccc3ac0 100644 --- a/pkg/types/contract_writer.go +++ b/pkg/types/contract_writer.go @@ -28,10 +28,6 @@ type ContractWriter interface { // GetFeeComponents retrieves the associated gas costs for executing a transaction. GetFeeComponents(ctx context.Context) (*ChainFeeComponents, error) - - // GetEstimateFee returns total cost of TX execution in the underlying chain's currency - // The value (val) is included in the fee calculation. - GetEstimateFee(ctx context.Context, contract, method string, args any, toAddress string, meta *TxMeta, val *big.Int) (EstimateFee, error) } // TxMeta contains metadata fields for a transaction. @@ -64,9 +60,3 @@ type ChainFeeComponents struct { // The cost associated with an L2 posting a transaction's data to the L1. 
DataAvailabilityFee *big.Int } - -// Estimated total cost of TX execution in the underlying chain's currency -type EstimateFee struct { - Fee *big.Int - Decimals uint32 -} From 3277762c47b6a49b1eb7daf1a4c2af31bf4f19d8 Mon Sep 17 00:00:00 2001 From: Silas Lenihan Date: Tue, 20 May 2025 11:27:18 -0400 Subject: [PATCH 03/16] Reapply "add GetEstimateFee (#1196)" This reverts commit f3ed96e428724590d600ab73b9c6be5fce146d10. --- pkg/loop/internal/pb/contract_writer.pb.go | 226 +++++++++++++++--- pkg/loop/internal/pb/contract_writer.proto | 17 ++ .../internal/pb/contract_writer_grpc.pb.go | 38 +++ .../contractwriter/contract_writer.go | 43 ++++ pkg/types/contract_writer.go | 10 + 5 files changed, 300 insertions(+), 34 deletions(-) diff --git a/pkg/loop/internal/pb/contract_writer.pb.go b/pkg/loop/internal/pb/contract_writer.pb.go index 5b61982db..372380200 100644 --- a/pkg/loop/internal/pb/contract_writer.pb.go +++ b/pkg/loop/internal/pb/contract_writer.pb.go @@ -271,6 +271,91 @@ func (x *GetTransactionStatusRequest) GetTransactionId() string { return "" } +// GetEstimateFeeReply has arguments for [github.com/smartcontractkit/chainlink-common/pkg/types.ContractWriter.GetEstimateFee]. +type GetEstimateFeeRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + ContractName string `protobuf:"bytes,1,opt,name=contract_name,json=contractName,proto3" json:"contract_name,omitempty"` + Method string `protobuf:"bytes,2,opt,name=method,proto3" json:"method,omitempty"` + Params *VersionedBytes `protobuf:"bytes,3,opt,name=params,proto3" json:"params,omitempty"` + ToAddress string `protobuf:"bytes,4,opt,name=to_address,json=toAddress,proto3" json:"to_address,omitempty"` + Meta *TransactionMeta `protobuf:"bytes,5,opt,name=meta,proto3" json:"meta,omitempty"` + Value *BigInt `protobuf:"bytes,6,opt,name=value,proto3" json:"value,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetEstimateFeeRequest) Reset() { + *x = GetEstimateFeeRequest{} + mi := &file_contract_writer_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetEstimateFeeRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetEstimateFeeRequest) ProtoMessage() {} + +func (x *GetEstimateFeeRequest) ProtoReflect() protoreflect.Message { + mi := &file_contract_writer_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetEstimateFeeRequest.ProtoReflect.Descriptor instead. 
+func (*GetEstimateFeeRequest) Descriptor() ([]byte, []int) { + return file_contract_writer_proto_rawDescGZIP(), []int{3} +} + +func (x *GetEstimateFeeRequest) GetContractName() string { + if x != nil { + return x.ContractName + } + return "" +} + +func (x *GetEstimateFeeRequest) GetMethod() string { + if x != nil { + return x.Method + } + return "" +} + +func (x *GetEstimateFeeRequest) GetParams() *VersionedBytes { + if x != nil { + return x.Params + } + return nil +} + +func (x *GetEstimateFeeRequest) GetToAddress() string { + if x != nil { + return x.ToAddress + } + return "" +} + +func (x *GetEstimateFeeRequest) GetMeta() *TransactionMeta { + if x != nil { + return x.Meta + } + return nil +} + +func (x *GetEstimateFeeRequest) GetValue() *BigInt { + if x != nil { + return x.Value + } + return nil +} + // GetTransactionStatusReply has return arguments for [github.com/smartcontractkit/chainlink-common/pkg/types.ContractWriter.GetTransactionStatus]. type GetTransactionStatusReply struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -281,7 +366,7 @@ type GetTransactionStatusReply struct { func (x *GetTransactionStatusReply) Reset() { *x = GetTransactionStatusReply{} - mi := &file_contract_writer_proto_msgTypes[3] + mi := &file_contract_writer_proto_msgTypes[4] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -293,7 +378,7 @@ func (x *GetTransactionStatusReply) String() string { func (*GetTransactionStatusReply) ProtoMessage() {} func (x *GetTransactionStatusReply) ProtoReflect() protoreflect.Message { - mi := &file_contract_writer_proto_msgTypes[3] + mi := &file_contract_writer_proto_msgTypes[4] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -306,7 +391,7 @@ func (x *GetTransactionStatusReply) ProtoReflect() protoreflect.Message { // Deprecated: Use GetTransactionStatusReply.ProtoReflect.Descriptor instead. func (*GetTransactionStatusReply) Descriptor() ([]byte, []int) { - return file_contract_writer_proto_rawDescGZIP(), []int{3} + return file_contract_writer_proto_rawDescGZIP(), []int{4} } func (x *GetTransactionStatusReply) GetTransactionStatus() TransactionStatus { @@ -327,7 +412,7 @@ type GetFeeComponentsReply struct { func (x *GetFeeComponentsReply) Reset() { *x = GetFeeComponentsReply{} - mi := &file_contract_writer_proto_msgTypes[4] + mi := &file_contract_writer_proto_msgTypes[5] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -339,7 +424,7 @@ func (x *GetFeeComponentsReply) String() string { func (*GetFeeComponentsReply) ProtoMessage() {} func (x *GetFeeComponentsReply) ProtoReflect() protoreflect.Message { - mi := &file_contract_writer_proto_msgTypes[4] + mi := &file_contract_writer_proto_msgTypes[5] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -352,7 +437,7 @@ func (x *GetFeeComponentsReply) ProtoReflect() protoreflect.Message { // Deprecated: Use GetFeeComponentsReply.ProtoReflect.Descriptor instead. func (*GetFeeComponentsReply) Descriptor() ([]byte, []int) { - return file_contract_writer_proto_rawDescGZIP(), []int{4} + return file_contract_writer_proto_rawDescGZIP(), []int{5} } func (x *GetFeeComponentsReply) GetExecutionFee() *BigInt { @@ -369,6 +454,59 @@ func (x *GetFeeComponentsReply) GetDataAvailabilityFee() *BigInt { return nil } +// GetEstimateFeeReply has return arguments for [github.com/smartcontractkit/chainlink-common/pkg/types.ContractWriter.GetEstimateFee]. 
+type GetEstimateFeeReply struct { + state protoimpl.MessageState `protogen:"open.v1"` + Fee *BigInt `protobuf:"bytes,1,opt,name=fee,proto3" json:"fee,omitempty"` + Decimals uint32 `protobuf:"varint,2,opt,name=decimals,proto3" json:"decimals,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetEstimateFeeReply) Reset() { + *x = GetEstimateFeeReply{} + mi := &file_contract_writer_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetEstimateFeeReply) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetEstimateFeeReply) ProtoMessage() {} + +func (x *GetEstimateFeeReply) ProtoReflect() protoreflect.Message { + mi := &file_contract_writer_proto_msgTypes[6] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetEstimateFeeReply.ProtoReflect.Descriptor instead. +func (*GetEstimateFeeReply) Descriptor() ([]byte, []int) { + return file_contract_writer_proto_rawDescGZIP(), []int{6} +} + +func (x *GetEstimateFeeReply) GetFee() *BigInt { + if x != nil { + return x.Fee + } + return nil +} + +func (x *GetEstimateFeeReply) GetDecimals() uint32 { + if x != nil { + return x.Decimals + } + return 0 +} + var File_contract_writer_proto protoreflect.FileDescriptor const file_contract_writer_proto_rawDesc = "" + @@ -387,23 +525,35 @@ const file_contract_writer_proto_rawDesc = "" + "\x15workflow_execution_id\x18\x01 \x01(\tR\x13workflowExecutionId\x12)\n" + "\tgas_limit\x18\x02 \x01(\v2\f.loop.BigIntR\bgasLimit\"D\n" + "\x1bGetTransactionStatusRequest\x12%\n" + - "\x0etransaction_id\x18\x01 \x01(\tR\rtransactionId\"c\n" + + "\x0etransaction_id\x18\x01 \x01(\tR\rtransactionId\"\xf0\x01\n" + + "\x15GetEstimateFeeRequest\x12#\n" + + "\rcontract_name\x18\x01 \x01(\tR\fcontractName\x12\x16\n" + + "\x06method\x18\x02 \x01(\tR\x06method\x12,\n" + + "\x06params\x18\x03 \x01(\v2\x14.loop.VersionedBytesR\x06params\x12\x1d\n" + + "\n" + + "to_address\x18\x04 \x01(\tR\ttoAddress\x12)\n" + + "\x04meta\x18\x05 \x01(\v2\x15.loop.TransactionMetaR\x04meta\x12\"\n" + + "\x05value\x18\x06 \x01(\v2\f.loop.BigIntR\x05value\"c\n" + "\x19GetTransactionStatusReply\x12F\n" + "\x12transaction_status\x18\x01 \x01(\x0e2\x17.loop.TransactionStatusR\x11transactionStatus\"\x8c\x01\n" + "\x15GetFeeComponentsReply\x121\n" + "\rexecution_fee\x18\x01 \x01(\v2\f.loop.BigIntR\fexecutionFee\x12@\n" + - "\x15data_availability_fee\x18\x02 \x01(\v2\f.loop.BigIntR\x13dataAvailabilityFee*\xd6\x01\n" + + "\x15data_availability_fee\x18\x02 \x01(\v2\f.loop.BigIntR\x13dataAvailabilityFee\"Q\n" + + "\x13GetEstimateFeeReply\x12\x1e\n" + + "\x03fee\x18\x01 \x01(\v2\f.loop.BigIntR\x03fee\x12\x1a\n" + + "\bdecimals\x18\x02 \x01(\rR\bdecimals*\xd6\x01\n" + "\x11TransactionStatus\x12\x1e\n" + "\x1aTRANSACTION_STATUS_UNKNOWN\x10\x00\x12\x1e\n" + "\x1aTRANSACTION_STATUS_PENDING\x10\x01\x12\"\n" + "\x1eTRANSACTION_STATUS_UNCONFIRMED\x10\x02\x12 \n" + "\x1cTRANSACTION_STATUS_FINALIZED\x10\x03\x12\x1d\n" + "\x19TRANSACTION_STATUS_FAILED\x10\x04\x12\x1c\n" + - "\x18TRANSACTION_STATUS_FATAL\x10\x052\x88\x02\n" + + "\x18TRANSACTION_STATUS_FATAL\x10\x052\xd4\x02\n" + "\x0eContractWriter\x12M\n" + "\x11SubmitTransaction\x12\x1e.loop.SubmitTransactionRequest\x1a\x16.google.protobuf.Empty\"\x00\x12\\\n" + 
"\x14GetTransactionStatus\x12!.loop.GetTransactionStatusRequest\x1a\x1f.loop.GetTransactionStatusReply\"\x00\x12I\n" + - "\x10GetFeeComponents\x12\x16.google.protobuf.Empty\x1a\x1b.loop.GetFeeComponentsReply\"\x00BCZAgithub.com/smartcontractkit/chainlink-common/pkg/loop/internal/pbb\x06proto3" + "\x10GetFeeComponents\x12\x16.google.protobuf.Empty\x1a\x1b.loop.GetFeeComponentsReply\"\x00\x12J\n" + + "\x0eGetEstimateFee\x12\x1b.loop.GetEstimateFeeRequest\x1a\x19.loop.GetEstimateFeeReply\"\x00BCZAgithub.com/smartcontractkit/chainlink-common/pkg/loop/internal/pbb\x06proto3" var ( file_contract_writer_proto_rawDescOnce sync.Once @@ -418,37 +568,45 @@ func file_contract_writer_proto_rawDescGZIP() []byte { } var file_contract_writer_proto_enumTypes = make([]protoimpl.EnumInfo, 1) -var file_contract_writer_proto_msgTypes = make([]protoimpl.MessageInfo, 5) +var file_contract_writer_proto_msgTypes = make([]protoimpl.MessageInfo, 7) var file_contract_writer_proto_goTypes = []any{ (TransactionStatus)(0), // 0: loop.TransactionStatus (*SubmitTransactionRequest)(nil), // 1: loop.SubmitTransactionRequest (*TransactionMeta)(nil), // 2: loop.TransactionMeta (*GetTransactionStatusRequest)(nil), // 3: loop.GetTransactionStatusRequest - (*GetTransactionStatusReply)(nil), // 4: loop.GetTransactionStatusReply - (*GetFeeComponentsReply)(nil), // 5: loop.GetFeeComponentsReply - (*VersionedBytes)(nil), // 6: loop.VersionedBytes - (*BigInt)(nil), // 7: loop.BigInt - (*emptypb.Empty)(nil), // 8: google.protobuf.Empty + (*GetEstimateFeeRequest)(nil), // 4: loop.GetEstimateFeeRequest + (*GetTransactionStatusReply)(nil), // 5: loop.GetTransactionStatusReply + (*GetFeeComponentsReply)(nil), // 6: loop.GetFeeComponentsReply + (*GetEstimateFeeReply)(nil), // 7: loop.GetEstimateFeeReply + (*VersionedBytes)(nil), // 8: loop.VersionedBytes + (*BigInt)(nil), // 9: loop.BigInt + (*emptypb.Empty)(nil), // 10: google.protobuf.Empty } var file_contract_writer_proto_depIdxs = []int32{ - 6, // 0: loop.SubmitTransactionRequest.params:type_name -> loop.VersionedBytes + 8, // 0: loop.SubmitTransactionRequest.params:type_name -> loop.VersionedBytes 2, // 1: loop.SubmitTransactionRequest.meta:type_name -> loop.TransactionMeta - 7, // 2: loop.SubmitTransactionRequest.value:type_name -> loop.BigInt - 7, // 3: loop.TransactionMeta.gas_limit:type_name -> loop.BigInt - 0, // 4: loop.GetTransactionStatusReply.transaction_status:type_name -> loop.TransactionStatus - 7, // 5: loop.GetFeeComponentsReply.execution_fee:type_name -> loop.BigInt - 7, // 6: loop.GetFeeComponentsReply.data_availability_fee:type_name -> loop.BigInt - 1, // 7: loop.ContractWriter.SubmitTransaction:input_type -> loop.SubmitTransactionRequest - 3, // 8: loop.ContractWriter.GetTransactionStatus:input_type -> loop.GetTransactionStatusRequest - 8, // 9: loop.ContractWriter.GetFeeComponents:input_type -> google.protobuf.Empty - 8, // 10: loop.ContractWriter.SubmitTransaction:output_type -> google.protobuf.Empty - 4, // 11: loop.ContractWriter.GetTransactionStatus:output_type -> loop.GetTransactionStatusReply - 5, // 12: loop.ContractWriter.GetFeeComponents:output_type -> loop.GetFeeComponentsReply - 10, // [10:13] is the sub-list for method output_type - 7, // [7:10] is the sub-list for method input_type - 7, // [7:7] is the sub-list for extension type_name - 7, // [7:7] is the sub-list for extension extendee - 0, // [0:7] is the sub-list for field type_name + 9, // 2: loop.SubmitTransactionRequest.value:type_name -> loop.BigInt + 9, // 3: 
loop.TransactionMeta.gas_limit:type_name -> loop.BigInt + 8, // 4: loop.GetEstimateFeeRequest.params:type_name -> loop.VersionedBytes + 2, // 5: loop.GetEstimateFeeRequest.meta:type_name -> loop.TransactionMeta + 9, // 6: loop.GetEstimateFeeRequest.value:type_name -> loop.BigInt + 0, // 7: loop.GetTransactionStatusReply.transaction_status:type_name -> loop.TransactionStatus + 9, // 8: loop.GetFeeComponentsReply.execution_fee:type_name -> loop.BigInt + 9, // 9: loop.GetFeeComponentsReply.data_availability_fee:type_name -> loop.BigInt + 9, // 10: loop.GetEstimateFeeReply.fee:type_name -> loop.BigInt + 1, // 11: loop.ContractWriter.SubmitTransaction:input_type -> loop.SubmitTransactionRequest + 3, // 12: loop.ContractWriter.GetTransactionStatus:input_type -> loop.GetTransactionStatusRequest + 10, // 13: loop.ContractWriter.GetFeeComponents:input_type -> google.protobuf.Empty + 4, // 14: loop.ContractWriter.GetEstimateFee:input_type -> loop.GetEstimateFeeRequest + 10, // 15: loop.ContractWriter.SubmitTransaction:output_type -> google.protobuf.Empty + 5, // 16: loop.ContractWriter.GetTransactionStatus:output_type -> loop.GetTransactionStatusReply + 6, // 17: loop.ContractWriter.GetFeeComponents:output_type -> loop.GetFeeComponentsReply + 7, // 18: loop.ContractWriter.GetEstimateFee:output_type -> loop.GetEstimateFeeReply + 15, // [15:19] is the sub-list for method output_type + 11, // [11:15] is the sub-list for method input_type + 11, // [11:11] is the sub-list for extension type_name + 11, // [11:11] is the sub-list for extension extendee + 0, // [0:11] is the sub-list for field type_name } func init() { file_contract_writer_proto_init() } @@ -464,7 +622,7 @@ func file_contract_writer_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_contract_writer_proto_rawDesc), len(file_contract_writer_proto_rawDesc)), NumEnums: 1, - NumMessages: 5, + NumMessages: 7, NumExtensions: 0, NumServices: 1, }, diff --git a/pkg/loop/internal/pb/contract_writer.proto b/pkg/loop/internal/pb/contract_writer.proto index dd33f99cd..4bd97e014 100644 --- a/pkg/loop/internal/pb/contract_writer.proto +++ b/pkg/loop/internal/pb/contract_writer.proto @@ -12,6 +12,7 @@ service ContractWriter { rpc SubmitTransaction(SubmitTransactionRequest) returns (google.protobuf.Empty) {} rpc GetTransactionStatus(GetTransactionStatusRequest) returns (GetTransactionStatusReply) {} rpc GetFeeComponents(google.protobuf.Empty) returns (GetFeeComponentsReply) {} + rpc GetEstimateFee(GetEstimateFeeRequest) returns (GetEstimateFeeReply) {} } message SubmitTransactionRequest { @@ -34,6 +35,16 @@ message GetTransactionStatusRequest { string transaction_id = 1; } +// GetEstimateFeeReply has arguments for [github.com/smartcontractkit/chainlink-common/pkg/types.ContractWriter.GetEstimateFee]. +message GetEstimateFeeRequest { + string contract_name = 1; + string method = 2; + VersionedBytes params = 3; + string to_address = 4; + TransactionMeta meta = 5; + BigInt value = 6; +} + // TransactionStatus is an enum for the status of a transaction. // This should always be a 1-1 mapping to: [github.com/smartcontractkit/chainlink-common/pkg/types.TransactionStatus]. enum TransactionStatus { @@ -55,3 +66,9 @@ message GetFeeComponentsReply { BigInt execution_fee = 1; BigInt data_availability_fee = 2; } + +// GetEstimateFeeReply has return arguments for [github.com/smartcontractkit/chainlink-common/pkg/types.ContractWriter.GetEstimateFee]. 
+message GetEstimateFeeReply { + BigInt fee = 1; + uint32 decimals = 2; +} diff --git a/pkg/loop/internal/pb/contract_writer_grpc.pb.go b/pkg/loop/internal/pb/contract_writer_grpc.pb.go index 6c2da2796..e9bcba861 100644 --- a/pkg/loop/internal/pb/contract_writer_grpc.pb.go +++ b/pkg/loop/internal/pb/contract_writer_grpc.pb.go @@ -23,6 +23,7 @@ const ( ContractWriter_SubmitTransaction_FullMethodName = "/loop.ContractWriter/SubmitTransaction" ContractWriter_GetTransactionStatus_FullMethodName = "/loop.ContractWriter/GetTransactionStatus" ContractWriter_GetFeeComponents_FullMethodName = "/loop.ContractWriter/GetFeeComponents" + ContractWriter_GetEstimateFee_FullMethodName = "/loop.ContractWriter/GetEstimateFee" ) // ContractWriterClient is the client API for ContractWriter service. @@ -32,6 +33,7 @@ type ContractWriterClient interface { SubmitTransaction(ctx context.Context, in *SubmitTransactionRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) GetTransactionStatus(ctx context.Context, in *GetTransactionStatusRequest, opts ...grpc.CallOption) (*GetTransactionStatusReply, error) GetFeeComponents(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*GetFeeComponentsReply, error) + GetEstimateFee(ctx context.Context, in *GetEstimateFeeRequest, opts ...grpc.CallOption) (*GetEstimateFeeReply, error) } type contractWriterClient struct { @@ -72,6 +74,16 @@ func (c *contractWriterClient) GetFeeComponents(ctx context.Context, in *emptypb return out, nil } +func (c *contractWriterClient) GetEstimateFee(ctx context.Context, in *GetEstimateFeeRequest, opts ...grpc.CallOption) (*GetEstimateFeeReply, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(GetEstimateFeeReply) + err := c.cc.Invoke(ctx, ContractWriter_GetEstimateFee_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + // ContractWriterServer is the server API for ContractWriter service. // All implementations must embed UnimplementedContractWriterServer // for forward compatibility. 
@@ -79,6 +91,7 @@ type ContractWriterServer interface { SubmitTransaction(context.Context, *SubmitTransactionRequest) (*emptypb.Empty, error) GetTransactionStatus(context.Context, *GetTransactionStatusRequest) (*GetTransactionStatusReply, error) GetFeeComponents(context.Context, *emptypb.Empty) (*GetFeeComponentsReply, error) + GetEstimateFee(context.Context, *GetEstimateFeeRequest) (*GetEstimateFeeReply, error) mustEmbedUnimplementedContractWriterServer() } @@ -98,6 +111,9 @@ func (UnimplementedContractWriterServer) GetTransactionStatus(context.Context, * func (UnimplementedContractWriterServer) GetFeeComponents(context.Context, *emptypb.Empty) (*GetFeeComponentsReply, error) { return nil, status.Errorf(codes.Unimplemented, "method GetFeeComponents not implemented") } +func (UnimplementedContractWriterServer) GetEstimateFee(context.Context, *GetEstimateFeeRequest) (*GetEstimateFeeReply, error) { + return nil, status.Errorf(codes.Unimplemented, "method GetEstimateFee not implemented") +} func (UnimplementedContractWriterServer) mustEmbedUnimplementedContractWriterServer() {} func (UnimplementedContractWriterServer) testEmbeddedByValue() {} @@ -173,6 +189,24 @@ func _ContractWriter_GetFeeComponents_Handler(srv interface{}, ctx context.Conte return interceptor(ctx, in, info, handler) } +func _ContractWriter_GetEstimateFee_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(GetEstimateFeeRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ContractWriterServer).GetEstimateFee(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: ContractWriter_GetEstimateFee_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ContractWriterServer).GetEstimateFee(ctx, req.(*GetEstimateFeeRequest)) + } + return interceptor(ctx, in, info, handler) +} + // ContractWriter_ServiceDesc is the grpc.ServiceDesc for ContractWriter service. 
// It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -192,6 +226,10 @@ var ContractWriter_ServiceDesc = grpc.ServiceDesc{ MethodName: "GetFeeComponents", Handler: _ContractWriter_GetFeeComponents_Handler, }, + { + MethodName: "GetEstimateFee", + Handler: _ContractWriter_GetEstimateFee_Handler, + }, }, Streams: []grpc.StreamDesc{}, Metadata: "contract_writer.proto", diff --git a/pkg/loop/internal/relayer/pluginprovider/contractwriter/contract_writer.go b/pkg/loop/internal/relayer/pluginprovider/contractwriter/contract_writer.go index bbf890f6b..99f975f8f 100644 --- a/pkg/loop/internal/relayer/pluginprovider/contractwriter/contract_writer.go +++ b/pkg/loop/internal/relayer/pluginprovider/contractwriter/contract_writer.go @@ -89,6 +89,32 @@ func (c *Client) GetFeeComponents(ctx context.Context) (*types.ChainFeeComponent }, nil } +func (c *Client) GetEstimateFee(ctx context.Context, contract, method string, params any, toAddress string, meta *types.TxMeta, val *big.Int) (types.EstimateFee, error) { + versionedParams, err := contractreader.EncodeVersionedBytes(params, c.encodeWith) + if err != nil { + return types.EstimateFee{}, err + } + + req := &pb.GetEstimateFeeRequest{ + ContractName: contract, + Method: method, + Params: versionedParams, + ToAddress: toAddress, + Meta: TxMetaToProto(meta), + Value: pb.NewBigIntFromInt(val), + } + + reply, err := c.grpc.GetEstimateFee(ctx, req) + if err != nil { + return types.EstimateFee{}, net.WrapRPCErr(err) + } + + return types.EstimateFee{ + Fee: reply.Fee.Int(), + Decimals: reply.Decimals, + }, nil +} + // Server. var _ pb.ContractWriterServer = (*Server)(nil) @@ -155,6 +181,23 @@ func (s *Server) GetFeeComponents(ctx context.Context, _ *emptypb.Empty) (*pb.Ge }, nil } +func (s *Server) GetEstimateFee(ctx context.Context, req *pb.GetEstimateFeeRequest) (*pb.GetEstimateFeeReply, error) { + params := map[string]any{} + if err := contractreader.DecodeVersionedBytes(¶ms, req.Params); err != nil { + return nil, err + } + + estimateFee, err := s.impl.GetEstimateFee(ctx, req.ContractName, req.Method, params, req.ToAddress, TxMetaFromProto(req.Meta), req.Value.Int()) + if err != nil { + return nil, err + } + + return &pb.GetEstimateFeeReply{ + Fee: pb.NewBigIntFromInt(estimateFee.Fee), + Decimals: estimateFee.Decimals, + }, nil +} + func RegisterContractWriterService(s *grpc.Server, contractWriter types.ContractWriter) { pb.RegisterServiceServer(s, &goplugin.ServiceServer{Srv: contractWriter}) pb.RegisterContractWriterServer(s, NewServer(contractWriter)) diff --git a/pkg/types/contract_writer.go b/pkg/types/contract_writer.go index a5ccc3ac0..d8e4ea291 100644 --- a/pkg/types/contract_writer.go +++ b/pkg/types/contract_writer.go @@ -28,6 +28,10 @@ type ContractWriter interface { // GetFeeComponents retrieves the associated gas costs for executing a transaction. GetFeeComponents(ctx context.Context) (*ChainFeeComponents, error) + + // GetEstimateFee returns total cost of TX execution in the underlying chain's currency + // The value (val) is included in the fee calculation. + GetEstimateFee(ctx context.Context, contract, method string, args any, toAddress string, meta *TxMeta, val *big.Int) (EstimateFee, error) } // TxMeta contains metadata fields for a transaction. @@ -60,3 +64,9 @@ type ChainFeeComponents struct { // The cost associated with an L2 posting a transaction's data to the L1. 
DataAvailabilityFee *big.Int } + +// Estimated total cost of TX execution in the underlying chain's currency +type EstimateFee struct { + Fee *big.Int + Decimals uint32 +} From 46548a5cb22fb1e307b65faa876f14c37faaa2de Mon Sep 17 00:00:00 2001 From: Silas Lenihan Date: Fri, 23 May 2025 14:05:34 -0400 Subject: [PATCH 04/16] addressed feedback --- .../consensus/ocr3/types/aggregator.go | 6 +- .../consensus/ocr3/types/aggregator_test.go | 198 +++++++++++++++++- 2 files changed, 197 insertions(+), 7 deletions(-) diff --git a/pkg/capabilities/consensus/ocr3/types/aggregator.go b/pkg/capabilities/consensus/ocr3/types/aggregator.go index db53c0239..444fe466f 100644 --- a/pkg/capabilities/consensus/ocr3/types/aggregator.go +++ b/pkg/capabilities/consensus/ocr3/types/aggregator.go @@ -176,11 +176,7 @@ func Decode(raw []byte) (Metadata, []byte, error) { } func (m Metadata) Length() int { - b, err := m.Encode() - if err != nil { - return 0 - } - return len(b) + return MetadataLen } // Aggregator is the interface that enables a hook to the Outcome() phase of OCR reporting. diff --git a/pkg/capabilities/consensus/ocr3/types/aggregator_test.go b/pkg/capabilities/consensus/ocr3/types/aggregator_test.go index 94660d103..58948ae14 100644 --- a/pkg/capabilities/consensus/ocr3/types/aggregator_test.go +++ b/pkg/capabilities/consensus/ocr3/types/aggregator_test.go @@ -1,11 +1,205 @@ package types import ( + "strings" "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) +func TestMetadata_EncodeDecode(t *testing.T) { + metadata := Metadata{ + Version: 1, + ExecutionID: "1234567890123456789012345678901234567890123456789012345678901234", + Timestamp: 1620000000, + DONID: 1, + DONConfigVersion: 1, + WorkflowID: "1234567890123456789012345678901234567890123456789012345678901234", + WorkflowName: "12", + WorkflowOwner: "1234567890123456789012345678901234567890", + ReportID: "1234", + } + + metadata.padWorkflowName() + + encoded, err := metadata.Encode() + require.NoError(t, err) + + require.Len(t, encoded, 109) + + // append tail to encoded + tail := []byte("tail") + encoded = append(encoded, tail...) 
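+	// Decode should consume exactly the MetadataLen-byte header and hand back the appended tail unchanged.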
+ decoded, remaining, err := Decode(encoded) + require.NoError(t, err) + require.Equal(t, metadata.Version, decoded.Version) + require.Equal(t, metadata.ExecutionID, decoded.ExecutionID) + require.Equal(t, metadata.Timestamp, decoded.Timestamp) + require.Equal(t, metadata.DONID, decoded.DONID) + require.Equal(t, metadata.DONConfigVersion, decoded.DONConfigVersion) + require.Equal(t, metadata.WorkflowID, decoded.WorkflowID) + require.Equal(t, metadata.WorkflowName, decoded.WorkflowName) + require.Equal(t, metadata.WorkflowOwner, decoded.WorkflowOwner) + require.Equal(t, metadata.ReportID, decoded.ReportID) + require.Equal(t, tail, remaining) +} + +func TestMetadata_Length(t *testing.T) { + var m Metadata + require.Equal(t, MetadataLen, m.Length()) +} + +func TestPadWorkflowName_NoPadWhenExactLength(t *testing.T) { + // 20 hex characters = 10 bytes, exact length + original := "abcdef0123456789abcd" + m := &Metadata{WorkflowName: original} + m.padWorkflowName() + require.Equal(t, original, m.WorkflowName) +} + +func TestPadWorkflowName_TooLong(t *testing.T) { + // 22 hex characters = 11 bytes, should not be truncated by pad + original := "abcdef0123456789abcd01" + m := &Metadata{WorkflowName: original} + m.padWorkflowName() + require.Equal(t, original, m.WorkflowName) +} + +func TestEncode_InvalidHexFields(t *testing.T) { + m := Metadata{ + Version: 1, + ExecutionID: "zzzz", // invalid hex + Timestamp: 0, + DONID: 0, + DONConfigVersion: 0, + WorkflowID: strings.Repeat("00", 32), + WorkflowName: "00", + WorkflowOwner: strings.Repeat("00", 20), + ReportID: "0000", + } + _, err := m.Encode() + require.Error(t, err) + require.Contains(t, err.Error(), "invalid hex") +} + +func TestEncode_WrongLengthFields(t *testing.T) { + tests := []struct { + name string + m Metadata + }{ + { + name: "short ExecutionID", + m: Metadata{ + Version: 1, + ExecutionID: "00", // too short + Timestamp: 0, + DONID: 0, + DONConfigVersion: 0, + WorkflowID: strings.Repeat("00", 32), + WorkflowName: "00", + WorkflowOwner: strings.Repeat("00", 20), + ReportID: "0000", + }, + }, + { + name: "short WorkflowID", + m: Metadata{ + Version: 1, + ExecutionID: strings.Repeat("00", 32), + Timestamp: 0, + DONID: 0, + DONConfigVersion: 0, + WorkflowID: "00", // too short + WorkflowName: "00", + WorkflowOwner: strings.Repeat("00", 20), + ReportID: "0000", + }, + }, + { + name: "long WorkflowName", + m: Metadata{ + Version: 1, + ExecutionID: strings.Repeat("00", 32), + Timestamp: 0, + DONID: 0, + DONConfigVersion: 0, + WorkflowID: strings.Repeat("00", 32), + WorkflowName: strings.Repeat("01", 11), // 22 chars, >20 + WorkflowOwner: strings.Repeat("00", 20), + ReportID: "0000", + }, + }, + { + name: "short WorkflowOwner", + m: Metadata{ + Version: 1, + ExecutionID: strings.Repeat("00", 32), + Timestamp: 0, + DONID: 0, + DONConfigVersion: 0, + WorkflowID: strings.Repeat("00", 32), + WorkflowName: "00", + WorkflowOwner: "00", // too short + ReportID: "0000", + }, + }, + { + name: "short ReportID", + m: Metadata{ + Version: 1, + ExecutionID: strings.Repeat("00", 32), + Timestamp: 0, + DONID: 0, + DONConfigVersion: 0, + WorkflowID: strings.Repeat("00", 32), + WorkflowName: "00", + WorkflowOwner: strings.Repeat("00", 20), + ReportID: "00", // too short + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := tt.m.Encode() + require.Error(t, err) + require.Contains(t, err.Error(), "wrong length") + }) + } +} + +func TestDecode_RawTooShort(t *testing.T) { + _, _, err := Decode([]byte{0x01, 0x02}) + 
require.Error(t, err) + require.Contains(t, err.Error(), "raw too short") +} + +func TestDecode_RemainingData(t *testing.T) { + m := Metadata{ + Version: 1, + ExecutionID: strings.Repeat("11", 32), + Timestamp: 2, + DONID: 3, + DONConfigVersion: 4, + WorkflowID: strings.Repeat("22", 32), + WorkflowName: "33", + WorkflowOwner: strings.Repeat("44", 20), + ReportID: "5555", + } + m.padWorkflowName() + + encoded, err := m.Encode() + require.NoError(t, err) + // add extra bytes to simulate payload + extra := []byte("extra") + data := append(encoded, extra...) + + decoded, remaining, err := Decode(data) + require.NoError(t, err) + require.Equal(t, extra, remaining) + require.Equal(t, m, decoded) +} + func TestMetadata_padWorkflowName(t *testing.T) { type fields struct { WorkflowName string @@ -50,7 +244,7 @@ func TestMetadata_padWorkflowName(t *testing.T) { WorkflowName: tt.fields.WorkflowName, } m.padWorkflowName() - assert.Equal(t, tt.want, m.WorkflowName, tt.name) + require.Equal(t, tt.want, m.WorkflowName, tt.name) }) } } From 5ca7af883f2a187e974868d2ab770aab7231a231 Mon Sep 17 00:00:00 2001 From: Silas Lenihan Date: Tue, 27 May 2025 14:33:39 -0400 Subject: [PATCH 05/16] Moved proto emitter and helpers to common --- pkg/beholder/proto_emitter.go | 96 ++++++++++++++++++++++++++++ pkg/beholder/schema.go | 116 ++++++++++++++++++++++++++++++++++ pkg/beholder/schema_test.go | 72 +++++++++++++++++++++ 3 files changed, 284 insertions(+) create mode 100644 pkg/beholder/proto_emitter.go create mode 100644 pkg/beholder/schema.go create mode 100644 pkg/beholder/schema_test.go diff --git a/pkg/beholder/proto_emitter.go b/pkg/beholder/proto_emitter.go new file mode 100644 index 000000000..1ade828bd --- /dev/null +++ b/pkg/beholder/proto_emitter.go @@ -0,0 +1,96 @@ +//nolint:revive,staticcheck // disable revive, staticcheck +package beholder + +import ( + "context" + + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" +) + +const ( + // Helper keys to avoid duplicating attributes + CtxKeySkipAppendAttrs = "skip_append_attrs" +) + +// BeholderClient is a Beholder client extension with a custom ProtoEmitter +type BeholderClient struct { + *Client + ProtoEmitter ProtoEmitter +} + +// ProtoEmitter is an interface for emitting protobuf messages +type ProtoEmitter interface { + // Sends message with bytes and attributes to OTel Collector + Emit(ctx context.Context, m proto.Message, attrKVs ...any) error + EmitWithLog(ctx context.Context, m proto.Message, attrKVs ...any) error +} + +// ProtoProcessor is an interface for processing emitted protobuf messages +type ProtoProcessor interface { + Process(ctx context.Context, m proto.Message, attrKVs ...any) error +} + +func NewProtoEmitter(lggr logger.Logger, client *Client, schemaBasePath string) ProtoEmitter { + return &protoEmitter{lggr, client, schemaBasePath} +} + +// protoEmitter is a ProtoEmitter implementation +var _ ProtoEmitter = (*protoEmitter)(nil) + +type protoEmitter struct { + lggr logger.Logger + client *Client + schemaBasePath string +} + +func (e *protoEmitter) Emit(ctx context.Context, m proto.Message, attrKVs ...any) error { + payload, err := proto.Marshal(m) + if err != nil { + // Notice: we log here because emit errors are usually not critical and swallowed by the caller + e.lggr.Errorw("[Beholder] Failed to marshal", "err", err) + return err + } + + // Skip appending attributes if the context says it's already done that + if skip, ok := 
ctx.Value(CtxKeySkipAppendAttrs).(bool); !ok || !skip { + attrKVs = e.appendAttrsRequired(ctx, m, attrKVs) + } + + // Emit the message with attributes + err = e.client.Emitter.Emit(ctx, payload, attrKVs...) + if err != nil { + // Notice: we log here because emit errors are usually not critical and swallowed by the caller + e.lggr.Errorw("[Beholder] Failed to client.Emitter.Emit", "err", err) + return err + } + + return nil +} + +// EmitWithLog emits a protobuf message with attributes and logs the emitted message +func (e *protoEmitter) EmitWithLog(ctx context.Context, m proto.Message, attrKVs ...any) error { + attrKVs = e.appendAttrsRequired(ctx, m, attrKVs) + // attach a bool switch to ctx to avoid duplicating common attrs + ctx = context.WithValue(ctx, CtxKeySkipAppendAttrs, true) + + // Marshal the message as JSON and log before emitting + // https://protobuf.dev/programming-guides/json/ + mStr := protojson.MarshalOptions{ + UseProtoNames: true, + EmitUnpopulated: true, + }.Format(m) + e.lggr.Infow("[Beholder.emit]", "message", mStr, "attributes", attrKVs) + + return e.Emit(ctx, m, attrKVs...) +} + +// appendAttrsRequired appends required attributes to the attribute key-value list +func (e *protoEmitter) appendAttrsRequired(ctx context.Context, m proto.Message, attrKVs []any) []any { + attrKVs = appendRequiredAttrDataSchema(m, attrKVs, e.schemaBasePath) + attrKVs = appendRequiredAttrEntity(m, attrKVs) + attrKVs = appendRequiredAttrDomain(m, attrKVs) + return attrKVs +} diff --git a/pkg/beholder/schema.go b/pkg/beholder/schema.go new file mode 100644 index 000000000..06b6214b1 --- /dev/null +++ b/pkg/beholder/schema.go @@ -0,0 +1,116 @@ +//nolint:gosimple // disable gosimple +package beholder + +import ( + "fmt" + "path" + "regexp" + "strings" + + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/runtime/protoimpl" +) + +const ( + AttrKeyBeholderDataSchema = "beholder_data_schema" + AttrKeyBeholderEntity = "beholder_entity" + AttrKeyBeholderDomain = "beholder_domain" +) + +// patternSnake is a regular expression to match CamelCase words +// Notice: we use the Unicode property 'Lu' (uppercase letter) to match +// the first letter of the word, and 'P{Lu}' (not uppercase letter) to match +// the rest of the word. 
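+// For example, toSnakeCase maps "FirstTest" to "first_test" and
+// "CamelCaseID" to "camel_case_id", while a leading run of capitals such as
+// "XMLEncode" collapses into a single word, "xmlencode" (see TestToSchemaPath).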
+var patternSnake = regexp.MustCompile("(\\p{Lu}+\\P{Lu}*)") + +// toSnakeCase converts a CamelCase to snake_case (used for type -> file name mapping) +func toSnakeCase(s string) string { + s = patternSnake.ReplaceAllString(s, "_${1}") + s, _ = strings.CutPrefix(strings.ToLower(s), "_") + return s +} + +// toSchemaName returns a protobuf message name (short) +func toSchemaName(m proto.Message) string { + return string(protoimpl.X.MessageTypeOf(m).Descriptor().Name()) +} + +// toSchemaName returns a protobuf message name (full) +func toSchemaFullName(m proto.Message) string { + return string(protoimpl.X.MessageTypeOf(m).Descriptor().FullName()) +} + +// toSchemaPath maps a protobuf message to a Beholder schema path +func toSchemaPath(m proto.Message, basePath string) string { + // Notice: a name like 'platform.on_chain.forwarder.ReportProcessed' + protoName := toSchemaFullName(m) + + // We map to a Beholder schema path like '/platform/on-chain/forwarder/report_processed.proto' + protoPath := protoName + protoPath = strings.ReplaceAll(protoPath, ".", "/") + protoPath = strings.ReplaceAll(protoPath, "_", "-") + + // Split the path components (at least one component) + pp := strings.Split(protoPath, "/") + pp[len(pp)-1] = toSnakeCase(pp[len(pp)-1]) + + // Join the path components again + protoPath = strings.Join(pp, "/") + protoPath = fmt.Sprintf("%s.proto", protoPath) + + // Return the full schema path + return path.Join(basePath, protoPath) +} + +// appendRequiredAttrDataSchema adds the message schema path as an attribute (required) +func appendRequiredAttrDataSchema(m proto.Message, attrKVs []any, basePath string) []any { + key := AttrKeyBeholderDataSchema + for i := 0; i < len(attrKVs); i += 2 { + if attrKVs[i] == key { + return attrKVs + } + } + + attrKVs = append(attrKVs, key) + // Needs to be an URI (Beholder requirement) + val := toSchemaPath(m, basePath) + attrKVs = append(attrKVs, val) + return attrKVs +} + +// appendRequiredAttrEntity adds the message entity type as an attribute (required) +func appendRequiredAttrEntity(m proto.Message, attrKVs []any) []any { + key := AttrKeyBeholderEntity + for i := 0; i < len(attrKVs); i += 2 { + if attrKVs[i] == key { + return attrKVs + } + } + + attrKVs = append(attrKVs, key) + attrKVs = append(attrKVs, toSchemaName(m)) + return attrKVs +} + +// appendRequiredAttrDomain adds the message domain as an attribute (required) +func appendRequiredAttrDomain(m proto.Message, attrKVs []any) []any { + key := AttrKeyBeholderDomain + for i := 0; i < len(attrKVs); i += 2 { + if attrKVs[i] == key { + return attrKVs + } + } + + // Notice: a name like 'platform.on_chain.forwarder.ReportProcessed' + protoName := toSchemaFullName(m) + + // Extract first path component (entrypoint package) as a domain + domain := "unknown" + if strings.Contains(protoName, ".") { + domain = strings.Split(protoName, ".")[0] + } + + attrKVs = append(attrKVs, key) + attrKVs = append(attrKVs, domain) + return attrKVs +} diff --git a/pkg/beholder/schema_test.go b/pkg/beholder/schema_test.go new file mode 100644 index 000000000..2709e0acd --- /dev/null +++ b/pkg/beholder/schema_test.go @@ -0,0 +1,72 @@ +package beholder + +import ( + "path" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protodesc" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/descriptorpb" + "google.golang.org/protobuf/types/dynamicpb" +) + +func 
makeDynamicMessage(t *testing.T, pkg, msgName string) protoreflect.ProtoMessage { + fdProto := &descriptorpb.FileDescriptorProto{ + Name: proto.String("test.proto"), + Package: proto.String(pkg), + MessageType: []*descriptorpb.DescriptorProto{{ + Name: proto.String(msgName), + }}, + } + + fd, err := protodesc.NewFile(fdProto, nil) + require.NoError(t, err) + + md := fd.Messages().ByName(protoreflect.Name(msgName)) + return dynamicpb.NewMessage(md) +} + +func TestToSchemaPath(t *testing.T) { + base := "/" + tests := []struct { + pkg, msgName, expected string + }{ + { + pkg: "alpha.bravo.charlie", + msgName: "FirstTest", + expected: path.Join(base, "alpha/bravo/charlie/first_test.proto"), + }, + { + pkg: "one.two", + msgName: "XMLEncode", + expected: path.Join(base, "one/two/xmlencode.proto"), + }, + { + pkg: "single", + msgName: "SimpleMessage", + expected: path.Join(base, "single/simple_message.proto"), + }, + { + pkg: "a.b.c.d.e", + msgName: "NestedLevel", + expected: path.Join(base, "a/b/c/d/e/nested_level.proto"), + }, + { + pkg: "mix.UpAndDOWN", + msgName: "CamelCaseID", + // package segment "UpAndDOWN" is left verbatim (no hyphenation), only the message gets snake_cased + expected: path.Join(base, "mix/UpAndDOWN/camel_case_id.proto"), + }, + } + + for _, tt := range tests { + t.Run(tt.expected, func(t *testing.T) { + m := makeDynamicMessage(t, tt.pkg, tt.msgName) + got := toSchemaPath(m, base) + assert.Equal(t, tt.expected, got) + }) + } +} From 7d73232acaed1f95dad92c149387381e598af3d6 Mon Sep 17 00:00:00 2001 From: Silas Lenihan Date: Tue, 27 May 2025 14:57:11 -0400 Subject: [PATCH 06/16] Published common attritutes --- pkg/beholder/attributes.go | 7 ++++ pkg/beholder/chip_ingress_emitter.go | 4 +-- pkg/beholder/chip_ingress_emitter_test.go | 34 +++++++++--------- pkg/beholder/client_test.go | 24 ++++++------- pkg/beholder/example_test.go | 12 +++---- pkg/beholder/message.go | 12 +++---- pkg/beholder/message_emitter_test.go | 42 +++++++++++------------ pkg/beholder/schema.go | 6 ---- pkg/capabilities/events/events.go | 2 +- pkg/custmsg/custom_message.go | 6 ++-- pkg/utils/tests/beholder.go | 8 ++--- 11 files changed, 79 insertions(+), 78 deletions(-) create mode 100644 pkg/beholder/attributes.go diff --git a/pkg/beholder/attributes.go b/pkg/beholder/attributes.go new file mode 100644 index 000000000..cd7a1fb5c --- /dev/null +++ b/pkg/beholder/attributes.go @@ -0,0 +1,7 @@ +package beholder + +const ( + AttrKeyBeholderDataSchema = "beholder_data_schema" + AttrKeyBeholderEntity = "beholder_entity" + AttrKeyBeholderDomain = "beholder_domain" +) diff --git a/pkg/beholder/chip_ingress_emitter.go b/pkg/beholder/chip_ingress_emitter.go index 5461ac906..a75da2fcd 100644 --- a/pkg/beholder/chip_ingress_emitter.go +++ b/pkg/beholder/chip_ingress_emitter.go @@ -51,12 +51,12 @@ func ExtractSourceAndType(attrKVs ...any) (string, string, error) { for key, value := range attributes { // Retrieve source and type using either ChIP or legacy attribute names, prioritizing source/type - if key == "source" || (key == "beholder_domain" && sourceDomain == "") { + if key == "source" || (key == AttrKeyBeholderDomain && sourceDomain == "") { if val, ok := value.(string); ok { sourceDomain = val } } - if key == "type" || (key == "beholder_entity" && entityType == "") { + if key == "type" || (key == AttrKeyBeholderEntity && entityType == "") { if val, ok := value.(string); ok { entityType = val } diff --git a/pkg/beholder/chip_ingress_emitter_test.go b/pkg/beholder/chip_ingress_emitter_test.go index 
bd214dca3..aa99b2490 100644 --- a/pkg/beholder/chip_ingress_emitter_test.go +++ b/pkg/beholder/chip_ingress_emitter_test.go @@ -42,7 +42,7 @@ func TestChipIngressEmit(t *testing.T) { emitter, err := beholder.NewChipIngressEmitter(clientMock) require.NoError(t, err) - err = emitter.Emit(t.Context(), body, "beholder_domain", domain, "beholder_entity", entity) + err = emitter.Emit(t.Context(), body, beholder.AttrKeyBeholderDomain, domain, beholder.AttrKeyBeholderEntity, entity) require.NoError(t, err) clientMock.AssertExpectations(t) @@ -59,7 +59,7 @@ func TestChipIngressEmit(t *testing.T) { emitter, err := beholder.NewChipIngressEmitter(clientMock) require.NoError(t, err) - err = emitter.Emit(t.Context(), body, "beholder_domain", domain) + err = emitter.Emit(t.Context(), body, beholder.AttrKeyBeholderDomain, domain) assert.Error(t, err) }) @@ -74,7 +74,7 @@ func TestChipIngressEmit(t *testing.T) { emitter, err := beholder.NewChipIngressEmitter(clientMock) require.NoError(t, err) - err = emitter.Emit(t.Context(), body, "beholder_domain", domain, "beholder_entity", entity) + err = emitter.Emit(t.Context(), body, beholder.AttrKeyBeholderDomain, domain, beholder.AttrKeyBeholderEntity, entity) require.Error(t, err) clientMock.AssertExpectations(t) @@ -92,7 +92,7 @@ func TestExtractSourceAndType(t *testing.T) { }{ { name: "happy path - domain and entity exist", - attrs: []any{map[string]any{"beholder_domain": "test-domain", "beholder_entity": "test-entity"}}, + attrs: []any{map[string]any{beholder.AttrKeyBeholderDomain: "test-domain", beholder.AttrKeyBeholderEntity: "test-entity"}}, wantDomain: "test-domain", wantEntity: "test-entity", wantErr: false, @@ -106,14 +106,14 @@ func TestExtractSourceAndType(t *testing.T) { }, { name: "happy path - domain and entity exist - uses source/type", - attrs: []any{map[string]any{"source": "other-domain", "beholder_domain": "test-domain", "beholder_entity": "test-entity", "type": "other-entity"}}, + attrs: []any{map[string]any{"source": "other-domain", beholder.AttrKeyBeholderDomain: "test-domain", beholder.AttrKeyBeholderEntity: "test-entity", "type": "other-entity"}}, wantDomain: "other-domain", wantEntity: "other-entity", wantErr: false, }, { name: "missing domain/source", - attrs: []any{map[string]any{"beholder_entity": "test-entity"}}, + attrs: []any{map[string]any{beholder.AttrKeyBeholderEntity: "test-entity"}}, wantDomain: "", wantEntity: "", wantErr: true, @@ -121,7 +121,7 @@ func TestExtractSourceAndType(t *testing.T) { }, { name: "missing entity/type", - attrs: []any{map[string]any{"beholder_domain": "test-domain"}}, + attrs: []any{map[string]any{beholder.AttrKeyBeholderDomain: "test-domain"}}, wantDomain: "", wantEntity: "", wantErr: true, @@ -146,10 +146,10 @@ func TestExtractSourceAndType(t *testing.T) { { name: "domain and entity with additional attributes", attrs: []any{map[string]any{ - "other_key": "other_value", - "beholder_domain": "test-domain", - "beholder_entity": "test-entity", - "something_else": 123, + "other_key": "other_value", + beholder.AttrKeyBeholderDomain: "test-domain", + beholder.AttrKeyBeholderEntity: "test-entity", + "something_else": 123, }}, wantDomain: "test-domain", wantEntity: "test-entity", @@ -158,9 +158,9 @@ func TestExtractSourceAndType(t *testing.T) { { name: "non-string keys ignored", attrs: []any{map[string]any{ - "other_value": "value", - "beholder_domain": "test-domain", - "beholder_entity": "test-entity", + "other_value": "value", + beholder.AttrKeyBeholderDomain: "test-domain", + 
beholder.AttrKeyBeholderEntity: "test-entity", }, 123, "other_key"}, wantDomain: "test-domain", wantEntity: "test-entity", @@ -169,9 +169,9 @@ func TestExtractSourceAndType(t *testing.T) { { name: "non-string values handled", attrs: []any{map[string]any{ - "other_key": 123, - "beholder_domain": "test-domain", - "beholder_entity": "test-entity", + "other_key": 123, + beholder.AttrKeyBeholderDomain: "test-domain", + beholder.AttrKeyBeholderEntity: "test-entity", }}, wantDomain: "test-domain", wantEntity: "test-entity", diff --git a/pkg/beholder/client_test.go b/pkg/beholder/client_test.go index 320427fda..5fba1d835 100644 --- a/pkg/beholder/client_test.go +++ b/pkg/beholder/client_test.go @@ -42,18 +42,18 @@ func (m *MockExporter) ForceFlush(ctx context.Context) error { func TestClient(t *testing.T) { defaultCustomAttributes := func() map[string]any { return map[string]any{ - "int_key_1": 123, - "int64_key_1": int64(123), - "int32_key_1": int32(123), - "str_key_1": "str_val_1", - "bool_key_1": true, - "float_key_1": 123.456, - "byte_key_1": []byte("byte_val_1"), - "str_slice_key_1": []string{"str_val_1", "str_val_2"}, - "nil_key_1": nil, - "beholder_domain": "TestDomain", // Required field - "beholder_entity": "TestEntity", // Required field - "beholder_data_schema": "/schemas/ids/1001", // Required field, URI + "int_key_1": 123, + "int64_key_1": int64(123), + "int32_key_1": int32(123), + "str_key_1": "str_val_1", + "bool_key_1": true, + "float_key_1": 123.456, + "byte_key_1": []byte("byte_val_1"), + "str_slice_key_1": []string{"str_val_1", "str_val_2"}, + "nil_key_1": nil, + beholder.AttrKeyBeholderDomain: "TestDomain", // Required field + beholder.AttrKeyBeholderEntity: "TestEntity", // Required field + beholder.AttrKeyBeholderDataSchema: "/schemas/ids/1001", // Required field, URI } } defaultMessageBody := []byte("body bytes") diff --git a/pkg/beholder/example_test.go b/pkg/beholder/example_test.go index 2045c3bb8..9d5264087 100644 --- a/pkg/beholder/example_test.go +++ b/pkg/beholder/example_test.go @@ -44,9 +44,9 @@ func ExampleNewClient() { fmt.Println("Emit custom messages") for range 10 { err := beholder.GetEmitter().Emit(context.Background(), payloadBytes, - "beholder_data_schema", "/custom-message/versions/1", // required - "beholder_domain", "ExampleDomain", // required - "beholder_entity", "ExampleEntity", // required + beholder.AttrKeyBeholderDataSchema, "/custom-message/versions/1", // required + beholder.AttrKeyBeholderDomain, "ExampleDomain", // required + beholder.AttrKeyBeholderEntity, "ExampleEntity", // required "beholder_data_type", "custom_message", "foo", "bar", ) @@ -106,9 +106,9 @@ func ExampleNewNoopClient() { fmt.Println("Emitting custom message via noop otel client") err := beholder.GetEmitter().Emit(context.Background(), []byte("test message"), - "beholder_data_schema", "/custom-message/versions/1", // required - "beholder_domain", "ExampleDomain", // required - "beholder_entity", "ExampleEntity", // required + beholder.AttrKeyBeholderDataSchema, "/custom-message/versions/1", // required + beholder.AttrKeyBeholderDomain, "ExampleDomain", // required + beholder.AttrKeyBeholderEntity, "ExampleEntity", // required ) if err != nil { log.Printf("Error emitting message: %v", err) diff --git a/pkg/beholder/message.go b/pkg/beholder/message.go index 7cb6f1bdb..4cb619a4d 100644 --- a/pkg/beholder/message.go +++ b/pkg/beholder/message.go @@ -60,9 +60,9 @@ func (m Metadata) Attributes() Attributes { "workflow_owner_address": m.WorkflowOwnerAddress, "workflow_spec_id": 
m.WorkflowSpecID, "workflow_execution_id": m.WorkflowExecutionID, - "beholder_domain": m.BeholderDomain, - "beholder_entity": m.BeholderEntity, - "beholder_data_schema": m.BeholderDataSchema, + AttrKeyBeholderDomain: m.BeholderDomain, + AttrKeyBeholderEntity: m.BeholderEntity, + AttrKeyBeholderDataSchema: m.BeholderDataSchema, "capability_contract_address": m.CapabilityContractAddress, "capability_id": m.CapabilityID, "capability_version": m.CapabilityVersion, @@ -206,11 +206,11 @@ func (m *Metadata) FromAttributes(attrs Attributes) *Metadata { m.WorkflowSpecID = v.(string) case "workflow_execution_id": m.WorkflowExecutionID = v.(string) - case "beholder_domain": + case AttrKeyBeholderDomain: m.BeholderDomain = v.(string) - case "beholder_entity": + case AttrKeyBeholderEntity: m.BeholderEntity = v.(string) - case "beholder_data_schema": + case AttrKeyBeholderDataSchema: m.BeholderDataSchema = v.(string) case "capability_contract_address": m.CapabilityContractAddress = v.(string) diff --git a/pkg/beholder/message_emitter_test.go b/pkg/beholder/message_emitter_test.go index 6ee4eeec0..a055a1fb7 100644 --- a/pkg/beholder/message_emitter_test.go +++ b/pkg/beholder/message_emitter_test.go @@ -44,9 +44,9 @@ func TestEmitterMessageValidation(t *testing.T) { { name: "Invalid URI", attrs: beholder.Attributes{ - "beholder_domain": "TestDomain", - "beholder_entity": "TestEntity", - "beholder_data_schema": "example-schema", + beholder.AttrKeyBeholderDomain: "TestDomain", + beholder.AttrKeyBeholderEntity: "TestEntity", + beholder.AttrKeyBeholderDataSchema: "example-schema", }, exporterCalledTimes: 0, expectedError: "'Metadata.BeholderDataSchema' Error:Field validation for 'BeholderDataSchema' failed on the 'uri' tag", @@ -54,9 +54,9 @@ func TestEmitterMessageValidation(t *testing.T) { { name: "Invalid Beholder domain (double underscore)", attrs: beholder.Attributes{ - "beholder_data_schema": "/example-schema/versions/1", - "beholder_entity": "TestEntity", - "beholder_domain": "Test__Domain", + beholder.AttrKeyBeholderDataSchema: "/example-schema/versions/1", + beholder.AttrKeyBeholderEntity: "TestEntity", + beholder.AttrKeyBeholderDomain: "Test__Domain", }, exporterCalledTimes: 0, expectedError: "'Metadata.BeholderDomain' Error:Field validation for 'BeholderDomain' failed on the 'domain_entity' tag", @@ -64,9 +64,9 @@ func TestEmitterMessageValidation(t *testing.T) { { name: "Invalid Beholder domain (special characters)", attrs: beholder.Attributes{ - "beholder_data_schema": "/example-schema/versions/1", - "beholder_entity": "TestEntity", - "beholder_domain": "TestDomain*$", + beholder.AttrKeyBeholderDataSchema: "/example-schema/versions/1", + beholder.AttrKeyBeholderEntity: "TestEntity", + beholder.AttrKeyBeholderDomain: "TestDomain*$", }, exporterCalledTimes: 0, expectedError: "'Metadata.BeholderDomain' Error:Field validation for 'BeholderDomain' failed on the 'domain_entity' tag", @@ -74,9 +74,9 @@ func TestEmitterMessageValidation(t *testing.T) { { name: "Invalid Beholder entity (double underscore)", attrs: beholder.Attributes{ - "beholder_data_schema": "/example-schema/versions/1", - "beholder_entity": "Test__Entity", - "beholder_domain": "TestDomain", + beholder.AttrKeyBeholderDataSchema: "/example-schema/versions/1", + beholder.AttrKeyBeholderEntity: "Test__Entity", + beholder.AttrKeyBeholderDomain: "TestDomain", }, exporterCalledTimes: 0, expectedError: "'Metadata.BeholderEntity' Error:Field validation for 'BeholderEntity' failed on the 'domain_entity' tag", @@ -84,9 +84,9 @@ func 
TestEmitterMessageValidation(t *testing.T) { { name: "Invalid Beholder entity (special characters)", attrs: beholder.Attributes{ - "beholder_data_schema": "/example-schema/versions/1", - "beholder_entity": "TestEntity*$", - "beholder_domain": "TestDomain", + beholder.AttrKeyBeholderDataSchema: "/example-schema/versions/1", + beholder.AttrKeyBeholderEntity: "TestEntity*$", + beholder.AttrKeyBeholderDomain: "TestDomain", }, exporterCalledTimes: 0, expectedError: "'Metadata.BeholderEntity' Error:Field validation for 'BeholderEntity' failed on the 'domain_entity' tag", @@ -95,9 +95,9 @@ func TestEmitterMessageValidation(t *testing.T) { name: "Valid Attributes", exporterCalledTimes: 1, attrs: beholder.Attributes{ - "beholder_domain": "TestDomain", - "beholder_entity": "TestEntity", - "beholder_data_schema": "/example-schema/versions/1", + beholder.AttrKeyBeholderDomain: "TestDomain", + beholder.AttrKeyBeholderEntity: "TestEntity", + beholder.AttrKeyBeholderDataSchema: "/example-schema/versions/1", }, expectedError: "", }, @@ -105,9 +105,9 @@ func TestEmitterMessageValidation(t *testing.T) { name: "Valid Attributes (special characters)", exporterCalledTimes: 1, attrs: beholder.Attributes{ - "beholder_domain": "Test.Domain_42-1", - "beholder_entity": "Test.Entity_42-1", - "beholder_data_schema": "/example-schema/versions/1", + beholder.AttrKeyBeholderDomain: "Test.Domain_42-1", + beholder.AttrKeyBeholderEntity: "Test.Entity_42-1", + beholder.AttrKeyBeholderDataSchema: "/example-schema/versions/1", }, expectedError: "", }, diff --git a/pkg/beholder/schema.go b/pkg/beholder/schema.go index 06b6214b1..8b9024bcf 100644 --- a/pkg/beholder/schema.go +++ b/pkg/beholder/schema.go @@ -11,12 +11,6 @@ import ( "google.golang.org/protobuf/runtime/protoimpl" ) -const ( - AttrKeyBeholderDataSchema = "beholder_data_schema" - AttrKeyBeholderEntity = "beholder_entity" - AttrKeyBeholderDomain = "beholder_domain" -) - // patternSnake is a regular expression to match CamelCase words // Notice: we use the Unicode property 'Lu' (uppercase letter) to match // the first letter of the word, and 'P{Lu}' (not uppercase letter) to match diff --git a/pkg/capabilities/events/events.go b/pkg/capabilities/events/events.go index 444f45705..bcb72159a 100644 --- a/pkg/capabilities/events/events.go +++ b/pkg/capabilities/events/events.go @@ -193,7 +193,7 @@ func (e *Emitter) Emit(ctx context.Context, msg Message) error { } attrs := []any{ - "beholder_data_schema", + beholder.AttrKeyBeholderDataSchema, "/capabilities-operational-event/versions/1", "beholder_data_type", "custom_message", diff --git a/pkg/custmsg/custom_message.go b/pkg/custmsg/custom_message.go index 49ed8c459..c897e125b 100644 --- a/pkg/custmsg/custom_message.go +++ b/pkg/custmsg/custom_message.go @@ -111,9 +111,9 @@ func sendLogAsCustomMessageW(ctx context.Context, msg string, labels map[string] } err = beholder.GetEmitter().Emit(ctx, payloadBytes, - "beholder_data_schema", "/beholder-base-message/versions/1", // required - "beholder_domain", "platform", // required - "beholder_entity", "BaseMessage", // required + beholder.AttrKeyBeholderDataSchema, "/beholder-base-message/versions/1", // required + beholder.AttrKeyBeholderDomain, "platform", // required + beholder.AttrKeyBeholderEntity, "BaseMessage", // required ) if err != nil { return fmt.Errorf("sending custom message failed on emit: %w", err) diff --git a/pkg/utils/tests/beholder.go b/pkg/utils/tests/beholder.go index 544547475..902bbb3f7 100644 --- a/pkg/utils/tests/beholder.go +++ 
b/pkg/utils/tests/beholder.go @@ -38,11 +38,11 @@ func (b BeholderTester) Len(t *testing.T, attrKVs ...any) int { // Messages returns messages matching the provided keys and values. func (b BeholderTester) Messages(t *testing.T, attrKVs ...any) []beholder.Message { t.Helper() - + if attrKVs == nil { - return b.emitter.msgs + return b.emitter.msgs } - + return b.msgsForKVs(t, attrKVs...) } func (b BeholderTester) msgsForKVs(t *testing.T, attrKVs ...any) []beholder.Message { @@ -86,7 +86,7 @@ func (b BeholderTester) BaseMessagesForLabels(t *testing.T, labels map[string]st messageLoop: for _, eMsg := range b.emitter.msgs { - dataSchema, ok := eMsg.Attrs["beholder_entity"].(string) + dataSchema, ok := eMsg.Attrs[beholder.AttrKeyBeholderEntity].(string) if !ok { continue } From acf79334d50f29274e6a2196349b76333e1789fc Mon Sep 17 00:00:00 2001 From: Silas Lenihan Date: Wed, 28 May 2025 12:46:57 -0400 Subject: [PATCH 07/16] addressed feedback --- pkg/beholder/proto_emitter.go | 12 ++--- pkg/beholder/schema.go | 44 +++++++++---------- .../consensus/ocr3/types/aggregator.go | 2 + 3 files changed, 29 insertions(+), 29 deletions(-) diff --git a/pkg/beholder/proto_emitter.go b/pkg/beholder/proto_emitter.go index 1ade828bd..6d5c8a784 100644 --- a/pkg/beholder/proto_emitter.go +++ b/pkg/beholder/proto_emitter.go @@ -56,7 +56,7 @@ func (e *protoEmitter) Emit(ctx context.Context, m proto.Message, attrKVs ...any // Skip appending attributes if the context says it's already done that if skip, ok := ctx.Value(CtxKeySkipAppendAttrs).(bool); !ok || !skip { - attrKVs = e.appendAttrsRequired(ctx, m, attrKVs) + attrKVs = e.appendAttrsRequired(attrKVs, m) } // Emit the message with attributes @@ -72,7 +72,7 @@ func (e *protoEmitter) Emit(ctx context.Context, m proto.Message, attrKVs ...any // EmitWithLog emits a protobuf message with attributes and logs the emitted message func (e *protoEmitter) EmitWithLog(ctx context.Context, m proto.Message, attrKVs ...any) error { - attrKVs = e.appendAttrsRequired(ctx, m, attrKVs) + attrKVs = e.appendAttrsRequired(attrKVs, m) // attach a bool switch to ctx to avoid duplicating common attrs ctx = context.WithValue(ctx, CtxKeySkipAppendAttrs, true) @@ -88,9 +88,9 @@ func (e *protoEmitter) EmitWithLog(ctx context.Context, m proto.Message, attrKVs } // appendAttrsRequired appends required attributes to the attribute key-value list -func (e *protoEmitter) appendAttrsRequired(ctx context.Context, m proto.Message, attrKVs []any) []any { - attrKVs = appendRequiredAttrDataSchema(m, attrKVs, e.schemaBasePath) - attrKVs = appendRequiredAttrEntity(m, attrKVs) - attrKVs = appendRequiredAttrDomain(m, attrKVs) +func (e *protoEmitter) appendAttrsRequired(attrKVs []any, m proto.Message) []any { + attrKVs = appendRequiredAttrDataSchema(attrKVs, toSchemaPath(m, e.schemaBasePath)) + attrKVs = appendRequiredAttrEntity(attrKVs, m) + attrKVs = appendRequiredAttrDomain(attrKVs, m) return attrKVs } diff --git a/pkg/beholder/schema.go b/pkg/beholder/schema.go index 8b9024bcf..f61daf608 100644 --- a/pkg/beholder/schema.go +++ b/pkg/beholder/schema.go @@ -57,42 +57,31 @@ func toSchemaPath(m proto.Message, basePath string) string { } // appendRequiredAttrDataSchema adds the message schema path as an attribute (required) -func appendRequiredAttrDataSchema(m proto.Message, attrKVs []any, basePath string) []any { - key := AttrKeyBeholderDataSchema - for i := 0; i < len(attrKVs); i += 2 { - if attrKVs[i] == key { - return attrKVs - } +func appendRequiredAttrDataSchema(attrKVs []any, val string) []any { + if 
containsKey(attrKVs, AttrKeyBeholderDataSchema) { + return attrKVs } - attrKVs = append(attrKVs, key) - // Needs to be an URI (Beholder requirement) - val := toSchemaPath(m, basePath) + attrKVs = append(attrKVs, AttrKeyBeholderDataSchema) attrKVs = append(attrKVs, val) return attrKVs } // appendRequiredAttrEntity adds the message entity type as an attribute (required) -func appendRequiredAttrEntity(m proto.Message, attrKVs []any) []any { - key := AttrKeyBeholderEntity - for i := 0; i < len(attrKVs); i += 2 { - if attrKVs[i] == key { - return attrKVs - } +func appendRequiredAttrEntity(attrKVs []any, m proto.Message) []any { + if containsKey(attrKVs, AttrKeyBeholderEntity) { + return attrKVs } - attrKVs = append(attrKVs, key) + attrKVs = append(attrKVs, AttrKeyBeholderEntity) attrKVs = append(attrKVs, toSchemaName(m)) return attrKVs } // appendRequiredAttrDomain adds the message domain as an attribute (required) -func appendRequiredAttrDomain(m proto.Message, attrKVs []any) []any { - key := AttrKeyBeholderDomain - for i := 0; i < len(attrKVs); i += 2 { - if attrKVs[i] == key { - return attrKVs - } +func appendRequiredAttrDomain(attrKVs []any, m proto.Message) []any { + if containsKey(attrKVs, AttrKeyBeholderDomain) { + return attrKVs } // Notice: a name like 'platform.on_chain.forwarder.ReportProcessed' @@ -104,7 +93,16 @@ func appendRequiredAttrDomain(m proto.Message, attrKVs []any) []any { domain = strings.Split(protoName, ".")[0] } - attrKVs = append(attrKVs, key) + attrKVs = append(attrKVs, AttrKeyBeholderDomain) attrKVs = append(attrKVs, domain) return attrKVs } + +func containsKey(attrKVs []any, key string) bool { + for i := 0; i < len(attrKVs); i += 2 { + if attrKVs[i] == key { + return true + } + } + return false +} diff --git a/pkg/capabilities/consensus/ocr3/types/aggregator.go b/pkg/capabilities/consensus/ocr3/types/aggregator.go index 444fe466f..40192f58a 100644 --- a/pkg/capabilities/consensus/ocr3/types/aggregator.go +++ b/pkg/capabilities/consensus/ocr3/types/aggregator.go @@ -100,6 +100,8 @@ func (m Metadata) Encode() ([]byte, error) { return buf.Bytes(), nil } +// 1B Version, 32B ExecutionID, 4B Timestamp, 4B DONID, 4B DONConfigVersion, +// 32B WorkflowID, 10B WorkflowName, 20B WorkflowOwner, 2B ReportID const MetadataLen = 1 + 32 + 4 + 4 + 4 + 32 + 10 + 20 + 2 // =109 // Decode parses exactly MetadataLen bytes from raw, returns a Metadata struct From a99824a0746bf2bb3c4758bdcfedd08affe726a7 Mon Sep 17 00:00:00 2001 From: Silas Lenihan Date: Tue, 3 Jun 2025 09:19:53 -0400 Subject: [PATCH 08/16] added beholder attribute data_type --- pkg/beholder/attributes.go | 1 + pkg/beholder/client.go | 4 ++-- pkg/beholder/example_test.go | 2 +- pkg/beholder/httpclient.go | 4 ++-- pkg/capabilities/events/events.go | 2 +- 5 files changed, 7 insertions(+), 6 deletions(-) diff --git a/pkg/beholder/attributes.go b/pkg/beholder/attributes.go index cd7a1fb5c..f47954e95 100644 --- a/pkg/beholder/attributes.go +++ b/pkg/beholder/attributes.go @@ -4,4 +4,5 @@ const ( AttrKeyBeholderDataSchema = "beholder_data_schema" AttrKeyBeholderEntity = "beholder_entity" AttrKeyBeholderDomain = "beholder_domain" + AttrKeyBeholderDataType = "beholder_data_type" ) diff --git a/pkg/beholder/client.go b/pkg/beholder/client.go index edbdace32..f697566a4 100644 --- a/pkg/beholder/client.go +++ b/pkg/beholder/client.go @@ -133,7 +133,7 @@ func NewGRPCClient(cfg Config, otlploggrpcNew otlploggrpcFactory) (*Client, erro loggerProcessor = sdklog.NewSimpleProcessor(sharedLogExporter) } loggerAttributes := 
[]attribute.KeyValue{ - attribute.String("beholder_data_type", "zap_log_message"), + attribute.String(AttrKeyBeholderDataType, "zap_log_message"), } loggerResource, err := sdkresource.Merge( sdkresource.NewSchemaless(loggerAttributes...), @@ -187,7 +187,7 @@ func NewGRPCClient(cfg Config, otlploggrpcNew otlploggrpcFactory) (*Client, erro } messageAttributes := []attribute.KeyValue{ - attribute.String("beholder_data_type", "custom_message"), + attribute.String(AttrKeyBeholderDataType, "custom_message"), } messageLoggerResource, err := sdkresource.Merge( sdkresource.NewSchemaless(messageAttributes...), diff --git a/pkg/beholder/example_test.go b/pkg/beholder/example_test.go index 9d5264087..e54da38d4 100644 --- a/pkg/beholder/example_test.go +++ b/pkg/beholder/example_test.go @@ -47,7 +47,7 @@ func ExampleNewClient() { beholder.AttrKeyBeholderDataSchema, "/custom-message/versions/1", // required beholder.AttrKeyBeholderDomain, "ExampleDomain", // required beholder.AttrKeyBeholderEntity, "ExampleEntity", // required - "beholder_data_type", "custom_message", + beholder.AttrKeyBeholderDataType, "custom_message", "foo", "bar", ) if err != nil { diff --git a/pkg/beholder/httpclient.go b/pkg/beholder/httpclient.go index c36f4df60..aeb245d69 100644 --- a/pkg/beholder/httpclient.go +++ b/pkg/beholder/httpclient.go @@ -99,7 +99,7 @@ func NewHTTPClient(cfg Config, otlploghttpNew otlploghttpFactory) (*Client, erro loggerProcessor = sdklog.NewSimpleProcessor(sharedLogExporter) } loggerAttributes := []attribute.KeyValue{ - attribute.String("beholder_data_type", "zap_log_message"), + attribute.String(AttrKeyBeholderDataType, "zap_log_message"), } loggerResource, err := sdkresource.Merge( sdkresource.NewSchemaless(loggerAttributes...), @@ -153,7 +153,7 @@ func NewHTTPClient(cfg Config, otlploghttpNew otlploghttpFactory) (*Client, erro } messageAttributes := []attribute.KeyValue{ - attribute.String("beholder_data_type", "custom_message"), + attribute.String(AttrKeyBeholderDataType, "custom_message"), } messageLoggerResource, err := sdkresource.Merge( sdkresource.NewSchemaless(messageAttributes...), diff --git a/pkg/capabilities/events/events.go b/pkg/capabilities/events/events.go index bcb72159a..0b2701bcf 100644 --- a/pkg/capabilities/events/events.go +++ b/pkg/capabilities/events/events.go @@ -195,7 +195,7 @@ func (e *Emitter) Emit(ctx context.Context, msg Message) error { attrs := []any{ beholder.AttrKeyBeholderDataSchema, "/capabilities-operational-event/versions/1", - "beholder_data_type", + beholder.AttrKeyBeholderDataType, "custom_message", } From 64c6a47131cca358e3a3f0838a89f7d3112ee3e2 Mon Sep 17 00:00:00 2001 From: Silas Lenihan Date: Thu, 5 Jun 2025 09:19:11 -0400 Subject: [PATCH 09/16] removed behodler for attr keys' --- pkg/beholder/attributes.go | 8 ++--- pkg/beholder/chip_ingress_emitter.go | 4 +-- pkg/beholder/chip_ingress_emitter_test.go | 34 +++++++++--------- pkg/beholder/client.go | 4 +-- pkg/beholder/client_test.go | 24 ++++++------- pkg/beholder/example_test.go | 14 ++++---- pkg/beholder/httpclient.go | 4 +-- pkg/beholder/message.go | 12 +++---- pkg/beholder/message_emitter_test.go | 42 +++++++++++------------ pkg/beholder/schema.go | 12 +++---- pkg/capabilities/events/events.go | 4 +-- pkg/custmsg/custom_message.go | 6 ++-- pkg/utils/tests/beholder.go | 2 +- 13 files changed, 85 insertions(+), 85 deletions(-) diff --git a/pkg/beholder/attributes.go b/pkg/beholder/attributes.go index f47954e95..b37e17411 100644 --- a/pkg/beholder/attributes.go +++ b/pkg/beholder/attributes.go @@ -1,8 
+1,8 @@ package beholder const ( - AttrKeyBeholderDataSchema = "beholder_data_schema" - AttrKeyBeholderEntity = "beholder_entity" - AttrKeyBeholderDomain = "beholder_domain" - AttrKeyBeholderDataType = "beholder_data_type" + AttrKeyDataSchema = "beholder_data_schema" + AttrKeyEntity = "beholder_entity" + AttrKeyDomain = "beholder_domain" + AttrKeyDataType = "beholder_data_type" ) diff --git a/pkg/beholder/chip_ingress_emitter.go b/pkg/beholder/chip_ingress_emitter.go index a75da2fcd..0ca3e991a 100644 --- a/pkg/beholder/chip_ingress_emitter.go +++ b/pkg/beholder/chip_ingress_emitter.go @@ -51,12 +51,12 @@ func ExtractSourceAndType(attrKVs ...any) (string, string, error) { for key, value := range attributes { // Retrieve source and type using either ChIP or legacy attribute names, prioritizing source/type - if key == "source" || (key == AttrKeyBeholderDomain && sourceDomain == "") { + if key == "source" || (key == AttrKeyDomain && sourceDomain == "") { if val, ok := value.(string); ok { sourceDomain = val } } - if key == "type" || (key == AttrKeyBeholderEntity && entityType == "") { + if key == "type" || (key == AttrKeyEntity && entityType == "") { if val, ok := value.(string); ok { entityType = val } diff --git a/pkg/beholder/chip_ingress_emitter_test.go b/pkg/beholder/chip_ingress_emitter_test.go index aa99b2490..9798b6fff 100644 --- a/pkg/beholder/chip_ingress_emitter_test.go +++ b/pkg/beholder/chip_ingress_emitter_test.go @@ -42,7 +42,7 @@ func TestChipIngressEmit(t *testing.T) { emitter, err := beholder.NewChipIngressEmitter(clientMock) require.NoError(t, err) - err = emitter.Emit(t.Context(), body, beholder.AttrKeyBeholderDomain, domain, beholder.AttrKeyBeholderEntity, entity) + err = emitter.Emit(t.Context(), body, beholder.AttrKeyDomain, domain, beholder.AttrKeyEntity, entity) require.NoError(t, err) clientMock.AssertExpectations(t) @@ -59,7 +59,7 @@ func TestChipIngressEmit(t *testing.T) { emitter, err := beholder.NewChipIngressEmitter(clientMock) require.NoError(t, err) - err = emitter.Emit(t.Context(), body, beholder.AttrKeyBeholderDomain, domain) + err = emitter.Emit(t.Context(), body, beholder.AttrKeyDomain, domain) assert.Error(t, err) }) @@ -74,7 +74,7 @@ func TestChipIngressEmit(t *testing.T) { emitter, err := beholder.NewChipIngressEmitter(clientMock) require.NoError(t, err) - err = emitter.Emit(t.Context(), body, beholder.AttrKeyBeholderDomain, domain, beholder.AttrKeyBeholderEntity, entity) + err = emitter.Emit(t.Context(), body, beholder.AttrKeyDomain, domain, beholder.AttrKeyEntity, entity) require.Error(t, err) clientMock.AssertExpectations(t) @@ -92,7 +92,7 @@ func TestExtractSourceAndType(t *testing.T) { }{ { name: "happy path - domain and entity exist", - attrs: []any{map[string]any{beholder.AttrKeyBeholderDomain: "test-domain", beholder.AttrKeyBeholderEntity: "test-entity"}}, + attrs: []any{map[string]any{beholder.AttrKeyDomain: "test-domain", beholder.AttrKeyEntity: "test-entity"}}, wantDomain: "test-domain", wantEntity: "test-entity", wantErr: false, @@ -106,14 +106,14 @@ func TestExtractSourceAndType(t *testing.T) { }, { name: "happy path - domain and entity exist - uses source/type", - attrs: []any{map[string]any{"source": "other-domain", beholder.AttrKeyBeholderDomain: "test-domain", beholder.AttrKeyBeholderEntity: "test-entity", "type": "other-entity"}}, + attrs: []any{map[string]any{"source": "other-domain", beholder.AttrKeyDomain: "test-domain", beholder.AttrKeyEntity: "test-entity", "type": "other-entity"}}, wantDomain: "other-domain", wantEntity: 
"other-entity", wantErr: false, }, { name: "missing domain/source", - attrs: []any{map[string]any{beholder.AttrKeyBeholderEntity: "test-entity"}}, + attrs: []any{map[string]any{beholder.AttrKeyEntity: "test-entity"}}, wantDomain: "", wantEntity: "", wantErr: true, @@ -121,7 +121,7 @@ func TestExtractSourceAndType(t *testing.T) { }, { name: "missing entity/type", - attrs: []any{map[string]any{beholder.AttrKeyBeholderDomain: "test-domain"}}, + attrs: []any{map[string]any{beholder.AttrKeyDomain: "test-domain"}}, wantDomain: "", wantEntity: "", wantErr: true, @@ -146,10 +146,10 @@ func TestExtractSourceAndType(t *testing.T) { { name: "domain and entity with additional attributes", attrs: []any{map[string]any{ - "other_key": "other_value", - beholder.AttrKeyBeholderDomain: "test-domain", - beholder.AttrKeyBeholderEntity: "test-entity", - "something_else": 123, + "other_key": "other_value", + beholder.AttrKeyDomain: "test-domain", + beholder.AttrKeyEntity: "test-entity", + "something_else": 123, }}, wantDomain: "test-domain", wantEntity: "test-entity", @@ -158,9 +158,9 @@ func TestExtractSourceAndType(t *testing.T) { { name: "non-string keys ignored", attrs: []any{map[string]any{ - "other_value": "value", - beholder.AttrKeyBeholderDomain: "test-domain", - beholder.AttrKeyBeholderEntity: "test-entity", + "other_value": "value", + beholder.AttrKeyDomain: "test-domain", + beholder.AttrKeyEntity: "test-entity", }, 123, "other_key"}, wantDomain: "test-domain", wantEntity: "test-entity", @@ -169,9 +169,9 @@ func TestExtractSourceAndType(t *testing.T) { { name: "non-string values handled", attrs: []any{map[string]any{ - "other_key": 123, - beholder.AttrKeyBeholderDomain: "test-domain", - beholder.AttrKeyBeholderEntity: "test-entity", + "other_key": 123, + beholder.AttrKeyDomain: "test-domain", + beholder.AttrKeyEntity: "test-entity", }}, wantDomain: "test-domain", wantEntity: "test-entity", diff --git a/pkg/beholder/client.go b/pkg/beholder/client.go index f697566a4..dd43e1540 100644 --- a/pkg/beholder/client.go +++ b/pkg/beholder/client.go @@ -133,7 +133,7 @@ func NewGRPCClient(cfg Config, otlploggrpcNew otlploggrpcFactory) (*Client, erro loggerProcessor = sdklog.NewSimpleProcessor(sharedLogExporter) } loggerAttributes := []attribute.KeyValue{ - attribute.String(AttrKeyBeholderDataType, "zap_log_message"), + attribute.String(AttrKeyDataType, "zap_log_message"), } loggerResource, err := sdkresource.Merge( sdkresource.NewSchemaless(loggerAttributes...), @@ -187,7 +187,7 @@ func NewGRPCClient(cfg Config, otlploggrpcNew otlploggrpcFactory) (*Client, erro } messageAttributes := []attribute.KeyValue{ - attribute.String(AttrKeyBeholderDataType, "custom_message"), + attribute.String(AttrKeyDataType, "custom_message"), } messageLoggerResource, err := sdkresource.Merge( sdkresource.NewSchemaless(messageAttributes...), diff --git a/pkg/beholder/client_test.go b/pkg/beholder/client_test.go index 2e19c3c95..d8bac61cc 100644 --- a/pkg/beholder/client_test.go +++ b/pkg/beholder/client_test.go @@ -42,18 +42,18 @@ func (m *MockExporter) ForceFlush(ctx context.Context) error { func TestClient(t *testing.T) { defaultCustomAttributes := func() map[string]any { return map[string]any{ - "int_key_1": 123, - "int64_key_1": int64(123), - "int32_key_1": int32(123), - "str_key_1": "str_val_1", - "bool_key_1": true, - "float_key_1": 123.456, - "byte_key_1": []byte("byte_val_1"), - "str_slice_key_1": []string{"str_val_1", "str_val_2"}, - "nil_key_1": nil, - beholder.AttrKeyBeholderDomain: "TestDomain", // Required field - 
beholder.AttrKeyBeholderEntity: "TestEntity", // Required field - beholder.AttrKeyBeholderDataSchema: "/schemas/ids/1001", // Required field, URI + "int_key_1": 123, + "int64_key_1": int64(123), + "int32_key_1": int32(123), + "str_key_1": "str_val_1", + "bool_key_1": true, + "float_key_1": 123.456, + "byte_key_1": []byte("byte_val_1"), + "str_slice_key_1": []string{"str_val_1", "str_val_2"}, + "nil_key_1": nil, + beholder.AttrKeyDomain: "TestDomain", // Required field + beholder.AttrKeyEntity: "TestEntity", // Required field + beholder.AttrKeyDataSchema: "/schemas/ids/1001", // Required field, URI } } defaultMessageBody := []byte("body bytes") diff --git a/pkg/beholder/example_test.go b/pkg/beholder/example_test.go index e54da38d4..ddb9c6bdc 100644 --- a/pkg/beholder/example_test.go +++ b/pkg/beholder/example_test.go @@ -44,10 +44,10 @@ func ExampleNewClient() { fmt.Println("Emit custom messages") for range 10 { err := beholder.GetEmitter().Emit(context.Background(), payloadBytes, - beholder.AttrKeyBeholderDataSchema, "/custom-message/versions/1", // required - beholder.AttrKeyBeholderDomain, "ExampleDomain", // required - beholder.AttrKeyBeholderEntity, "ExampleEntity", // required - beholder.AttrKeyBeholderDataType, "custom_message", + beholder.AttrKeyDataSchema, "/custom-message/versions/1", // required + beholder.AttrKeyDomain, "ExampleDomain", // required + beholder.AttrKeyEntity, "ExampleEntity", // required + beholder.AttrKeyDataType, "custom_message", "foo", "bar", ) if err != nil { @@ -106,9 +106,9 @@ func ExampleNewNoopClient() { fmt.Println("Emitting custom message via noop otel client") err := beholder.GetEmitter().Emit(context.Background(), []byte("test message"), - beholder.AttrKeyBeholderDataSchema, "/custom-message/versions/1", // required - beholder.AttrKeyBeholderDomain, "ExampleDomain", // required - beholder.AttrKeyBeholderEntity, "ExampleEntity", // required + beholder.AttrKeyDataSchema, "/custom-message/versions/1", // required + beholder.AttrKeyDomain, "ExampleDomain", // required + beholder.AttrKeyEntity, "ExampleEntity", // required ) if err != nil { log.Printf("Error emitting message: %v", err) diff --git a/pkg/beholder/httpclient.go b/pkg/beholder/httpclient.go index aeb245d69..9bca3931e 100644 --- a/pkg/beholder/httpclient.go +++ b/pkg/beholder/httpclient.go @@ -99,7 +99,7 @@ func NewHTTPClient(cfg Config, otlploghttpNew otlploghttpFactory) (*Client, erro loggerProcessor = sdklog.NewSimpleProcessor(sharedLogExporter) } loggerAttributes := []attribute.KeyValue{ - attribute.String(AttrKeyBeholderDataType, "zap_log_message"), + attribute.String(AttrKeyDataType, "zap_log_message"), } loggerResource, err := sdkresource.Merge( sdkresource.NewSchemaless(loggerAttributes...), @@ -153,7 +153,7 @@ func NewHTTPClient(cfg Config, otlploghttpNew otlploghttpFactory) (*Client, erro } messageAttributes := []attribute.KeyValue{ - attribute.String(AttrKeyBeholderDataType, "custom_message"), + attribute.String(AttrKeyDataType, "custom_message"), } messageLoggerResource, err := sdkresource.Merge( sdkresource.NewSchemaless(messageAttributes...), diff --git a/pkg/beholder/message.go b/pkg/beholder/message.go index 4cb619a4d..14320c7b2 100644 --- a/pkg/beholder/message.go +++ b/pkg/beholder/message.go @@ -60,9 +60,9 @@ func (m Metadata) Attributes() Attributes { "workflow_owner_address": m.WorkflowOwnerAddress, "workflow_spec_id": m.WorkflowSpecID, "workflow_execution_id": m.WorkflowExecutionID, - AttrKeyBeholderDomain: m.BeholderDomain, - AttrKeyBeholderEntity: m.BeholderEntity, - 
AttrKeyBeholderDataSchema: m.BeholderDataSchema, + AttrKeyDomain: m.BeholderDomain, + AttrKeyEntity: m.BeholderEntity, + AttrKeyDataSchema: m.BeholderDataSchema, "capability_contract_address": m.CapabilityContractAddress, "capability_id": m.CapabilityID, "capability_version": m.CapabilityVersion, @@ -206,11 +206,11 @@ func (m *Metadata) FromAttributes(attrs Attributes) *Metadata { m.WorkflowSpecID = v.(string) case "workflow_execution_id": m.WorkflowExecutionID = v.(string) - case AttrKeyBeholderDomain: + case AttrKeyDomain: m.BeholderDomain = v.(string) - case AttrKeyBeholderEntity: + case AttrKeyEntity: m.BeholderEntity = v.(string) - case AttrKeyBeholderDataSchema: + case AttrKeyDataSchema: m.BeholderDataSchema = v.(string) case "capability_contract_address": m.CapabilityContractAddress = v.(string) diff --git a/pkg/beholder/message_emitter_test.go b/pkg/beholder/message_emitter_test.go index a055a1fb7..471bb73db 100644 --- a/pkg/beholder/message_emitter_test.go +++ b/pkg/beholder/message_emitter_test.go @@ -44,9 +44,9 @@ func TestEmitterMessageValidation(t *testing.T) { { name: "Invalid URI", attrs: beholder.Attributes{ - beholder.AttrKeyBeholderDomain: "TestDomain", - beholder.AttrKeyBeholderEntity: "TestEntity", - beholder.AttrKeyBeholderDataSchema: "example-schema", + beholder.AttrKeyDomain: "TestDomain", + beholder.AttrKeyEntity: "TestEntity", + beholder.AttrKeyDataSchema: "example-schema", }, exporterCalledTimes: 0, expectedError: "'Metadata.BeholderDataSchema' Error:Field validation for 'BeholderDataSchema' failed on the 'uri' tag", @@ -54,9 +54,9 @@ func TestEmitterMessageValidation(t *testing.T) { { name: "Invalid Beholder domain (double underscore)", attrs: beholder.Attributes{ - beholder.AttrKeyBeholderDataSchema: "/example-schema/versions/1", - beholder.AttrKeyBeholderEntity: "TestEntity", - beholder.AttrKeyBeholderDomain: "Test__Domain", + beholder.AttrKeyDataSchema: "/example-schema/versions/1", + beholder.AttrKeyEntity: "TestEntity", + beholder.AttrKeyDomain: "Test__Domain", }, exporterCalledTimes: 0, expectedError: "'Metadata.BeholderDomain' Error:Field validation for 'BeholderDomain' failed on the 'domain_entity' tag", @@ -64,9 +64,9 @@ func TestEmitterMessageValidation(t *testing.T) { { name: "Invalid Beholder domain (special characters)", attrs: beholder.Attributes{ - beholder.AttrKeyBeholderDataSchema: "/example-schema/versions/1", - beholder.AttrKeyBeholderEntity: "TestEntity", - beholder.AttrKeyBeholderDomain: "TestDomain*$", + beholder.AttrKeyDataSchema: "/example-schema/versions/1", + beholder.AttrKeyEntity: "TestEntity", + beholder.AttrKeyDomain: "TestDomain*$", }, exporterCalledTimes: 0, expectedError: "'Metadata.BeholderDomain' Error:Field validation for 'BeholderDomain' failed on the 'domain_entity' tag", @@ -74,9 +74,9 @@ func TestEmitterMessageValidation(t *testing.T) { { name: "Invalid Beholder entity (double underscore)", attrs: beholder.Attributes{ - beholder.AttrKeyBeholderDataSchema: "/example-schema/versions/1", - beholder.AttrKeyBeholderEntity: "Test__Entity", - beholder.AttrKeyBeholderDomain: "TestDomain", + beholder.AttrKeyDataSchema: "/example-schema/versions/1", + beholder.AttrKeyEntity: "Test__Entity", + beholder.AttrKeyDomain: "TestDomain", }, exporterCalledTimes: 0, expectedError: "'Metadata.BeholderEntity' Error:Field validation for 'BeholderEntity' failed on the 'domain_entity' tag", @@ -84,9 +84,9 @@ func TestEmitterMessageValidation(t *testing.T) { { name: "Invalid Beholder entity (special characters)", attrs: beholder.Attributes{ - 
beholder.AttrKeyBeholderDataSchema: "/example-schema/versions/1", - beholder.AttrKeyBeholderEntity: "TestEntity*$", - beholder.AttrKeyBeholderDomain: "TestDomain", + beholder.AttrKeyDataSchema: "/example-schema/versions/1", + beholder.AttrKeyEntity: "TestEntity*$", + beholder.AttrKeyDomain: "TestDomain", }, exporterCalledTimes: 0, expectedError: "'Metadata.BeholderEntity' Error:Field validation for 'BeholderEntity' failed on the 'domain_entity' tag", @@ -95,9 +95,9 @@ func TestEmitterMessageValidation(t *testing.T) { name: "Valid Attributes", exporterCalledTimes: 1, attrs: beholder.Attributes{ - beholder.AttrKeyBeholderDomain: "TestDomain", - beholder.AttrKeyBeholderEntity: "TestEntity", - beholder.AttrKeyBeholderDataSchema: "/example-schema/versions/1", + beholder.AttrKeyDomain: "TestDomain", + beholder.AttrKeyEntity: "TestEntity", + beholder.AttrKeyDataSchema: "/example-schema/versions/1", }, expectedError: "", }, @@ -105,9 +105,9 @@ func TestEmitterMessageValidation(t *testing.T) { name: "Valid Attributes (special characters)", exporterCalledTimes: 1, attrs: beholder.Attributes{ - beholder.AttrKeyBeholderDomain: "Test.Domain_42-1", - beholder.AttrKeyBeholderEntity: "Test.Entity_42-1", - beholder.AttrKeyBeholderDataSchema: "/example-schema/versions/1", + beholder.AttrKeyDomain: "Test.Domain_42-1", + beholder.AttrKeyEntity: "Test.Entity_42-1", + beholder.AttrKeyDataSchema: "/example-schema/versions/1", }, expectedError: "", }, diff --git a/pkg/beholder/schema.go b/pkg/beholder/schema.go index f61daf608..97f34837a 100644 --- a/pkg/beholder/schema.go +++ b/pkg/beholder/schema.go @@ -58,29 +58,29 @@ func toSchemaPath(m proto.Message, basePath string) string { // appendRequiredAttrDataSchema adds the message schema path as an attribute (required) func appendRequiredAttrDataSchema(attrKVs []any, val string) []any { - if containsKey(attrKVs, AttrKeyBeholderDataSchema) { + if containsKey(attrKVs, AttrKeyDataSchema) { return attrKVs } - attrKVs = append(attrKVs, AttrKeyBeholderDataSchema) + attrKVs = append(attrKVs, AttrKeyDataSchema) attrKVs = append(attrKVs, val) return attrKVs } // appendRequiredAttrEntity adds the message entity type as an attribute (required) func appendRequiredAttrEntity(attrKVs []any, m proto.Message) []any { - if containsKey(attrKVs, AttrKeyBeholderEntity) { + if containsKey(attrKVs, AttrKeyEntity) { return attrKVs } - attrKVs = append(attrKVs, AttrKeyBeholderEntity) + attrKVs = append(attrKVs, AttrKeyEntity) attrKVs = append(attrKVs, toSchemaName(m)) return attrKVs } // appendRequiredAttrDomain adds the message domain as an attribute (required) func appendRequiredAttrDomain(attrKVs []any, m proto.Message) []any { - if containsKey(attrKVs, AttrKeyBeholderDomain) { + if containsKey(attrKVs, AttrKeyDomain) { return attrKVs } @@ -93,7 +93,7 @@ func appendRequiredAttrDomain(attrKVs []any, m proto.Message) []any { domain = strings.Split(protoName, ".")[0] } - attrKVs = append(attrKVs, AttrKeyBeholderDomain) + attrKVs = append(attrKVs, AttrKeyDomain) attrKVs = append(attrKVs, domain) return attrKVs } diff --git a/pkg/capabilities/events/events.go b/pkg/capabilities/events/events.go index 0b2701bcf..78e28217e 100644 --- a/pkg/capabilities/events/events.go +++ b/pkg/capabilities/events/events.go @@ -193,9 +193,9 @@ func (e *Emitter) Emit(ctx context.Context, msg Message) error { } attrs := []any{ - beholder.AttrKeyBeholderDataSchema, + beholder.AttrKeyDataSchema, "/capabilities-operational-event/versions/1", - beholder.AttrKeyBeholderDataType, + beholder.AttrKeyDataType, 
"custom_message", } diff --git a/pkg/custmsg/custom_message.go b/pkg/custmsg/custom_message.go index c897e125b..a2f9b7d2a 100644 --- a/pkg/custmsg/custom_message.go +++ b/pkg/custmsg/custom_message.go @@ -111,9 +111,9 @@ func sendLogAsCustomMessageW(ctx context.Context, msg string, labels map[string] } err = beholder.GetEmitter().Emit(ctx, payloadBytes, - beholder.AttrKeyBeholderDataSchema, "/beholder-base-message/versions/1", // required - beholder.AttrKeyBeholderDomain, "platform", // required - beholder.AttrKeyBeholderEntity, "BaseMessage", // required + beholder.AttrKeyDataSchema, "/beholder-base-message/versions/1", // required + beholder.AttrKeyDomain, "platform", // required + beholder.AttrKeyEntity, "BaseMessage", // required ) if err != nil { return fmt.Errorf("sending custom message failed on emit: %w", err) diff --git a/pkg/utils/tests/beholder.go b/pkg/utils/tests/beholder.go index 902bbb3f7..c2479e645 100644 --- a/pkg/utils/tests/beholder.go +++ b/pkg/utils/tests/beholder.go @@ -86,7 +86,7 @@ func (b BeholderTester) BaseMessagesForLabels(t *testing.T, labels map[string]st messageLoop: for _, eMsg := range b.emitter.msgs { - dataSchema, ok := eMsg.Attrs[beholder.AttrKeyBeholderEntity].(string) + dataSchema, ok := eMsg.Attrs[beholder.AttrKeyEntity].(string) if !ok { continue } From 1ff9a4634a1f7bbf2ffcaaec8d038587689e8857 Mon Sep 17 00:00:00 2001 From: Silas Lenihan Date: Thu, 5 Jun 2025 12:34:55 -0400 Subject: [PATCH 10/16] Revert "pkg/loop: expand EnvConfig and make available from Server (#1149)" This reverts commit c4fb36f5716e26a7ff98f2224d2a6cff1bb05fc9. --- pkg/loop/ccip_commit_test.go | 2 +- pkg/loop/ccip_execution_test.go | 2 +- pkg/loop/config.go | 137 +--------- pkg/loop/config_test.go | 311 +++++++++++++++------- pkg/loop/internal/example-relay/main.go | 2 +- pkg/loop/internal/pb/relayer.pb.go | 13 +- pkg/loop/internal/pb/relayer.proto | 3 +- pkg/loop/internal/relayer/relayer.go | 37 +-- pkg/loop/internal/relayer/test/relayer.go | 13 +- pkg/loop/internal/types/types.go | 2 +- pkg/loop/plugin_median_test.go | 2 +- pkg/loop/plugin_mercury_test.go | 2 +- pkg/loop/plugin_relayer_test.go | 2 +- pkg/loop/relayer_service.go | 4 +- pkg/loop/relayer_service_test.go | 12 +- pkg/loop/server.go | 70 +++-- 16 files changed, 285 insertions(+), 329 deletions(-) diff --git a/pkg/loop/ccip_commit_test.go b/pkg/loop/ccip_commit_test.go index 06281cd67..fbce4f14b 100644 --- a/pkg/loop/ccip_commit_test.go +++ b/pkg/loop/ccip_commit_test.go @@ -112,7 +112,7 @@ func TestCommitLOOP(t *testing.T) { func newCommitProvider(t *testing.T, pr loop.PluginRelayer) (types.CCIPCommitProvider, error) { ctx := t.Context() - r, err := pr.NewRelayer(ctx, test.ConfigTOML, keystoretest.Keystore, keystoretest.Keystore, nil) + r, err := pr.NewRelayer(ctx, test.ConfigTOML, keystoretest.Keystore, nil) require.NoError(t, err) servicetest.Run(t, r) diff --git a/pkg/loop/ccip_execution_test.go b/pkg/loop/ccip_execution_test.go index b665e156e..52f6cb6a4 100644 --- a/pkg/loop/ccip_execution_test.go +++ b/pkg/loop/ccip_execution_test.go @@ -113,7 +113,7 @@ func TestExecLOOP(t *testing.T) { func newExecutionProvider(t *testing.T, pr loop.PluginRelayer) (types.CCIPExecProvider, error) { ctx := t.Context() - r, err := pr.NewRelayer(ctx, test.ConfigTOML, keystoretest.Keystore, keystoretest.Keystore, nil) + r, err := pr.NewRelayer(ctx, test.ConfigTOML, keystoretest.Keystore, nil) require.NoError(t, err) servicetest.Run(t, r) diff --git a/pkg/loop/config.go b/pkg/loop/config.go index 59010ce2f..0be7a4eab 100644 --- 
a/pkg/loop/config.go +++ b/pkg/loop/config.go @@ -14,30 +14,13 @@ import ( ) const ( - envAppID = "CL_APP_ID" - - envDatabaseURL = "CL_DATABASE_URL" - envDatabaseIdleInTxSessionTimeout = "CL_DATABASE_IDLE_IN_TX_SESSION_TIMEOUT" - envDatabaseLockTimeout = "CL_DATABASE_LOCK_TIMEOUT" - envDatabaseQueryTimeout = "CL_DATABASE_QUERY_TIMEOUT" - envDatabaseListenerFallbackPollInterval = "CL_DATABASE_LISTNER_FALLBACK_POLL_INTERVAL" - envDatabaseLogSQL = "CL_DATABASE_LOG_SQL" - envDatabaseMaxOpenConns = "CL_DATABASE_MAX_OPEN_CONNS" - envDatabaseMaxIdleConns = "CL_DATABASE_MAX_IDLE_CONNS" - - envFeatureLogPoller = "CL_FEATURE_LOG_POLLER" - - envMercuryCacheLatestReportDeadline = "CL_MERCURY_CACHE_LATEST_REPORT_DEADLINE" - envMercuryCacheLatestReportTTL = "CL_MERCURY_CACHE_LATEST_REPORT_TTL" - envMercuryCacheMaxStaleAge = "CL_MERCURY_CACHE_MAX_STALE_AGE" - - envMercuryTransmitterProtocol = "CL_MERCURY_TRANSMITTER_PROTOCOL" - envMercuryTransmitterTransmitQueueMaxSize = "CL_MERCURY_TRANSMITTER_TRANSMIT__QUEUE_MAX_SIZE" - envMercuryTransmitterTransmitTimeout = "CL_MERCURY_TRANSMITTER_TRANSMIT_TIMEOUT" - envMercuryTransmitterTransmitConcurrency = "CL_MERCURY_TRANSMITTER_TRANSMIT_CONCURRENCY" - envMercuryTransmitterReaperFrequency = "CL_MERCURY_TRANSMITTER_REAPER_FREQUENCY" - envMercuryTransmitterReaperMaxAge = "CL_MERCURY_TRANSMITTER_REAPER_MAX_AGE" - envMercuryVerboseLogging = "CL_MERCURY_VERBOSE_LOGGING" + envDatabaseURL = "CL_DATABASE_URL" + envDatabaseIdleInTxSessionTimeout = "CL_DATABASE_IDLE_IN_TX_SESSION_TIMEOUT" + envDatabaseLockTimeout = "CL_DATABASE_LOCK_TIMEOUT" + envDatabaseQueryTimeout = "CL_DATABASE_QUERY_TIMEOUT" + envDatabaseLogSQL = "CL_DATABASE_LOG_SQL" + envDatabaseMaxOpenConns = "CL_DATABASE_MAX_OPEN_CONNS" + envDatabaseMaxIdleConns = "CL_DATABASE_MAX_IDLE_CONNS" envPromPort = "CL_PROMETHEUS_PORT" @@ -67,30 +50,13 @@ const ( // EnvConfig is the configuration between the application and the LOOP executable. The values // are fully resolved and static and passed via the environment. 
type EnvConfig struct { - AppID string - - DatabaseURL *config.SecretURL - DatabaseIdleInTxSessionTimeout time.Duration - DatabaseLockTimeout time.Duration - DatabaseQueryTimeout time.Duration - DatabaseListenerFallbackPollInterval time.Duration - DatabaseLogSQL bool - DatabaseMaxOpenConns int - DatabaseMaxIdleConns int - - FeatureLogPoller bool - - MercuryCacheLatestReportDeadline time.Duration - MercuryCacheLatestReportTTL time.Duration - MercuryCacheMaxStaleAge time.Duration - - MercuryTransmitterProtocol string - MercuryTransmitterTransmitQueueMaxSize uint32 - MercuryTransmitterTransmitTimeout time.Duration - MercuryTransmitterTransmitConcurrency uint32 - MercuryTransmitterReaperFrequency time.Duration - MercuryTransmitterReaperMaxAge time.Duration - MercuryVerboseLogging bool + DatabaseURL *config.SecretURL + DatabaseIdleInTxSessionTimeout time.Duration + DatabaseLockTimeout time.Duration + DatabaseQueryTimeout time.Duration + DatabaseLogSQL bool + DatabaseMaxOpenConns int + DatabaseMaxIdleConns int PrometheusPort int @@ -123,33 +89,16 @@ func (e *EnvConfig) AsCmdEnv() (env []string) { env = append(env, k+"="+v) } - add(envAppID, e.AppID) - if e.DatabaseURL != nil { // optional add(envDatabaseURL, e.DatabaseURL.URL().String()) add(envDatabaseIdleInTxSessionTimeout, e.DatabaseIdleInTxSessionTimeout.String()) add(envDatabaseLockTimeout, e.DatabaseLockTimeout.String()) add(envDatabaseQueryTimeout, e.DatabaseQueryTimeout.String()) - add(envDatabaseListenerFallbackPollInterval, e.DatabaseListenerFallbackPollInterval.String()) add(envDatabaseLogSQL, strconv.FormatBool(e.DatabaseLogSQL)) add(envDatabaseMaxOpenConns, strconv.Itoa(e.DatabaseMaxOpenConns)) add(envDatabaseMaxIdleConns, strconv.Itoa(e.DatabaseMaxIdleConns)) } - add(envFeatureLogPoller, strconv.FormatBool(e.FeatureLogPoller)) - - add(envMercuryCacheLatestReportDeadline, e.MercuryCacheLatestReportDeadline.String()) - add(envMercuryCacheLatestReportTTL, e.MercuryCacheLatestReportTTL.String()) - add(envMercuryCacheMaxStaleAge, e.MercuryCacheMaxStaleAge.String()) - - add(envMercuryTransmitterProtocol, e.MercuryTransmitterProtocol) - add(envMercuryTransmitterTransmitQueueMaxSize, strconv.FormatUint(uint64(e.MercuryTransmitterTransmitQueueMaxSize), 10)) - add(envMercuryTransmitterTransmitTimeout, e.MercuryTransmitterTransmitTimeout.String()) - add(envMercuryTransmitterTransmitConcurrency, strconv.FormatUint(uint64(e.MercuryTransmitterTransmitConcurrency), 10)) - add(envMercuryTransmitterReaperFrequency, e.MercuryTransmitterReaperFrequency.String()) - add(envMercuryTransmitterReaperMaxAge, e.MercuryTransmitterReaperMaxAge.String()) - add(envMercuryVerboseLogging, strconv.FormatBool(e.MercuryVerboseLogging)) - add(envPromPort, strconv.Itoa(e.PrometheusPort)) add(envTracingEnabled, strconv.FormatBool(e.TracingEnabled)) @@ -187,7 +136,6 @@ func (e *EnvConfig) AsCmdEnv() (env []string) { // parse deserializes environment variables func (e *EnvConfig) parse() error { - e.AppID = os.Getenv(envAppID) var err error e.DatabaseURL, err = getEnv(envDatabaseURL, func(s string) (*config.SecretURL, error) { if s == "" { // DatabaseURL is optional @@ -215,10 +163,6 @@ func (e *EnvConfig) parse() error { if err != nil { return err } - e.DatabaseListenerFallbackPollInterval, err = getEnv(envDatabaseListenerFallbackPollInterval, time.ParseDuration) - if err != nil { - return err - } e.DatabaseLogSQL, err = getEnv(envDatabaseLogSQL, strconv.ParseBool) if err != nil { return err @@ -233,50 +177,6 @@ func (e *EnvConfig) parse() error { } } - 
e.FeatureLogPoller, err = getBool(envFeatureLogPoller) - if err != nil { - return err - } - - e.MercuryCacheLatestReportDeadline, err = getEnv(envMercuryCacheLatestReportDeadline, time.ParseDuration) - if err != nil { - return err - } - e.MercuryCacheLatestReportTTL, err = getEnv(envMercuryCacheLatestReportTTL, time.ParseDuration) - if err != nil { - return err - } - e.MercuryCacheMaxStaleAge, err = getEnv(envMercuryCacheMaxStaleAge, time.ParseDuration) - if err != nil { - return err - } - - e.MercuryTransmitterProtocol = os.Getenv(envMercuryTransmitterProtocol) - e.MercuryTransmitterTransmitQueueMaxSize, err = getUint32(envMercuryTransmitterTransmitQueueMaxSize) - if err != nil { - return err - } - e.MercuryTransmitterTransmitTimeout, err = getEnv(envMercuryTransmitterTransmitTimeout, time.ParseDuration) - if err != nil { - return err - } - e.MercuryTransmitterTransmitConcurrency, err = getUint32(envMercuryTransmitterTransmitConcurrency) - if err != nil { - return err - } - e.MercuryTransmitterReaperFrequency, err = getEnv(envMercuryTransmitterReaperFrequency, time.ParseDuration) - if err != nil { - return err - } - e.MercuryTransmitterReaperMaxAge, err = getEnv(envMercuryTransmitterReaperMaxAge, time.ParseDuration) - if err != nil { - return err - } - e.MercuryVerboseLogging, err = getBool(envMercuryVerboseLogging) - if err != nil { - return err - } - promPortStr := os.Getenv(envPromPort) e.PrometheusPort, err = strconv.Atoi(promPortStr) if err != nil { @@ -399,15 +299,6 @@ func getFloat64OrZero(envKey string) float64 { return f } -func getUint32(envKey string) (uint32, error) { - s := os.Getenv(envKey) - u, err := strconv.ParseUint(s, 10, 32) - if err != nil { - return 0, err - } - return uint32(u), nil -} - func getEnv[T any](key string, parse func(string) (T, error)) (t T, err error) { v := os.Getenv(key) t, err = parse(v) diff --git a/pkg/loop/config_test.go b/pkg/loop/config_test.go index e3d240c41..85a0b6547 100644 --- a/pkg/loop/config_test.go +++ b/pkg/loop/config_test.go @@ -1,6 +1,8 @@ package loop import ( + "maps" + "net/url" "os" "strconv" "strings" @@ -22,37 +24,47 @@ func TestEnvConfig_parse(t *testing.T) { envVars map[string]string expectError bool - expectConfig EnvConfig + expectedDatabaseURL string + expectedDatabaseIdleInTxSessionTimeout time.Duration + expectedDatabaseLockTimeout time.Duration + expectedDatabaseQueryTimeout time.Duration + expectedDatabaseLogSQL bool + expectedDatabaseMaxOpenConns int + expectedDatabaseMaxIdleConns int + + expectedPrometheusPort int + expectedTracingEnabled bool + expectedTracingCollectorTarget string + expectedTracingSamplingRatio float64 + expectedTracingTLSCertPath string + + expectedTelemetryEnabled bool + expectedTelemetryEndpoint string + expectedTelemetryInsecureConn bool + expectedTelemetryCACertFile string + expectedTelemetryAttributes OtelAttributes + expectedTelemetryTraceSampleRatio float64 + expectedTelemetryAuthHeaders map[string]string + expectedTelemetryAuthPubKeyHex string + expectedTelemetryEmitterBatchProcessor bool + expectedTelemetryEmitterExportTimeout time.Duration + expectedTelemetryEmitterExportInterval time.Duration + expectedTelemetryEmitterExportMaxBatchSize int + expectedTelemetryEmitterMaxQueueSize int + expectedChipIngressEndpoint string }{ { name: "All variables set correctly", envVars: map[string]string{ - envAppID: "app-id", - envDatabaseURL: "postgres://user:password@localhost:5432/db", - envDatabaseIdleInTxSessionTimeout: "42s", - envDatabaseLockTimeout: "8m", - envDatabaseQueryTimeout: "7s", - 
envDatabaseListenerFallbackPollInterval: "17s", - envDatabaseLogSQL: "true", - envDatabaseMaxOpenConns: "9999", - envDatabaseMaxIdleConns: "8080", - - envFeatureLogPoller: "true", - - envMercuryCacheLatestReportDeadline: "1ms", - envMercuryCacheLatestReportTTL: "1µs", - envMercuryCacheMaxStaleAge: "1ns", - - envMercuryTransmitterProtocol: "foo", - envMercuryTransmitterTransmitQueueMaxSize: "42", - envMercuryTransmitterTransmitTimeout: "1s", - envMercuryTransmitterTransmitConcurrency: "13", - envMercuryTransmitterReaperFrequency: "1h", - envMercuryTransmitterReaperMaxAge: "1m", - envMercuryVerboseLogging: "true", - - envPromPort: "8080", - + envDatabaseURL: "postgres://user:password@localhost:5432/db", + envDatabaseIdleInTxSessionTimeout: "42s", + envDatabaseLockTimeout: "8m", + envDatabaseQueryTimeout: "7s", + envDatabaseLogSQL: "true", + envDatabaseMaxOpenConns: "9999", + envDatabaseMaxIdleConns: "8080", + + envPromPort: "8080", envTracingEnabled: "true", envTracingCollectorTarget: "some:target", envTracingSamplingRatio: "1.0", @@ -73,11 +85,38 @@ func TestEnvConfig_parse(t *testing.T) { envTelemetryEmitterExportInterval: "2s", envTelemetryEmitterExportMaxBatchSize: "100", envTelemetryEmitterMaxQueueSize: "1000", - - envChipIngressEndpoint: "http://chip-ingress.example.com", + envChipIngressEndpoint: "http://chip-ingress.example.com", }, - expectError: false, - expectConfig: envCfgFull, + expectError: false, + + expectedDatabaseURL: "postgres://user:password@localhost:5432/db", + expectedDatabaseIdleInTxSessionTimeout: 42 * time.Second, + expectedDatabaseLockTimeout: 8 * time.Minute, + expectedDatabaseQueryTimeout: 7 * time.Second, + expectedDatabaseLogSQL: true, + expectedDatabaseMaxOpenConns: 9999, + expectedDatabaseMaxIdleConns: 8080, + + expectedPrometheusPort: 8080, + expectedTracingEnabled: true, + expectedTracingCollectorTarget: "some:target", + expectedTracingSamplingRatio: 1.0, + expectedTracingTLSCertPath: "internal/test/fixtures/client.pem", + + expectedTelemetryEnabled: true, + expectedTelemetryEndpoint: "example.com/beholder", + expectedTelemetryInsecureConn: true, + expectedTelemetryCACertFile: "foo/bar", + expectedTelemetryAttributes: OtelAttributes{"foo": "bar", "baz": "42"}, + expectedTelemetryTraceSampleRatio: 0.42, + expectedTelemetryAuthHeaders: map[string]string{"header-key": "header-value"}, + expectedTelemetryAuthPubKeyHex: "pub-key-hex", + expectedTelemetryEmitterBatchProcessor: true, + expectedTelemetryEmitterExportTimeout: 1 * time.Second, + expectedTelemetryEmitterExportInterval: 2 * time.Second, + expectedTelemetryEmitterExportMaxBatchSize: 100, + expectedTelemetryEmitterMaxQueueSize: 1000, + expectedChipIngressEndpoint: "http://chip-ingress.example.com", }, { name: "CL_DATABASE_URL parse error", @@ -113,94 +152,164 @@ func TestEnvConfig_parse(t *testing.T) { err := config.parse() if tc.expectError { - require.Error(t, err) + if err == nil { + t.Errorf("Expected error, got nil") + } } else { - require.NoError(t, err) - require.Equal(t, tc.expectConfig, config) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } else { + if config.DatabaseURL.URL().String() != tc.expectedDatabaseURL { + t.Errorf("Expected Database URL %s, got %s", tc.expectedDatabaseURL, config.DatabaseURL.String()) + } + if config.DatabaseIdleInTxSessionTimeout != tc.expectedDatabaseIdleInTxSessionTimeout { + t.Errorf("Expected Database idle in tx session timeout %s, got %s", tc.expectedDatabaseIdleInTxSessionTimeout, config.DatabaseIdleInTxSessionTimeout) + } + if 
config.DatabaseLockTimeout != tc.expectedDatabaseLockTimeout { + t.Errorf("Expected Database lock timeout %s, got %s", tc.expectedDatabaseLockTimeout, config.DatabaseLockTimeout) + } + if config.DatabaseQueryTimeout != tc.expectedDatabaseQueryTimeout { + t.Errorf("Expected Database query timeout %s, got %s", tc.expectedDatabaseQueryTimeout, config.DatabaseQueryTimeout) + } + if config.DatabaseLogSQL != tc.expectedDatabaseLogSQL { + t.Errorf("Expected Database log sql %t, got %t", tc.expectedDatabaseLogSQL, config.DatabaseLogSQL) + } + if config.DatabaseMaxOpenConns != tc.expectedDatabaseMaxOpenConns { + t.Errorf("Expected Database max open conns %d, got %d", tc.expectedDatabaseMaxOpenConns, config.DatabaseMaxOpenConns) + } + if config.DatabaseMaxIdleConns != tc.expectedDatabaseMaxIdleConns { + t.Errorf("Expected Database max idle conns %d, got %d", tc.expectedDatabaseMaxIdleConns, config.DatabaseMaxIdleConns) + } + + if config.PrometheusPort != tc.expectedPrometheusPort { + t.Errorf("Expected Prometheus port %d, got %d", tc.expectedPrometheusPort, config.PrometheusPort) + } + if config.TracingEnabled != tc.expectedTracingEnabled { + t.Errorf("Expected tracingEnabled %v, got %v", tc.expectedTracingEnabled, config.TracingEnabled) + } + if config.TracingCollectorTarget != tc.expectedTracingCollectorTarget { + t.Errorf("Expected tracingCollectorTarget %s, got %s", tc.expectedTracingCollectorTarget, config.TracingCollectorTarget) + } + if config.TracingSamplingRatio != tc.expectedTracingSamplingRatio { + t.Errorf("Expected tracingSamplingRatio %f, got %f", tc.expectedTracingSamplingRatio, config.TracingSamplingRatio) + } + if config.TracingTLSCertPath != tc.expectedTracingTLSCertPath { + t.Errorf("Expected tracingTLSCertPath %s, got %s", tc.expectedTracingTLSCertPath, config.TracingTLSCertPath) + } + if config.TelemetryEnabled != tc.expectedTelemetryEnabled { + t.Errorf("Expected telemetryEnabled %v, got %v", tc.expectedTelemetryEnabled, config.TelemetryEnabled) + } + if config.TelemetryEndpoint != tc.expectedTelemetryEndpoint { + t.Errorf("Expected telemetryEndpoint %s, got %s", tc.expectedTelemetryEndpoint, config.TelemetryEndpoint) + } + if config.TelemetryInsecureConnection != tc.expectedTelemetryInsecureConn { + t.Errorf("Expected telemetryInsecureConn %v, got %v", tc.expectedTelemetryInsecureConn, config.TelemetryInsecureConnection) + } + if config.TelemetryCACertFile != tc.expectedTelemetryCACertFile { + t.Errorf("Expected telemetryCACertFile %s, got %s", tc.expectedTelemetryCACertFile, config.TelemetryCACertFile) + } + if !maps.Equal(config.TelemetryAttributes, tc.expectedTelemetryAttributes) { + t.Errorf("Expected telemetryAttributes %v, got %v", tc.expectedTelemetryAttributes, config.TelemetryAttributes) + } + if config.TelemetryTraceSampleRatio != tc.expectedTelemetryTraceSampleRatio { + t.Errorf("Expected telemetryTraceSampleRatio %f, got %f", tc.expectedTelemetryTraceSampleRatio, config.TelemetryTraceSampleRatio) + } + if !maps.Equal(config.TelemetryAuthHeaders, tc.expectedTelemetryAuthHeaders) { + t.Errorf("Expected telemetryAuthHeaders %v, got %v", tc.expectedTelemetryAuthHeaders, config.TelemetryAuthHeaders) + } + if config.TelemetryAuthPubKeyHex != tc.expectedTelemetryAuthPubKeyHex { + t.Errorf("Expected telemetryAuthPubKeyHex %s, got %s", tc.expectedTelemetryAuthPubKeyHex, config.TelemetryAuthPubKeyHex) + } + if config.TelemetryEmitterBatchProcessor != tc.expectedTelemetryEmitterBatchProcessor { + t.Errorf("Expected telemetryEmitterBatchProcessor %v, got %v", 
tc.expectedTelemetryEmitterBatchProcessor, config.TelemetryEmitterBatchProcessor) + } + if config.TelemetryEmitterExportTimeout != tc.expectedTelemetryEmitterExportTimeout { + t.Errorf("Expected telemetryEmitterExportTimeout %v, got %v", tc.expectedTelemetryEmitterExportTimeout, config.TelemetryEmitterExportTimeout) + } + if config.TelemetryEmitterExportInterval != tc.expectedTelemetryEmitterExportInterval { + t.Errorf("Expected telemetryEmitterExportInterval %v, got %v", tc.expectedTelemetryEmitterExportInterval, config.TelemetryEmitterExportInterval) + } + if config.TelemetryEmitterExportMaxBatchSize != tc.expectedTelemetryEmitterExportMaxBatchSize { + t.Errorf("Expected telemetryEmitterExportMaxBatchSize %d, got %d", tc.expectedTelemetryEmitterExportMaxBatchSize, config.TelemetryEmitterExportMaxBatchSize) + } + if config.TelemetryEmitterMaxQueueSize != tc.expectedTelemetryEmitterMaxQueueSize { + t.Errorf("Expected telemetryEmitterMaxQueueSize %d, got %d", tc.expectedTelemetryEmitterMaxQueueSize, config.TelemetryEmitterMaxQueueSize) + } + if config.ChipIngressEndpoint != tc.expectedChipIngressEndpoint { + t.Errorf("Expected ChipIngressEndpoint %s, got %s", tc.expectedChipIngressEndpoint, config.ChipIngressEndpoint) + } + } } }) } } -var envCfgFull = EnvConfig{ - AppID: "app-id", - - DatabaseURL: config.MustSecretURL("postgres://user:password@localhost:5432/db"), - DatabaseIdleInTxSessionTimeout: 42 * time.Second, - DatabaseLockTimeout: 8 * time.Minute, - DatabaseQueryTimeout: 7 * time.Second, - DatabaseListenerFallbackPollInterval: 17 * time.Second, - DatabaseLogSQL: true, - DatabaseMaxOpenConns: 9999, - DatabaseMaxIdleConns: 8080, - - FeatureLogPoller: true, - - MercuryCacheLatestReportDeadline: time.Millisecond, - MercuryCacheLatestReportTTL: time.Microsecond, - MercuryCacheMaxStaleAge: time.Nanosecond, - - MercuryTransmitterProtocol: "foo", - MercuryTransmitterTransmitQueueMaxSize: 42, - MercuryTransmitterTransmitTimeout: time.Second, - MercuryTransmitterTransmitConcurrency: 13, - MercuryTransmitterReaperFrequency: time.Hour, - MercuryTransmitterReaperMaxAge: time.Minute, - MercuryVerboseLogging: true, - - PrometheusPort: 8080, - - TracingEnabled: true, - TracingAttributes: map[string]string{"XYZ": "value"}, - TracingCollectorTarget: "some:target", - TracingSamplingRatio: 1.0, - TracingTLSCertPath: "internal/test/fixtures/client.pem", - - TelemetryEnabled: true, - TelemetryEndpoint: "example.com/beholder", - TelemetryInsecureConnection: true, - TelemetryCACertFile: "foo/bar", - TelemetryAttributes: OtelAttributes{"foo": "bar", "baz": "42"}, - TelemetryTraceSampleRatio: 0.42, - TelemetryAuthHeaders: map[string]string{"header-key": "header-value"}, - TelemetryAuthPubKeyHex: "pub-key-hex", - TelemetryEmitterBatchProcessor: true, - TelemetryEmitterExportTimeout: 1 * time.Second, - TelemetryEmitterExportInterval: 2 * time.Second, - TelemetryEmitterExportMaxBatchSize: 100, - TelemetryEmitterMaxQueueSize: 1000, - - ChipIngressEndpoint: "http://chip-ingress.example.com", +func equalOtelAttributes(a, b OtelAttributes) bool { + if len(a) != len(b) { + return false + } + for k, v := range a { + if b[k] != v { + return false + } + } + return true +} + +func equalStringMaps(a, b map[string]string) bool { + if len(a) != len(b) { + return false + } + for k, v := range a { + if b[k] != v { + return false + } + } + return true } func TestEnvConfig_AsCmdEnv(t *testing.T) { + envCfg := EnvConfig{ + DatabaseURL: (*config.SecretURL)(&url.URL{Scheme: "postgres", Host: "localhost:5432", User: 
url.UserPassword("user", "password"), Path: "/db"}), + PrometheusPort: 9090, + + TracingEnabled: true, + TracingCollectorTarget: "http://localhost:9000", + TracingSamplingRatio: 0.1, + TracingTLSCertPath: "some/path", + TracingAttributes: map[string]string{"key": "value"}, + + TelemetryEnabled: true, + TelemetryEndpoint: "example.com/beholder", + TelemetryInsecureConnection: true, + TelemetryCACertFile: "foo/bar", + TelemetryAttributes: OtelAttributes{"foo": "bar", "baz": "42"}, + TelemetryTraceSampleRatio: 0.42, + TelemetryAuthHeaders: map[string]string{"header-key": "header-value"}, + TelemetryAuthPubKeyHex: "pub-key-hex", + TelemetryEmitterBatchProcessor: true, + TelemetryEmitterExportTimeout: 1 * time.Second, + TelemetryEmitterExportInterval: 2 * time.Second, + TelemetryEmitterExportMaxBatchSize: 100, + TelemetryEmitterMaxQueueSize: 1000, + + ChipIngressEndpoint: "http://chip-ingress.example.com", + } got := map[string]string{} - for _, kv := range envCfgFull.AsCmdEnv() { + for _, kv := range envCfg.AsCmdEnv() { pair := strings.SplitN(kv, "=", 2) require.Len(t, pair, 2) got[pair[0]] = pair[1] } assert.Equal(t, "postgres://user:password@localhost:5432/db", got[envDatabaseURL]) - - assert.Equal(t, "1ms", got[envMercuryCacheLatestReportDeadline]) - assert.Equal(t, "1µs", got[envMercuryCacheLatestReportTTL]) - assert.Equal(t, "1ns", got[envMercuryCacheMaxStaleAge]) - assert.Equal(t, "foo", got[envMercuryTransmitterProtocol]) - assert.Equal(t, "42", got[envMercuryTransmitterTransmitQueueMaxSize]) - assert.Equal(t, "1s", got[envMercuryTransmitterTransmitTimeout]) - assert.Equal(t, "13", got[envMercuryTransmitterTransmitConcurrency]) - assert.Equal(t, "1h0m0s", got[envMercuryTransmitterReaperFrequency]) - assert.Equal(t, "1m0s", got[envMercuryTransmitterReaperMaxAge]) - assert.Equal(t, "true", got[envMercuryVerboseLogging]) - - assert.Equal(t, strconv.Itoa(8080), got[envPromPort]) + assert.Equal(t, strconv.Itoa(9090), got[envPromPort]) assert.Equal(t, "true", got[envTracingEnabled]) - assert.Equal(t, "some:target", got[envTracingCollectorTarget]) - assert.Equal(t, "1", got[envTracingSamplingRatio]) - assert.Equal(t, "internal/test/fixtures/client.pem", got[envTracingTLSCertPath]) - assert.Equal(t, "value", got[envTracingAttribute+"XYZ"]) + assert.Equal(t, "http://localhost:9000", got[envTracingCollectorTarget]) + assert.Equal(t, "0.1", got[envTracingSamplingRatio]) + assert.Equal(t, "some/path", got[envTracingTLSCertPath]) + assert.Equal(t, "value", got[envTracingAttribute+"key"]) assert.Equal(t, "true", got[envTelemetryEnabled]) assert.Equal(t, "example.com/beholder", got[envTelemetryEndpoint]) diff --git a/pkg/loop/internal/example-relay/main.go b/pkg/loop/internal/example-relay/main.go index c74c1f83b..635dcb839 100644 --- a/pkg/loop/internal/example-relay/main.go +++ b/pkg/loop/internal/example-relay/main.go @@ -60,7 +60,7 @@ func (p *pluginRelayer) HealthReport() map[string]error { return map[string]erro func (p *pluginRelayer) Name() string { return p.lggr.Name() } -func (p *pluginRelayer) NewRelayer(ctx context.Context, config string, keystore, csaKeystore core.Keystore, cr core.CapabilitiesRegistry) (loop.Relayer, error) { +func (p *pluginRelayer) NewRelayer(ctx context.Context, config string, keystore core.Keystore, cr core.CapabilitiesRegistry) (loop.Relayer, error) { return &relayer{lggr: logger.Named(p.lggr, "Relayer"), ds: p.ds}, nil } diff --git a/pkg/loop/internal/pb/relayer.pb.go b/pkg/loop/internal/pb/relayer.pb.go index dfe17c046..eee091b12 100644 --- 
a/pkg/loop/internal/pb/relayer.pb.go +++ b/pkg/loop/internal/pb/relayer.pb.go @@ -28,7 +28,6 @@ type NewRelayerRequest struct { Config string `protobuf:"bytes,1,opt,name=config,proto3" json:"config,omitempty"` // toml (is chain instance config enough?) KeystoreID uint32 `protobuf:"varint,2,opt,name=keystoreID,proto3" json:"keystoreID,omitempty"` CapabilityRegistryID uint32 `protobuf:"varint,3,opt,name=capabilityRegistryID,proto3" json:"capabilityRegistryID,omitempty"` - KeystoreCSAID uint32 `protobuf:"varint,4,opt,name=keystoreCSAID,proto3" json:"keystoreCSAID,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -84,13 +83,6 @@ func (x *NewRelayerRequest) GetCapabilityRegistryID() uint32 { return 0 } -func (x *NewRelayerRequest) GetKeystoreCSAID() uint32 { - if x != nil { - return x.KeystoreCSAID - } - return 0 -} - type NewRelayerReply struct { state protoimpl.MessageState `protogen:"open.v1"` RelayerID uint32 `protobuf:"varint,1,opt,name=relayerID,proto3" json:"relayerID,omitempty"` @@ -2517,14 +2509,13 @@ var File_loop_internal_pb_relayer_proto protoreflect.FileDescriptor const file_loop_internal_pb_relayer_proto_rawDesc = "" + "\n" + - "\x1eloop/internal/pb/relayer.proto\x12\x04loop\x1a\x1bgoogle/protobuf/empty.proto\x1a\x1cgoogle/protobuf/struct.proto\x1a&loop/internal/pb/contract_reader.proto\"\xa5\x01\n" + + "\x1eloop/internal/pb/relayer.proto\x12\x04loop\x1a\x1bgoogle/protobuf/empty.proto\x1a\x1cgoogle/protobuf/struct.proto\x1a&loop/internal/pb/contract_reader.proto\"\x7f\n" + "\x11NewRelayerRequest\x12\x16\n" + "\x06config\x18\x01 \x01(\tR\x06config\x12\x1e\n" + "\n" + "keystoreID\x18\x02 \x01(\rR\n" + "keystoreID\x122\n" + - "\x14capabilityRegistryID\x18\x03 \x01(\rR\x14capabilityRegistryID\x12$\n" + - "\rkeystoreCSAID\x18\x04 \x01(\rR\rkeystoreCSAID\"/\n" + + "\x14capabilityRegistryID\x18\x03 \x01(\rR\x14capabilityRegistryID\"/\n" + "\x0fNewRelayerReply\x12\x1c\n" + "\trelayerID\x18\x01 \x01(\rR\trelayerID\"+\n" + "\rAccountsReply\x12\x1a\n" + diff --git a/pkg/loop/internal/pb/relayer.proto b/pkg/loop/internal/pb/relayer.proto index aa5e3b246..58aff16a6 100644 --- a/pkg/loop/internal/pb/relayer.proto +++ b/pkg/loop/internal/pb/relayer.proto @@ -16,7 +16,8 @@ message NewRelayerRequest { string config = 1; // toml (is chain instance config enough?) uint32 keystoreID = 2; uint32 capabilityRegistryID = 3; - uint32 keystoreCSAID = 4; + + //TODO prometheus? 
https://smartcontract-it.atlassian.net/browse/BCF-2075 } message NewRelayerReply { diff --git a/pkg/loop/internal/relayer/relayer.go b/pkg/loop/internal/relayer/relayer.go index 616bfa7f3..c44600189 100644 --- a/pkg/loop/internal/relayer/relayer.go +++ b/pkg/loop/internal/relayer/relayer.go @@ -46,10 +46,10 @@ func NewPluginRelayerClient(brokerCfg net.BrokerConfig) *PluginRelayerClient { return &PluginRelayerClient{PluginClient: pc, pluginRelayer: pb.NewPluginRelayerClient(pc), ServiceClient: goplugin.NewServiceClient(pc.BrokerExt, pc)} } -func (p *PluginRelayerClient) NewRelayer(ctx context.Context, config string, keystore, csaKeystore core.Keystore, capabilityRegistry core.CapabilitiesRegistry) (looptypes.Relayer, error) { - cc := p.NewClientConn("Relayer", func(ctx context.Context) (relayerID uint32, deps net.Resources, err error) { +func (p *PluginRelayerClient) NewRelayer(ctx context.Context, config string, keystore core.Keystore, capabilityRegistry core.CapabilitiesRegistry) (looptypes.Relayer, error) { + cc := p.NewClientConn("Relayer", func(ctx context.Context) (id uint32, deps net.Resources, err error) { var ksRes net.Resource - ksID, ksRes, err := p.ServeNew("Keystore", func(s *grpc.Server) { + id, ksRes, err = p.ServeNew("Keystore", func(s *grpc.Server) { pb.RegisterKeystoreServer(s, &keystoreServer{impl: keystore}) }) if err != nil { @@ -57,15 +57,6 @@ func (p *PluginRelayerClient) NewRelayer(ctx context.Context, config string, key } deps.Add(ksRes) - var ksCSARes net.Resource - ksCSAID, ksCSARes, err := p.ServeNew("CSAKeystore", func(s *grpc.Server) { - pb.RegisterKeystoreServer(s, &keystoreServer{impl: csaKeystore}) - }) - if err != nil { - return 0, nil, fmt.Errorf("Failed to create relayer client: failed to serve CSA keystore: %w", err) - } - deps.Add(ksCSARes) - capabilityRegistryID, capabilityRegistryResource, err := p.ServeNew("CapabilitiesRegistry", func(s *grpc.Server) { pb.RegisterCapabilitiesRegistryServer(s, capability.NewCapabilitiesRegistryServer(p.BrokerExt, capabilityRegistry)) }) @@ -76,8 +67,7 @@ func (p *PluginRelayerClient) NewRelayer(ctx context.Context, config string, key reply, err := p.pluginRelayer.NewRelayer(ctx, &pb.NewRelayerRequest{ Config: config, - KeystoreID: ksID, - KeystoreCSAID: ksCSAID, + KeystoreID: id, CapabilityRegistryID: capabilityRegistryID, }) if err != nil { @@ -112,31 +102,22 @@ func (p *pluginRelayerServer) NewRelayer(ctx context.Context, request *pb.NewRel if err != nil { return nil, net.ErrConnDial{Name: "Keystore", ID: request.KeystoreID, Err: err} } - ksRes := net.Resource{Closer: ksConn, Name: "Keystore"} - - ksCSAConn, err := p.Dial(request.KeystoreCSAID) - if err != nil { - p.CloseAll(ksRes) - return nil, net.ErrConnDial{Name: "CSAKeystore", ID: request.KeystoreCSAID, Err: err} - } - ksCSARes := net.Resource{Closer: ksConn, Name: "CSAKeystore"} - + ksRes := net.Resource{Closer: ksConn, Name: "CapabilityRegistry"} capRegistryConn, err := p.Dial(request.CapabilityRegistryID) if err != nil { - p.CloseAll(ksRes, ksCSARes) return nil, net.ErrConnDial{Name: "CapabilityRegistry", ID: request.CapabilityRegistryID, Err: err} } crRes := net.Resource{Closer: capRegistryConn, Name: "CapabilityRegistry"} capRegistry := capability.NewCapabilitiesRegistryClient(capRegistryConn, p.BrokerExt) - r, err := p.impl.NewRelayer(ctx, request.Config, newKeystoreClient(ksConn), newKeystoreClient(ksCSAConn), capRegistry) + r, err := p.impl.NewRelayer(ctx, request.Config, newKeystoreClient(ksConn), capRegistry) if err != nil { - p.CloseAll(ksRes, 
ksCSARes, crRes) + p.CloseAll(ksRes, crRes) return nil, err } err = r.Start(ctx) if err != nil { - p.CloseAll(ksRes, ksCSARes, crRes) + p.CloseAll(ksRes, crRes) return nil, err } @@ -148,7 +129,7 @@ func (p *pluginRelayerServer) NewRelayer(ctx context.Context, request *pb.NewRel if evmService, ok := r.(types.EVMService); ok { evmpb.RegisterEVMServer(s, newEVMServer(evmService, p.BrokerExt)) } - }, rRes, ksRes, ksCSARes, crRes) + }, rRes, ksRes, crRes) if err != nil { return nil, err } diff --git a/pkg/loop/internal/relayer/test/relayer.go b/pkg/loop/internal/relayer/test/relayer.go index 7af353998..e10397d2e 100644 --- a/pkg/loop/internal/relayer/test/relayer.go +++ b/pkg/loop/internal/relayer/test/relayer.go @@ -141,7 +141,7 @@ func (s staticPluginRelayer) HealthReport() map[string]error { return hp } -func (s staticPluginRelayer) NewRelayer(ctx context.Context, config string, keystore, csaKeystore core.Keystore, capabilityRegistry core.CapabilitiesRegistry) (looptypes.Relayer, error) { +func (s staticPluginRelayer) NewRelayer(ctx context.Context, config string, keystore core.Keystore, capabilityRegistry core.CapabilitiesRegistry) (looptypes.Relayer, error) { if s.relayer.StaticChecks && config != ConfigTOML { return nil, fmt.Errorf("expected config %q but got %q", ConfigTOML, config) } @@ -152,13 +152,6 @@ func (s staticPluginRelayer) NewRelayer(ctx context.Context, config string, keys if len(keys) == 0 { return nil, fmt.Errorf("expected at least one key but got none") } - keys, err = csaKeystore.Accounts(ctx) - if err != nil { - return nil, err - } - if len(keys) == 0 { - return nil, fmt.Errorf("expected at least one CSA key but got none") - } return s.relayer, nil } @@ -440,7 +433,7 @@ func newRelayArgsWithProviderType(_type types.OCR2PluginType) types.RelayArgs { func RunPlugin(t *testing.T, p looptypes.PluginRelayer) { t.Run("Relayer", func(t *testing.T) { ctx := t.Context() - relayer, err := p.NewRelayer(ctx, ConfigTOML, keystoretest.Keystore, keystoretest.Keystore, nil) + relayer, err := p.NewRelayer(ctx, ConfigTOML, keystoretest.Keystore, nil) require.NoError(t, err) servicetest.Run(t, relayer) Run(t, relayer) @@ -487,7 +480,7 @@ func RunFuzzPluginRelayer(f *testing.F, relayerFunc func(*testing.T) looptypes.P } ctx := t.Context() - _, err := relayerFunc(t).NewRelayer(ctx, fConfig, keystore, keystore, nil) + _, err := relayerFunc(t).NewRelayer(ctx, fConfig, keystore, nil) grpcUnavailableErr(t, err) }) diff --git a/pkg/loop/internal/types/types.go b/pkg/loop/internal/types/types.go index 7b1f384a9..bf4132cf2 100644 --- a/pkg/loop/internal/types/types.go +++ b/pkg/loop/internal/types/types.go @@ -11,7 +11,7 @@ import ( type PluginRelayer interface { services.Service - NewRelayer(ctx context.Context, config string, keystore, csaKeystore core.Keystore, capabilityRegistry core.CapabilitiesRegistry) (Relayer, error) + NewRelayer(ctx context.Context, config string, keystore core.Keystore, capabilityRegistry core.CapabilitiesRegistry) (Relayer, error) } type MedianProvider interface { diff --git a/pkg/loop/plugin_median_test.go b/pkg/loop/plugin_median_test.go index b2cf94050..e3d29ab9c 100644 --- a/pkg/loop/plugin_median_test.go +++ b/pkg/loop/plugin_median_test.go @@ -85,7 +85,7 @@ func newStopCh(t *testing.T) <-chan struct{} { func newMedianProvider(t *testing.T, pr loop.PluginRelayer) types.MedianProvider { ctx := t.Context() - r, err := pr.NewRelayer(ctx, test.ConfigTOML, keystoretest.Keystore, keystoretest.Keystore, nil) + r, err := pr.NewRelayer(ctx, test.ConfigTOML, 
keystoretest.Keystore, nil) require.NoError(t, err) servicetest.Run(t, r) p, err := r.NewPluginProvider(ctx, relayertest.RelayArgs, relayertest.PluginArgs) diff --git a/pkg/loop/plugin_mercury_test.go b/pkg/loop/plugin_mercury_test.go index d349b700e..20e16142b 100644 --- a/pkg/loop/plugin_mercury_test.go +++ b/pkg/loop/plugin_mercury_test.go @@ -67,7 +67,7 @@ func TestPluginMercuryExec(t *testing.T) { func newMercuryProvider(t *testing.T, pr loop.PluginRelayer) types.MercuryProvider { ctx := t.Context() - r, err := pr.NewRelayer(ctx, test.ConfigTOML, keystoretest.Keystore, keystoretest.Keystore, nil) + r, err := pr.NewRelayer(ctx, test.ConfigTOML, keystoretest.Keystore, nil) require.NoError(t, err) servicetest.Run(t, r) p, err := r.NewPluginProvider(ctx, mercurytest.RelayArgs, mercurytest.PluginArgs) diff --git a/pkg/loop/plugin_relayer_test.go b/pkg/loop/plugin_relayer_test.go index 44b650df1..62aa7a7fe 100644 --- a/pkg/loop/plugin_relayer_test.go +++ b/pkg/loop/plugin_relayer_test.go @@ -58,7 +58,7 @@ func FuzzRelayer(f *testing.F) { p := newPluginRelayerExec(t, false, stopCh) ctx := t.Context() capRegistry := mocks.NewCapabilitiesRegistry(t) - relayer, err := p.NewRelayer(ctx, test.ConfigTOML, keystoretest.Keystore, keystoretest.Keystore, capRegistry) + relayer, err := p.NewRelayer(ctx, test.ConfigTOML, keystoretest.Keystore, capRegistry) require.NoError(t, err) diff --git a/pkg/loop/relayer_service.go b/pkg/loop/relayer_service.go index 1ede94eaa..a67cd1574 100644 --- a/pkg/loop/relayer_service.go +++ b/pkg/loop/relayer_service.go @@ -22,13 +22,13 @@ type RelayerService struct { // NewRelayerService returns a new [*RelayerService]. // cmd must return a new exec.Cmd each time it is called. -func NewRelayerService(lggr logger.Logger, grpcOpts GRPCOpts, cmd func() *exec.Cmd, config string, keystore core.Keystore, csaKeystore core.Keystore, capabilityRegistry core.CapabilitiesRegistry) *RelayerService { +func NewRelayerService(lggr logger.Logger, grpcOpts GRPCOpts, cmd func() *exec.Cmd, config string, keystore core.Keystore, capabilityRegistry core.CapabilitiesRegistry) *RelayerService { newService := func(ctx context.Context, instance any) (Relayer, services.HealthReporter, error) { plug, ok := instance.(PluginRelayer) if !ok { return nil, nil, fmt.Errorf("expected PluginRelayer but got %T", instance) } - r, err := plug.NewRelayer(ctx, config, keystore, csaKeystore, capabilityRegistry) + r, err := plug.NewRelayer(ctx, config, keystore, capabilityRegistry) if err != nil { return nil, nil, fmt.Errorf("failed to create Relayer: %w", err) } diff --git a/pkg/loop/relayer_service_test.go b/pkg/loop/relayer_service_test.go index 394b87d29..9852dbd5e 100644 --- a/pkg/loop/relayer_service_test.go +++ b/pkg/loop/relayer_service_test.go @@ -43,7 +43,7 @@ func TestRelayerService(t *testing.T) { capRegistry := mocks.NewCapabilitiesRegistry(t) relayer := loop.NewRelayerService(logger.Test(t), loop.GRPCOpts{}, func() *exec.Cmd { return NewHelperProcessCommand(loop.PluginRelayerName, false, 0) - }, test.ConfigTOML, keystoretest.Keystore, keystoretest.Keystore, capRegistry) + }, test.ConfigTOML, keystoretest.Keystore, capRegistry) hook := relayer.XXXTestHook() servicetest.Run(t, relayer) @@ -83,16 +83,12 @@ func TestRelayerService_recovery(t *testing.T) { Command: loop.PluginRelayerName, Limit: int(limit.Add(1)), }.New() - }, test.ConfigTOML, keystoretest.Keystore, keystoretest.Keystore, nil) + }, test.ConfigTOML, keystoretest.Keystore, nil) servicetest.Run(t, relayer) relayertest.Run(t, relayer) - if hp 
:= relayer.HealthReport(); len(hp) == 2 { - servicetest.AssertHealthReportNames(t, hp, relayerServiceNames[:2]...) - } else { - servicetest.AssertHealthReportNames(t, hp, relayerServiceNames...) - } + servicetest.AssertHealthReportNames(t, relayer.HealthReport(), relayerServiceNames[:2]...) } @@ -104,7 +100,7 @@ func TestRelayerService_HealthReport(t *testing.T) { capRegistry := mocks.NewCapabilitiesRegistry(t) s := loop.NewRelayerService(lggr, loop.GRPCOpts{}, func() *exec.Cmd { return HelperProcessCommand{Command: loop.PluginRelayerName}.New() - }, test.ConfigTOML, keystoretest.Keystore, keystoretest.Keystore, capRegistry) + }, test.ConfigTOML, keystoretest.Keystore, capRegistry) servicetest.AssertHealthReportNames(t, s.HealthReport(), relayerServiceNames[0]) diff --git a/pkg/loop/server.go b/pkg/loop/server.go index e2960bd55..abe79ed8a 100644 --- a/pkg/loop/server.go +++ b/pkg/loop/server.go @@ -9,7 +9,6 @@ import ( "github.com/jmoiron/sqlx" "go.opentelemetry.io/otel/attribute" - semconv "go.opentelemetry.io/otel/semconv/v1.17.0" "github.com/smartcontractkit/chainlink-common/pkg/beholder" "github.com/smartcontractkit/chainlink-common/pkg/config/build" @@ -52,7 +51,6 @@ func MustNewStartedServer(loggerName string) *Server { // Server holds common plugin server fields. type Server struct { - EnvConfig EnvConfig GRPCOpts GRPCOpts Logger logger.SugaredLogger db *sqlx.DB // optional @@ -83,25 +81,21 @@ func (s *Server) start() error { stopAfter := context.AfterFunc(ctx, stopSig) defer stopAfter() - if err := s.EnvConfig.parse(); err != nil { + var envCfg EnvConfig + if err := envCfg.parse(); err != nil { return fmt.Errorf("error getting environment configuration: %w", err) } - tracingAttrs := s.EnvConfig.TracingAttributes - if tracingAttrs == nil { - tracingAttrs = make(map[string]string, 1) - } - tracingAttrs[string(semconv.ServiceInstanceIDKey)] = s.EnvConfig.AppID tracingConfig := TracingConfig{ - Enabled: s.EnvConfig.TracingEnabled, - CollectorTarget: s.EnvConfig.TracingCollectorTarget, - SamplingRatio: s.EnvConfig.TracingSamplingRatio, - TLSCertPath: s.EnvConfig.TracingTLSCertPath, - NodeAttributes: tracingAttrs, + Enabled: envCfg.TracingEnabled, + CollectorTarget: envCfg.TracingCollectorTarget, + SamplingRatio: envCfg.TracingSamplingRatio, + TLSCertPath: envCfg.TracingTLSCertPath, + NodeAttributes: envCfg.TracingAttributes, OnDialError: func(err error) { s.Logger.Errorw("Failed to dial", "err", err) }, } - if s.EnvConfig.TelemetryEndpoint == "" { + if envCfg.TelemetryEndpoint == "" { err := SetupTracing(tracingConfig) if err != nil { return fmt.Errorf("failed to setup tracing: %w", err) @@ -112,20 +106,20 @@ func (s *Server) start() error { attributes = tracingConfig.Attributes() } beholderCfg := beholder.Config{ - InsecureConnection: s.EnvConfig.TelemetryInsecureConnection, - CACertFile: s.EnvConfig.TelemetryCACertFile, - OtelExporterGRPCEndpoint: s.EnvConfig.TelemetryEndpoint, - ResourceAttributes: append(attributes, s.EnvConfig.TelemetryAttributes.AsStringAttributes()...), - TraceSampleRatio: s.EnvConfig.TelemetryTraceSampleRatio, - AuthHeaders: s.EnvConfig.TelemetryAuthHeaders, - AuthPublicKeyHex: s.EnvConfig.TelemetryAuthPubKeyHex, - EmitterBatchProcessor: s.EnvConfig.TelemetryEmitterBatchProcessor, - EmitterExportTimeout: s.EnvConfig.TelemetryEmitterExportTimeout, - EmitterExportInterval: s.EnvConfig.TelemetryEmitterExportInterval, - EmitterExportMaxBatchSize: s.EnvConfig.TelemetryEmitterExportMaxBatchSize, - EmitterMaxQueueSize: s.EnvConfig.TelemetryEmitterMaxQueueSize, - 
ChipIngressEmitterEnabled: s.EnvConfig.ChipIngressEndpoint != "", - ChipIngressEmitterGRPCEndpoint: s.EnvConfig.ChipIngressEndpoint, + InsecureConnection: envCfg.TelemetryInsecureConnection, + CACertFile: envCfg.TelemetryCACertFile, + OtelExporterGRPCEndpoint: envCfg.TelemetryEndpoint, + ResourceAttributes: append(attributes, envCfg.TelemetryAttributes.AsStringAttributes()...), + TraceSampleRatio: envCfg.TelemetryTraceSampleRatio, + AuthHeaders: envCfg.TelemetryAuthHeaders, + AuthPublicKeyHex: envCfg.TelemetryAuthPubKeyHex, + EmitterBatchProcessor: envCfg.TelemetryEmitterBatchProcessor, + EmitterExportTimeout: envCfg.TelemetryEmitterExportTimeout, + EmitterExportInterval: envCfg.TelemetryEmitterExportInterval, + EmitterExportMaxBatchSize: envCfg.TelemetryEmitterExportMaxBatchSize, + EmitterMaxQueueSize: envCfg.TelemetryEmitterMaxQueueSize, + ChipIngressEmitterEnabled: envCfg.ChipIngressEndpoint != "", + ChipIngressEmitterGRPCEndpoint: envCfg.ChipIngressEndpoint, } if tracingConfig.Enabled { @@ -147,7 +141,7 @@ func (s *Server) start() error { beholder.SetGlobalOtelProviders() } - s.promServer = NewPromServer(s.EnvConfig.PrometheusPort, s.Logger) + s.promServer = NewPromServer(envCfg.PrometheusPort, s.Logger) if err := s.promServer.Start(); err != nil { return fmt.Errorf("error starting prometheus server: %w", err) } @@ -157,22 +151,22 @@ func (s *Server) start() error { return fmt.Errorf("error starting health checker: %w", err) } - if s.EnvConfig.DatabaseURL != nil { - pg.SetApplicationName(s.EnvConfig.DatabaseURL.URL(), build.Program) - dbURL := s.EnvConfig.DatabaseURL.URL().String() + if envCfg.DatabaseURL != nil { + pg.SetApplicationName(envCfg.DatabaseURL.URL(), build.Program) + dbURL := envCfg.DatabaseURL.URL().String() var err error s.db, err = pg.DBConfig{ - IdleInTxSessionTimeout: s.EnvConfig.DatabaseIdleInTxSessionTimeout, - LockTimeout: s.EnvConfig.DatabaseLockTimeout, - MaxOpenConns: s.EnvConfig.DatabaseMaxOpenConns, - MaxIdleConns: s.EnvConfig.DatabaseMaxIdleConns, + IdleInTxSessionTimeout: envCfg.DatabaseIdleInTxSessionTimeout, + LockTimeout: envCfg.DatabaseLockTimeout, + MaxOpenConns: envCfg.DatabaseMaxOpenConns, + MaxIdleConns: envCfg.DatabaseMaxIdleConns, }.New(ctx, dbURL, pg.DriverPostgres) if err != nil { return fmt.Errorf("error connecting to DataBase: %w", err) } s.DataSource = sqlutil.WrapDataSource(s.db, s.Logger, - sqlutil.TimeoutHook(func() time.Duration { return s.EnvConfig.DatabaseQueryTimeout }), - sqlutil.MonitorHook(func() bool { return s.EnvConfig.DatabaseLogSQL })) + sqlutil.TimeoutHook(func() time.Duration { return envCfg.DatabaseQueryTimeout }), + sqlutil.MonitorHook(func() bool { return envCfg.DatabaseLogSQL })) s.dbStatsReporter = pg.NewStatsReporter(s.db.Stats, s.Logger) s.dbStatsReporter.Start() From 5981ca5ebf20aaf4f0aacd09afee349d3c32c7be Mon Sep 17 00:00:00 2001 From: Silas Lenihan Date: Fri, 6 Jun 2025 14:02:35 -0400 Subject: [PATCH 11/16] Reapply "pkg/loop: expand EnvConfig and make available from Server (#1149)" This reverts commit 1ff9a4634a1f7bbf2ffcaaec8d038587689e8857. 
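For reviewers, a minimal sketch of the plugin-side surface this reapply restores; it is
purely illustrative and not part of the diff below (the package, function name, and
logger name are placeholders, and a nil CapabilitiesRegistry is passed only for brevity):

    package example

    import (
        "context"

        "github.com/smartcontractkit/chainlink-common/pkg/loop"
        "github.com/smartcontractkit/chainlink-common/pkg/types/core"
    )

    // startAndCreateRelayer is a hypothetical helper showing what this reapply
    // brings back: Server.EnvConfig is parsed from the environment again, and
    // PluginRelayer.NewRelayer takes a dedicated CSA keystore argument.
    func startAndCreateRelayer(ctx context.Context, p loop.PluginRelayer, configTOML string, ks, csaKS core.Keystore) (loop.Relayer, error) {
        s := loop.MustNewStartedServer("ExampleRelayer")
        // Resolved values such as CL_APP_ID and CL_PROMETHEUS_PORT are exposed
        // on the server via EnvConfig once more.
        s.Logger.Infow("loop env", "appID", s.EnvConfig.AppID, "promPort", s.EnvConfig.PrometheusPort)

        // NewRelayer again accepts the CSA keystore separately from the primary
        // keystore; the nil CapabilitiesRegistry keeps the sketch short.
        return p.NewRelayer(ctx, configTOML, ks, csaKS, nil)
    }
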
--- pkg/loop/ccip_commit_test.go | 2 +- pkg/loop/ccip_execution_test.go | 2 +- pkg/loop/config.go | 137 +++++++++- pkg/loop/config_test.go | 311 +++++++--------------- pkg/loop/internal/example-relay/main.go | 2 +- pkg/loop/internal/pb/relayer.pb.go | 13 +- pkg/loop/internal/pb/relayer.proto | 3 +- pkg/loop/internal/relayer/relayer.go | 37 ++- pkg/loop/internal/relayer/test/relayer.go | 13 +- pkg/loop/internal/types/types.go | 2 +- pkg/loop/plugin_median_test.go | 2 +- pkg/loop/plugin_mercury_test.go | 2 +- pkg/loop/plugin_relayer_test.go | 2 +- pkg/loop/relayer_service.go | 4 +- pkg/loop/relayer_service_test.go | 12 +- pkg/loop/server.go | 70 ++--- 16 files changed, 329 insertions(+), 285 deletions(-) diff --git a/pkg/loop/ccip_commit_test.go b/pkg/loop/ccip_commit_test.go index fbce4f14b..06281cd67 100644 --- a/pkg/loop/ccip_commit_test.go +++ b/pkg/loop/ccip_commit_test.go @@ -112,7 +112,7 @@ func TestCommitLOOP(t *testing.T) { func newCommitProvider(t *testing.T, pr loop.PluginRelayer) (types.CCIPCommitProvider, error) { ctx := t.Context() - r, err := pr.NewRelayer(ctx, test.ConfigTOML, keystoretest.Keystore, nil) + r, err := pr.NewRelayer(ctx, test.ConfigTOML, keystoretest.Keystore, keystoretest.Keystore, nil) require.NoError(t, err) servicetest.Run(t, r) diff --git a/pkg/loop/ccip_execution_test.go b/pkg/loop/ccip_execution_test.go index 52f6cb6a4..b665e156e 100644 --- a/pkg/loop/ccip_execution_test.go +++ b/pkg/loop/ccip_execution_test.go @@ -113,7 +113,7 @@ func TestExecLOOP(t *testing.T) { func newExecutionProvider(t *testing.T, pr loop.PluginRelayer) (types.CCIPExecProvider, error) { ctx := t.Context() - r, err := pr.NewRelayer(ctx, test.ConfigTOML, keystoretest.Keystore, nil) + r, err := pr.NewRelayer(ctx, test.ConfigTOML, keystoretest.Keystore, keystoretest.Keystore, nil) require.NoError(t, err) servicetest.Run(t, r) diff --git a/pkg/loop/config.go b/pkg/loop/config.go index 0be7a4eab..59010ce2f 100644 --- a/pkg/loop/config.go +++ b/pkg/loop/config.go @@ -14,13 +14,30 @@ import ( ) const ( - envDatabaseURL = "CL_DATABASE_URL" - envDatabaseIdleInTxSessionTimeout = "CL_DATABASE_IDLE_IN_TX_SESSION_TIMEOUT" - envDatabaseLockTimeout = "CL_DATABASE_LOCK_TIMEOUT" - envDatabaseQueryTimeout = "CL_DATABASE_QUERY_TIMEOUT" - envDatabaseLogSQL = "CL_DATABASE_LOG_SQL" - envDatabaseMaxOpenConns = "CL_DATABASE_MAX_OPEN_CONNS" - envDatabaseMaxIdleConns = "CL_DATABASE_MAX_IDLE_CONNS" + envAppID = "CL_APP_ID" + + envDatabaseURL = "CL_DATABASE_URL" + envDatabaseIdleInTxSessionTimeout = "CL_DATABASE_IDLE_IN_TX_SESSION_TIMEOUT" + envDatabaseLockTimeout = "CL_DATABASE_LOCK_TIMEOUT" + envDatabaseQueryTimeout = "CL_DATABASE_QUERY_TIMEOUT" + envDatabaseListenerFallbackPollInterval = "CL_DATABASE_LISTNER_FALLBACK_POLL_INTERVAL" + envDatabaseLogSQL = "CL_DATABASE_LOG_SQL" + envDatabaseMaxOpenConns = "CL_DATABASE_MAX_OPEN_CONNS" + envDatabaseMaxIdleConns = "CL_DATABASE_MAX_IDLE_CONNS" + + envFeatureLogPoller = "CL_FEATURE_LOG_POLLER" + + envMercuryCacheLatestReportDeadline = "CL_MERCURY_CACHE_LATEST_REPORT_DEADLINE" + envMercuryCacheLatestReportTTL = "CL_MERCURY_CACHE_LATEST_REPORT_TTL" + envMercuryCacheMaxStaleAge = "CL_MERCURY_CACHE_MAX_STALE_AGE" + + envMercuryTransmitterProtocol = "CL_MERCURY_TRANSMITTER_PROTOCOL" + envMercuryTransmitterTransmitQueueMaxSize = "CL_MERCURY_TRANSMITTER_TRANSMIT__QUEUE_MAX_SIZE" + envMercuryTransmitterTransmitTimeout = "CL_MERCURY_TRANSMITTER_TRANSMIT_TIMEOUT" + envMercuryTransmitterTransmitConcurrency = "CL_MERCURY_TRANSMITTER_TRANSMIT_CONCURRENCY" + 
envMercuryTransmitterReaperFrequency = "CL_MERCURY_TRANSMITTER_REAPER_FREQUENCY" + envMercuryTransmitterReaperMaxAge = "CL_MERCURY_TRANSMITTER_REAPER_MAX_AGE" + envMercuryVerboseLogging = "CL_MERCURY_VERBOSE_LOGGING" envPromPort = "CL_PROMETHEUS_PORT" @@ -50,13 +67,30 @@ const ( // EnvConfig is the configuration between the application and the LOOP executable. The values // are fully resolved and static and passed via the environment. type EnvConfig struct { - DatabaseURL *config.SecretURL - DatabaseIdleInTxSessionTimeout time.Duration - DatabaseLockTimeout time.Duration - DatabaseQueryTimeout time.Duration - DatabaseLogSQL bool - DatabaseMaxOpenConns int - DatabaseMaxIdleConns int + AppID string + + DatabaseURL *config.SecretURL + DatabaseIdleInTxSessionTimeout time.Duration + DatabaseLockTimeout time.Duration + DatabaseQueryTimeout time.Duration + DatabaseListenerFallbackPollInterval time.Duration + DatabaseLogSQL bool + DatabaseMaxOpenConns int + DatabaseMaxIdleConns int + + FeatureLogPoller bool + + MercuryCacheLatestReportDeadline time.Duration + MercuryCacheLatestReportTTL time.Duration + MercuryCacheMaxStaleAge time.Duration + + MercuryTransmitterProtocol string + MercuryTransmitterTransmitQueueMaxSize uint32 + MercuryTransmitterTransmitTimeout time.Duration + MercuryTransmitterTransmitConcurrency uint32 + MercuryTransmitterReaperFrequency time.Duration + MercuryTransmitterReaperMaxAge time.Duration + MercuryVerboseLogging bool PrometheusPort int @@ -89,16 +123,33 @@ func (e *EnvConfig) AsCmdEnv() (env []string) { env = append(env, k+"="+v) } + add(envAppID, e.AppID) + if e.DatabaseURL != nil { // optional add(envDatabaseURL, e.DatabaseURL.URL().String()) add(envDatabaseIdleInTxSessionTimeout, e.DatabaseIdleInTxSessionTimeout.String()) add(envDatabaseLockTimeout, e.DatabaseLockTimeout.String()) add(envDatabaseQueryTimeout, e.DatabaseQueryTimeout.String()) + add(envDatabaseListenerFallbackPollInterval, e.DatabaseListenerFallbackPollInterval.String()) add(envDatabaseLogSQL, strconv.FormatBool(e.DatabaseLogSQL)) add(envDatabaseMaxOpenConns, strconv.Itoa(e.DatabaseMaxOpenConns)) add(envDatabaseMaxIdleConns, strconv.Itoa(e.DatabaseMaxIdleConns)) } + add(envFeatureLogPoller, strconv.FormatBool(e.FeatureLogPoller)) + + add(envMercuryCacheLatestReportDeadline, e.MercuryCacheLatestReportDeadline.String()) + add(envMercuryCacheLatestReportTTL, e.MercuryCacheLatestReportTTL.String()) + add(envMercuryCacheMaxStaleAge, e.MercuryCacheMaxStaleAge.String()) + + add(envMercuryTransmitterProtocol, e.MercuryTransmitterProtocol) + add(envMercuryTransmitterTransmitQueueMaxSize, strconv.FormatUint(uint64(e.MercuryTransmitterTransmitQueueMaxSize), 10)) + add(envMercuryTransmitterTransmitTimeout, e.MercuryTransmitterTransmitTimeout.String()) + add(envMercuryTransmitterTransmitConcurrency, strconv.FormatUint(uint64(e.MercuryTransmitterTransmitConcurrency), 10)) + add(envMercuryTransmitterReaperFrequency, e.MercuryTransmitterReaperFrequency.String()) + add(envMercuryTransmitterReaperMaxAge, e.MercuryTransmitterReaperMaxAge.String()) + add(envMercuryVerboseLogging, strconv.FormatBool(e.MercuryVerboseLogging)) + add(envPromPort, strconv.Itoa(e.PrometheusPort)) add(envTracingEnabled, strconv.FormatBool(e.TracingEnabled)) @@ -136,6 +187,7 @@ func (e *EnvConfig) AsCmdEnv() (env []string) { // parse deserializes environment variables func (e *EnvConfig) parse() error { + e.AppID = os.Getenv(envAppID) var err error e.DatabaseURL, err = getEnv(envDatabaseURL, func(s string) (*config.SecretURL, error) { if s == "" { 
// DatabaseURL is optional @@ -163,6 +215,10 @@ func (e *EnvConfig) parse() error { if err != nil { return err } + e.DatabaseListenerFallbackPollInterval, err = getEnv(envDatabaseListenerFallbackPollInterval, time.ParseDuration) + if err != nil { + return err + } e.DatabaseLogSQL, err = getEnv(envDatabaseLogSQL, strconv.ParseBool) if err != nil { return err @@ -177,6 +233,50 @@ func (e *EnvConfig) parse() error { } } + e.FeatureLogPoller, err = getBool(envFeatureLogPoller) + if err != nil { + return err + } + + e.MercuryCacheLatestReportDeadline, err = getEnv(envMercuryCacheLatestReportDeadline, time.ParseDuration) + if err != nil { + return err + } + e.MercuryCacheLatestReportTTL, err = getEnv(envMercuryCacheLatestReportTTL, time.ParseDuration) + if err != nil { + return err + } + e.MercuryCacheMaxStaleAge, err = getEnv(envMercuryCacheMaxStaleAge, time.ParseDuration) + if err != nil { + return err + } + + e.MercuryTransmitterProtocol = os.Getenv(envMercuryTransmitterProtocol) + e.MercuryTransmitterTransmitQueueMaxSize, err = getUint32(envMercuryTransmitterTransmitQueueMaxSize) + if err != nil { + return err + } + e.MercuryTransmitterTransmitTimeout, err = getEnv(envMercuryTransmitterTransmitTimeout, time.ParseDuration) + if err != nil { + return err + } + e.MercuryTransmitterTransmitConcurrency, err = getUint32(envMercuryTransmitterTransmitConcurrency) + if err != nil { + return err + } + e.MercuryTransmitterReaperFrequency, err = getEnv(envMercuryTransmitterReaperFrequency, time.ParseDuration) + if err != nil { + return err + } + e.MercuryTransmitterReaperMaxAge, err = getEnv(envMercuryTransmitterReaperMaxAge, time.ParseDuration) + if err != nil { + return err + } + e.MercuryVerboseLogging, err = getBool(envMercuryVerboseLogging) + if err != nil { + return err + } + promPortStr := os.Getenv(envPromPort) e.PrometheusPort, err = strconv.Atoi(promPortStr) if err != nil { @@ -299,6 +399,15 @@ func getFloat64OrZero(envKey string) float64 { return f } +func getUint32(envKey string) (uint32, error) { + s := os.Getenv(envKey) + u, err := strconv.ParseUint(s, 10, 32) + if err != nil { + return 0, err + } + return uint32(u), nil +} + func getEnv[T any](key string, parse func(string) (T, error)) (t T, err error) { v := os.Getenv(key) t, err = parse(v) diff --git a/pkg/loop/config_test.go b/pkg/loop/config_test.go index 85a0b6547..e3d240c41 100644 --- a/pkg/loop/config_test.go +++ b/pkg/loop/config_test.go @@ -1,8 +1,6 @@ package loop import ( - "maps" - "net/url" "os" "strconv" "strings" @@ -24,47 +22,37 @@ func TestEnvConfig_parse(t *testing.T) { envVars map[string]string expectError bool - expectedDatabaseURL string - expectedDatabaseIdleInTxSessionTimeout time.Duration - expectedDatabaseLockTimeout time.Duration - expectedDatabaseQueryTimeout time.Duration - expectedDatabaseLogSQL bool - expectedDatabaseMaxOpenConns int - expectedDatabaseMaxIdleConns int - - expectedPrometheusPort int - expectedTracingEnabled bool - expectedTracingCollectorTarget string - expectedTracingSamplingRatio float64 - expectedTracingTLSCertPath string - - expectedTelemetryEnabled bool - expectedTelemetryEndpoint string - expectedTelemetryInsecureConn bool - expectedTelemetryCACertFile string - expectedTelemetryAttributes OtelAttributes - expectedTelemetryTraceSampleRatio float64 - expectedTelemetryAuthHeaders map[string]string - expectedTelemetryAuthPubKeyHex string - expectedTelemetryEmitterBatchProcessor bool - expectedTelemetryEmitterExportTimeout time.Duration - expectedTelemetryEmitterExportInterval time.Duration - 
expectedTelemetryEmitterExportMaxBatchSize int - expectedTelemetryEmitterMaxQueueSize int - expectedChipIngressEndpoint string + expectConfig EnvConfig }{ { name: "All variables set correctly", envVars: map[string]string{ - envDatabaseURL: "postgres://user:password@localhost:5432/db", - envDatabaseIdleInTxSessionTimeout: "42s", - envDatabaseLockTimeout: "8m", - envDatabaseQueryTimeout: "7s", - envDatabaseLogSQL: "true", - envDatabaseMaxOpenConns: "9999", - envDatabaseMaxIdleConns: "8080", - - envPromPort: "8080", + envAppID: "app-id", + envDatabaseURL: "postgres://user:password@localhost:5432/db", + envDatabaseIdleInTxSessionTimeout: "42s", + envDatabaseLockTimeout: "8m", + envDatabaseQueryTimeout: "7s", + envDatabaseListenerFallbackPollInterval: "17s", + envDatabaseLogSQL: "true", + envDatabaseMaxOpenConns: "9999", + envDatabaseMaxIdleConns: "8080", + + envFeatureLogPoller: "true", + + envMercuryCacheLatestReportDeadline: "1ms", + envMercuryCacheLatestReportTTL: "1µs", + envMercuryCacheMaxStaleAge: "1ns", + + envMercuryTransmitterProtocol: "foo", + envMercuryTransmitterTransmitQueueMaxSize: "42", + envMercuryTransmitterTransmitTimeout: "1s", + envMercuryTransmitterTransmitConcurrency: "13", + envMercuryTransmitterReaperFrequency: "1h", + envMercuryTransmitterReaperMaxAge: "1m", + envMercuryVerboseLogging: "true", + + envPromPort: "8080", + envTracingEnabled: "true", envTracingCollectorTarget: "some:target", envTracingSamplingRatio: "1.0", @@ -85,38 +73,11 @@ func TestEnvConfig_parse(t *testing.T) { envTelemetryEmitterExportInterval: "2s", envTelemetryEmitterExportMaxBatchSize: "100", envTelemetryEmitterMaxQueueSize: "1000", - envChipIngressEndpoint: "http://chip-ingress.example.com", + + envChipIngressEndpoint: "http://chip-ingress.example.com", }, - expectError: false, - - expectedDatabaseURL: "postgres://user:password@localhost:5432/db", - expectedDatabaseIdleInTxSessionTimeout: 42 * time.Second, - expectedDatabaseLockTimeout: 8 * time.Minute, - expectedDatabaseQueryTimeout: 7 * time.Second, - expectedDatabaseLogSQL: true, - expectedDatabaseMaxOpenConns: 9999, - expectedDatabaseMaxIdleConns: 8080, - - expectedPrometheusPort: 8080, - expectedTracingEnabled: true, - expectedTracingCollectorTarget: "some:target", - expectedTracingSamplingRatio: 1.0, - expectedTracingTLSCertPath: "internal/test/fixtures/client.pem", - - expectedTelemetryEnabled: true, - expectedTelemetryEndpoint: "example.com/beholder", - expectedTelemetryInsecureConn: true, - expectedTelemetryCACertFile: "foo/bar", - expectedTelemetryAttributes: OtelAttributes{"foo": "bar", "baz": "42"}, - expectedTelemetryTraceSampleRatio: 0.42, - expectedTelemetryAuthHeaders: map[string]string{"header-key": "header-value"}, - expectedTelemetryAuthPubKeyHex: "pub-key-hex", - expectedTelemetryEmitterBatchProcessor: true, - expectedTelemetryEmitterExportTimeout: 1 * time.Second, - expectedTelemetryEmitterExportInterval: 2 * time.Second, - expectedTelemetryEmitterExportMaxBatchSize: 100, - expectedTelemetryEmitterMaxQueueSize: 1000, - expectedChipIngressEndpoint: "http://chip-ingress.example.com", + expectError: false, + expectConfig: envCfgFull, }, { name: "CL_DATABASE_URL parse error", @@ -152,164 +113,94 @@ func TestEnvConfig_parse(t *testing.T) { err := config.parse() if tc.expectError { - if err == nil { - t.Errorf("Expected error, got nil") - } + require.Error(t, err) } else { - if err != nil { - t.Errorf("Unexpected error: %v", err) - } else { - if config.DatabaseURL.URL().String() != tc.expectedDatabaseURL { - t.Errorf("Expected 
Database URL %s, got %s", tc.expectedDatabaseURL, config.DatabaseURL.String()) - } - if config.DatabaseIdleInTxSessionTimeout != tc.expectedDatabaseIdleInTxSessionTimeout { - t.Errorf("Expected Database idle in tx session timeout %s, got %s", tc.expectedDatabaseIdleInTxSessionTimeout, config.DatabaseIdleInTxSessionTimeout) - } - if config.DatabaseLockTimeout != tc.expectedDatabaseLockTimeout { - t.Errorf("Expected Database lock timeout %s, got %s", tc.expectedDatabaseLockTimeout, config.DatabaseLockTimeout) - } - if config.DatabaseQueryTimeout != tc.expectedDatabaseQueryTimeout { - t.Errorf("Expected Database query timeout %s, got %s", tc.expectedDatabaseQueryTimeout, config.DatabaseQueryTimeout) - } - if config.DatabaseLogSQL != tc.expectedDatabaseLogSQL { - t.Errorf("Expected Database log sql %t, got %t", tc.expectedDatabaseLogSQL, config.DatabaseLogSQL) - } - if config.DatabaseMaxOpenConns != tc.expectedDatabaseMaxOpenConns { - t.Errorf("Expected Database max open conns %d, got %d", tc.expectedDatabaseMaxOpenConns, config.DatabaseMaxOpenConns) - } - if config.DatabaseMaxIdleConns != tc.expectedDatabaseMaxIdleConns { - t.Errorf("Expected Database max idle conns %d, got %d", tc.expectedDatabaseMaxIdleConns, config.DatabaseMaxIdleConns) - } - - if config.PrometheusPort != tc.expectedPrometheusPort { - t.Errorf("Expected Prometheus port %d, got %d", tc.expectedPrometheusPort, config.PrometheusPort) - } - if config.TracingEnabled != tc.expectedTracingEnabled { - t.Errorf("Expected tracingEnabled %v, got %v", tc.expectedTracingEnabled, config.TracingEnabled) - } - if config.TracingCollectorTarget != tc.expectedTracingCollectorTarget { - t.Errorf("Expected tracingCollectorTarget %s, got %s", tc.expectedTracingCollectorTarget, config.TracingCollectorTarget) - } - if config.TracingSamplingRatio != tc.expectedTracingSamplingRatio { - t.Errorf("Expected tracingSamplingRatio %f, got %f", tc.expectedTracingSamplingRatio, config.TracingSamplingRatio) - } - if config.TracingTLSCertPath != tc.expectedTracingTLSCertPath { - t.Errorf("Expected tracingTLSCertPath %s, got %s", tc.expectedTracingTLSCertPath, config.TracingTLSCertPath) - } - if config.TelemetryEnabled != tc.expectedTelemetryEnabled { - t.Errorf("Expected telemetryEnabled %v, got %v", tc.expectedTelemetryEnabled, config.TelemetryEnabled) - } - if config.TelemetryEndpoint != tc.expectedTelemetryEndpoint { - t.Errorf("Expected telemetryEndpoint %s, got %s", tc.expectedTelemetryEndpoint, config.TelemetryEndpoint) - } - if config.TelemetryInsecureConnection != tc.expectedTelemetryInsecureConn { - t.Errorf("Expected telemetryInsecureConn %v, got %v", tc.expectedTelemetryInsecureConn, config.TelemetryInsecureConnection) - } - if config.TelemetryCACertFile != tc.expectedTelemetryCACertFile { - t.Errorf("Expected telemetryCACertFile %s, got %s", tc.expectedTelemetryCACertFile, config.TelemetryCACertFile) - } - if !maps.Equal(config.TelemetryAttributes, tc.expectedTelemetryAttributes) { - t.Errorf("Expected telemetryAttributes %v, got %v", tc.expectedTelemetryAttributes, config.TelemetryAttributes) - } - if config.TelemetryTraceSampleRatio != tc.expectedTelemetryTraceSampleRatio { - t.Errorf("Expected telemetryTraceSampleRatio %f, got %f", tc.expectedTelemetryTraceSampleRatio, config.TelemetryTraceSampleRatio) - } - if !maps.Equal(config.TelemetryAuthHeaders, tc.expectedTelemetryAuthHeaders) { - t.Errorf("Expected telemetryAuthHeaders %v, got %v", tc.expectedTelemetryAuthHeaders, config.TelemetryAuthHeaders) - } - if config.TelemetryAuthPubKeyHex != 
tc.expectedTelemetryAuthPubKeyHex { - t.Errorf("Expected telemetryAuthPubKeyHex %s, got %s", tc.expectedTelemetryAuthPubKeyHex, config.TelemetryAuthPubKeyHex) - } - if config.TelemetryEmitterBatchProcessor != tc.expectedTelemetryEmitterBatchProcessor { - t.Errorf("Expected telemetryEmitterBatchProcessor %v, got %v", tc.expectedTelemetryEmitterBatchProcessor, config.TelemetryEmitterBatchProcessor) - } - if config.TelemetryEmitterExportTimeout != tc.expectedTelemetryEmitterExportTimeout { - t.Errorf("Expected telemetryEmitterExportTimeout %v, got %v", tc.expectedTelemetryEmitterExportTimeout, config.TelemetryEmitterExportTimeout) - } - if config.TelemetryEmitterExportInterval != tc.expectedTelemetryEmitterExportInterval { - t.Errorf("Expected telemetryEmitterExportInterval %v, got %v", tc.expectedTelemetryEmitterExportInterval, config.TelemetryEmitterExportInterval) - } - if config.TelemetryEmitterExportMaxBatchSize != tc.expectedTelemetryEmitterExportMaxBatchSize { - t.Errorf("Expected telemetryEmitterExportMaxBatchSize %d, got %d", tc.expectedTelemetryEmitterExportMaxBatchSize, config.TelemetryEmitterExportMaxBatchSize) - } - if config.TelemetryEmitterMaxQueueSize != tc.expectedTelemetryEmitterMaxQueueSize { - t.Errorf("Expected telemetryEmitterMaxQueueSize %d, got %d", tc.expectedTelemetryEmitterMaxQueueSize, config.TelemetryEmitterMaxQueueSize) - } - if config.ChipIngressEndpoint != tc.expectedChipIngressEndpoint { - t.Errorf("Expected ChipIngressEndpoint %s, got %s", tc.expectedChipIngressEndpoint, config.ChipIngressEndpoint) - } - } + require.NoError(t, err) + require.Equal(t, tc.expectConfig, config) } }) } } -func equalOtelAttributes(a, b OtelAttributes) bool { - if len(a) != len(b) { - return false - } - for k, v := range a { - if b[k] != v { - return false - } - } - return true -} - -func equalStringMaps(a, b map[string]string) bool { - if len(a) != len(b) { - return false - } - for k, v := range a { - if b[k] != v { - return false - } - } - return true +var envCfgFull = EnvConfig{ + AppID: "app-id", + + DatabaseURL: config.MustSecretURL("postgres://user:password@localhost:5432/db"), + DatabaseIdleInTxSessionTimeout: 42 * time.Second, + DatabaseLockTimeout: 8 * time.Minute, + DatabaseQueryTimeout: 7 * time.Second, + DatabaseListenerFallbackPollInterval: 17 * time.Second, + DatabaseLogSQL: true, + DatabaseMaxOpenConns: 9999, + DatabaseMaxIdleConns: 8080, + + FeatureLogPoller: true, + + MercuryCacheLatestReportDeadline: time.Millisecond, + MercuryCacheLatestReportTTL: time.Microsecond, + MercuryCacheMaxStaleAge: time.Nanosecond, + + MercuryTransmitterProtocol: "foo", + MercuryTransmitterTransmitQueueMaxSize: 42, + MercuryTransmitterTransmitTimeout: time.Second, + MercuryTransmitterTransmitConcurrency: 13, + MercuryTransmitterReaperFrequency: time.Hour, + MercuryTransmitterReaperMaxAge: time.Minute, + MercuryVerboseLogging: true, + + PrometheusPort: 8080, + + TracingEnabled: true, + TracingAttributes: map[string]string{"XYZ": "value"}, + TracingCollectorTarget: "some:target", + TracingSamplingRatio: 1.0, + TracingTLSCertPath: "internal/test/fixtures/client.pem", + + TelemetryEnabled: true, + TelemetryEndpoint: "example.com/beholder", + TelemetryInsecureConnection: true, + TelemetryCACertFile: "foo/bar", + TelemetryAttributes: OtelAttributes{"foo": "bar", "baz": "42"}, + TelemetryTraceSampleRatio: 0.42, + TelemetryAuthHeaders: map[string]string{"header-key": "header-value"}, + TelemetryAuthPubKeyHex: "pub-key-hex", + TelemetryEmitterBatchProcessor: true, + 
TelemetryEmitterExportTimeout: 1 * time.Second, + TelemetryEmitterExportInterval: 2 * time.Second, + TelemetryEmitterExportMaxBatchSize: 100, + TelemetryEmitterMaxQueueSize: 1000, + + ChipIngressEndpoint: "http://chip-ingress.example.com", } func TestEnvConfig_AsCmdEnv(t *testing.T) { - envCfg := EnvConfig{ - DatabaseURL: (*config.SecretURL)(&url.URL{Scheme: "postgres", Host: "localhost:5432", User: url.UserPassword("user", "password"), Path: "/db"}), - PrometheusPort: 9090, - - TracingEnabled: true, - TracingCollectorTarget: "http://localhost:9000", - TracingSamplingRatio: 0.1, - TracingTLSCertPath: "some/path", - TracingAttributes: map[string]string{"key": "value"}, - - TelemetryEnabled: true, - TelemetryEndpoint: "example.com/beholder", - TelemetryInsecureConnection: true, - TelemetryCACertFile: "foo/bar", - TelemetryAttributes: OtelAttributes{"foo": "bar", "baz": "42"}, - TelemetryTraceSampleRatio: 0.42, - TelemetryAuthHeaders: map[string]string{"header-key": "header-value"}, - TelemetryAuthPubKeyHex: "pub-key-hex", - TelemetryEmitterBatchProcessor: true, - TelemetryEmitterExportTimeout: 1 * time.Second, - TelemetryEmitterExportInterval: 2 * time.Second, - TelemetryEmitterExportMaxBatchSize: 100, - TelemetryEmitterMaxQueueSize: 1000, - - ChipIngressEndpoint: "http://chip-ingress.example.com", - } got := map[string]string{} - for _, kv := range envCfg.AsCmdEnv() { + for _, kv := range envCfgFull.AsCmdEnv() { pair := strings.SplitN(kv, "=", 2) require.Len(t, pair, 2) got[pair[0]] = pair[1] } assert.Equal(t, "postgres://user:password@localhost:5432/db", got[envDatabaseURL]) - assert.Equal(t, strconv.Itoa(9090), got[envPromPort]) + + assert.Equal(t, "1ms", got[envMercuryCacheLatestReportDeadline]) + assert.Equal(t, "1µs", got[envMercuryCacheLatestReportTTL]) + assert.Equal(t, "1ns", got[envMercuryCacheMaxStaleAge]) + assert.Equal(t, "foo", got[envMercuryTransmitterProtocol]) + assert.Equal(t, "42", got[envMercuryTransmitterTransmitQueueMaxSize]) + assert.Equal(t, "1s", got[envMercuryTransmitterTransmitTimeout]) + assert.Equal(t, "13", got[envMercuryTransmitterTransmitConcurrency]) + assert.Equal(t, "1h0m0s", got[envMercuryTransmitterReaperFrequency]) + assert.Equal(t, "1m0s", got[envMercuryTransmitterReaperMaxAge]) + assert.Equal(t, "true", got[envMercuryVerboseLogging]) + + assert.Equal(t, strconv.Itoa(8080), got[envPromPort]) assert.Equal(t, "true", got[envTracingEnabled]) - assert.Equal(t, "http://localhost:9000", got[envTracingCollectorTarget]) - assert.Equal(t, "0.1", got[envTracingSamplingRatio]) - assert.Equal(t, "some/path", got[envTracingTLSCertPath]) - assert.Equal(t, "value", got[envTracingAttribute+"key"]) + assert.Equal(t, "some:target", got[envTracingCollectorTarget]) + assert.Equal(t, "1", got[envTracingSamplingRatio]) + assert.Equal(t, "internal/test/fixtures/client.pem", got[envTracingTLSCertPath]) + assert.Equal(t, "value", got[envTracingAttribute+"XYZ"]) assert.Equal(t, "true", got[envTelemetryEnabled]) assert.Equal(t, "example.com/beholder", got[envTelemetryEndpoint]) diff --git a/pkg/loop/internal/example-relay/main.go b/pkg/loop/internal/example-relay/main.go index 635dcb839..c74c1f83b 100644 --- a/pkg/loop/internal/example-relay/main.go +++ b/pkg/loop/internal/example-relay/main.go @@ -60,7 +60,7 @@ func (p *pluginRelayer) HealthReport() map[string]error { return map[string]erro func (p *pluginRelayer) Name() string { return p.lggr.Name() } -func (p *pluginRelayer) NewRelayer(ctx context.Context, config string, keystore core.Keystore, cr core.CapabilitiesRegistry) 
(loop.Relayer, error) { +func (p *pluginRelayer) NewRelayer(ctx context.Context, config string, keystore, csaKeystore core.Keystore, cr core.CapabilitiesRegistry) (loop.Relayer, error) { return &relayer{lggr: logger.Named(p.lggr, "Relayer"), ds: p.ds}, nil } diff --git a/pkg/loop/internal/pb/relayer.pb.go b/pkg/loop/internal/pb/relayer.pb.go index eee091b12..dfe17c046 100644 --- a/pkg/loop/internal/pb/relayer.pb.go +++ b/pkg/loop/internal/pb/relayer.pb.go @@ -28,6 +28,7 @@ type NewRelayerRequest struct { Config string `protobuf:"bytes,1,opt,name=config,proto3" json:"config,omitempty"` // toml (is chain instance config enough?) KeystoreID uint32 `protobuf:"varint,2,opt,name=keystoreID,proto3" json:"keystoreID,omitempty"` CapabilityRegistryID uint32 `protobuf:"varint,3,opt,name=capabilityRegistryID,proto3" json:"capabilityRegistryID,omitempty"` + KeystoreCSAID uint32 `protobuf:"varint,4,opt,name=keystoreCSAID,proto3" json:"keystoreCSAID,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -83,6 +84,13 @@ func (x *NewRelayerRequest) GetCapabilityRegistryID() uint32 { return 0 } +func (x *NewRelayerRequest) GetKeystoreCSAID() uint32 { + if x != nil { + return x.KeystoreCSAID + } + return 0 +} + type NewRelayerReply struct { state protoimpl.MessageState `protogen:"open.v1"` RelayerID uint32 `protobuf:"varint,1,opt,name=relayerID,proto3" json:"relayerID,omitempty"` @@ -2509,13 +2517,14 @@ var File_loop_internal_pb_relayer_proto protoreflect.FileDescriptor const file_loop_internal_pb_relayer_proto_rawDesc = "" + "\n" + - "\x1eloop/internal/pb/relayer.proto\x12\x04loop\x1a\x1bgoogle/protobuf/empty.proto\x1a\x1cgoogle/protobuf/struct.proto\x1a&loop/internal/pb/contract_reader.proto\"\x7f\n" + + "\x1eloop/internal/pb/relayer.proto\x12\x04loop\x1a\x1bgoogle/protobuf/empty.proto\x1a\x1cgoogle/protobuf/struct.proto\x1a&loop/internal/pb/contract_reader.proto\"\xa5\x01\n" + "\x11NewRelayerRequest\x12\x16\n" + "\x06config\x18\x01 \x01(\tR\x06config\x12\x1e\n" + "\n" + "keystoreID\x18\x02 \x01(\rR\n" + "keystoreID\x122\n" + - "\x14capabilityRegistryID\x18\x03 \x01(\rR\x14capabilityRegistryID\"/\n" + + "\x14capabilityRegistryID\x18\x03 \x01(\rR\x14capabilityRegistryID\x12$\n" + + "\rkeystoreCSAID\x18\x04 \x01(\rR\rkeystoreCSAID\"/\n" + "\x0fNewRelayerReply\x12\x1c\n" + "\trelayerID\x18\x01 \x01(\rR\trelayerID\"+\n" + "\rAccountsReply\x12\x1a\n" + diff --git a/pkg/loop/internal/pb/relayer.proto b/pkg/loop/internal/pb/relayer.proto index 58aff16a6..aa5e3b246 100644 --- a/pkg/loop/internal/pb/relayer.proto +++ b/pkg/loop/internal/pb/relayer.proto @@ -16,8 +16,7 @@ message NewRelayerRequest { string config = 1; // toml (is chain instance config enough?) uint32 keystoreID = 2; uint32 capabilityRegistryID = 3; - - //TODO prometheus? 
https://smartcontract-it.atlassian.net/browse/BCF-2075 + uint32 keystoreCSAID = 4; } message NewRelayerReply { diff --git a/pkg/loop/internal/relayer/relayer.go b/pkg/loop/internal/relayer/relayer.go index c44600189..616bfa7f3 100644 --- a/pkg/loop/internal/relayer/relayer.go +++ b/pkg/loop/internal/relayer/relayer.go @@ -46,10 +46,10 @@ func NewPluginRelayerClient(brokerCfg net.BrokerConfig) *PluginRelayerClient { return &PluginRelayerClient{PluginClient: pc, pluginRelayer: pb.NewPluginRelayerClient(pc), ServiceClient: goplugin.NewServiceClient(pc.BrokerExt, pc)} } -func (p *PluginRelayerClient) NewRelayer(ctx context.Context, config string, keystore core.Keystore, capabilityRegistry core.CapabilitiesRegistry) (looptypes.Relayer, error) { - cc := p.NewClientConn("Relayer", func(ctx context.Context) (id uint32, deps net.Resources, err error) { +func (p *PluginRelayerClient) NewRelayer(ctx context.Context, config string, keystore, csaKeystore core.Keystore, capabilityRegistry core.CapabilitiesRegistry) (looptypes.Relayer, error) { + cc := p.NewClientConn("Relayer", func(ctx context.Context) (relayerID uint32, deps net.Resources, err error) { var ksRes net.Resource - id, ksRes, err = p.ServeNew("Keystore", func(s *grpc.Server) { + ksID, ksRes, err := p.ServeNew("Keystore", func(s *grpc.Server) { pb.RegisterKeystoreServer(s, &keystoreServer{impl: keystore}) }) if err != nil { @@ -57,6 +57,15 @@ func (p *PluginRelayerClient) NewRelayer(ctx context.Context, config string, key } deps.Add(ksRes) + var ksCSARes net.Resource + ksCSAID, ksCSARes, err := p.ServeNew("CSAKeystore", func(s *grpc.Server) { + pb.RegisterKeystoreServer(s, &keystoreServer{impl: csaKeystore}) + }) + if err != nil { + return 0, nil, fmt.Errorf("Failed to create relayer client: failed to serve CSA keystore: %w", err) + } + deps.Add(ksCSARes) + capabilityRegistryID, capabilityRegistryResource, err := p.ServeNew("CapabilitiesRegistry", func(s *grpc.Server) { pb.RegisterCapabilitiesRegistryServer(s, capability.NewCapabilitiesRegistryServer(p.BrokerExt, capabilityRegistry)) }) @@ -67,7 +76,8 @@ func (p *PluginRelayerClient) NewRelayer(ctx context.Context, config string, key reply, err := p.pluginRelayer.NewRelayer(ctx, &pb.NewRelayerRequest{ Config: config, - KeystoreID: id, + KeystoreID: ksID, + KeystoreCSAID: ksCSAID, CapabilityRegistryID: capabilityRegistryID, }) if err != nil { @@ -102,22 +112,31 @@ func (p *pluginRelayerServer) NewRelayer(ctx context.Context, request *pb.NewRel if err != nil { return nil, net.ErrConnDial{Name: "Keystore", ID: request.KeystoreID, Err: err} } - ksRes := net.Resource{Closer: ksConn, Name: "CapabilityRegistry"} + ksRes := net.Resource{Closer: ksConn, Name: "Keystore"} + + ksCSAConn, err := p.Dial(request.KeystoreCSAID) + if err != nil { + p.CloseAll(ksRes) + return nil, net.ErrConnDial{Name: "CSAKeystore", ID: request.KeystoreCSAID, Err: err} + } + ksCSARes := net.Resource{Closer: ksConn, Name: "CSAKeystore"} + capRegistryConn, err := p.Dial(request.CapabilityRegistryID) if err != nil { + p.CloseAll(ksRes, ksCSARes) return nil, net.ErrConnDial{Name: "CapabilityRegistry", ID: request.CapabilityRegistryID, Err: err} } crRes := net.Resource{Closer: capRegistryConn, Name: "CapabilityRegistry"} capRegistry := capability.NewCapabilitiesRegistryClient(capRegistryConn, p.BrokerExt) - r, err := p.impl.NewRelayer(ctx, request.Config, newKeystoreClient(ksConn), capRegistry) + r, err := p.impl.NewRelayer(ctx, request.Config, newKeystoreClient(ksConn), newKeystoreClient(ksCSAConn), capRegistry) if err != nil 
{ - p.CloseAll(ksRes, crRes) + p.CloseAll(ksRes, ksCSARes, crRes) return nil, err } err = r.Start(ctx) if err != nil { - p.CloseAll(ksRes, crRes) + p.CloseAll(ksRes, ksCSARes, crRes) return nil, err } @@ -129,7 +148,7 @@ func (p *pluginRelayerServer) NewRelayer(ctx context.Context, request *pb.NewRel if evmService, ok := r.(types.EVMService); ok { evmpb.RegisterEVMServer(s, newEVMServer(evmService, p.BrokerExt)) } - }, rRes, ksRes, crRes) + }, rRes, ksRes, ksCSARes, crRes) if err != nil { return nil, err } diff --git a/pkg/loop/internal/relayer/test/relayer.go b/pkg/loop/internal/relayer/test/relayer.go index e10397d2e..7af353998 100644 --- a/pkg/loop/internal/relayer/test/relayer.go +++ b/pkg/loop/internal/relayer/test/relayer.go @@ -141,7 +141,7 @@ func (s staticPluginRelayer) HealthReport() map[string]error { return hp } -func (s staticPluginRelayer) NewRelayer(ctx context.Context, config string, keystore core.Keystore, capabilityRegistry core.CapabilitiesRegistry) (looptypes.Relayer, error) { +func (s staticPluginRelayer) NewRelayer(ctx context.Context, config string, keystore, csaKeystore core.Keystore, capabilityRegistry core.CapabilitiesRegistry) (looptypes.Relayer, error) { if s.relayer.StaticChecks && config != ConfigTOML { return nil, fmt.Errorf("expected config %q but got %q", ConfigTOML, config) } @@ -152,6 +152,13 @@ func (s staticPluginRelayer) NewRelayer(ctx context.Context, config string, keys if len(keys) == 0 { return nil, fmt.Errorf("expected at least one key but got none") } + keys, err = csaKeystore.Accounts(ctx) + if err != nil { + return nil, err + } + if len(keys) == 0 { + return nil, fmt.Errorf("expected at least one CSA key but got none") + } return s.relayer, nil } @@ -433,7 +440,7 @@ func newRelayArgsWithProviderType(_type types.OCR2PluginType) types.RelayArgs { func RunPlugin(t *testing.T, p looptypes.PluginRelayer) { t.Run("Relayer", func(t *testing.T) { ctx := t.Context() - relayer, err := p.NewRelayer(ctx, ConfigTOML, keystoretest.Keystore, nil) + relayer, err := p.NewRelayer(ctx, ConfigTOML, keystoretest.Keystore, keystoretest.Keystore, nil) require.NoError(t, err) servicetest.Run(t, relayer) Run(t, relayer) @@ -480,7 +487,7 @@ func RunFuzzPluginRelayer(f *testing.F, relayerFunc func(*testing.T) looptypes.P } ctx := t.Context() - _, err := relayerFunc(t).NewRelayer(ctx, fConfig, keystore, nil) + _, err := relayerFunc(t).NewRelayer(ctx, fConfig, keystore, keystore, nil) grpcUnavailableErr(t, err) }) diff --git a/pkg/loop/internal/types/types.go b/pkg/loop/internal/types/types.go index bf4132cf2..7b1f384a9 100644 --- a/pkg/loop/internal/types/types.go +++ b/pkg/loop/internal/types/types.go @@ -11,7 +11,7 @@ import ( type PluginRelayer interface { services.Service - NewRelayer(ctx context.Context, config string, keystore core.Keystore, capabilityRegistry core.CapabilitiesRegistry) (Relayer, error) + NewRelayer(ctx context.Context, config string, keystore, csaKeystore core.Keystore, capabilityRegistry core.CapabilitiesRegistry) (Relayer, error) } type MedianProvider interface { diff --git a/pkg/loop/plugin_median_test.go b/pkg/loop/plugin_median_test.go index e3d29ab9c..b2cf94050 100644 --- a/pkg/loop/plugin_median_test.go +++ b/pkg/loop/plugin_median_test.go @@ -85,7 +85,7 @@ func newStopCh(t *testing.T) <-chan struct{} { func newMedianProvider(t *testing.T, pr loop.PluginRelayer) types.MedianProvider { ctx := t.Context() - r, err := pr.NewRelayer(ctx, test.ConfigTOML, keystoretest.Keystore, nil) + r, err := pr.NewRelayer(ctx, test.ConfigTOML, 
keystoretest.Keystore, keystoretest.Keystore, nil) require.NoError(t, err) servicetest.Run(t, r) p, err := r.NewPluginProvider(ctx, relayertest.RelayArgs, relayertest.PluginArgs) diff --git a/pkg/loop/plugin_mercury_test.go b/pkg/loop/plugin_mercury_test.go index 20e16142b..d349b700e 100644 --- a/pkg/loop/plugin_mercury_test.go +++ b/pkg/loop/plugin_mercury_test.go @@ -67,7 +67,7 @@ func TestPluginMercuryExec(t *testing.T) { func newMercuryProvider(t *testing.T, pr loop.PluginRelayer) types.MercuryProvider { ctx := t.Context() - r, err := pr.NewRelayer(ctx, test.ConfigTOML, keystoretest.Keystore, nil) + r, err := pr.NewRelayer(ctx, test.ConfigTOML, keystoretest.Keystore, keystoretest.Keystore, nil) require.NoError(t, err) servicetest.Run(t, r) p, err := r.NewPluginProvider(ctx, mercurytest.RelayArgs, mercurytest.PluginArgs) diff --git a/pkg/loop/plugin_relayer_test.go b/pkg/loop/plugin_relayer_test.go index 62aa7a7fe..44b650df1 100644 --- a/pkg/loop/plugin_relayer_test.go +++ b/pkg/loop/plugin_relayer_test.go @@ -58,7 +58,7 @@ func FuzzRelayer(f *testing.F) { p := newPluginRelayerExec(t, false, stopCh) ctx := t.Context() capRegistry := mocks.NewCapabilitiesRegistry(t) - relayer, err := p.NewRelayer(ctx, test.ConfigTOML, keystoretest.Keystore, capRegistry) + relayer, err := p.NewRelayer(ctx, test.ConfigTOML, keystoretest.Keystore, keystoretest.Keystore, capRegistry) require.NoError(t, err) diff --git a/pkg/loop/relayer_service.go b/pkg/loop/relayer_service.go index a67cd1574..1ede94eaa 100644 --- a/pkg/loop/relayer_service.go +++ b/pkg/loop/relayer_service.go @@ -22,13 +22,13 @@ type RelayerService struct { // NewRelayerService returns a new [*RelayerService]. // cmd must return a new exec.Cmd each time it is called. -func NewRelayerService(lggr logger.Logger, grpcOpts GRPCOpts, cmd func() *exec.Cmd, config string, keystore core.Keystore, capabilityRegistry core.CapabilitiesRegistry) *RelayerService { +func NewRelayerService(lggr logger.Logger, grpcOpts GRPCOpts, cmd func() *exec.Cmd, config string, keystore core.Keystore, csaKeystore core.Keystore, capabilityRegistry core.CapabilitiesRegistry) *RelayerService { newService := func(ctx context.Context, instance any) (Relayer, services.HealthReporter, error) { plug, ok := instance.(PluginRelayer) if !ok { return nil, nil, fmt.Errorf("expected PluginRelayer but got %T", instance) } - r, err := plug.NewRelayer(ctx, config, keystore, capabilityRegistry) + r, err := plug.NewRelayer(ctx, config, keystore, csaKeystore, capabilityRegistry) if err != nil { return nil, nil, fmt.Errorf("failed to create Relayer: %w", err) } diff --git a/pkg/loop/relayer_service_test.go b/pkg/loop/relayer_service_test.go index 9852dbd5e..394b87d29 100644 --- a/pkg/loop/relayer_service_test.go +++ b/pkg/loop/relayer_service_test.go @@ -43,7 +43,7 @@ func TestRelayerService(t *testing.T) { capRegistry := mocks.NewCapabilitiesRegistry(t) relayer := loop.NewRelayerService(logger.Test(t), loop.GRPCOpts{}, func() *exec.Cmd { return NewHelperProcessCommand(loop.PluginRelayerName, false, 0) - }, test.ConfigTOML, keystoretest.Keystore, capRegistry) + }, test.ConfigTOML, keystoretest.Keystore, keystoretest.Keystore, capRegistry) hook := relayer.XXXTestHook() servicetest.Run(t, relayer) @@ -83,12 +83,16 @@ func TestRelayerService_recovery(t *testing.T) { Command: loop.PluginRelayerName, Limit: int(limit.Add(1)), }.New() - }, test.ConfigTOML, keystoretest.Keystore, nil) + }, test.ConfigTOML, keystoretest.Keystore, keystoretest.Keystore, nil) servicetest.Run(t, relayer) 
relayertest.Run(t, relayer) - servicetest.AssertHealthReportNames(t, relayer.HealthReport(), relayerServiceNames[:2]...) + if hp := relayer.HealthReport(); len(hp) == 2 { + servicetest.AssertHealthReportNames(t, hp, relayerServiceNames[:2]...) + } else { + servicetest.AssertHealthReportNames(t, hp, relayerServiceNames...) + } } @@ -100,7 +104,7 @@ func TestRelayerService_HealthReport(t *testing.T) { capRegistry := mocks.NewCapabilitiesRegistry(t) s := loop.NewRelayerService(lggr, loop.GRPCOpts{}, func() *exec.Cmd { return HelperProcessCommand{Command: loop.PluginRelayerName}.New() - }, test.ConfigTOML, keystoretest.Keystore, capRegistry) + }, test.ConfigTOML, keystoretest.Keystore, keystoretest.Keystore, capRegistry) servicetest.AssertHealthReportNames(t, s.HealthReport(), relayerServiceNames[0]) diff --git a/pkg/loop/server.go b/pkg/loop/server.go index abe79ed8a..e2960bd55 100644 --- a/pkg/loop/server.go +++ b/pkg/loop/server.go @@ -9,6 +9,7 @@ import ( "github.com/jmoiron/sqlx" "go.opentelemetry.io/otel/attribute" + semconv "go.opentelemetry.io/otel/semconv/v1.17.0" "github.com/smartcontractkit/chainlink-common/pkg/beholder" "github.com/smartcontractkit/chainlink-common/pkg/config/build" @@ -51,6 +52,7 @@ func MustNewStartedServer(loggerName string) *Server { // Server holds common plugin server fields. type Server struct { + EnvConfig EnvConfig GRPCOpts GRPCOpts Logger logger.SugaredLogger db *sqlx.DB // optional @@ -81,21 +83,25 @@ func (s *Server) start() error { stopAfter := context.AfterFunc(ctx, stopSig) defer stopAfter() - var envCfg EnvConfig - if err := envCfg.parse(); err != nil { + if err := s.EnvConfig.parse(); err != nil { return fmt.Errorf("error getting environment configuration: %w", err) } + tracingAttrs := s.EnvConfig.TracingAttributes + if tracingAttrs == nil { + tracingAttrs = make(map[string]string, 1) + } + tracingAttrs[string(semconv.ServiceInstanceIDKey)] = s.EnvConfig.AppID tracingConfig := TracingConfig{ - Enabled: envCfg.TracingEnabled, - CollectorTarget: envCfg.TracingCollectorTarget, - SamplingRatio: envCfg.TracingSamplingRatio, - TLSCertPath: envCfg.TracingTLSCertPath, - NodeAttributes: envCfg.TracingAttributes, + Enabled: s.EnvConfig.TracingEnabled, + CollectorTarget: s.EnvConfig.TracingCollectorTarget, + SamplingRatio: s.EnvConfig.TracingSamplingRatio, + TLSCertPath: s.EnvConfig.TracingTLSCertPath, + NodeAttributes: tracingAttrs, OnDialError: func(err error) { s.Logger.Errorw("Failed to dial", "err", err) }, } - if envCfg.TelemetryEndpoint == "" { + if s.EnvConfig.TelemetryEndpoint == "" { err := SetupTracing(tracingConfig) if err != nil { return fmt.Errorf("failed to setup tracing: %w", err) @@ -106,20 +112,20 @@ func (s *Server) start() error { attributes = tracingConfig.Attributes() } beholderCfg := beholder.Config{ - InsecureConnection: envCfg.TelemetryInsecureConnection, - CACertFile: envCfg.TelemetryCACertFile, - OtelExporterGRPCEndpoint: envCfg.TelemetryEndpoint, - ResourceAttributes: append(attributes, envCfg.TelemetryAttributes.AsStringAttributes()...), - TraceSampleRatio: envCfg.TelemetryTraceSampleRatio, - AuthHeaders: envCfg.TelemetryAuthHeaders, - AuthPublicKeyHex: envCfg.TelemetryAuthPubKeyHex, - EmitterBatchProcessor: envCfg.TelemetryEmitterBatchProcessor, - EmitterExportTimeout: envCfg.TelemetryEmitterExportTimeout, - EmitterExportInterval: envCfg.TelemetryEmitterExportInterval, - EmitterExportMaxBatchSize: envCfg.TelemetryEmitterExportMaxBatchSize, - EmitterMaxQueueSize: envCfg.TelemetryEmitterMaxQueueSize, - ChipIngressEmitterEnabled: 
envCfg.ChipIngressEndpoint != "", - ChipIngressEmitterGRPCEndpoint: envCfg.ChipIngressEndpoint, + InsecureConnection: s.EnvConfig.TelemetryInsecureConnection, + CACertFile: s.EnvConfig.TelemetryCACertFile, + OtelExporterGRPCEndpoint: s.EnvConfig.TelemetryEndpoint, + ResourceAttributes: append(attributes, s.EnvConfig.TelemetryAttributes.AsStringAttributes()...), + TraceSampleRatio: s.EnvConfig.TelemetryTraceSampleRatio, + AuthHeaders: s.EnvConfig.TelemetryAuthHeaders, + AuthPublicKeyHex: s.EnvConfig.TelemetryAuthPubKeyHex, + EmitterBatchProcessor: s.EnvConfig.TelemetryEmitterBatchProcessor, + EmitterExportTimeout: s.EnvConfig.TelemetryEmitterExportTimeout, + EmitterExportInterval: s.EnvConfig.TelemetryEmitterExportInterval, + EmitterExportMaxBatchSize: s.EnvConfig.TelemetryEmitterExportMaxBatchSize, + EmitterMaxQueueSize: s.EnvConfig.TelemetryEmitterMaxQueueSize, + ChipIngressEmitterEnabled: s.EnvConfig.ChipIngressEndpoint != "", + ChipIngressEmitterGRPCEndpoint: s.EnvConfig.ChipIngressEndpoint, } if tracingConfig.Enabled { @@ -141,7 +147,7 @@ func (s *Server) start() error { beholder.SetGlobalOtelProviders() } - s.promServer = NewPromServer(envCfg.PrometheusPort, s.Logger) + s.promServer = NewPromServer(s.EnvConfig.PrometheusPort, s.Logger) if err := s.promServer.Start(); err != nil { return fmt.Errorf("error starting prometheus server: %w", err) } @@ -151,22 +157,22 @@ func (s *Server) start() error { return fmt.Errorf("error starting health checker: %w", err) } - if envCfg.DatabaseURL != nil { - pg.SetApplicationName(envCfg.DatabaseURL.URL(), build.Program) - dbURL := envCfg.DatabaseURL.URL().String() + if s.EnvConfig.DatabaseURL != nil { + pg.SetApplicationName(s.EnvConfig.DatabaseURL.URL(), build.Program) + dbURL := s.EnvConfig.DatabaseURL.URL().String() var err error s.db, err = pg.DBConfig{ - IdleInTxSessionTimeout: envCfg.DatabaseIdleInTxSessionTimeout, - LockTimeout: envCfg.DatabaseLockTimeout, - MaxOpenConns: envCfg.DatabaseMaxOpenConns, - MaxIdleConns: envCfg.DatabaseMaxIdleConns, + IdleInTxSessionTimeout: s.EnvConfig.DatabaseIdleInTxSessionTimeout, + LockTimeout: s.EnvConfig.DatabaseLockTimeout, + MaxOpenConns: s.EnvConfig.DatabaseMaxOpenConns, + MaxIdleConns: s.EnvConfig.DatabaseMaxIdleConns, }.New(ctx, dbURL, pg.DriverPostgres) if err != nil { return fmt.Errorf("error connecting to DataBase: %w", err) } s.DataSource = sqlutil.WrapDataSource(s.db, s.Logger, - sqlutil.TimeoutHook(func() time.Duration { return envCfg.DatabaseQueryTimeout }), - sqlutil.MonitorHook(func() bool { return envCfg.DatabaseLogSQL })) + sqlutil.TimeoutHook(func() time.Duration { return s.EnvConfig.DatabaseQueryTimeout }), + sqlutil.MonitorHook(func() bool { return s.EnvConfig.DatabaseLogSQL })) s.dbStatsReporter = pg.NewStatsReporter(s.db.Stats, s.Logger) s.dbStatsReporter.Start() From 54755c8caea7637e627f489b1603c74b1d80cd1d Mon Sep 17 00:00:00 2001 From: Silas Lenihan Date: Tue, 10 Jun 2025 07:52:39 -0400 Subject: [PATCH 12/16] Revert "Seed random for setup and modes (#1236)" This reverts commit bbf13d4e5c0428ed03830c37e5e9cc39e2484602. 
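
With the seeded-random support gone, the SDK runtime goes back to exposing
capability dispatch as plain function fields (CallCapabilityFn / AwaitCapabilitiesFn)
instead of the RuntimeHelpers interface, and the host-side handler is named
CapabilityExecutor again. As a rough sketch of the restored shape — field,
type, and package names are taken from the diff below; the surrounding
workflow/test setup and imports are assumed:

    // Call and Await are ordinary function fields on sdkimpl.RuntimeBase,
    // so a test can stub out capability awaiting in place:
    drt := rt.(*sdkimpl.DonRuntime)
    drt.Await = func(request *pb.AwaitCapabilitiesRequest, maxResponseSize uint64) (*pb.AwaitCapabilitiesResponse, error) {
        return &pb.AwaitCapabilitiesResponse{Responses: map[int32]*pb.CapabilityResponse{}}, nil
    }

The Rand() accessor, the random_seed and switch_modes host imports, and the
per-mode DON/node seeds are removed along with the reverted commit.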
--- .mockery.yaml | 2 +- pkg/workflows/internal/v2/sdkimpl/runtime.go | 67 +---- .../internal/v2/sdkimpl/runtime_test.go | 74 +----- pkg/workflows/sdk/v2/runtime.go | 2 - pkg/workflows/sdk/v2/runtime_test.go | 9 - pkg/workflows/sdk/v2/testutils/runner.go | 21 +- pkg/workflows/sdk/v2/testutils/runtime.go | 92 +++---- pkg/workflows/wasm/host/execution.go | 24 +- .../host/mock_capability_executor_test.go | 96 ++++++++ .../wasm/host/mock_execution_helper_test.go | 233 ------------------ pkg/workflows/wasm/host/mocks/module_v2.go | 16 +- pkg/workflows/wasm/host/module.go | 54 +--- pkg/workflows/wasm/host/module_test.go | 6 +- .../wasm/host/test/nodag/randoms/cmd/main.go | 54 ---- pkg/workflows/wasm/host/wasip1.go | 34 +-- pkg/workflows/wasm/host/wasm_nodag_test.go | 124 +--------- pkg/workflows/wasm/v2/runner.go | 1 - pkg/workflows/wasm/v2/runner_test_hooks.go | 7 - pkg/workflows/wasm/v2/runner_wasip1.go | 10 - pkg/workflows/wasm/v2/runtime.go | 114 ++++----- pkg/workflows/wasm/v2/runtime_test.go | 25 -- pkg/workflows/wasm/v2/runtime_test_hooks.go | 13 - pkg/workflows/wasm/v2/runtime_wasip1.go | 11 - 23 files changed, 240 insertions(+), 849 deletions(-) create mode 100644 pkg/workflows/wasm/host/mock_capability_executor_test.go delete mode 100644 pkg/workflows/wasm/host/mock_execution_helper_test.go delete mode 100644 pkg/workflows/wasm/host/test/nodag/randoms/cmd/main.go diff --git a/.mockery.yaml b/.mockery.yaml index ae7134afa..bacecec83 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -38,7 +38,7 @@ packages: interfaces: ModuleV1: {} ModuleV2: {} - ExecutionHelper: + CapabilityExecutor: config: inpackage: true filename: "mock_{{.InterfaceName | snakecase}}_test.go" diff --git a/pkg/workflows/internal/v2/sdkimpl/runtime.go b/pkg/workflows/internal/v2/sdkimpl/runtime.go index 47ecc01be..9cfcbf876 100644 --- a/pkg/workflows/internal/v2/sdkimpl/runtime.go +++ b/pkg/workflows/internal/v2/sdkimpl/runtime.go @@ -4,7 +4,6 @@ import ( "fmt" "io" "log/slog" - "math/rand" "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/consensus" "github.com/smartcontractkit/chainlink-common/pkg/values" @@ -13,30 +12,21 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/workflows/sdk/v2/pb" ) -type RuntimeHelpers interface { - Call(request *pb.CapabilityRequest) error - Await(request *pb.AwaitCapabilitiesRequest, maxResponseSize uint64) (*pb.AwaitCapabilitiesResponse, error) - SwitchModes(mode pb.Mode) - GetSource(mode pb.Mode) rand.Source -} +type CallCapabilityFn func(request *pb.CapabilityRequest) error +type AwaitCapabilitiesFn func(request *pb.AwaitCapabilitiesRequest, maxResponseSize uint64) (*pb.AwaitCapabilitiesResponse, error) type RuntimeBase struct { ConfigBytes []byte MaxResponseSize uint64 + Call CallCapabilityFn + Await AwaitCapabilitiesFn Writer io.Writer - RuntimeHelpers - source rand.Source - source64 rand.Source64 modeErr error Mode pb.Mode nextCallId int32 } -var _ sdk.RuntimeBase = (*RuntimeBase)(nil) -var _ rand.Source = (*RuntimeBase)(nil) -var _ rand.Source64 = (*RuntimeBase)(nil) - func (r *RuntimeBase) CallCapability(request *pb.CapabilityRequest) sdk.Promise[*pb.CapabilityResponse] { if r.Mode == pb.Mode_DON { r.nextCallId++ @@ -50,7 +40,7 @@ func (r *RuntimeBase) CallCapability(request *pb.CapabilityRequest) sdk.Promise[ return sdk.PromiseFromResult[*pb.CapabilityResponse](nil, r.modeErr) } - err := r.RuntimeHelpers.Call(request) + err := r.Call(request) if err != nil { return sdk.PromiseFromResult[*pb.CapabilityResponse](nil, err) } @@ -85,22 +75,6 @@ func (r 
*RuntimeBase) Logger() *slog.Logger { return slog.New(slog.NewTextHandler(r.LogWriter(), nil)) } -func (r *RuntimeBase) Rand() (*rand.Rand, error) { - if r.modeErr != nil { - return nil, r.modeErr - } - - if r.source == nil { - r.source = r.RuntimeHelpers.GetSource(r.Mode) - r64, ok := r.source.(rand.Source64) - if ok { - r.source64 = r64 - } - } - - return rand.New(r), nil -} - type DonRuntime struct { RuntimeBase nextNodeCallId int32 @@ -111,9 +85,7 @@ func (d *DonRuntime) RunInNodeMode(fn func(nodeRuntime sdk.NodeRuntime) *pb.Simp nrt.nextCallId = d.nextNodeCallId nrt.Mode = pb.Mode_Node d.modeErr = sdk.DonModeCallInNodeMode() - d.SwitchModes(pb.Mode_Node) observation := fn(nrt) - d.SwitchModes(pb.Mode_DON) nrt.modeErr = sdk.NodeModeCallInDonMode() d.modeErr = nil d.nextNodeCallId = nrt.nextCallId @@ -123,35 +95,6 @@ func (d *DonRuntime) RunInNodeMode(fn func(nodeRuntime sdk.NodeRuntime) *pb.Simp }) } -func (r *RuntimeBase) Int63() int64 { - if r.modeErr != nil { - panic("random cannot be used outside the mode it was created in") - } - - return r.source.Int63() -} - -func (r *RuntimeBase) Uint64() uint64 { - if r.modeErr != nil { - panic("random cannot be used outside the mode it was created in") - } - - // borrowed from math/rand - if r.source64 != nil { - return r.source64.Uint64() - } - - return uint64(r.source.Int63())>>31 | uint64(r.source.Int63())<<32 -} - -func (r *RuntimeBase) Seed(seed int64) { - if r.modeErr != nil { - panic("random cannot be used outside the mode it was created in") - } - - r.source.Seed(seed) -} - var _ sdk.DonRuntime = &DonRuntime{} type NodeRuntime struct { diff --git a/pkg/workflows/internal/v2/sdkimpl/runtime_test.go b/pkg/workflows/internal/v2/sdkimpl/runtime_test.go index c017f33c5..966e3b24c 100644 --- a/pkg/workflows/internal/v2/sdkimpl/runtime_test.go +++ b/pkg/workflows/internal/v2/sdkimpl/runtime_test.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "math/rand" "testing" "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/consensus/consensusmock" @@ -111,11 +110,8 @@ func TestRuntime_CallCapability(t *testing.T) { test := func(rt sdk.DonRuntime, _ *basictrigger.Outputs) (string, error) { drt := rt.(*sdkimpl.DonRuntime) - drt.RuntimeHelpers = &awaitOverride{ - RuntimeHelpers: drt.RuntimeHelpers, - await: func(request *pb.AwaitCapabilitiesRequest, maxResponseSize uint64) (*pb.AwaitCapabilitiesResponse, error) { - return nil, expectedErr - }, + drt.Await = func(request *pb.AwaitCapabilitiesRequest, maxResponseSize uint64) (*pb.AwaitCapabilitiesResponse, error) { + return nil, expectedErr } _, err := capability.PerformAction(rt, &basicaction.Inputs{InputThing: true}).Await() return "", err @@ -137,12 +133,8 @@ func TestRuntime_CallCapability(t *testing.T) { test := func(rt sdk.DonRuntime, _ *basictrigger.Outputs) (string, error) { drt := rt.(*sdkimpl.DonRuntime) - - drt.RuntimeHelpers = &awaitOverride{ - RuntimeHelpers: drt.RuntimeHelpers, - await: func(request *pb.AwaitCapabilitiesRequest, maxResponseSize uint64) (*pb.AwaitCapabilitiesResponse, error) { - return &pb.AwaitCapabilitiesResponse{Responses: map[int32]*pb.CapabilityResponse{}}, nil - }, + drt.Await = func(request *pb.AwaitCapabilitiesRequest, maxResponseSize uint64) (*pb.AwaitCapabilitiesResponse, error) { + return &pb.AwaitCapabilitiesResponse{Responses: map[int32]*pb.CapabilityResponse{}}, nil } _, err := capability.PerformAction(rt, &basicaction.Inputs{InputThing: true}).Await() return "", err @@ -153,55 +145,6 @@ func TestRuntime_CallCapability(t *testing.T) { }) } 
-func TestRuntime_Rand(t *testing.T) { - t.Run("random delegates", func(t *testing.T) { - test := func(rt sdk.DonRuntime, _ *basictrigger.Outputs) (uint64, error) { - r, err := rt.Rand() - if err != nil { - return 0, err - } - return r.Uint64(), nil - } - - ran, result, err := testRuntime(t, test) - require.NoError(t, err) - assert.True(t, ran) - assert.Equal(t, rand.New(rand.NewSource(1)).Uint64(), result) - }) - - t.Run("random does not allow use in the wrong mode", func(t *testing.T) { - test := func(rt sdk.DonRuntime, _ *basictrigger.Outputs) (uint64, error) { - return sdk.RunInNodeMode[uint64](rt, func(_ sdk.NodeRuntime) (uint64, error) { - if _, err := rt.Rand(); err != nil { - return 0, err - } - - return 0, fmt.Errorf("should not be called in node mode") - }, sdk.ConsensusMedianAggregation[uint64]()).Await() - } - - _, _, err := testRuntime(t, test) - require.Error(t, err) - }) - - t.Run("returned random panics if you use it in the wrong mode ", func(t *testing.T) { - assert.Panics(t, func() { - test := func(rt sdk.DonRuntime, _ *basictrigger.Outputs) (uint64, error) { - r, err := rt.Rand() - if err != nil { - return 0, err - } - return sdk.RunInNodeMode[uint64](rt, func(_ sdk.NodeRuntime) (uint64, error) { - r.Uint64() - return 0, fmt.Errorf("should not be called in node mode") - }, sdk.ConsensusMedianAggregation[uint64]()).Await() - } - - _, _, _ = testRuntime(t, test) - }) - }) -} - func TestDonRuntime_RunInNodeMode(t *testing.T) { t.Run("Successful consensus", func(t *testing.T) { nodeMock, err := nodeactionmock.NewBasicActionCapability(t) @@ -373,12 +316,3 @@ type consensusValues struct { Err error Resp int64 } - -type awaitOverride struct { - sdkimpl.RuntimeHelpers - await func(request *pb.AwaitCapabilitiesRequest, maxResponseSize uint64) (*pb.AwaitCapabilitiesResponse, error) -} - -func (a *awaitOverride) Await(request *pb.AwaitCapabilitiesRequest, maxResponseSize uint64) (*pb.AwaitCapabilitiesResponse, error) { - return a.await(request, maxResponseSize) -} diff --git a/pkg/workflows/sdk/v2/runtime.go b/pkg/workflows/sdk/v2/runtime.go index 9c0d474ee..ed34c98a8 100644 --- a/pkg/workflows/sdk/v2/runtime.go +++ b/pkg/workflows/sdk/v2/runtime.go @@ -4,7 +4,6 @@ import ( "errors" "io" "log/slog" - "math/rand" "reflect" "github.com/smartcontractkit/chainlink-common/pkg/values" @@ -18,7 +17,6 @@ type RuntimeBase interface { Config() []byte LogWriter() io.Writer Logger() *slog.Logger - Rand() (*rand.Rand, error) } // NodeRuntime is not thread safe and must not be used concurrently. diff --git a/pkg/workflows/sdk/v2/runtime_test.go b/pkg/workflows/sdk/v2/runtime_test.go index 70ebd3182..f8c09572f 100644 --- a/pkg/workflows/sdk/v2/runtime_test.go +++ b/pkg/workflows/sdk/v2/runtime_test.go @@ -4,7 +4,6 @@ import ( "errors" "io" "log/slog" - "math/rand" "testing" "github.com/smartcontractkit/chainlink-common/pkg/values" @@ -96,10 +95,6 @@ func TestRunInNodeMode_ErrorWrappingDefault(t *testing.T) { // mockNodeRuntime implements NodeRuntime for testing. 
type mockNodeRuntime struct{} -func (m mockNodeRuntime) Rand() (*rand.Rand, error) { - panic("unused in tests") -} - func (m mockNodeRuntime) CallCapability(_ *pb.CapabilityRequest) sdk.Promise[*pb.CapabilityResponse] { panic("unused in tests") } @@ -120,10 +115,6 @@ func (m mockNodeRuntime) IsNodeRuntime() {} type mockDonRuntime struct{} -func (m *mockDonRuntime) Rand() (*rand.Rand, error) { - panic("unused in tests") -} - func (m *mockDonRuntime) RunInNodeMode(fn func(nodeRuntime sdk.NodeRuntime) *pb.SimpleConsensusInputs) sdk.Promise[values.Value] { req := fn(mockNodeRuntime{}) diff --git a/pkg/workflows/sdk/v2/testutils/runner.go b/pkg/workflows/sdk/v2/testutils/runner.go index 842b65d21..f6debf114 100644 --- a/pkg/workflows/sdk/v2/testutils/runner.go +++ b/pkg/workflows/sdk/v2/testutils/runner.go @@ -4,7 +4,6 @@ import ( "errors" "io" "log/slog" - "math/rand" "testing" "github.com/google/uuid" @@ -28,7 +27,6 @@ type runner[T any] struct { runtime T writer *testWriter base *sdkimpl.RuntimeBase - source rand.Source } func (r *runner[T]) Logs() []string { @@ -43,10 +41,6 @@ func (r *runner[T]) LogWriter() io.Writer { return r.writer } -func (r *runner[T]) SetRandSource(source rand.Source) { - r.source = source -} - type TestRunner interface { Result() (bool, any, error) @@ -59,8 +53,6 @@ type TestRunner interface { SetMaxResponseSizeBytes(maxResponseSizebytes uint64) Logs() []string - - SetRandSource(source rand.Source) } type DonRunner interface { @@ -75,18 +67,14 @@ type NodeRunner interface { func NewDonRunner(tb testing.TB, config []byte) DonRunner { writer := &testWriter{} - drt := &sdkimpl.DonRuntime{} - r := newRunner[sdk.DonRuntime](tb, config, writer, drt, &drt.RuntimeBase) - drt.RuntimeBase = newRuntime(tb, config, writer, func() rand.Source { return r.source }) - return r + drt := &sdkimpl.DonRuntime{RuntimeBase: newRuntime(tb, config, writer)} + return newRunner[sdk.DonRuntime](tb, config, writer, drt, &drt.RuntimeBase) } func NewNodeRunner(tb testing.TB, config []byte) NodeRunner { writer := &testWriter{} - nrt := &sdkimpl.NodeRuntime{} - r := newRunner[sdk.NodeRuntime](tb, config, writer, nrt, &nrt.RuntimeBase) - nrt.RuntimeBase = newRuntime(tb, config, writer, func() rand.Source { return r.source }) - return r + nrt := &sdkimpl.NodeRuntime{RuntimeBase: newRuntime(tb, config, writer)} + return newRunner[sdk.NodeRuntime](tb, config, writer, nrt, &nrt.RuntimeBase) } func newRunner[T any](tb testing.TB, config []byte, writer *testWriter, t T, base *sdkimpl.RuntimeBase) *runner[T] { @@ -98,7 +86,6 @@ func newRunner[T any](tb testing.TB, config []byte, writer *testWriter, t T, bas runtime: t, writer: writer, base: base, - source: rand.NewSource(1), } return r diff --git a/pkg/workflows/sdk/v2/testutils/runtime.go b/pkg/workflows/sdk/v2/testutils/runtime.go index fb3987621..da4d08c0d 100644 --- a/pkg/workflows/sdk/v2/testutils/runtime.go +++ b/pkg/workflows/sdk/v2/testutils/runtime.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "math/rand" "testing" "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/consensus/consensusmock" @@ -16,7 +15,9 @@ import ( "google.golang.org/protobuf/proto" ) -func newRuntime(tb testing.TB, configBytes []byte, writer *testWriter, sourceFn func() rand.Source) sdkimpl.RuntimeBase { +func newRuntime(tb testing.TB, configBytes []byte, writer *testWriter) sdkimpl.RuntimeBase { + tb.Cleanup(func() { delete(calls, tb) }) + defaultConsensus, err := consensusmock.NewConsensusCapability(tb) // Do not override if the user provided their own 
consensus method @@ -27,8 +28,9 @@ func newRuntime(tb testing.TB, configBytes []byte, writer *testWriter, sourceFn return sdkimpl.RuntimeBase{ ConfigBytes: configBytes, MaxResponseSize: sdk.DefaultMaxResponseSizeBytes, + Call: createCallCapability(tb), + Await: createAwaitCapabilities(tb), Writer: writer, - RuntimeHelpers: &runtimeHelpers{tb: tb, calls: map[int32]chan *pb.CapabilityResponse{}, sourceFn: sourceFn}, } } @@ -47,55 +49,59 @@ func defaultSimpleConsensus(_ context.Context, input *pb.SimpleConsensusInputs) } } -type runtimeHelpers struct { - tb testing.TB - calls map[int32]chan *pb.CapabilityResponse - sourceFn func() rand.Source -} +var calls = map[testing.TB]map[int32]chan *pb.CapabilityResponse{} -func (rh *runtimeHelpers) GetSource(_ pb.Mode) rand.Source { - return rh.sourceFn() -} +func createCallCapability(tb testing.TB) func(request *pb.CapabilityRequest) error { + return func(request *pb.CapabilityRequest) error { + reg := registry.GetRegistry(tb) + capability, err := reg.GetCapability(request.Id) + if err != nil { + return err + } -func (rh *runtimeHelpers) Call(request *pb.CapabilityRequest) error { - reg := registry.GetRegistry(rh.tb) - capability, err := reg.GetCapability(request.Id) - if err != nil { - return err + respCh := make(chan *pb.CapabilityResponse, 1) + tbCalls, ok := calls[tb] + if !ok { + tbCalls = map[int32]chan *pb.CapabilityResponse{} + calls[tb] = tbCalls + } + tbCalls[request.CallbackId] = respCh + go func() { + respCh <- capability.Invoke(tb.Context(), request) + }() + return nil } - - respCh := make(chan *pb.CapabilityResponse, 1) - rh.calls[request.CallbackId] = respCh - go func() { - respCh <- capability.Invoke(rh.tb.Context(), request) - }() - return nil } -func (rh *runtimeHelpers) Await(request *pb.AwaitCapabilitiesRequest, maxResponseSize uint64) (*pb.AwaitCapabilitiesResponse, error) { - response := &pb.AwaitCapabilitiesResponse{Responses: map[int32]*pb.CapabilityResponse{}} +func createAwaitCapabilities(tb testing.TB) sdkimpl.AwaitCapabilitiesFn { + return func(request *pb.AwaitCapabilitiesRequest, maxResponseSize uint64) (*pb.AwaitCapabilitiesResponse, error) { + response := &pb.AwaitCapabilitiesResponse{Responses: map[int32]*pb.CapabilityResponse{}} - var errs []error - for _, id := range request.Ids { - ch, ok := rh.calls[id] + testCalls, ok := calls[tb] if !ok { - errs = append(errs, fmt.Errorf("no call found for %d", id)) - continue + return nil, errors.New("no calls found for this test") } - select { - case resp := <-ch: - response.Responses[id] = resp - case <-rh.tb.Context().Done(): - return nil, rh.tb.Context().Err() + + var errs []error + for _, id := range request.Ids { + ch, ok := testCalls[id] + if !ok { + errs = append(errs, fmt.Errorf("no call found for %d", id)) + continue + } + select { + case resp := <-ch: + response.Responses[id] = resp + case <-tb.Context().Done(): + return nil, tb.Context().Err() + } } - } - bytes, _ := proto.Marshal(response) - if len(bytes) > int(maxResponseSize) { - return nil, errors.New(sdk.ResponseBufferTooSmall) - } + bytes, _ := proto.Marshal(response) + if len(bytes) > int(maxResponseSize) { + return nil, errors.New(sdk.ResponseBufferTooSmall) + } - return response, errors.Join(errs...) + return response, errors.Join(errs...) 
+ } } - -func (rh *runtimeHelpers) SwitchModes(_ pb.Mode) {} diff --git a/pkg/workflows/wasm/host/execution.go b/pkg/workflows/wasm/host/execution.go index 91d1d1675..90fec4aa3 100644 --- a/pkg/workflows/wasm/host/execution.go +++ b/pkg/workflows/wasm/host/execution.go @@ -16,11 +16,7 @@ type execution[T any] struct { capabilityResponses map[int32]<-chan *sdkpb.CapabilityResponse lock sync.RWMutex module *module - executor ExecutionHelper - hasRun bool - mode sdkpb.Mode - donSeed int64 - nodeSeed int64 + executor CapabilityExecutor } // callCapAsync async calls a capability by placing execution results onto a @@ -85,22 +81,6 @@ func (e *execution[T]) log(caller *wasmtime.Caller, ptr int32, ptrlen int32) { lggr.Errorf("error calling log: %s", innerErr) return } - + lggr.Info(string(b)) } - -func (e *execution[T]) getSeed(mode int32) int64 { - switch sdkpb.Mode(mode) { - case sdkpb.Mode_DON: - return e.donSeed - case sdkpb.Mode_Node: - return e.nodeSeed - } - - return -1 -} - -func (e *execution[T]) switchModes(_ *wasmtime.Caller, mode int32) { - e.hasRun = true - e.mode = sdkpb.Mode(mode) -} diff --git a/pkg/workflows/wasm/host/mock_capability_executor_test.go b/pkg/workflows/wasm/host/mock_capability_executor_test.go new file mode 100644 index 000000000..89ccc38b7 --- /dev/null +++ b/pkg/workflows/wasm/host/mock_capability_executor_test.go @@ -0,0 +1,96 @@ +// Code generated by mockery v2.53.3. DO NOT EDIT. + +package host + +import ( + context "context" + + pb "github.com/smartcontractkit/chainlink-common/pkg/workflows/sdk/v2/pb" + mock "github.com/stretchr/testify/mock" +) + +// MockCapabilityExecutor is an autogenerated mock type for the CapabilityExecutor type +type MockCapabilityExecutor struct { + mock.Mock +} + +type MockCapabilityExecutor_Expecter struct { + mock *mock.Mock +} + +func (_m *MockCapabilityExecutor) EXPECT() *MockCapabilityExecutor_Expecter { + return &MockCapabilityExecutor_Expecter{mock: &_m.Mock} +} + +// CallCapability provides a mock function with given fields: ctx, request +func (_m *MockCapabilityExecutor) CallCapability(ctx context.Context, request *pb.CapabilityRequest) (*pb.CapabilityResponse, error) { + ret := _m.Called(ctx, request) + + if len(ret) == 0 { + panic("no return value specified for CallCapability") + } + + var r0 *pb.CapabilityResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *pb.CapabilityRequest) (*pb.CapabilityResponse, error)); ok { + return rf(ctx, request) + } + if rf, ok := ret.Get(0).(func(context.Context, *pb.CapabilityRequest) *pb.CapabilityResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*pb.CapabilityResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *pb.CapabilityRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockCapabilityExecutor_CallCapability_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CallCapability' +type MockCapabilityExecutor_CallCapability_Call struct { + *mock.Call +} + +// CallCapability is a helper method to define mock.On call +// - ctx context.Context +// - request *pb.CapabilityRequest +func (_e *MockCapabilityExecutor_Expecter) CallCapability(ctx interface{}, request interface{}) *MockCapabilityExecutor_CallCapability_Call { + return &MockCapabilityExecutor_CallCapability_Call{Call: _e.mock.On("CallCapability", ctx, request)} +} + +func (_c *MockCapabilityExecutor_CallCapability_Call) Run(run func(ctx context.Context, 
request *pb.CapabilityRequest)) *MockCapabilityExecutor_CallCapability_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*pb.CapabilityRequest)) + }) + return _c +} + +func (_c *MockCapabilityExecutor_CallCapability_Call) Return(_a0 *pb.CapabilityResponse, _a1 error) *MockCapabilityExecutor_CallCapability_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockCapabilityExecutor_CallCapability_Call) RunAndReturn(run func(context.Context, *pb.CapabilityRequest) (*pb.CapabilityResponse, error)) *MockCapabilityExecutor_CallCapability_Call { + _c.Call.Return(run) + return _c +} + +// NewMockCapabilityExecutor creates a new instance of MockCapabilityExecutor. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockCapabilityExecutor(t interface { + mock.TestingT + Cleanup(func()) +}) *MockCapabilityExecutor { + mock := &MockCapabilityExecutor{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/workflows/wasm/host/mock_execution_helper_test.go b/pkg/workflows/wasm/host/mock_execution_helper_test.go deleted file mode 100644 index 68e1efca8..000000000 --- a/pkg/workflows/wasm/host/mock_execution_helper_test.go +++ /dev/null @@ -1,233 +0,0 @@ -// Code generated by mockery v2.53.3. DO NOT EDIT. - -package host - -import ( - context "context" - - pb "github.com/smartcontractkit/chainlink-common/pkg/workflows/sdk/v2/pb" - mock "github.com/stretchr/testify/mock" - - time "time" -) - -// MockExecutionHelper is an autogenerated mock type for the ExecutionHelper type -type MockExecutionHelper struct { - mock.Mock -} - -type MockExecutionHelper_Expecter struct { - mock *mock.Mock -} - -func (_m *MockExecutionHelper) EXPECT() *MockExecutionHelper_Expecter { - return &MockExecutionHelper_Expecter{mock: &_m.Mock} -} - -// CallCapability provides a mock function with given fields: ctx, request -func (_m *MockExecutionHelper) CallCapability(ctx context.Context, request *pb.CapabilityRequest) (*pb.CapabilityResponse, error) { - ret := _m.Called(ctx, request) - - if len(ret) == 0 { - panic("no return value specified for CallCapability") - } - - var r0 *pb.CapabilityResponse - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, *pb.CapabilityRequest) (*pb.CapabilityResponse, error)); ok { - return rf(ctx, request) - } - if rf, ok := ret.Get(0).(func(context.Context, *pb.CapabilityRequest) *pb.CapabilityResponse); ok { - r0 = rf(ctx, request) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*pb.CapabilityResponse) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, *pb.CapabilityRequest) error); ok { - r1 = rf(ctx, request) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// MockExecutionHelper_CallCapability_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CallCapability' -type MockExecutionHelper_CallCapability_Call struct { - *mock.Call -} - -// CallCapability is a helper method to define mock.On call -// - ctx context.Context -// - request *pb.CapabilityRequest -func (_e *MockExecutionHelper_Expecter) CallCapability(ctx interface{}, request interface{}) *MockExecutionHelper_CallCapability_Call { - return &MockExecutionHelper_CallCapability_Call{Call: _e.mock.On("CallCapability", ctx, request)} -} - -func (_c *MockExecutionHelper_CallCapability_Call) Run(run func(ctx context.Context, request 
*pb.CapabilityRequest)) *MockExecutionHelper_CallCapability_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(*pb.CapabilityRequest)) - }) - return _c -} - -func (_c *MockExecutionHelper_CallCapability_Call) Return(_a0 *pb.CapabilityResponse, _a1 error) *MockExecutionHelper_CallCapability_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *MockExecutionHelper_CallCapability_Call) RunAndReturn(run func(context.Context, *pb.CapabilityRequest) (*pb.CapabilityResponse, error)) *MockExecutionHelper_CallCapability_Call { - _c.Call.Return(run) - return _c -} - -// GetDONTime provides a mock function with no fields -func (_m *MockExecutionHelper) GetDONTime() time.Time { - ret := _m.Called() - - if len(ret) == 0 { - panic("no return value specified for GetDONTime") - } - - var r0 time.Time - if rf, ok := ret.Get(0).(func() time.Time); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(time.Time) - } - - return r0 -} - -// MockExecutionHelper_GetDONTime_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetDONTime' -type MockExecutionHelper_GetDONTime_Call struct { - *mock.Call -} - -// GetDONTime is a helper method to define mock.On call -func (_e *MockExecutionHelper_Expecter) GetDONTime() *MockExecutionHelper_GetDONTime_Call { - return &MockExecutionHelper_GetDONTime_Call{Call: _e.mock.On("GetDONTime")} -} - -func (_c *MockExecutionHelper_GetDONTime_Call) Run(run func()) *MockExecutionHelper_GetDONTime_Call { - _c.Call.Run(func(args mock.Arguments) { - run() - }) - return _c -} - -func (_c *MockExecutionHelper_GetDONTime_Call) Return(_a0 time.Time) *MockExecutionHelper_GetDONTime_Call { - _c.Call.Return(_a0) - return _c -} - -func (_c *MockExecutionHelper_GetDONTime_Call) RunAndReturn(run func() time.Time) *MockExecutionHelper_GetDONTime_Call { - _c.Call.Return(run) - return _c -} - -// GetId provides a mock function with no fields -func (_m *MockExecutionHelper) GetId() string { - ret := _m.Called() - - if len(ret) == 0 { - panic("no return value specified for GetId") - } - - var r0 string - if rf, ok := ret.Get(0).(func() string); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(string) - } - - return r0 -} - -// MockExecutionHelper_GetId_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetId' -type MockExecutionHelper_GetId_Call struct { - *mock.Call -} - -// GetId is a helper method to define mock.On call -func (_e *MockExecutionHelper_Expecter) GetId() *MockExecutionHelper_GetId_Call { - return &MockExecutionHelper_GetId_Call{Call: _e.mock.On("GetId")} -} - -func (_c *MockExecutionHelper_GetId_Call) Run(run func()) *MockExecutionHelper_GetId_Call { - _c.Call.Run(func(args mock.Arguments) { - run() - }) - return _c -} - -func (_c *MockExecutionHelper_GetId_Call) Return(_a0 string) *MockExecutionHelper_GetId_Call { - _c.Call.Return(_a0) - return _c -} - -func (_c *MockExecutionHelper_GetId_Call) RunAndReturn(run func() string) *MockExecutionHelper_GetId_Call { - _c.Call.Return(run) - return _c -} - -// GetNodeTime provides a mock function with no fields -func (_m *MockExecutionHelper) GetNodeTime() time.Time { - ret := _m.Called() - - if len(ret) == 0 { - panic("no return value specified for GetNodeTime") - } - - var r0 time.Time - if rf, ok := ret.Get(0).(func() time.Time); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(time.Time) - } - - return r0 -} - -// MockExecutionHelper_GetNodeTime_Call is a *mock.Call that shadows Run/Return methods with type explicit 
version for method 'GetNodeTime' -type MockExecutionHelper_GetNodeTime_Call struct { - *mock.Call -} - -// GetNodeTime is a helper method to define mock.On call -func (_e *MockExecutionHelper_Expecter) GetNodeTime() *MockExecutionHelper_GetNodeTime_Call { - return &MockExecutionHelper_GetNodeTime_Call{Call: _e.mock.On("GetNodeTime")} -} - -func (_c *MockExecutionHelper_GetNodeTime_Call) Run(run func()) *MockExecutionHelper_GetNodeTime_Call { - _c.Call.Run(func(args mock.Arguments) { - run() - }) - return _c -} - -func (_c *MockExecutionHelper_GetNodeTime_Call) Return(_a0 time.Time) *MockExecutionHelper_GetNodeTime_Call { - _c.Call.Return(_a0) - return _c -} - -func (_c *MockExecutionHelper_GetNodeTime_Call) RunAndReturn(run func() time.Time) *MockExecutionHelper_GetNodeTime_Call { - _c.Call.Return(run) - return _c -} - -// NewMockExecutionHelper creates a new instance of MockExecutionHelper. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewMockExecutionHelper(t interface { - mock.TestingT - Cleanup(func()) -}) *MockExecutionHelper { - mock := &MockExecutionHelper{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} diff --git a/pkg/workflows/wasm/host/mocks/module_v2.go b/pkg/workflows/wasm/host/mocks/module_v2.go index 72da9a931..ba16a8855 100644 --- a/pkg/workflows/wasm/host/mocks/module_v2.go +++ b/pkg/workflows/wasm/host/mocks/module_v2.go @@ -57,7 +57,7 @@ func (_c *ModuleV2_Close_Call) RunAndReturn(run func()) *ModuleV2_Close_Call { } // Execute provides a mock function with given fields: ctx, request, handler -func (_m *ModuleV2) Execute(ctx context.Context, request *pb.ExecuteRequest, handler host.ExecutionHelper) (*pb.ExecutionResult, error) { +func (_m *ModuleV2) Execute(ctx context.Context, request *pb.ExecuteRequest, handler host.CapabilityExecutor) (*pb.ExecutionResult, error) { ret := _m.Called(ctx, request, handler) if len(ret) == 0 { @@ -66,10 +66,10 @@ func (_m *ModuleV2) Execute(ctx context.Context, request *pb.ExecuteRequest, han var r0 *pb.ExecutionResult var r1 error - if rf, ok := ret.Get(0).(func(context.Context, *pb.ExecuteRequest, host.ExecutionHelper) (*pb.ExecutionResult, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, *pb.ExecuteRequest, host.CapabilityExecutor) (*pb.ExecutionResult, error)); ok { return rf(ctx, request, handler) } - if rf, ok := ret.Get(0).(func(context.Context, *pb.ExecuteRequest, host.ExecutionHelper) *pb.ExecutionResult); ok { + if rf, ok := ret.Get(0).(func(context.Context, *pb.ExecuteRequest, host.CapabilityExecutor) *pb.ExecutionResult); ok { r0 = rf(ctx, request, handler) } else { if ret.Get(0) != nil { @@ -77,7 +77,7 @@ func (_m *ModuleV2) Execute(ctx context.Context, request *pb.ExecuteRequest, han } } - if rf, ok := ret.Get(1).(func(context.Context, *pb.ExecuteRequest, host.ExecutionHelper) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, *pb.ExecuteRequest, host.CapabilityExecutor) error); ok { r1 = rf(ctx, request, handler) } else { r1 = ret.Error(1) @@ -94,14 +94,14 @@ type ModuleV2_Execute_Call struct { // Execute is a helper method to define mock.On call // - ctx context.Context // - request *pb.ExecuteRequest -// - handler host.ExecutionHelper +// - handler host.CapabilityExecutor func (_e *ModuleV2_Expecter) Execute(ctx interface{}, request interface{}, handler interface{}) *ModuleV2_Execute_Call { return &ModuleV2_Execute_Call{Call: 
_e.mock.On("Execute", ctx, request, handler)} } -func (_c *ModuleV2_Execute_Call) Run(run func(ctx context.Context, request *pb.ExecuteRequest, handler host.ExecutionHelper)) *ModuleV2_Execute_Call { +func (_c *ModuleV2_Execute_Call) Run(run func(ctx context.Context, request *pb.ExecuteRequest, handler host.CapabilityExecutor)) *ModuleV2_Execute_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(*pb.ExecuteRequest), args[2].(host.ExecutionHelper)) + run(args[0].(context.Context), args[1].(*pb.ExecuteRequest), args[2].(host.CapabilityExecutor)) }) return _c } @@ -111,7 +111,7 @@ func (_c *ModuleV2_Execute_Call) Return(_a0 *pb.ExecutionResult, _a1 error) *Mod return _c } -func (_c *ModuleV2_Execute_Call) RunAndReturn(run func(context.Context, *pb.ExecuteRequest, host.ExecutionHelper) (*pb.ExecutionResult, error)) *ModuleV2_Execute_Call { +func (_c *ModuleV2_Execute_Call) RunAndReturn(run func(context.Context, *pb.ExecuteRequest, host.CapabilityExecutor) (*pb.ExecutionResult, error)) *ModuleV2_Execute_Call { _c.Call.Return(run) return _c } diff --git a/pkg/workflows/wasm/host/module.go b/pkg/workflows/wasm/host/module.go index 1d9df7670..6e4e1f249 100644 --- a/pkg/workflows/wasm/host/module.go +++ b/pkg/workflows/wasm/host/module.go @@ -8,10 +8,8 @@ import ( "encoding/json" "errors" "fmt" - "hash/fnv" "io" "math" - "math/rand" "regexp" "strings" "sync" @@ -88,19 +86,13 @@ type ModuleV2 interface { ModuleBase // V2/"NoDAG" API - request either the list of Trigger Subscriptions or launch workflow execution - Execute(ctx context.Context, request *wasmpb.ExecuteRequest, handler ExecutionHelper) (*wasmpb.ExecutionResult, error) + Execute(ctx context.Context, request *wasmpb.ExecuteRequest, handler CapabilityExecutor) (*wasmpb.ExecutionResult, error) } -// ExecutionHelper Implemented by those running the host, for example the Workflow Engine -type ExecutionHelper interface { - // CallCapability blocking call to the Workflow Engine +// Implemented by the Workflow Engine +type CapabilityExecutor interface { + // blocking call to the Workflow Engine CallCapability(ctx context.Context, request *sdkpb.CapabilityRequest) (*sdkpb.CapabilityResponse, error) - - GetId() string - - GetNodeTime() time.Time - - GetDONTime() time.Time } type module struct { @@ -251,7 +243,7 @@ func NewModule(modCfg *ModuleConfig, binary []byte, opts ...func(*ModuleConfig)) } func linkNoDAG(m *module, store *wasmtime.Store, exec *execution[*wasmpb.ExecutionResult]) (*wasmtime.Instance, error) { - linker, err := newWasiLinker(exec, m.engine) + linker, err := newWasiLinker(m.cfg, m.engine) if err != nil { return nil, err } @@ -298,25 +290,11 @@ func linkNoDAG(m *module, store *wasmtime.Store, exec *execution[*wasmpb.Executi return nil, fmt.Errorf("error wrapping log func: %w", err) } - if err = linker.FuncWrap( - "env", - "switch_modes", - exec.switchModes); err != nil { - return nil, fmt.Errorf("error wrapping switchModes func: %w", err) - } - - if err = linker.FuncWrap( - "env", - "random_seed", - exec.getSeed); err != nil { - return nil, fmt.Errorf("error wrapping getSeed func: %w", err) - } - return linker.Instantiate(store, m.module) } func linkLegacyDAG(m *module, store *wasmtime.Store, exec *execution[*wasmdagpb.Response]) (*wasmtime.Instance, error) { - linker, err := newDagWasiLinker(m.cfg, m.engine) + linker, err := newWasiLinker(m.cfg, m.engine) if err != nil { return nil, err } @@ -392,7 +370,7 @@ func (m *module) IsLegacyDAG() bool { return m.v2ImportName == "" } -func (m *module) 
Execute(ctx context.Context, req *wasmpb.ExecuteRequest, executor ExecutionHelper) (*wasmpb.ExecutionResult, error) { +func (m *module) Execute(ctx context.Context, req *wasmpb.ExecuteRequest, executor CapabilityExecutor) (*wasmpb.ExecutionResult, error) { if m.IsLegacyDAG() { return nil, errors.New("cannot execute a legacy dag workflow") } @@ -444,7 +422,7 @@ func runWasm[I, O proto.Message]( request I, setMaxResponseSize func(i I, maxSize uint64), linkWasm linkFn[O], - helper ExecutionHelper) (O, error) { + executor CapabilityExecutor) (O, error) { var o O @@ -490,23 +468,11 @@ func runWasm[I, O proto.Message]( deadline := *m.cfg.Timeout / m.cfg.TickInterval store.SetEpochDeadline(uint64(deadline)) - h := fnv.New64a() - if helper != nil { - id := helper.GetId() - _, _ = h.Write([]byte(id)) - } - - donSeed := int64(h.Sum64()) - - _ = ctxWithTimeout exec := &execution[O]{ - //ctx: ctxWithTimeout, - ctx: ctx, + ctx: ctxWithTimeout, capabilityResponses: map[int32]<-chan *sdkpb.CapabilityResponse{}, module: m, - executor: helper, - donSeed: donSeed, - nodeSeed: int64(rand.Uint64()), + executor: executor, } instance, err := linkWasm(m, store, exec) diff --git a/pkg/workflows/wasm/host/module_test.go b/pkg/workflows/wasm/host/module_test.go index c284320ae..8637e1b23 100644 --- a/pkg/workflows/wasm/host/module_test.go +++ b/pkg/workflows/wasm/host/module_test.go @@ -590,8 +590,8 @@ func Test_toEmissible(t *testing.T) { // CallAwaitRace validates that every call can be awaited. func Test_CallAwaitRace(t *testing.T) { ctx := t.Context() - mockExecHelper := NewMockExecutionHelper(t) - mockExecHelper.EXPECT(). + mockCapExec := NewMockCapabilityExecutor(t) + mockCapExec.EXPECT(). CallCapability(matches.AnyContext, mock.Anything). Return(&sdkpb.CapabilityResponse{}, nil) @@ -604,7 +604,7 @@ func Test_CallAwaitRace(t *testing.T) { module: m, capabilityResponses: map[int32]<-chan *sdkpb.CapabilityResponse{}, ctx: t.Context(), - executor: mockExecHelper, + executor: mockCapExec, } wg.Add(wantAttempts) diff --git a/pkg/workflows/wasm/host/test/nodag/randoms/cmd/main.go b/pkg/workflows/wasm/host/test/nodag/randoms/cmd/main.go deleted file mode 100644 index 27ca5ab0c..000000000 --- a/pkg/workflows/wasm/host/test/nodag/randoms/cmd/main.go +++ /dev/null @@ -1,54 +0,0 @@ -//go:build wasip1 - -package main - -import ( - "strconv" - - "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/protoc/pkg/test_capabilities/basictrigger" - "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/protoc/pkg/test_capabilities/nodeaction" - "github.com/smartcontractkit/chainlink-common/pkg/workflows/sdk/v2" - "github.com/smartcontractkit/chainlink-common/pkg/workflows/testhelpers/v2" - "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/v2" -) - -func main() { - runner := wasm.NewDonRunner() - basic := &basictrigger.Basic{} - - runner.Run(&sdk.WorkflowArgs[sdk.DonRuntime]{ - Handlers: []sdk.Handler[sdk.DonRuntime]{ - sdk.NewDonHandler( - basic.Trigger(testhelpers.TestWorkflowTriggerConfig()), - func(runtime sdk.DonRuntime, trigger *basictrigger.Outputs) (uint64, error) { - r, err := runtime.Rand() - if err != nil { - return 0, err - } - total := r.Uint64() - sdk.RunInNodeMode[uint64](runtime, func(nrt sdk.NodeRuntime) (uint64, error) { - node, err := (&nodeaction.BasicAction{}).PerformAction(nrt, &nodeaction.NodeInputs{ - InputThing: false, - }).Await() - - if err != nil { - return 0, err - } - - // Conditionally generate a random number based on the node output. 
- // This ensures it doesn't impact the next DON mode number. - if node.OutputThing < 100 { - nr, err := nrt.Rand() - if err != nil { - return 0, err - } - runtime.LogWriter().Write([]byte(strconv.FormatUint(nr.Uint64(), 10))) - } - return 0, nil - }, sdk.ConsensusIdenticalAggregation[uint64]()) - total += r.Uint64() - return total, nil - }), - }, - }) -} diff --git a/pkg/workflows/wasm/host/wasip1.go b/pkg/workflows/wasm/host/wasip1.go index b72181911..3343bbf70 100644 --- a/pkg/workflows/wasm/host/wasip1.go +++ b/pkg/workflows/wasm/host/wasip1.go @@ -17,39 +17,7 @@ var ( tick = 100 * time.Millisecond ) -func newWasiLinker[T any](exec *execution[T], engine *wasmtime.Engine) (*wasmtime.Linker, error) { - linker := wasmtime.NewLinker(engine) - linker.AllowShadowing(true) - - err := linker.DefineWasi() - if err != nil { - return nil, err - } - - // TODO: https://smartcontract-it.atlassian.net/browse/CAPPL-903 - err = linker.FuncWrap( - "wasi_snapshot_preview1", - "poll_oneoff", - pollOneoff, - ) - if err != nil { - return nil, err - } - - // TODO: https://smartcontract-it.atlassian.net/browse/CAPPL-903 - err = linker.FuncWrap( - "wasi_snapshot_preview1", - "clock_time_get", - clockTimeGet, - ) - if err != nil { - return nil, err - } - - return linker, nil -} - -func newDagWasiLinker(modCfg *ModuleConfig, engine *wasmtime.Engine) (*wasmtime.Linker, error) { +func newWasiLinker(modCfg *ModuleConfig, engine *wasmtime.Engine) (*wasmtime.Linker, error) { linker := wasmtime.NewLinker(engine) linker.AllowShadowing(true) diff --git a/pkg/workflows/wasm/host/wasm_nodag_test.go b/pkg/workflows/wasm/host/wasm_nodag_test.go index e99193d97..84ec5de2b 100644 --- a/pkg/workflows/wasm/host/wasm_nodag_test.go +++ b/pkg/workflows/wasm/host/wasm_nodag_test.go @@ -7,7 +7,6 @@ import ( "strings" "testing" - "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/protoc/pkg/test_capabilities/nodeaction" sdkpb "github.com/smartcontractkit/chainlink-common/pkg/workflows/sdk/v2/pb" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -24,12 +23,10 @@ import ( ) const ( + nodagBinaryLocation = "test/nodag/singlehandler/cmd/testmodule.wasm" + nodagMultiTriggerBinaryLocation = "test/nodag/multihandler/cmd/testmodule.wasm" nodagBinaryCmd = "test/nodag/singlehandler/cmd" - nodagBinaryLocation = nodagBinaryCmd + "/testmodule.wasm" nodagMultiTriggerBinaryCmd = "test/nodag/multihandler/cmd" - nodagMultiTriggerBinaryLocation = nodagMultiTriggerBinaryCmd + "/testmodule.wasm" - nodagRandomBinaryCmd = "test/nodag/randoms/cmd" - nodagRandomBinaryLocation = nodagRandomBinaryCmd + "/testmodule.wasm" ) var wordList = []string{"Hello, ", "world", "!"} @@ -39,7 +36,7 @@ func Test_NoDag_Run(t *testing.T) { binary := createTestBinary(nodagBinaryCmd, nodagBinaryLocation, true, t) - t.Run("NOK fails with unset ExecutionHelper for trigger", func(t *testing.T) { + t.Run("NOK fails with unset CapabilityExecutor for trigger", func(t *testing.T) { mc := defaultNoDAGModCfg(t) m, err := NewModule(mc, binary) require.NoError(t, err) @@ -57,7 +54,7 @@ func Test_NoDag_Run(t *testing.T) { require.ErrorContains(t, err, "invalid capability executor") }) - t.Run("OK can subscribe without setting ExecutionHelper", func(t *testing.T) { + t.Run("OK can subscribe without setting CapabilityExecutor", func(t *testing.T) { mc := defaultNoDAGModCfg(t) m, err := NewModule(mc, binary) require.NoError(t, err) @@ -84,8 +81,7 @@ func Test_NoDag_Run(t *testing.T) { m.Start() defer m.Close() - mockExecutionHelper := 
NewMockExecutionHelper(t) - mockExecutionHelper.EXPECT().GetId().Return("Id") + mockCapExecutor := NewMockCapabilityExecutor(t) // wrap some common payload newWantedCapResponse := func(i int) *sdkpb.CapabilityResponse { @@ -101,7 +97,7 @@ func Test_NoDag_Run(t *testing.T) { for i := 1; i < len(wordList); i++ { wantCapResp := newWantedCapResponse(i) - mockExecutionHelper.EXPECT().CallCapability(mock.Anything, mock.Anything). + mockCapExecutor.EXPECT().CallCapability(mock.Anything, mock.Anything). Run( func(ctx context.Context, request *sdkpb.CapabilityRequest) { require.Equal(t, "basic-test-action@1.0.0", request.Id) @@ -125,7 +121,7 @@ func Test_NoDag_Run(t *testing.T) { }, } - response, err := m.Execute(ctx, req, mockExecutionHelper) + response, err := m.Execute(ctx, req, mockCapExecutor) require.NoError(t, err) logs := observer.TakeAll() @@ -199,9 +195,9 @@ func Test_NoDag_MultipleTriggers_Run(t *testing.T) { m.Start() defer m.Close() - mockExecutionHelper := NewMockExecutionHelper(t) - mockExecutionHelper.EXPECT().GetId().Return("Id") + mockCapExecutor := NewMockCapabilityExecutor(t) + // wrap some common payload newWantedCapResponse := func(i int) *sdkpb.CapabilityResponse { action := &basicaction.Outputs{AdaptedThing: wordList[i]} anyAction, err := anypb.New(action) @@ -215,7 +211,7 @@ func Test_NoDag_MultipleTriggers_Run(t *testing.T) { for i := 1; i < len(wordList); i++ { wantCapResp := newWantedCapResponse(i) - mockExecutionHelper.EXPECT().CallCapability(mock.Anything, mock.Anything). + mockCapExecutor.EXPECT().CallCapability(mock.Anything, mock.Anything). Run( func(ctx context.Context, request *sdkpb.CapabilityRequest) { require.Equal(t, "basic-test-action@1.0.0", request.Id) @@ -238,7 +234,7 @@ func Test_NoDag_MultipleTriggers_Run(t *testing.T) { }, }, } - response, err := m.Execute(ctx, req, mockExecutionHelper) + response, err := m.Execute(ctx, req, mockCapExecutor) require.NoError(t, err) switch output := response.Result.(type) { @@ -255,100 +251,6 @@ func Test_NoDag_MultipleTriggers_Run(t *testing.T) { }) } -func Test_NoDag_Random(t *testing.T) { - t.Parallel() - - mc := defaultNoDAGModCfg(t) - lggr, observed := logger.TestObserved(t, zapcore.DebugLevel) - mc.Logger = lggr - - binary := createTestBinary(nodagRandomBinaryCmd, nodagRandomBinaryLocation, true, t) - - m, err := NewModule(mc, binary) - require.NoError(t, err) - - // Test binary executes node mode code conditionally based on the value >= 100 - anyId := "Id" - gte100Exec := NewMockExecutionHelper(t) - gte100Exec.EXPECT().GetId().Return(anyId) - gte100 := &nodeaction.NodeOutputs{OutputThing: 120} - gte100Payload, err := anypb.New(gte100) - require.NoError(t, err) - - gte100Exec.EXPECT().CallCapability(mock.Anything, mock.Anything).Return(&sdkpb.CapabilityResponse{ - Response: &sdkpb.CapabilityResponse_Payload{ - Payload: gte100Payload, - }, - }, nil) - - m.Start() - defer m.Close() - - trigger := &basictrigger.Outputs{CoolOutput: "trigger1"} - triggerPayload, err := anypb.New(trigger) - require.NoError(t, err) - anyRequest := &wasmpb.ExecuteRequest{ - Request: &wasmpb.ExecuteRequest_Trigger{ - Trigger: &sdkpb.Trigger{ - Id: uint64(0), - Payload: triggerPayload, - }, - }, - } - execution1Result, err := m.Execute(t.Context(), anyRequest, gte100Exec) - require.NoError(t, err) - wrappedValue1, err := values.FromProto(execution1Result.GetValue()) - require.NoError(t, err) - value1, err := wrappedValue1.Unwrap() - require.NoError(t, err) - - t.Run("Same execution id gives the same randoms, even if random is called in node 
mode", func(t *testing.T) { - // Clear from any previous test - observed.TakeAll() - - lt100Exec := NewMockExecutionHelper(t) - lt100Exec.EXPECT().GetId().Return(anyId) - lt100 := &nodeaction.NodeOutputs{OutputThing: 120} - lt100Payload, err := anypb.New(lt100) - require.NoError(t, err) - - lt100Exec.EXPECT().CallCapability(mock.Anything, mock.Anything).Return(&sdkpb.CapabilityResponse{ - Response: &sdkpb.CapabilityResponse_Payload{ - Payload: lt100Payload, - }, - }, nil) - - exectuion2Result, err := m.Execute(t.Context(), anyRequest, lt100Exec) - require.NoError(t, err) - wrappedValue2, err := values.FromProto(exectuion2Result.GetValue()) - require.NoError(t, err) - value2, err := wrappedValue2.Unwrap() - require.NoError(t, err) - require.Equal(t, value1, value2, "Expected the same random number to be generated for the same trigger") - }) - - t.Run("Different execution id give different randoms", func(t *testing.T) { - require.NoError(t, err) - - gte100Exec2 := NewMockExecutionHelper(t) - gte100Exec2.EXPECT().GetId().Return("differentId") - - gte100Exec2.EXPECT().CallCapability(mock.Anything, mock.Anything).Return(&sdkpb.CapabilityResponse{ - Response: &sdkpb.CapabilityResponse_Payload{ - Payload: gte100Payload, - }, - }, nil) - - executionResult2, err := m.Execute(t.Context(), anyRequest, gte100Exec2) - require.NoError(t, err) - wrappedValue2, err := values.FromProto(executionResult2.GetValue()) - require.NoError(t, err) - value2, err := wrappedValue2.Unwrap() - require.NoError(t, err) - require.NotEqual(t, value1, value2, "Expected different random numbers for different triggers") - }) -} - func defaultNoDAGModCfg(t testing.TB) *ModuleConfig { return &ModuleConfig{ Logger: logger.Test(t), @@ -357,12 +259,10 @@ func defaultNoDAGModCfg(t testing.TB) *ModuleConfig { } func getTriggersSpec(t *testing.T, m ModuleV2, config []byte) (*sdkpb.TriggerSubscriptionRequest, error) { - helper := NewMockExecutionHelper(t) - helper.EXPECT().GetId().Return("Id") execResult, err := m.Execute(t.Context(), &wasmpb.ExecuteRequest{ Config: config, Request: &wasmpb.ExecuteRequest_Subscribe{Subscribe: &emptypb.Empty{}}, - }, helper) + }, NewMockCapabilityExecutor(t)) if err != nil { return nil, err diff --git a/pkg/workflows/wasm/v2/runner.go b/pkg/workflows/wasm/v2/runner.go index 093fdb627..3d6930a77 100644 --- a/pkg/workflows/wasm/v2/runner.go +++ b/pkg/workflows/wasm/v2/runner.go @@ -20,7 +20,6 @@ type runnerInternals interface { args() []string sendResponse(response unsafe.Pointer, responseLen int32) int32 versionV2() - switchModes(mode int32) } func newDonRunner(runnerInternals runnerInternals, runtimeInternals runtimeInternals) sdk.DonRunner { diff --git a/pkg/workflows/wasm/v2/runner_test_hooks.go b/pkg/workflows/wasm/v2/runner_test_hooks.go index 6344d29c1..83c7527c9 100644 --- a/pkg/workflows/wasm/v2/runner_test_hooks.go +++ b/pkg/workflows/wasm/v2/runner_test_hooks.go @@ -12,8 +12,6 @@ type runnerInternalsTestHook struct { execId string arguments []string sentResponse []byte - modeSwitched bool - mode int32 } func (r *runnerInternalsTestHook) args() []string { @@ -27,9 +25,4 @@ func (r *runnerInternalsTestHook) sendResponse(response unsafe.Pointer, response func (r *runnerInternalsTestHook) versionV2() {} -func (r *runnerInternalsTestHook) switchModes(mode int32) { - r.mode = mode - r.modeSwitched = true -} - var _ runnerInternals = (*runnerInternalsTestHook)(nil) diff --git a/pkg/workflows/wasm/v2/runner_wasip1.go b/pkg/workflows/wasm/v2/runner_wasip1.go index 145b76738..d4d29fb4f 100644 --- 
a/pkg/workflows/wasm/v2/runner_wasip1.go +++ b/pkg/workflows/wasm/v2/runner_wasip1.go @@ -5,7 +5,6 @@ import ( "unsafe" "github.com/smartcontractkit/chainlink-common/pkg/workflows/sdk/v2" - sdkpb "github.com/smartcontractkit/chainlink-common/pkg/workflows/sdk/v2/pb" ) //go:wasmimport env send_response @@ -14,16 +13,11 @@ func sendResponse(response unsafe.Pointer, responseLen int32) int32 //go:wasmimport env version_v2 func versionV2() -//go:wasmimport env switch_modes -func switchModes(mode int32) - func NewDonRunner() sdk.DonRunner { - switchModes((int32)(sdkpb.Mode_DON)) return newDonRunner(runnerInternalsImpl{}, runtimeInternalsImpl{}) } func NewNodeRunner() sdk.NodeRunner { - switchModes((int32)(sdkpb.Mode_Node)) return newNodeRunner(runnerInternalsImpl{}, runtimeInternalsImpl{}) } @@ -42,7 +36,3 @@ func (r runnerInternalsImpl) sendResponse(response unsafe.Pointer, responseLen i func (r runnerInternalsImpl) versionV2() { versionV2() } - -func (r runnerInternalsImpl) switchModes(mode int32) { - switchModes(mode) -} diff --git a/pkg/workflows/wasm/v2/runtime.go b/pkg/workflows/wasm/v2/runtime.go index 033991d35..07e3953d0 100644 --- a/pkg/workflows/wasm/v2/runtime.go +++ b/pkg/workflows/wasm/v2/runtime.go @@ -2,7 +2,6 @@ package wasm import ( "errors" - "math/rand" "unsafe" "github.com/smartcontractkit/chainlink-common/pkg/workflows/internal/v2/sdkimpl" @@ -13,92 +12,69 @@ import ( type runtimeInternals interface { callCapability(req unsafe.Pointer, reqLen int32) int64 awaitCapabilities(awaitRequest unsafe.Pointer, awaitRequestLen int32, responseBuffer unsafe.Pointer, maxResponseLen int32) int64 - switchModes(mode int32) - getSeed(mode int32) int64 } func newRuntime(internals runtimeInternals, mode sdkpb.Mode) sdkimpl.RuntimeBase { return sdkimpl.RuntimeBase{ - Writer: &writer{}, - Mode: mode, - RuntimeHelpers: &runtimeHelper{runtimeInternals: internals}, + Call: callCapabilityWasmWrapper(internals), + Await: awaitCapabilitiesWasmWrapper(internals), + Writer: &writer{}, + Mode: mode, } } -type runtimeHelper struct { - runtimeInternals - donSource rand.Source - nodeSource rand.Source -} - -func (r *runtimeHelper) GetSource(mode sdkpb.Mode) rand.Source { - switch mode { - case sdkpb.Mode_DON: - if r.donSource == nil { - seed := r.getSeed(int32(mode)) - r.donSource = rand.NewSource(seed) - } - return r.donSource - default: - if r.nodeSource == nil { - seed := r.getSeed(int32(mode)) - r.nodeSource = rand.NewSource(seed) +func callCapabilityWasmWrapper(internals runtimeInternals) func(request *sdkpb.CapabilityRequest) error { + return func(request *sdkpb.CapabilityRequest) error { + marshalled, err := proto.Marshal(request) + if err != nil { + return err } - return r.nodeSource - } -} -func (r *runtimeHelper) Call(request *sdkpb.CapabilityRequest) error { - marshalled, err := proto.Marshal(request) - if err != nil { - return err - } + marshalledPtr, marshalledLen, err := bufferToPointerLen(marshalled) + if err != nil { + return err + } - marshalledPtr, marshalledLen, err := bufferToPointerLen(marshalled) - if err != nil { - return err - } + // TODO (CAPPL-846): callCapability should also have a response pointer and response pointer buffer + result := internals.callCapability(marshalledPtr, marshalledLen) + if result < 0 { + return errors.New("cannot find capability " + request.Id) + } - // TODO (CAPPL-846): callCapability should also have a response pointer and response pointer buffer - result := r.callCapability(marshalledPtr, marshalledLen) - if result < 0 { - return errors.New("cannot find 
capability " + request.Id) + return nil } - - return nil } -func (r *runtimeHelper) Await(request *sdkpb.AwaitCapabilitiesRequest, maxResponseSize uint64) (*sdkpb.AwaitCapabilitiesResponse, error) { - m, err := proto.Marshal(request) - if err != nil { - return nil, err - } +func awaitCapabilitiesWasmWrapper(internals runtimeInternals) func(request *sdkpb.AwaitCapabilitiesRequest, maxResponseSize uint64) (*sdkpb.AwaitCapabilitiesResponse, error) { + return func(request *sdkpb.AwaitCapabilitiesRequest, maxResponseSize uint64) (*sdkpb.AwaitCapabilitiesResponse, error) { - mptr, mlen, err := bufferToPointerLen(m) - if err != nil { - return nil, err - } + m, err := proto.Marshal(request) + if err != nil { + return nil, err + } - response := make([]byte, maxResponseSize) - responsePtr, responseLen, err := bufferToPointerLen(response) - if err != nil { - return nil, err - } + mptr, mlen, err := bufferToPointerLen(m) + if err != nil { + return nil, err + } - bytes := r.awaitCapabilities(mptr, mlen, responsePtr, responseLen) - if bytes < 0 { - return nil, errors.New(string(response[:-bytes])) - } + response := make([]byte, maxResponseSize) + responsePtr, responseLen, err := bufferToPointerLen(response) + if err != nil { + return nil, err + } - awaitResponse := &sdkpb.AwaitCapabilitiesResponse{} - err = proto.Unmarshal(response[:bytes], awaitResponse) - if err != nil { - return nil, err - } + bytes := internals.awaitCapabilities(mptr, mlen, responsePtr, responseLen) + if bytes < 0 { + return nil, errors.New(string(response[:-bytes])) + } - return awaitResponse, nil -} + awaitResponse := &sdkpb.AwaitCapabilitiesResponse{} + err = proto.Unmarshal(response[:bytes], awaitResponse) + if err != nil { + return nil, err + } -func (r *runtimeHelper) SwitchModes(mode sdkpb.Mode) { - r.switchModes(int32(mode)) + return awaitResponse, nil + } } diff --git a/pkg/workflows/wasm/v2/runtime_test.go b/pkg/workflows/wasm/v2/runtime_test.go index f8292808d..493be10dc 100644 --- a/pkg/workflows/wasm/v2/runtime_test.go +++ b/pkg/workflows/wasm/v2/runtime_test.go @@ -3,7 +3,6 @@ package wasm import ( "context" "errors" - "math/rand" "testing" "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/protoc/pkg/test_capabilities/basicaction" @@ -82,30 +81,6 @@ func TestRuntimeBase_LogWriter(t *testing.T) { assert.IsType(t, &writer{}, runtime.LogWriter()) } -func Test_runtimeInternals_UsesSeeds(t *testing.T) { - anyDonSeed := int64(123456789) - anyNodeSeed := int64(987654321) - helper := &runtimeHelper{runtimeInternals: &runtimeInternalsTestHook{ - donSeed: anyDonSeed, - nodeSeed: anyNodeSeed, - }} - assertRnd(t, helper, sdkpb.Mode_DON, anyDonSeed) - assertRnd(t, helper, sdkpb.Mode_Node, anyNodeSeed) -} - -func assertRnd(t *testing.T, helper *runtimeHelper, mode sdkpb.Mode, seed int64) { - rnd := rand.New(helper.GetSource(mode)) - buff := make([]byte, 1000) - n, err := rnd.Read(buff) - require.NoError(t, err) - assert.Equal(t, len(buff), n) - expectedBuf := make([]byte, 1000) - n, err = rand.New(rand.NewSource(seed)).Read(expectedBuf) - require.NoError(t, err) - assert.Equal(t, len(expectedBuf), n) - assert.Equal(t, string(expectedBuf), string(buff)) -} - func newTestRuntime(t *testing.T, callCapabilityErr bool, awaitResponseOverride func() ([]byte, error)) sdkimpl.RuntimeBase { internals := testRuntimeInternals(t) internals.callCapabilityErr = callCapabilityErr diff --git a/pkg/workflows/wasm/v2/runtime_test_hooks.go b/pkg/workflows/wasm/v2/runtime_test_hooks.go index 33ae81263..3be8d9d38 100644 --- 
a/pkg/workflows/wasm/v2/runtime_test_hooks.go +++ b/pkg/workflows/wasm/v2/runtime_test_hooks.go @@ -19,8 +19,6 @@ type runtimeInternalsTestHook struct { awaitResponseOverride func() ([]byte, error) callCapabilityErr bool outstandingCalls map[int32]sdk.Promise[*sdkpb.CapabilityResponse] - nodeSeed int64 - donSeed int64 } var _ runtimeInternals = (*runtimeInternalsTestHook)(nil) @@ -120,14 +118,3 @@ func readHostMessage(response []byte, msg string, isError bool) int64 { return written } - -func (r *runtimeInternalsTestHook) switchModes(_ int32) {} - -func (r *runtimeInternalsTestHook) getSeed(mode int32) int64 { - switch mode { - case int32(sdkpb.Mode_DON): - return r.donSeed - default: - return r.nodeSeed - } -} diff --git a/pkg/workflows/wasm/v2/runtime_wasip1.go b/pkg/workflows/wasm/v2/runtime_wasip1.go index d3e2fe80f..6cfd330c2 100644 --- a/pkg/workflows/wasm/v2/runtime_wasip1.go +++ b/pkg/workflows/wasm/v2/runtime_wasip1.go @@ -10,9 +10,6 @@ func callCapability(req unsafe.Pointer, reqLen int32) int64 //go:wasmimport env await_capabilities func awaitCapabilities(awaitRequest unsafe.Pointer, awaitRequestLen int32, responseBuffer unsafe.Pointer, maxResponseLen int32) int64 -//go:wasmimport env random_seed -func getSeed(mode int32) int64 - type runtimeInternalsImpl struct{} var _ runtimeInternals = runtimeInternalsImpl{} @@ -24,11 +21,3 @@ func (r runtimeInternalsImpl) callCapability(req unsafe.Pointer, reqLen int32) i func (r runtimeInternalsImpl) awaitCapabilities(awaitRequest unsafe.Pointer, awaitRequestLen int32, responseBuffer unsafe.Pointer, maxResponseLen int32) int64 { return awaitCapabilities(awaitRequest, awaitRequestLen, responseBuffer, maxResponseLen) } - -func (r runtimeInternalsImpl) switchModes(mode int32) { - switchModes(mode) -} - -func (r runtimeInternalsImpl) getSeed(mode int32) int64 { - return getSeed(mode) -} From 491d1b22ac61e81d5cb74388b0c4395de5906849 Mon Sep 17 00:00:00 2001 From: Silas Lenihan Date: Tue, 10 Jun 2025 14:00:06 -0400 Subject: [PATCH 13/16] Reapply "Seed random for setup and modes (#1236)" This reverts commit 54755c8caea7637e627f489b1603c74b1d80cd1d. 
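For context, a minimal sketch of how a NoDAG workflow consumes the reapplied Rand() API. It mirrors the test workflow re-added under pkg/workflows/wasm/host/test/nodag/randoms/cmd in this patch; the handler body below is illustrative only and is not itself part of the change.

    //go:build wasip1

    package main

    import (
        "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/protoc/pkg/test_capabilities/basictrigger"
        "github.com/smartcontractkit/chainlink-common/pkg/workflows/sdk/v2"
        "github.com/smartcontractkit/chainlink-common/pkg/workflows/testhelpers/v2"
        "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/v2"
    )

    func main() {
        runner := wasm.NewDonRunner()
        basic := &basictrigger.Basic{}

        runner.Run(&sdk.WorkflowArgs[sdk.DonRuntime]{
            Handlers: []sdk.Handler[sdk.DonRuntime]{
                sdk.NewDonHandler(
                    basic.Trigger(testhelpers.TestWorkflowTriggerConfig()),
                    func(runtime sdk.DonRuntime, _ *basictrigger.Outputs) (uint64, error) {
                        // Rand() returns a *rand.Rand backed by a per-mode source that the
                        // host seeds per execution, so DON-mode draws stay deterministic
                        // across nodes for the same execution ID.
                        r, err := runtime.Rand()
                        if err != nil {
                            return 0, err
                        }
                        return r.Uint64(), nil
                    }),
            },
        })
    }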
--- .mockery.yaml | 2 +- pkg/workflows/internal/v2/sdkimpl/runtime.go | 67 ++++- .../internal/v2/sdkimpl/runtime_test.go | 74 +++++- pkg/workflows/sdk/v2/runtime.go | 2 + pkg/workflows/sdk/v2/runtime_test.go | 9 + pkg/workflows/sdk/v2/testutils/runner.go | 21 +- pkg/workflows/sdk/v2/testutils/runtime.go | 92 ++++--- pkg/workflows/wasm/host/execution.go | 24 +- .../host/mock_capability_executor_test.go | 96 -------- .../wasm/host/mock_execution_helper_test.go | 233 ++++++++++++++++++ pkg/workflows/wasm/host/mocks/module_v2.go | 16 +- pkg/workflows/wasm/host/module.go | 54 +++- pkg/workflows/wasm/host/module_test.go | 6 +- .../wasm/host/test/nodag/randoms/cmd/main.go | 54 ++++ pkg/workflows/wasm/host/wasip1.go | 34 ++- pkg/workflows/wasm/host/wasm_nodag_test.go | 124 +++++++++- pkg/workflows/wasm/v2/runner.go | 1 + pkg/workflows/wasm/v2/runner_test_hooks.go | 7 + pkg/workflows/wasm/v2/runner_wasip1.go | 10 + pkg/workflows/wasm/v2/runtime.go | 114 +++++---- pkg/workflows/wasm/v2/runtime_test.go | 25 ++ pkg/workflows/wasm/v2/runtime_test_hooks.go | 13 + pkg/workflows/wasm/v2/runtime_wasip1.go | 11 + 23 files changed, 849 insertions(+), 240 deletions(-) delete mode 100644 pkg/workflows/wasm/host/mock_capability_executor_test.go create mode 100644 pkg/workflows/wasm/host/mock_execution_helper_test.go create mode 100644 pkg/workflows/wasm/host/test/nodag/randoms/cmd/main.go diff --git a/.mockery.yaml b/.mockery.yaml index bacecec83..ae7134afa 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -38,7 +38,7 @@ packages: interfaces: ModuleV1: {} ModuleV2: {} - CapabilityExecutor: + ExecutionHelper: config: inpackage: true filename: "mock_{{.InterfaceName | snakecase}}_test.go" diff --git a/pkg/workflows/internal/v2/sdkimpl/runtime.go b/pkg/workflows/internal/v2/sdkimpl/runtime.go index 9cfcbf876..47ecc01be 100644 --- a/pkg/workflows/internal/v2/sdkimpl/runtime.go +++ b/pkg/workflows/internal/v2/sdkimpl/runtime.go @@ -4,6 +4,7 @@ import ( "fmt" "io" "log/slog" + "math/rand" "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/consensus" "github.com/smartcontractkit/chainlink-common/pkg/values" @@ -12,21 +13,30 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/workflows/sdk/v2/pb" ) -type CallCapabilityFn func(request *pb.CapabilityRequest) error -type AwaitCapabilitiesFn func(request *pb.AwaitCapabilitiesRequest, maxResponseSize uint64) (*pb.AwaitCapabilitiesResponse, error) +type RuntimeHelpers interface { + Call(request *pb.CapabilityRequest) error + Await(request *pb.AwaitCapabilitiesRequest, maxResponseSize uint64) (*pb.AwaitCapabilitiesResponse, error) + SwitchModes(mode pb.Mode) + GetSource(mode pb.Mode) rand.Source +} type RuntimeBase struct { ConfigBytes []byte MaxResponseSize uint64 - Call CallCapabilityFn - Await AwaitCapabilitiesFn Writer io.Writer + RuntimeHelpers + source rand.Source + source64 rand.Source64 modeErr error Mode pb.Mode nextCallId int32 } +var _ sdk.RuntimeBase = (*RuntimeBase)(nil) +var _ rand.Source = (*RuntimeBase)(nil) +var _ rand.Source64 = (*RuntimeBase)(nil) + func (r *RuntimeBase) CallCapability(request *pb.CapabilityRequest) sdk.Promise[*pb.CapabilityResponse] { if r.Mode == pb.Mode_DON { r.nextCallId++ @@ -40,7 +50,7 @@ func (r *RuntimeBase) CallCapability(request *pb.CapabilityRequest) sdk.Promise[ return sdk.PromiseFromResult[*pb.CapabilityResponse](nil, r.modeErr) } - err := r.Call(request) + err := r.RuntimeHelpers.Call(request) if err != nil { return sdk.PromiseFromResult[*pb.CapabilityResponse](nil, err) } @@ -75,6 +85,22 @@ func (r 
*RuntimeBase) Logger() *slog.Logger { return slog.New(slog.NewTextHandler(r.LogWriter(), nil)) } +func (r *RuntimeBase) Rand() (*rand.Rand, error) { + if r.modeErr != nil { + return nil, r.modeErr + } + + if r.source == nil { + r.source = r.RuntimeHelpers.GetSource(r.Mode) + r64, ok := r.source.(rand.Source64) + if ok { + r.source64 = r64 + } + } + + return rand.New(r), nil +} + type DonRuntime struct { RuntimeBase nextNodeCallId int32 @@ -85,7 +111,9 @@ func (d *DonRuntime) RunInNodeMode(fn func(nodeRuntime sdk.NodeRuntime) *pb.Simp nrt.nextCallId = d.nextNodeCallId nrt.Mode = pb.Mode_Node d.modeErr = sdk.DonModeCallInNodeMode() + d.SwitchModes(pb.Mode_Node) observation := fn(nrt) + d.SwitchModes(pb.Mode_DON) nrt.modeErr = sdk.NodeModeCallInDonMode() d.modeErr = nil d.nextNodeCallId = nrt.nextCallId @@ -95,6 +123,35 @@ func (d *DonRuntime) RunInNodeMode(fn func(nodeRuntime sdk.NodeRuntime) *pb.Simp }) } +func (r *RuntimeBase) Int63() int64 { + if r.modeErr != nil { + panic("random cannot be used outside the mode it was created in") + } + + return r.source.Int63() +} + +func (r *RuntimeBase) Uint64() uint64 { + if r.modeErr != nil { + panic("random cannot be used outside the mode it was created in") + } + + // borrowed from math/rand + if r.source64 != nil { + return r.source64.Uint64() + } + + return uint64(r.source.Int63())>>31 | uint64(r.source.Int63())<<32 +} + +func (r *RuntimeBase) Seed(seed int64) { + if r.modeErr != nil { + panic("random cannot be used outside the mode it was created in") + } + + r.source.Seed(seed) +} + var _ sdk.DonRuntime = &DonRuntime{} type NodeRuntime struct { diff --git a/pkg/workflows/internal/v2/sdkimpl/runtime_test.go b/pkg/workflows/internal/v2/sdkimpl/runtime_test.go index 966e3b24c..c017f33c5 100644 --- a/pkg/workflows/internal/v2/sdkimpl/runtime_test.go +++ b/pkg/workflows/internal/v2/sdkimpl/runtime_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "math/rand" "testing" "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/consensus/consensusmock" @@ -110,8 +111,11 @@ func TestRuntime_CallCapability(t *testing.T) { test := func(rt sdk.DonRuntime, _ *basictrigger.Outputs) (string, error) { drt := rt.(*sdkimpl.DonRuntime) - drt.Await = func(request *pb.AwaitCapabilitiesRequest, maxResponseSize uint64) (*pb.AwaitCapabilitiesResponse, error) { - return nil, expectedErr + drt.RuntimeHelpers = &awaitOverride{ + RuntimeHelpers: drt.RuntimeHelpers, + await: func(request *pb.AwaitCapabilitiesRequest, maxResponseSize uint64) (*pb.AwaitCapabilitiesResponse, error) { + return nil, expectedErr + }, } _, err := capability.PerformAction(rt, &basicaction.Inputs{InputThing: true}).Await() return "", err @@ -133,8 +137,12 @@ func TestRuntime_CallCapability(t *testing.T) { test := func(rt sdk.DonRuntime, _ *basictrigger.Outputs) (string, error) { drt := rt.(*sdkimpl.DonRuntime) - drt.Await = func(request *pb.AwaitCapabilitiesRequest, maxResponseSize uint64) (*pb.AwaitCapabilitiesResponse, error) { - return &pb.AwaitCapabilitiesResponse{Responses: map[int32]*pb.CapabilityResponse{}}, nil + + drt.RuntimeHelpers = &awaitOverride{ + RuntimeHelpers: drt.RuntimeHelpers, + await: func(request *pb.AwaitCapabilitiesRequest, maxResponseSize uint64) (*pb.AwaitCapabilitiesResponse, error) { + return &pb.AwaitCapabilitiesResponse{Responses: map[int32]*pb.CapabilityResponse{}}, nil + }, } _, err := capability.PerformAction(rt, &basicaction.Inputs{InputThing: true}).Await() return "", err @@ -145,6 +153,55 @@ func TestRuntime_CallCapability(t *testing.T) { }) } 
+func TestRuntime_Rand(t *testing.T) { + t.Run("random delegates", func(t *testing.T) { + test := func(rt sdk.DonRuntime, _ *basictrigger.Outputs) (uint64, error) { + r, err := rt.Rand() + if err != nil { + return 0, err + } + return r.Uint64(), nil + } + + ran, result, err := testRuntime(t, test) + require.NoError(t, err) + assert.True(t, ran) + assert.Equal(t, rand.New(rand.NewSource(1)).Uint64(), result) + }) + + t.Run("random does not allow use in the wrong mode", func(t *testing.T) { + test := func(rt sdk.DonRuntime, _ *basictrigger.Outputs) (uint64, error) { + return sdk.RunInNodeMode[uint64](rt, func(_ sdk.NodeRuntime) (uint64, error) { + if _, err := rt.Rand(); err != nil { + return 0, err + } + + return 0, fmt.Errorf("should not be called in node mode") + }, sdk.ConsensusMedianAggregation[uint64]()).Await() + } + + _, _, err := testRuntime(t, test) + require.Error(t, err) + }) + + t.Run("returned random panics if you use it in the wrong mode ", func(t *testing.T) { + assert.Panics(t, func() { + test := func(rt sdk.DonRuntime, _ *basictrigger.Outputs) (uint64, error) { + r, err := rt.Rand() + if err != nil { + return 0, err + } + return sdk.RunInNodeMode[uint64](rt, func(_ sdk.NodeRuntime) (uint64, error) { + r.Uint64() + return 0, fmt.Errorf("should not be called in node mode") + }, sdk.ConsensusMedianAggregation[uint64]()).Await() + } + + _, _, _ = testRuntime(t, test) + }) + }) +} + func TestDonRuntime_RunInNodeMode(t *testing.T) { t.Run("Successful consensus", func(t *testing.T) { nodeMock, err := nodeactionmock.NewBasicActionCapability(t) @@ -316,3 +373,12 @@ type consensusValues struct { Err error Resp int64 } + +type awaitOverride struct { + sdkimpl.RuntimeHelpers + await func(request *pb.AwaitCapabilitiesRequest, maxResponseSize uint64) (*pb.AwaitCapabilitiesResponse, error) +} + +func (a *awaitOverride) Await(request *pb.AwaitCapabilitiesRequest, maxResponseSize uint64) (*pb.AwaitCapabilitiesResponse, error) { + return a.await(request, maxResponseSize) +} diff --git a/pkg/workflows/sdk/v2/runtime.go b/pkg/workflows/sdk/v2/runtime.go index ed34c98a8..9c0d474ee 100644 --- a/pkg/workflows/sdk/v2/runtime.go +++ b/pkg/workflows/sdk/v2/runtime.go @@ -4,6 +4,7 @@ import ( "errors" "io" "log/slog" + "math/rand" "reflect" "github.com/smartcontractkit/chainlink-common/pkg/values" @@ -17,6 +18,7 @@ type RuntimeBase interface { Config() []byte LogWriter() io.Writer Logger() *slog.Logger + Rand() (*rand.Rand, error) } // NodeRuntime is not thread safe and must not be used concurrently. diff --git a/pkg/workflows/sdk/v2/runtime_test.go b/pkg/workflows/sdk/v2/runtime_test.go index f8c09572f..70ebd3182 100644 --- a/pkg/workflows/sdk/v2/runtime_test.go +++ b/pkg/workflows/sdk/v2/runtime_test.go @@ -4,6 +4,7 @@ import ( "errors" "io" "log/slog" + "math/rand" "testing" "github.com/smartcontractkit/chainlink-common/pkg/values" @@ -95,6 +96,10 @@ func TestRunInNodeMode_ErrorWrappingDefault(t *testing.T) { // mockNodeRuntime implements NodeRuntime for testing. 
type mockNodeRuntime struct{} +func (m mockNodeRuntime) Rand() (*rand.Rand, error) { + panic("unused in tests") +} + func (m mockNodeRuntime) CallCapability(_ *pb.CapabilityRequest) sdk.Promise[*pb.CapabilityResponse] { panic("unused in tests") } @@ -115,6 +120,10 @@ func (m mockNodeRuntime) IsNodeRuntime() {} type mockDonRuntime struct{} +func (m *mockDonRuntime) Rand() (*rand.Rand, error) { + panic("unused in tests") +} + func (m *mockDonRuntime) RunInNodeMode(fn func(nodeRuntime sdk.NodeRuntime) *pb.SimpleConsensusInputs) sdk.Promise[values.Value] { req := fn(mockNodeRuntime{}) diff --git a/pkg/workflows/sdk/v2/testutils/runner.go b/pkg/workflows/sdk/v2/testutils/runner.go index f6debf114..842b65d21 100644 --- a/pkg/workflows/sdk/v2/testutils/runner.go +++ b/pkg/workflows/sdk/v2/testutils/runner.go @@ -4,6 +4,7 @@ import ( "errors" "io" "log/slog" + "math/rand" "testing" "github.com/google/uuid" @@ -27,6 +28,7 @@ type runner[T any] struct { runtime T writer *testWriter base *sdkimpl.RuntimeBase + source rand.Source } func (r *runner[T]) Logs() []string { @@ -41,6 +43,10 @@ func (r *runner[T]) LogWriter() io.Writer { return r.writer } +func (r *runner[T]) SetRandSource(source rand.Source) { + r.source = source +} + type TestRunner interface { Result() (bool, any, error) @@ -53,6 +59,8 @@ type TestRunner interface { SetMaxResponseSizeBytes(maxResponseSizebytes uint64) Logs() []string + + SetRandSource(source rand.Source) } type DonRunner interface { @@ -67,14 +75,18 @@ type NodeRunner interface { func NewDonRunner(tb testing.TB, config []byte) DonRunner { writer := &testWriter{} - drt := &sdkimpl.DonRuntime{RuntimeBase: newRuntime(tb, config, writer)} - return newRunner[sdk.DonRuntime](tb, config, writer, drt, &drt.RuntimeBase) + drt := &sdkimpl.DonRuntime{} + r := newRunner[sdk.DonRuntime](tb, config, writer, drt, &drt.RuntimeBase) + drt.RuntimeBase = newRuntime(tb, config, writer, func() rand.Source { return r.source }) + return r } func NewNodeRunner(tb testing.TB, config []byte) NodeRunner { writer := &testWriter{} - nrt := &sdkimpl.NodeRuntime{RuntimeBase: newRuntime(tb, config, writer)} - return newRunner[sdk.NodeRuntime](tb, config, writer, nrt, &nrt.RuntimeBase) + nrt := &sdkimpl.NodeRuntime{} + r := newRunner[sdk.NodeRuntime](tb, config, writer, nrt, &nrt.RuntimeBase) + nrt.RuntimeBase = newRuntime(tb, config, writer, func() rand.Source { return r.source }) + return r } func newRunner[T any](tb testing.TB, config []byte, writer *testWriter, t T, base *sdkimpl.RuntimeBase) *runner[T] { @@ -86,6 +98,7 @@ func newRunner[T any](tb testing.TB, config []byte, writer *testWriter, t T, bas runtime: t, writer: writer, base: base, + source: rand.NewSource(1), } return r diff --git a/pkg/workflows/sdk/v2/testutils/runtime.go b/pkg/workflows/sdk/v2/testutils/runtime.go index da4d08c0d..fb3987621 100644 --- a/pkg/workflows/sdk/v2/testutils/runtime.go +++ b/pkg/workflows/sdk/v2/testutils/runtime.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "math/rand" "testing" "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/consensus/consensusmock" @@ -15,9 +16,7 @@ import ( "google.golang.org/protobuf/proto" ) -func newRuntime(tb testing.TB, configBytes []byte, writer *testWriter) sdkimpl.RuntimeBase { - tb.Cleanup(func() { delete(calls, tb) }) - +func newRuntime(tb testing.TB, configBytes []byte, writer *testWriter, sourceFn func() rand.Source) sdkimpl.RuntimeBase { defaultConsensus, err := consensusmock.NewConsensusCapability(tb) // Do not override if the user provided their own 
consensus method @@ -28,9 +27,8 @@ func newRuntime(tb testing.TB, configBytes []byte, writer *testWriter) sdkimpl.R return sdkimpl.RuntimeBase{ ConfigBytes: configBytes, MaxResponseSize: sdk.DefaultMaxResponseSizeBytes, - Call: createCallCapability(tb), - Await: createAwaitCapabilities(tb), Writer: writer, + RuntimeHelpers: &runtimeHelpers{tb: tb, calls: map[int32]chan *pb.CapabilityResponse{}, sourceFn: sourceFn}, } } @@ -49,59 +47,55 @@ func defaultSimpleConsensus(_ context.Context, input *pb.SimpleConsensusInputs) } } -var calls = map[testing.TB]map[int32]chan *pb.CapabilityResponse{} +type runtimeHelpers struct { + tb testing.TB + calls map[int32]chan *pb.CapabilityResponse + sourceFn func() rand.Source +} -func createCallCapability(tb testing.TB) func(request *pb.CapabilityRequest) error { - return func(request *pb.CapabilityRequest) error { - reg := registry.GetRegistry(tb) - capability, err := reg.GetCapability(request.Id) - if err != nil { - return err - } +func (rh *runtimeHelpers) GetSource(_ pb.Mode) rand.Source { + return rh.sourceFn() +} - respCh := make(chan *pb.CapabilityResponse, 1) - tbCalls, ok := calls[tb] - if !ok { - tbCalls = map[int32]chan *pb.CapabilityResponse{} - calls[tb] = tbCalls - } - tbCalls[request.CallbackId] = respCh - go func() { - respCh <- capability.Invoke(tb.Context(), request) - }() - return nil +func (rh *runtimeHelpers) Call(request *pb.CapabilityRequest) error { + reg := registry.GetRegistry(rh.tb) + capability, err := reg.GetCapability(request.Id) + if err != nil { + return err } + + respCh := make(chan *pb.CapabilityResponse, 1) + rh.calls[request.CallbackId] = respCh + go func() { + respCh <- capability.Invoke(rh.tb.Context(), request) + }() + return nil } -func createAwaitCapabilities(tb testing.TB) sdkimpl.AwaitCapabilitiesFn { - return func(request *pb.AwaitCapabilitiesRequest, maxResponseSize uint64) (*pb.AwaitCapabilitiesResponse, error) { - response := &pb.AwaitCapabilitiesResponse{Responses: map[int32]*pb.CapabilityResponse{}} +func (rh *runtimeHelpers) Await(request *pb.AwaitCapabilitiesRequest, maxResponseSize uint64) (*pb.AwaitCapabilitiesResponse, error) { + response := &pb.AwaitCapabilitiesResponse{Responses: map[int32]*pb.CapabilityResponse{}} - testCalls, ok := calls[tb] + var errs []error + for _, id := range request.Ids { + ch, ok := rh.calls[id] if !ok { - return nil, errors.New("no calls found for this test") + errs = append(errs, fmt.Errorf("no call found for %d", id)) + continue } - - var errs []error - for _, id := range request.Ids { - ch, ok := testCalls[id] - if !ok { - errs = append(errs, fmt.Errorf("no call found for %d", id)) - continue - } - select { - case resp := <-ch: - response.Responses[id] = resp - case <-tb.Context().Done(): - return nil, tb.Context().Err() - } - } - - bytes, _ := proto.Marshal(response) - if len(bytes) > int(maxResponseSize) { - return nil, errors.New(sdk.ResponseBufferTooSmall) + select { + case resp := <-ch: + response.Responses[id] = resp + case <-rh.tb.Context().Done(): + return nil, rh.tb.Context().Err() } + } - return response, errors.Join(errs...) + bytes, _ := proto.Marshal(response) + if len(bytes) > int(maxResponseSize) { + return nil, errors.New(sdk.ResponseBufferTooSmall) } + + return response, errors.Join(errs...) 
} + +func (rh *runtimeHelpers) SwitchModes(_ pb.Mode) {} diff --git a/pkg/workflows/wasm/host/execution.go b/pkg/workflows/wasm/host/execution.go index 90fec4aa3..91d1d1675 100644 --- a/pkg/workflows/wasm/host/execution.go +++ b/pkg/workflows/wasm/host/execution.go @@ -16,7 +16,11 @@ type execution[T any] struct { capabilityResponses map[int32]<-chan *sdkpb.CapabilityResponse lock sync.RWMutex module *module - executor CapabilityExecutor + executor ExecutionHelper + hasRun bool + mode sdkpb.Mode + donSeed int64 + nodeSeed int64 } // callCapAsync async calls a capability by placing execution results onto a @@ -81,6 +85,22 @@ func (e *execution[T]) log(caller *wasmtime.Caller, ptr int32, ptrlen int32) { lggr.Errorf("error calling log: %s", innerErr) return } - + lggr.Info(string(b)) } + +func (e *execution[T]) getSeed(mode int32) int64 { + switch sdkpb.Mode(mode) { + case sdkpb.Mode_DON: + return e.donSeed + case sdkpb.Mode_Node: + return e.nodeSeed + } + + return -1 +} + +func (e *execution[T]) switchModes(_ *wasmtime.Caller, mode int32) { + e.hasRun = true + e.mode = sdkpb.Mode(mode) +} diff --git a/pkg/workflows/wasm/host/mock_capability_executor_test.go b/pkg/workflows/wasm/host/mock_capability_executor_test.go deleted file mode 100644 index 89ccc38b7..000000000 --- a/pkg/workflows/wasm/host/mock_capability_executor_test.go +++ /dev/null @@ -1,96 +0,0 @@ -// Code generated by mockery v2.53.3. DO NOT EDIT. - -package host - -import ( - context "context" - - pb "github.com/smartcontractkit/chainlink-common/pkg/workflows/sdk/v2/pb" - mock "github.com/stretchr/testify/mock" -) - -// MockCapabilityExecutor is an autogenerated mock type for the CapabilityExecutor type -type MockCapabilityExecutor struct { - mock.Mock -} - -type MockCapabilityExecutor_Expecter struct { - mock *mock.Mock -} - -func (_m *MockCapabilityExecutor) EXPECT() *MockCapabilityExecutor_Expecter { - return &MockCapabilityExecutor_Expecter{mock: &_m.Mock} -} - -// CallCapability provides a mock function with given fields: ctx, request -func (_m *MockCapabilityExecutor) CallCapability(ctx context.Context, request *pb.CapabilityRequest) (*pb.CapabilityResponse, error) { - ret := _m.Called(ctx, request) - - if len(ret) == 0 { - panic("no return value specified for CallCapability") - } - - var r0 *pb.CapabilityResponse - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, *pb.CapabilityRequest) (*pb.CapabilityResponse, error)); ok { - return rf(ctx, request) - } - if rf, ok := ret.Get(0).(func(context.Context, *pb.CapabilityRequest) *pb.CapabilityResponse); ok { - r0 = rf(ctx, request) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*pb.CapabilityResponse) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, *pb.CapabilityRequest) error); ok { - r1 = rf(ctx, request) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// MockCapabilityExecutor_CallCapability_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CallCapability' -type MockCapabilityExecutor_CallCapability_Call struct { - *mock.Call -} - -// CallCapability is a helper method to define mock.On call -// - ctx context.Context -// - request *pb.CapabilityRequest -func (_e *MockCapabilityExecutor_Expecter) CallCapability(ctx interface{}, request interface{}) *MockCapabilityExecutor_CallCapability_Call { - return &MockCapabilityExecutor_CallCapability_Call{Call: _e.mock.On("CallCapability", ctx, request)} -} - -func (_c *MockCapabilityExecutor_CallCapability_Call) Run(run func(ctx context.Context, 
request *pb.CapabilityRequest)) *MockCapabilityExecutor_CallCapability_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(*pb.CapabilityRequest)) - }) - return _c -} - -func (_c *MockCapabilityExecutor_CallCapability_Call) Return(_a0 *pb.CapabilityResponse, _a1 error) *MockCapabilityExecutor_CallCapability_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *MockCapabilityExecutor_CallCapability_Call) RunAndReturn(run func(context.Context, *pb.CapabilityRequest) (*pb.CapabilityResponse, error)) *MockCapabilityExecutor_CallCapability_Call { - _c.Call.Return(run) - return _c -} - -// NewMockCapabilityExecutor creates a new instance of MockCapabilityExecutor. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewMockCapabilityExecutor(t interface { - mock.TestingT - Cleanup(func()) -}) *MockCapabilityExecutor { - mock := &MockCapabilityExecutor{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} diff --git a/pkg/workflows/wasm/host/mock_execution_helper_test.go b/pkg/workflows/wasm/host/mock_execution_helper_test.go new file mode 100644 index 000000000..68e1efca8 --- /dev/null +++ b/pkg/workflows/wasm/host/mock_execution_helper_test.go @@ -0,0 +1,233 @@ +// Code generated by mockery v2.53.3. DO NOT EDIT. + +package host + +import ( + context "context" + + pb "github.com/smartcontractkit/chainlink-common/pkg/workflows/sdk/v2/pb" + mock "github.com/stretchr/testify/mock" + + time "time" +) + +// MockExecutionHelper is an autogenerated mock type for the ExecutionHelper type +type MockExecutionHelper struct { + mock.Mock +} + +type MockExecutionHelper_Expecter struct { + mock *mock.Mock +} + +func (_m *MockExecutionHelper) EXPECT() *MockExecutionHelper_Expecter { + return &MockExecutionHelper_Expecter{mock: &_m.Mock} +} + +// CallCapability provides a mock function with given fields: ctx, request +func (_m *MockExecutionHelper) CallCapability(ctx context.Context, request *pb.CapabilityRequest) (*pb.CapabilityResponse, error) { + ret := _m.Called(ctx, request) + + if len(ret) == 0 { + panic("no return value specified for CallCapability") + } + + var r0 *pb.CapabilityResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *pb.CapabilityRequest) (*pb.CapabilityResponse, error)); ok { + return rf(ctx, request) + } + if rf, ok := ret.Get(0).(func(context.Context, *pb.CapabilityRequest) *pb.CapabilityResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*pb.CapabilityResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *pb.CapabilityRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockExecutionHelper_CallCapability_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CallCapability' +type MockExecutionHelper_CallCapability_Call struct { + *mock.Call +} + +// CallCapability is a helper method to define mock.On call +// - ctx context.Context +// - request *pb.CapabilityRequest +func (_e *MockExecutionHelper_Expecter) CallCapability(ctx interface{}, request interface{}) *MockExecutionHelper_CallCapability_Call { + return &MockExecutionHelper_CallCapability_Call{Call: _e.mock.On("CallCapability", ctx, request)} +} + +func (_c *MockExecutionHelper_CallCapability_Call) Run(run func(ctx context.Context, request *pb.CapabilityRequest)) 
*MockExecutionHelper_CallCapability_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*pb.CapabilityRequest)) + }) + return _c +} + +func (_c *MockExecutionHelper_CallCapability_Call) Return(_a0 *pb.CapabilityResponse, _a1 error) *MockExecutionHelper_CallCapability_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockExecutionHelper_CallCapability_Call) RunAndReturn(run func(context.Context, *pb.CapabilityRequest) (*pb.CapabilityResponse, error)) *MockExecutionHelper_CallCapability_Call { + _c.Call.Return(run) + return _c +} + +// GetDONTime provides a mock function with no fields +func (_m *MockExecutionHelper) GetDONTime() time.Time { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for GetDONTime") + } + + var r0 time.Time + if rf, ok := ret.Get(0).(func() time.Time); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(time.Time) + } + + return r0 +} + +// MockExecutionHelper_GetDONTime_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetDONTime' +type MockExecutionHelper_GetDONTime_Call struct { + *mock.Call +} + +// GetDONTime is a helper method to define mock.On call +func (_e *MockExecutionHelper_Expecter) GetDONTime() *MockExecutionHelper_GetDONTime_Call { + return &MockExecutionHelper_GetDONTime_Call{Call: _e.mock.On("GetDONTime")} +} + +func (_c *MockExecutionHelper_GetDONTime_Call) Run(run func()) *MockExecutionHelper_GetDONTime_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockExecutionHelper_GetDONTime_Call) Return(_a0 time.Time) *MockExecutionHelper_GetDONTime_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockExecutionHelper_GetDONTime_Call) RunAndReturn(run func() time.Time) *MockExecutionHelper_GetDONTime_Call { + _c.Call.Return(run) + return _c +} + +// GetId provides a mock function with no fields +func (_m *MockExecutionHelper) GetId() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for GetId") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// MockExecutionHelper_GetId_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetId' +type MockExecutionHelper_GetId_Call struct { + *mock.Call +} + +// GetId is a helper method to define mock.On call +func (_e *MockExecutionHelper_Expecter) GetId() *MockExecutionHelper_GetId_Call { + return &MockExecutionHelper_GetId_Call{Call: _e.mock.On("GetId")} +} + +func (_c *MockExecutionHelper_GetId_Call) Run(run func()) *MockExecutionHelper_GetId_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockExecutionHelper_GetId_Call) Return(_a0 string) *MockExecutionHelper_GetId_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockExecutionHelper_GetId_Call) RunAndReturn(run func() string) *MockExecutionHelper_GetId_Call { + _c.Call.Return(run) + return _c +} + +// GetNodeTime provides a mock function with no fields +func (_m *MockExecutionHelper) GetNodeTime() time.Time { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for GetNodeTime") + } + + var r0 time.Time + if rf, ok := ret.Get(0).(func() time.Time); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(time.Time) + } + + return r0 +} + +// MockExecutionHelper_GetNodeTime_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 
'GetNodeTime' +type MockExecutionHelper_GetNodeTime_Call struct { + *mock.Call +} + +// GetNodeTime is a helper method to define mock.On call +func (_e *MockExecutionHelper_Expecter) GetNodeTime() *MockExecutionHelper_GetNodeTime_Call { + return &MockExecutionHelper_GetNodeTime_Call{Call: _e.mock.On("GetNodeTime")} +} + +func (_c *MockExecutionHelper_GetNodeTime_Call) Run(run func()) *MockExecutionHelper_GetNodeTime_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockExecutionHelper_GetNodeTime_Call) Return(_a0 time.Time) *MockExecutionHelper_GetNodeTime_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockExecutionHelper_GetNodeTime_Call) RunAndReturn(run func() time.Time) *MockExecutionHelper_GetNodeTime_Call { + _c.Call.Return(run) + return _c +} + +// NewMockExecutionHelper creates a new instance of MockExecutionHelper. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockExecutionHelper(t interface { + mock.TestingT + Cleanup(func()) +}) *MockExecutionHelper { + mock := &MockExecutionHelper{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/workflows/wasm/host/mocks/module_v2.go b/pkg/workflows/wasm/host/mocks/module_v2.go index ba16a8855..72da9a931 100644 --- a/pkg/workflows/wasm/host/mocks/module_v2.go +++ b/pkg/workflows/wasm/host/mocks/module_v2.go @@ -57,7 +57,7 @@ func (_c *ModuleV2_Close_Call) RunAndReturn(run func()) *ModuleV2_Close_Call { } // Execute provides a mock function with given fields: ctx, request, handler -func (_m *ModuleV2) Execute(ctx context.Context, request *pb.ExecuteRequest, handler host.CapabilityExecutor) (*pb.ExecutionResult, error) { +func (_m *ModuleV2) Execute(ctx context.Context, request *pb.ExecuteRequest, handler host.ExecutionHelper) (*pb.ExecutionResult, error) { ret := _m.Called(ctx, request, handler) if len(ret) == 0 { @@ -66,10 +66,10 @@ func (_m *ModuleV2) Execute(ctx context.Context, request *pb.ExecuteRequest, han var r0 *pb.ExecutionResult var r1 error - if rf, ok := ret.Get(0).(func(context.Context, *pb.ExecuteRequest, host.CapabilityExecutor) (*pb.ExecutionResult, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, *pb.ExecuteRequest, host.ExecutionHelper) (*pb.ExecutionResult, error)); ok { return rf(ctx, request, handler) } - if rf, ok := ret.Get(0).(func(context.Context, *pb.ExecuteRequest, host.CapabilityExecutor) *pb.ExecutionResult); ok { + if rf, ok := ret.Get(0).(func(context.Context, *pb.ExecuteRequest, host.ExecutionHelper) *pb.ExecutionResult); ok { r0 = rf(ctx, request, handler) } else { if ret.Get(0) != nil { @@ -77,7 +77,7 @@ func (_m *ModuleV2) Execute(ctx context.Context, request *pb.ExecuteRequest, han } } - if rf, ok := ret.Get(1).(func(context.Context, *pb.ExecuteRequest, host.CapabilityExecutor) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, *pb.ExecuteRequest, host.ExecutionHelper) error); ok { r1 = rf(ctx, request, handler) } else { r1 = ret.Error(1) @@ -94,14 +94,14 @@ type ModuleV2_Execute_Call struct { // Execute is a helper method to define mock.On call // - ctx context.Context // - request *pb.ExecuteRequest -// - handler host.CapabilityExecutor +// - handler host.ExecutionHelper func (_e *ModuleV2_Expecter) Execute(ctx interface{}, request interface{}, handler interface{}) *ModuleV2_Execute_Call { return &ModuleV2_Execute_Call{Call: _e.mock.On("Execute", 
ctx, request, handler)} } -func (_c *ModuleV2_Execute_Call) Run(run func(ctx context.Context, request *pb.ExecuteRequest, handler host.CapabilityExecutor)) *ModuleV2_Execute_Call { +func (_c *ModuleV2_Execute_Call) Run(run func(ctx context.Context, request *pb.ExecuteRequest, handler host.ExecutionHelper)) *ModuleV2_Execute_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(*pb.ExecuteRequest), args[2].(host.CapabilityExecutor)) + run(args[0].(context.Context), args[1].(*pb.ExecuteRequest), args[2].(host.ExecutionHelper)) }) return _c } @@ -111,7 +111,7 @@ func (_c *ModuleV2_Execute_Call) Return(_a0 *pb.ExecutionResult, _a1 error) *Mod return _c } -func (_c *ModuleV2_Execute_Call) RunAndReturn(run func(context.Context, *pb.ExecuteRequest, host.CapabilityExecutor) (*pb.ExecutionResult, error)) *ModuleV2_Execute_Call { +func (_c *ModuleV2_Execute_Call) RunAndReturn(run func(context.Context, *pb.ExecuteRequest, host.ExecutionHelper) (*pb.ExecutionResult, error)) *ModuleV2_Execute_Call { _c.Call.Return(run) return _c } diff --git a/pkg/workflows/wasm/host/module.go b/pkg/workflows/wasm/host/module.go index 6e4e1f249..1d9df7670 100644 --- a/pkg/workflows/wasm/host/module.go +++ b/pkg/workflows/wasm/host/module.go @@ -8,8 +8,10 @@ import ( "encoding/json" "errors" "fmt" + "hash/fnv" "io" "math" + "math/rand" "regexp" "strings" "sync" @@ -86,13 +88,19 @@ type ModuleV2 interface { ModuleBase // V2/"NoDAG" API - request either the list of Trigger Subscriptions or launch workflow execution - Execute(ctx context.Context, request *wasmpb.ExecuteRequest, handler CapabilityExecutor) (*wasmpb.ExecutionResult, error) + Execute(ctx context.Context, request *wasmpb.ExecuteRequest, handler ExecutionHelper) (*wasmpb.ExecutionResult, error) } -// Implemented by the Workflow Engine -type CapabilityExecutor interface { - // blocking call to the Workflow Engine +// ExecutionHelper Implemented by those running the host, for example the Workflow Engine +type ExecutionHelper interface { + // CallCapability blocking call to the Workflow Engine CallCapability(ctx context.Context, request *sdkpb.CapabilityRequest) (*sdkpb.CapabilityResponse, error) + + GetId() string + + GetNodeTime() time.Time + + GetDONTime() time.Time } type module struct { @@ -243,7 +251,7 @@ func NewModule(modCfg *ModuleConfig, binary []byte, opts ...func(*ModuleConfig)) } func linkNoDAG(m *module, store *wasmtime.Store, exec *execution[*wasmpb.ExecutionResult]) (*wasmtime.Instance, error) { - linker, err := newWasiLinker(m.cfg, m.engine) + linker, err := newWasiLinker(exec, m.engine) if err != nil { return nil, err } @@ -290,11 +298,25 @@ func linkNoDAG(m *module, store *wasmtime.Store, exec *execution[*wasmpb.Executi return nil, fmt.Errorf("error wrapping log func: %w", err) } + if err = linker.FuncWrap( + "env", + "switch_modes", + exec.switchModes); err != nil { + return nil, fmt.Errorf("error wrapping switchModes func: %w", err) + } + + if err = linker.FuncWrap( + "env", + "random_seed", + exec.getSeed); err != nil { + return nil, fmt.Errorf("error wrapping getSeed func: %w", err) + } + return linker.Instantiate(store, m.module) } func linkLegacyDAG(m *module, store *wasmtime.Store, exec *execution[*wasmdagpb.Response]) (*wasmtime.Instance, error) { - linker, err := newWasiLinker(m.cfg, m.engine) + linker, err := newDagWasiLinker(m.cfg, m.engine) if err != nil { return nil, err } @@ -370,7 +392,7 @@ func (m *module) IsLegacyDAG() bool { return m.v2ImportName == "" } -func (m *module) Execute(ctx 
context.Context, req *wasmpb.ExecuteRequest, executor CapabilityExecutor) (*wasmpb.ExecutionResult, error) { +func (m *module) Execute(ctx context.Context, req *wasmpb.ExecuteRequest, executor ExecutionHelper) (*wasmpb.ExecutionResult, error) { if m.IsLegacyDAG() { return nil, errors.New("cannot execute a legacy dag workflow") } @@ -422,7 +444,7 @@ func runWasm[I, O proto.Message]( request I, setMaxResponseSize func(i I, maxSize uint64), linkWasm linkFn[O], - executor CapabilityExecutor) (O, error) { + helper ExecutionHelper) (O, error) { var o O @@ -468,11 +490,23 @@ func runWasm[I, O proto.Message]( deadline := *m.cfg.Timeout / m.cfg.TickInterval store.SetEpochDeadline(uint64(deadline)) + h := fnv.New64a() + if helper != nil { + id := helper.GetId() + _, _ = h.Write([]byte(id)) + } + + donSeed := int64(h.Sum64()) + + _ = ctxWithTimeout exec := &execution[O]{ - ctx: ctxWithTimeout, + //ctx: ctxWithTimeout, + ctx: ctx, capabilityResponses: map[int32]<-chan *sdkpb.CapabilityResponse{}, module: m, - executor: executor, + executor: helper, + donSeed: donSeed, + nodeSeed: int64(rand.Uint64()), } instance, err := linkWasm(m, store, exec) diff --git a/pkg/workflows/wasm/host/module_test.go b/pkg/workflows/wasm/host/module_test.go index 8637e1b23..c284320ae 100644 --- a/pkg/workflows/wasm/host/module_test.go +++ b/pkg/workflows/wasm/host/module_test.go @@ -590,8 +590,8 @@ func Test_toEmissible(t *testing.T) { // CallAwaitRace validates that every call can be awaited. func Test_CallAwaitRace(t *testing.T) { ctx := t.Context() - mockCapExec := NewMockCapabilityExecutor(t) - mockCapExec.EXPECT(). + mockExecHelper := NewMockExecutionHelper(t) + mockExecHelper.EXPECT(). CallCapability(matches.AnyContext, mock.Anything). Return(&sdkpb.CapabilityResponse{}, nil) @@ -604,7 +604,7 @@ func Test_CallAwaitRace(t *testing.T) { module: m, capabilityResponses: map[int32]<-chan *sdkpb.CapabilityResponse{}, ctx: t.Context(), - executor: mockCapExec, + executor: mockExecHelper, } wg.Add(wantAttempts) diff --git a/pkg/workflows/wasm/host/test/nodag/randoms/cmd/main.go b/pkg/workflows/wasm/host/test/nodag/randoms/cmd/main.go new file mode 100644 index 000000000..27ca5ab0c --- /dev/null +++ b/pkg/workflows/wasm/host/test/nodag/randoms/cmd/main.go @@ -0,0 +1,54 @@ +//go:build wasip1 + +package main + +import ( + "strconv" + + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/protoc/pkg/test_capabilities/basictrigger" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/protoc/pkg/test_capabilities/nodeaction" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/sdk/v2" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/testhelpers/v2" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/v2" +) + +func main() { + runner := wasm.NewDonRunner() + basic := &basictrigger.Basic{} + + runner.Run(&sdk.WorkflowArgs[sdk.DonRuntime]{ + Handlers: []sdk.Handler[sdk.DonRuntime]{ + sdk.NewDonHandler( + basic.Trigger(testhelpers.TestWorkflowTriggerConfig()), + func(runtime sdk.DonRuntime, trigger *basictrigger.Outputs) (uint64, error) { + r, err := runtime.Rand() + if err != nil { + return 0, err + } + total := r.Uint64() + sdk.RunInNodeMode[uint64](runtime, func(nrt sdk.NodeRuntime) (uint64, error) { + node, err := (&nodeaction.BasicAction{}).PerformAction(nrt, &nodeaction.NodeInputs{ + InputThing: false, + }).Await() + + if err != nil { + return 0, err + } + + // Conditionally generate a random number based on the node output. 
+ // This ensures it doesn't impact the next DON mode number. + if node.OutputThing < 100 { + nr, err := nrt.Rand() + if err != nil { + return 0, err + } + runtime.LogWriter().Write([]byte(strconv.FormatUint(nr.Uint64(), 10))) + } + return 0, nil + }, sdk.ConsensusIdenticalAggregation[uint64]()) + total += r.Uint64() + return total, nil + }), + }, + }) +} diff --git a/pkg/workflows/wasm/host/wasip1.go b/pkg/workflows/wasm/host/wasip1.go index 3343bbf70..b72181911 100644 --- a/pkg/workflows/wasm/host/wasip1.go +++ b/pkg/workflows/wasm/host/wasip1.go @@ -17,7 +17,39 @@ var ( tick = 100 * time.Millisecond ) -func newWasiLinker(modCfg *ModuleConfig, engine *wasmtime.Engine) (*wasmtime.Linker, error) { +func newWasiLinker[T any](exec *execution[T], engine *wasmtime.Engine) (*wasmtime.Linker, error) { + linker := wasmtime.NewLinker(engine) + linker.AllowShadowing(true) + + err := linker.DefineWasi() + if err != nil { + return nil, err + } + + // TODO: https://smartcontract-it.atlassian.net/browse/CAPPL-903 + err = linker.FuncWrap( + "wasi_snapshot_preview1", + "poll_oneoff", + pollOneoff, + ) + if err != nil { + return nil, err + } + + // TODO: https://smartcontract-it.atlassian.net/browse/CAPPL-903 + err = linker.FuncWrap( + "wasi_snapshot_preview1", + "clock_time_get", + clockTimeGet, + ) + if err != nil { + return nil, err + } + + return linker, nil +} + +func newDagWasiLinker(modCfg *ModuleConfig, engine *wasmtime.Engine) (*wasmtime.Linker, error) { linker := wasmtime.NewLinker(engine) linker.AllowShadowing(true) diff --git a/pkg/workflows/wasm/host/wasm_nodag_test.go b/pkg/workflows/wasm/host/wasm_nodag_test.go index 84ec5de2b..e99193d97 100644 --- a/pkg/workflows/wasm/host/wasm_nodag_test.go +++ b/pkg/workflows/wasm/host/wasm_nodag_test.go @@ -7,6 +7,7 @@ import ( "strings" "testing" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/protoc/pkg/test_capabilities/nodeaction" sdkpb "github.com/smartcontractkit/chainlink-common/pkg/workflows/sdk/v2/pb" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -23,10 +24,12 @@ import ( ) const ( - nodagBinaryLocation = "test/nodag/singlehandler/cmd/testmodule.wasm" - nodagMultiTriggerBinaryLocation = "test/nodag/multihandler/cmd/testmodule.wasm" nodagBinaryCmd = "test/nodag/singlehandler/cmd" + nodagBinaryLocation = nodagBinaryCmd + "/testmodule.wasm" nodagMultiTriggerBinaryCmd = "test/nodag/multihandler/cmd" + nodagMultiTriggerBinaryLocation = nodagMultiTriggerBinaryCmd + "/testmodule.wasm" + nodagRandomBinaryCmd = "test/nodag/randoms/cmd" + nodagRandomBinaryLocation = nodagRandomBinaryCmd + "/testmodule.wasm" ) var wordList = []string{"Hello, ", "world", "!"} @@ -36,7 +39,7 @@ func Test_NoDag_Run(t *testing.T) { binary := createTestBinary(nodagBinaryCmd, nodagBinaryLocation, true, t) - t.Run("NOK fails with unset CapabilityExecutor for trigger", func(t *testing.T) { + t.Run("NOK fails with unset ExecutionHelper for trigger", func(t *testing.T) { mc := defaultNoDAGModCfg(t) m, err := NewModule(mc, binary) require.NoError(t, err) @@ -54,7 +57,7 @@ func Test_NoDag_Run(t *testing.T) { require.ErrorContains(t, err, "invalid capability executor") }) - t.Run("OK can subscribe without setting CapabilityExecutor", func(t *testing.T) { + t.Run("OK can subscribe without setting ExecutionHelper", func(t *testing.T) { mc := defaultNoDAGModCfg(t) m, err := NewModule(mc, binary) require.NoError(t, err) @@ -81,7 +84,8 @@ func Test_NoDag_Run(t *testing.T) { m.Start() defer m.Close() - mockCapExecutor := 
NewMockCapabilityExecutor(t) + mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper.EXPECT().GetId().Return("Id") // wrap some common payload newWantedCapResponse := func(i int) *sdkpb.CapabilityResponse { @@ -97,7 +101,7 @@ func Test_NoDag_Run(t *testing.T) { for i := 1; i < len(wordList); i++ { wantCapResp := newWantedCapResponse(i) - mockCapExecutor.EXPECT().CallCapability(mock.Anything, mock.Anything). + mockExecutionHelper.EXPECT().CallCapability(mock.Anything, mock.Anything). Run( func(ctx context.Context, request *sdkpb.CapabilityRequest) { require.Equal(t, "basic-test-action@1.0.0", request.Id) @@ -121,7 +125,7 @@ func Test_NoDag_Run(t *testing.T) { }, } - response, err := m.Execute(ctx, req, mockCapExecutor) + response, err := m.Execute(ctx, req, mockExecutionHelper) require.NoError(t, err) logs := observer.TakeAll() @@ -195,9 +199,9 @@ func Test_NoDag_MultipleTriggers_Run(t *testing.T) { m.Start() defer m.Close() - mockCapExecutor := NewMockCapabilityExecutor(t) + mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper.EXPECT().GetId().Return("Id") - // wrap some common payload newWantedCapResponse := func(i int) *sdkpb.CapabilityResponse { action := &basicaction.Outputs{AdaptedThing: wordList[i]} anyAction, err := anypb.New(action) @@ -211,7 +215,7 @@ func Test_NoDag_MultipleTriggers_Run(t *testing.T) { for i := 1; i < len(wordList); i++ { wantCapResp := newWantedCapResponse(i) - mockCapExecutor.EXPECT().CallCapability(mock.Anything, mock.Anything). + mockExecutionHelper.EXPECT().CallCapability(mock.Anything, mock.Anything). Run( func(ctx context.Context, request *sdkpb.CapabilityRequest) { require.Equal(t, "basic-test-action@1.0.0", request.Id) @@ -234,7 +238,7 @@ func Test_NoDag_MultipleTriggers_Run(t *testing.T) { }, }, } - response, err := m.Execute(ctx, req, mockCapExecutor) + response, err := m.Execute(ctx, req, mockExecutionHelper) require.NoError(t, err) switch output := response.Result.(type) { @@ -251,6 +255,100 @@ func Test_NoDag_MultipleTriggers_Run(t *testing.T) { }) } +func Test_NoDag_Random(t *testing.T) { + t.Parallel() + + mc := defaultNoDAGModCfg(t) + lggr, observed := logger.TestObserved(t, zapcore.DebugLevel) + mc.Logger = lggr + + binary := createTestBinary(nodagRandomBinaryCmd, nodagRandomBinaryLocation, true, t) + + m, err := NewModule(mc, binary) + require.NoError(t, err) + + // Test binary executes node mode code conditionally based on the value >= 100 + anyId := "Id" + gte100Exec := NewMockExecutionHelper(t) + gte100Exec.EXPECT().GetId().Return(anyId) + gte100 := &nodeaction.NodeOutputs{OutputThing: 120} + gte100Payload, err := anypb.New(gte100) + require.NoError(t, err) + + gte100Exec.EXPECT().CallCapability(mock.Anything, mock.Anything).Return(&sdkpb.CapabilityResponse{ + Response: &sdkpb.CapabilityResponse_Payload{ + Payload: gte100Payload, + }, + }, nil) + + m.Start() + defer m.Close() + + trigger := &basictrigger.Outputs{CoolOutput: "trigger1"} + triggerPayload, err := anypb.New(trigger) + require.NoError(t, err) + anyRequest := &wasmpb.ExecuteRequest{ + Request: &wasmpb.ExecuteRequest_Trigger{ + Trigger: &sdkpb.Trigger{ + Id: uint64(0), + Payload: triggerPayload, + }, + }, + } + execution1Result, err := m.Execute(t.Context(), anyRequest, gte100Exec) + require.NoError(t, err) + wrappedValue1, err := values.FromProto(execution1Result.GetValue()) + require.NoError(t, err) + value1, err := wrappedValue1.Unwrap() + require.NoError(t, err) + + t.Run("Same execution id gives the same randoms, even if random is called in 
node mode", func(t *testing.T) {
+		// Clear from any previous test
+		observed.TakeAll()
+
+		lt100Exec := NewMockExecutionHelper(t)
+		lt100Exec.EXPECT().GetId().Return(anyId)
+		lt100 := &nodeaction.NodeOutputs{OutputThing: 99}
+		lt100Payload, err := anypb.New(lt100)
+		require.NoError(t, err)
+
+		lt100Exec.EXPECT().CallCapability(mock.Anything, mock.Anything).Return(&sdkpb.CapabilityResponse{
+			Response: &sdkpb.CapabilityResponse_Payload{
+				Payload: lt100Payload,
+			},
+		}, nil)
+
+		execution2Result, err := m.Execute(t.Context(), anyRequest, lt100Exec)
+		require.NoError(t, err)
+		wrappedValue2, err := values.FromProto(execution2Result.GetValue())
+		require.NoError(t, err)
+		value2, err := wrappedValue2.Unwrap()
+		require.NoError(t, err)
+		require.Equal(t, value1, value2, "Expected the same random number to be generated for the same execution id")
+	})
+
+	t.Run("Different execution ids give different randoms", func(t *testing.T) {
+		require.NoError(t, err)
+
+		gte100Exec2 := NewMockExecutionHelper(t)
+		gte100Exec2.EXPECT().GetId().Return("differentId")
+
+		gte100Exec2.EXPECT().CallCapability(mock.Anything, mock.Anything).Return(&sdkpb.CapabilityResponse{
+			Response: &sdkpb.CapabilityResponse_Payload{
+				Payload: gte100Payload,
+			},
+		}, nil)
+
+		executionResult2, err := m.Execute(t.Context(), anyRequest, gte100Exec2)
+		require.NoError(t, err)
+		wrappedValue2, err := values.FromProto(executionResult2.GetValue())
+		require.NoError(t, err)
+		value2, err := wrappedValue2.Unwrap()
+		require.NoError(t, err)
+		require.NotEqual(t, value1, value2, "Expected different random numbers for different execution ids")
+	})
+}
+
 func defaultNoDAGModCfg(t testing.TB) *ModuleConfig {
 	return &ModuleConfig{
 		Logger: logger.Test(t),
@@ -259,10 +357,12 @@ func defaultNoDAGModCfg(t testing.TB) *ModuleConfig {
 }
 
 func getTriggersSpec(t *testing.T, m ModuleV2, config []byte) (*sdkpb.TriggerSubscriptionRequest, error) {
+	helper := NewMockExecutionHelper(t)
+	helper.EXPECT().GetId().Return("Id")
 	execResult, err := m.Execute(t.Context(), &wasmpb.ExecuteRequest{
 		Config:  config,
 		Request: &wasmpb.ExecuteRequest_Subscribe{Subscribe: &emptypb.Empty{}},
-	}, NewMockCapabilityExecutor(t))
+	}, helper)
 
 	if err != nil {
 		return nil, err
diff --git a/pkg/workflows/wasm/v2/runner.go b/pkg/workflows/wasm/v2/runner.go
index 3d6930a77..093fdb627 100644
--- a/pkg/workflows/wasm/v2/runner.go
+++ b/pkg/workflows/wasm/v2/runner.go
@@ -20,6 +20,7 @@ type runnerInternals interface {
 	args() []string
 	sendResponse(response unsafe.Pointer, responseLen int32) int32
 	versionV2()
+	switchModes(mode int32)
 }
 
 func newDonRunner(runnerInternals runnerInternals, runtimeInternals runtimeInternals) sdk.DonRunner {
diff --git a/pkg/workflows/wasm/v2/runner_test_hooks.go b/pkg/workflows/wasm/v2/runner_test_hooks.go
index 83c7527c9..6344d29c1 100644
--- a/pkg/workflows/wasm/v2/runner_test_hooks.go
+++ b/pkg/workflows/wasm/v2/runner_test_hooks.go
@@ -12,6 +12,8 @@ type runnerInternalsTestHook struct {
 	execId       string
 	arguments    []string
 	sentResponse []byte
+	modeSwitched bool
+	mode         int32
 }
 
 func (r *runnerInternalsTestHook) args() []string {
@@ -25,4 +27,9 @@ func (r *runnerInternalsTestHook) sendResponse(response unsafe.Pointer, response
 
 func (r *runnerInternalsTestHook) versionV2() {}
 
+func (r *runnerInternalsTestHook) switchModes(mode int32) {
+	r.mode = mode
+	r.modeSwitched = true
+}
+
 var _ runnerInternals = (*runnerInternalsTestHook)(nil)
diff --git a/pkg/workflows/wasm/v2/runner_wasip1.go b/pkg/workflows/wasm/v2/runner_wasip1.go
index d4d29fb4f..145b76738 100644
---
a/pkg/workflows/wasm/v2/runner_wasip1.go +++ b/pkg/workflows/wasm/v2/runner_wasip1.go @@ -5,6 +5,7 @@ import ( "unsafe" "github.com/smartcontractkit/chainlink-common/pkg/workflows/sdk/v2" + sdkpb "github.com/smartcontractkit/chainlink-common/pkg/workflows/sdk/v2/pb" ) //go:wasmimport env send_response @@ -13,11 +14,16 @@ func sendResponse(response unsafe.Pointer, responseLen int32) int32 //go:wasmimport env version_v2 func versionV2() +//go:wasmimport env switch_modes +func switchModes(mode int32) + func NewDonRunner() sdk.DonRunner { + switchModes((int32)(sdkpb.Mode_DON)) return newDonRunner(runnerInternalsImpl{}, runtimeInternalsImpl{}) } func NewNodeRunner() sdk.NodeRunner { + switchModes((int32)(sdkpb.Mode_Node)) return newNodeRunner(runnerInternalsImpl{}, runtimeInternalsImpl{}) } @@ -36,3 +42,7 @@ func (r runnerInternalsImpl) sendResponse(response unsafe.Pointer, responseLen i func (r runnerInternalsImpl) versionV2() { versionV2() } + +func (r runnerInternalsImpl) switchModes(mode int32) { + switchModes(mode) +} diff --git a/pkg/workflows/wasm/v2/runtime.go b/pkg/workflows/wasm/v2/runtime.go index 07e3953d0..033991d35 100644 --- a/pkg/workflows/wasm/v2/runtime.go +++ b/pkg/workflows/wasm/v2/runtime.go @@ -2,6 +2,7 @@ package wasm import ( "errors" + "math/rand" "unsafe" "github.com/smartcontractkit/chainlink-common/pkg/workflows/internal/v2/sdkimpl" @@ -12,69 +13,92 @@ import ( type runtimeInternals interface { callCapability(req unsafe.Pointer, reqLen int32) int64 awaitCapabilities(awaitRequest unsafe.Pointer, awaitRequestLen int32, responseBuffer unsafe.Pointer, maxResponseLen int32) int64 + switchModes(mode int32) + getSeed(mode int32) int64 } func newRuntime(internals runtimeInternals, mode sdkpb.Mode) sdkimpl.RuntimeBase { return sdkimpl.RuntimeBase{ - Call: callCapabilityWasmWrapper(internals), - Await: awaitCapabilitiesWasmWrapper(internals), - Writer: &writer{}, - Mode: mode, + Writer: &writer{}, + Mode: mode, + RuntimeHelpers: &runtimeHelper{runtimeInternals: internals}, } } -func callCapabilityWasmWrapper(internals runtimeInternals) func(request *sdkpb.CapabilityRequest) error { - return func(request *sdkpb.CapabilityRequest) error { - marshalled, err := proto.Marshal(request) - if err != nil { - return err - } +type runtimeHelper struct { + runtimeInternals + donSource rand.Source + nodeSource rand.Source +} - marshalledPtr, marshalledLen, err := bufferToPointerLen(marshalled) - if err != nil { - return err +func (r *runtimeHelper) GetSource(mode sdkpb.Mode) rand.Source { + switch mode { + case sdkpb.Mode_DON: + if r.donSource == nil { + seed := r.getSeed(int32(mode)) + r.donSource = rand.NewSource(seed) } - - // TODO (CAPPL-846): callCapability should also have a response pointer and response pointer buffer - result := internals.callCapability(marshalledPtr, marshalledLen) - if result < 0 { - return errors.New("cannot find capability " + request.Id) + return r.donSource + default: + if r.nodeSource == nil { + seed := r.getSeed(int32(mode)) + r.nodeSource = rand.NewSource(seed) } - - return nil + return r.nodeSource } } -func awaitCapabilitiesWasmWrapper(internals runtimeInternals) func(request *sdkpb.AwaitCapabilitiesRequest, maxResponseSize uint64) (*sdkpb.AwaitCapabilitiesResponse, error) { - return func(request *sdkpb.AwaitCapabilitiesRequest, maxResponseSize uint64) (*sdkpb.AwaitCapabilitiesResponse, error) { +func (r *runtimeHelper) Call(request *sdkpb.CapabilityRequest) error { + marshalled, err := proto.Marshal(request) + if err != nil { + return err + } - m, err 
:= proto.Marshal(request) - if err != nil { - return nil, err - } + marshalledPtr, marshalledLen, err := bufferToPointerLen(marshalled) + if err != nil { + return err + } - mptr, mlen, err := bufferToPointerLen(m) - if err != nil { - return nil, err - } + // TODO (CAPPL-846): callCapability should also have a response pointer and response pointer buffer + result := r.callCapability(marshalledPtr, marshalledLen) + if result < 0 { + return errors.New("cannot find capability " + request.Id) + } - response := make([]byte, maxResponseSize) - responsePtr, responseLen, err := bufferToPointerLen(response) - if err != nil { - return nil, err - } + return nil +} - bytes := internals.awaitCapabilities(mptr, mlen, responsePtr, responseLen) - if bytes < 0 { - return nil, errors.New(string(response[:-bytes])) - } +func (r *runtimeHelper) Await(request *sdkpb.AwaitCapabilitiesRequest, maxResponseSize uint64) (*sdkpb.AwaitCapabilitiesResponse, error) { + m, err := proto.Marshal(request) + if err != nil { + return nil, err + } - awaitResponse := &sdkpb.AwaitCapabilitiesResponse{} - err = proto.Unmarshal(response[:bytes], awaitResponse) - if err != nil { - return nil, err - } + mptr, mlen, err := bufferToPointerLen(m) + if err != nil { + return nil, err + } + + response := make([]byte, maxResponseSize) + responsePtr, responseLen, err := bufferToPointerLen(response) + if err != nil { + return nil, err + } - return awaitResponse, nil + bytes := r.awaitCapabilities(mptr, mlen, responsePtr, responseLen) + if bytes < 0 { + return nil, errors.New(string(response[:-bytes])) } + + awaitResponse := &sdkpb.AwaitCapabilitiesResponse{} + err = proto.Unmarshal(response[:bytes], awaitResponse) + if err != nil { + return nil, err + } + + return awaitResponse, nil +} + +func (r *runtimeHelper) SwitchModes(mode sdkpb.Mode) { + r.switchModes(int32(mode)) } diff --git a/pkg/workflows/wasm/v2/runtime_test.go b/pkg/workflows/wasm/v2/runtime_test.go index 493be10dc..f8292808d 100644 --- a/pkg/workflows/wasm/v2/runtime_test.go +++ b/pkg/workflows/wasm/v2/runtime_test.go @@ -3,6 +3,7 @@ package wasm import ( "context" "errors" + "math/rand" "testing" "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/protoc/pkg/test_capabilities/basicaction" @@ -81,6 +82,30 @@ func TestRuntimeBase_LogWriter(t *testing.T) { assert.IsType(t, &writer{}, runtime.LogWriter()) } +func Test_runtimeInternals_UsesSeeds(t *testing.T) { + anyDonSeed := int64(123456789) + anyNodeSeed := int64(987654321) + helper := &runtimeHelper{runtimeInternals: &runtimeInternalsTestHook{ + donSeed: anyDonSeed, + nodeSeed: anyNodeSeed, + }} + assertRnd(t, helper, sdkpb.Mode_DON, anyDonSeed) + assertRnd(t, helper, sdkpb.Mode_Node, anyNodeSeed) +} + +func assertRnd(t *testing.T, helper *runtimeHelper, mode sdkpb.Mode, seed int64) { + rnd := rand.New(helper.GetSource(mode)) + buff := make([]byte, 1000) + n, err := rnd.Read(buff) + require.NoError(t, err) + assert.Equal(t, len(buff), n) + expectedBuf := make([]byte, 1000) + n, err = rand.New(rand.NewSource(seed)).Read(expectedBuf) + require.NoError(t, err) + assert.Equal(t, len(expectedBuf), n) + assert.Equal(t, string(expectedBuf), string(buff)) +} + func newTestRuntime(t *testing.T, callCapabilityErr bool, awaitResponseOverride func() ([]byte, error)) sdkimpl.RuntimeBase { internals := testRuntimeInternals(t) internals.callCapabilityErr = callCapabilityErr diff --git a/pkg/workflows/wasm/v2/runtime_test_hooks.go b/pkg/workflows/wasm/v2/runtime_test_hooks.go index 3be8d9d38..33ae81263 100644 --- 
a/pkg/workflows/wasm/v2/runtime_test_hooks.go +++ b/pkg/workflows/wasm/v2/runtime_test_hooks.go @@ -19,6 +19,8 @@ type runtimeInternalsTestHook struct { awaitResponseOverride func() ([]byte, error) callCapabilityErr bool outstandingCalls map[int32]sdk.Promise[*sdkpb.CapabilityResponse] + nodeSeed int64 + donSeed int64 } var _ runtimeInternals = (*runtimeInternalsTestHook)(nil) @@ -118,3 +120,14 @@ func readHostMessage(response []byte, msg string, isError bool) int64 { return written } + +func (r *runtimeInternalsTestHook) switchModes(_ int32) {} + +func (r *runtimeInternalsTestHook) getSeed(mode int32) int64 { + switch mode { + case int32(sdkpb.Mode_DON): + return r.donSeed + default: + return r.nodeSeed + } +} diff --git a/pkg/workflows/wasm/v2/runtime_wasip1.go b/pkg/workflows/wasm/v2/runtime_wasip1.go index 6cfd330c2..d3e2fe80f 100644 --- a/pkg/workflows/wasm/v2/runtime_wasip1.go +++ b/pkg/workflows/wasm/v2/runtime_wasip1.go @@ -10,6 +10,9 @@ func callCapability(req unsafe.Pointer, reqLen int32) int64 //go:wasmimport env await_capabilities func awaitCapabilities(awaitRequest unsafe.Pointer, awaitRequestLen int32, responseBuffer unsafe.Pointer, maxResponseLen int32) int64 +//go:wasmimport env random_seed +func getSeed(mode int32) int64 + type runtimeInternalsImpl struct{} var _ runtimeInternals = runtimeInternalsImpl{} @@ -21,3 +24,11 @@ func (r runtimeInternalsImpl) callCapability(req unsafe.Pointer, reqLen int32) i func (r runtimeInternalsImpl) awaitCapabilities(awaitRequest unsafe.Pointer, awaitRequestLen int32, responseBuffer unsafe.Pointer, maxResponseLen int32) int64 { return awaitCapabilities(awaitRequest, awaitRequestLen, responseBuffer, maxResponseLen) } + +func (r runtimeInternalsImpl) switchModes(mode int32) { + switchModes(mode) +} + +func (r runtimeInternalsImpl) getSeed(mode int32) int64 { + return getSeed(mode) +} From 00a1fdb4172af042269eefc66880b0b900ab89d4 Mon Sep 17 00:00:00 2001 From: Silas Lenihan Date: Wed, 11 Jun 2025 13:59:52 -0400 Subject: [PATCH 14/16] Revert "requests handling (#1247)" This reverts commit 26e78071ce46e347cf2fc8256b707b856279877e. 
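
For context, a rough sketch of the call flow this revert restores: the requests
package moves back under consensus/ocr3, and the generic Store[T, R] /
ReportRequest / ReportResponse go back to the non-generic Store, Handler,
Request and Response. Illustrative only, adapted from handler_test.go in this
patch; lggr, ctx and respValue are assumed placeholders (a logger.Logger, a
context.Context and a *values.Map), and imports of time, clockwork and the
restored requests package are assumed rather than shown.

    store := requests.NewStore()
    h := requests.NewHandler(lggr, store, clockwork.NewRealClock(), time.Second)
    // Handler is a services.Service; it must be started before use
    // (the tests use servicetest.Run for this).

    // The capability queues a pending request keyed by workflow execution ID.
    callbackCh := make(chan requests.Response, 10)
    h.SendRequest(ctx, &requests.Request{
        WorkflowExecutionID: "exec-1",
        CallbackCh:          callbackCh,
        ExpiresAt:           time.Now().Add(time.Hour),
    })

    // The OCR3 plugin later replies by the same execution ID; the handler matches
    // it to the pending request (or caches it briefly if the request has not
    // arrived yet) and closes the callback channel after delivering the response.
    h.SendResponse(ctx, requests.Response{WorkflowExecutionID: "exec-1", Value: respValue})
    resp := <-callbackCh
    _ = resp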
--- .../consensus/ocr3/benchmark_test.go | 15 +- pkg/capabilities/consensus/ocr3/capability.go | 18 +- .../consensus/ocr3/capability_test.go | 24 +-- pkg/capabilities/consensus/ocr3/factory.go | 6 +- pkg/capabilities/consensus/ocr3/ocr3.go | 6 +- .../consensus/ocr3/reporting_plugin.go | 11 +- .../consensus/ocr3/reporting_plugin_test.go | 36 ++-- .../consensus/ocr3/requests/handler.go | 153 +++++++++++++++++ .../{ => ocr3}/requests/handler_test.go | 38 ++--- .../request.go} | 43 +---- .../consensus/{ => ocr3}/requests/store.go | 52 +++--- .../{ => ocr3}/requests/store_test.go | 60 +++---- .../consensus/ocr3/transmitter_test.go | 6 +- .../consensus/requests/handler.go | 160 ------------------ pkg/capabilities/v2/consensus/consensus.pb.go | 4 +- pkg/capabilities/v2/consensus/consensus.proto | 2 +- .../v2/consensus/consensus_sdk_gen.go | 2 +- .../consensusmock/consensus_mock_gen.go | 2 +- .../consensus/server/consensus_server_gen.go | 4 +- 19 files changed, 296 insertions(+), 346 deletions(-) create mode 100644 pkg/capabilities/consensus/ocr3/requests/handler.go rename pkg/capabilities/consensus/{ => ocr3}/requests/handler_test.go (69%) rename pkg/capabilities/consensus/ocr3/{report_request.go => requests/request.go} (62%) rename pkg/capabilities/consensus/{ => ocr3}/requests/store.go (52%) rename pkg/capabilities/consensus/{ => ocr3}/requests/store_test.go (70%) delete mode 100644 pkg/capabilities/consensus/requests/handler.go diff --git a/pkg/capabilities/consensus/ocr3/benchmark_test.go b/pkg/capabilities/consensus/ocr3/benchmark_test.go index d124184d2..09f9b983a 100644 --- a/pkg/capabilities/consensus/ocr3/benchmark_test.go +++ b/pkg/capabilities/consensus/ocr3/benchmark_test.go @@ -8,19 +8,18 @@ import ( "time" "github.com/shopspring/decimal" + ocrcommon "github.com/smartcontractkit/libocr/commontypes" + "github.com/smartcontractkit/libocr/offchainreporting2/types" + "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" "github.com/stretchr/testify/require" "go.uber.org/zap/zapcore" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" - ocrcommon "github.com/smartcontractkit/libocr/commontypes" - "github.com/smartcontractkit/libocr/offchainreporting2/types" - "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" - "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3" "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/datafeeds" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/requests" pbtypes "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/types" - "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/requests" "github.com/smartcontractkit/chainlink-common/pkg/capabilities/datastreams" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/values" @@ -109,7 +108,7 @@ func runObservationBenchmarkWithParams(b *testing.B, lggr logger.Logger, numWork ) // Create request store with requests for each workflow - store := requests.NewStore[*ocr3.ReportRequest, ocr3.ReportResponse]() + store := requests.NewStore() // Create capability with LLO aggregators for each workflow mockCap := &mockCapability{ @@ -136,7 +135,7 @@ func runObservationBenchmarkWithParams(b *testing.B, lggr logger.Logger, numWork require.NoError(b, err) // Create and add request to store - req := &ocr3.ReportRequest{ + req := &requests.Request{ WorkflowID: workflowID, 
WorkflowExecutionID: executionID, WorkflowName: fmt.Sprintf("Workflow %d", i), @@ -230,7 +229,7 @@ func runBenchmarkWithParams(b *testing.B, lggr logger.Logger, numWorkflows, numS ) // Create request store - store := requests.NewStore[*ocr3.ReportRequest]() + store := requests.NewStore() // Create capability with LLO aggregators for each workflow mockCap := &mockCapability{ diff --git a/pkg/capabilities/consensus/ocr3/capability.go b/pkg/capabilities/consensus/ocr3/capability.go index 2639f5653..ac0f1f4d6 100644 --- a/pkg/capabilities/consensus/ocr3/capability.go +++ b/pkg/capabilities/consensus/ocr3/capability.go @@ -10,8 +10,8 @@ import ( "google.golang.org/protobuf/proto" "github.com/smartcontractkit/chainlink-common/pkg/capabilities" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/requests" "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/types" - "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/requests" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/metering" "github.com/smartcontractkit/chainlink-common/pkg/services" @@ -39,9 +39,9 @@ type capability struct { eng *services.Engine capabilities.CapabilityInfo - capabilities.Validator[config, inputs, ReportResponse] + capabilities.Validator[config, inputs, requests.Response] - reqHandler *requests.Handler[*ReportRequest, ReportResponse] + reqHandler *requests.Handler requestTimeout time.Duration requestTimeoutLock sync.RWMutex @@ -63,11 +63,11 @@ type capability struct { var _ CapabilityIface = (*capability)(nil) var _ capabilities.ExecutableCapability = (*capability)(nil) -func NewCapability(s *requests.Store[*ReportRequest, ReportResponse], clock clockwork.Clock, requestTimeout time.Duration, aggregatorFactory types.AggregatorFactory, encoderFactory types.EncoderFactory, lggr logger.Logger, +func NewCapability(s *requests.Store, clock clockwork.Clock, requestTimeout time.Duration, aggregatorFactory types.AggregatorFactory, encoderFactory types.EncoderFactory, lggr logger.Logger, callbackChannelBufferSize int) *capability { o := &capability{ CapabilityInfo: info, - Validator: capabilities.NewValidator[config, inputs, ReportResponse](capabilities.ValidatorArgs{Info: info}), + Validator: capabilities.NewValidator[config, inputs, requests.Response](capabilities.ValidatorArgs{Info: info}), clock: clock, requestTimeout: requestTimeout, aggregatorFactory: aggregatorFactory, @@ -195,7 +195,7 @@ func (o *capability) Execute(ctx context.Context, r capabilities.CapabilityReque o.eng.Debugw("Execute - terminating execution", "workflowExecutionID", r.Metadata.WorkflowExecutionID) responseErr = capabilities.ErrStopExecution } - out := ReportResponse{ + out := requests.Response{ WorkflowExecutionID: r.Metadata.WorkflowExecutionID, Value: inputs, Err: responseErr, @@ -254,8 +254,8 @@ func (o *capability) queueRequestForProcessing( metadata capabilities.RequestMetadata, i *inputs, c *config, -) (<-chan ReportResponse, error) { - callbackCh := make(chan ReportResponse, o.callbackChannelBufferSize) +) (<-chan requests.Response, error) { + callbackCh := make(chan requests.Response, o.callbackChannelBufferSize) // Use the capability-level request timeout unless the request's config specifies // its own timeout, in which case we'll use that instead. 
This allows the workflow spec @@ -267,7 +267,7 @@ func (o *capability) queueRequestForProcessing( } o.requestTimeoutLock.RUnlock() - r := &ReportRequest{ + r := &requests.Request{ StopCh: make(chan struct{}), CallbackCh: callbackCh, WorkflowExecutionID: metadata.WorkflowExecutionID, diff --git a/pkg/capabilities/consensus/ocr3/capability_test.go b/pkg/capabilities/consensus/ocr3/capability_test.go index 9f7a07c20..952d02dc8 100644 --- a/pkg/capabilities/consensus/ocr3/capability_test.go +++ b/pkg/capabilities/consensus/ocr3/capability_test.go @@ -13,8 +13,8 @@ import ( "github.com/stretchr/testify/require" "github.com/smartcontractkit/chainlink-common/pkg/capabilities" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/requests" "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/types" - "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/requests" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/utils" "github.com/smartcontractkit/chainlink-common/pkg/values" @@ -48,7 +48,7 @@ func TestOCR3Capability_Schema(t *testing.T) { fc := clockwork.NewFakeClockAt(n) lggr := logger.Nop() - s := requests.NewStore[*ReportRequest]() + s := requests.NewStore() cp := NewCapability(s, fc, 1*time.Second, mockAggregatorFactory, mockEncoderFactory, lggr, 10) schema, err := cp.Schema() @@ -88,7 +88,7 @@ func TestOCR3Capability(t *testing.T) { ctx := t.Context() - s := requests.NewStore[*ReportRequest]() + s := requests.NewStore() cp := NewCapability(s, fc, 1*time.Second, mockAggregatorFactory, mockEncoderFactory, lggr, 10) require.NoError(t, cp.Start(ctx)) @@ -129,7 +129,7 @@ func TestOCR3Capability(t *testing.T) { // Mock the oracle returning a response mresp, err := values.NewMap(map[string]any{"observations": obsv}) - cp.reqHandler.SendResponse(ctx, ReportResponse{ + cp.reqHandler.SendResponse(ctx, requests.Response{ Value: mresp, WorkflowExecutionID: workflowExecutionTestID, }) @@ -155,7 +155,7 @@ func TestOCR3Capability_Eviction(t *testing.T) { defer cancel() rea := time.Second - s := requests.NewStore[*ReportRequest]() + s := requests.NewStore() cp := NewCapability(s, fc, rea, mockAggregatorFactory, mockEncoderFactory, lggr, 10) require.NoError(t, cp.Start(ctx)) @@ -223,7 +223,7 @@ func TestOCR3Capability_EvictionUsingConfig(t *testing.T) { defer cancel() // This is the default expired at rea := time.Hour - s := requests.NewStore[*ReportRequest]() + s := requests.NewStore() cp := NewCapability(s, fc, rea, mockAggregatorFactory, mockEncoderFactory, lggr, 10) require.NoError(t, cp.Start(ctx)) @@ -290,7 +290,7 @@ func TestOCR3Capability_Registration(t *testing.T) { lggr := logger.Test(t) ctx := t.Context() - s := requests.NewStore[*ReportRequest]() + s := requests.NewStore() cp := NewCapability(s, fc, 1*time.Second, mockAggregatorFactory, mockEncoderFactory, lggr, 10) require.NoError(t, cp.Start(ctx)) @@ -336,7 +336,7 @@ func TestOCR3Capability_ValidateConfig(t *testing.T) { fc := clockwork.NewFakeClockAt(n) lggr := logger.Test(t) - s := requests.NewStore[*ReportRequest]() + s := requests.NewStore() o := NewCapability(s, fc, 1*time.Second, mockAggregatorFactory, mockEncoderFactory, lggr, 10) @@ -411,7 +411,7 @@ func TestOCR3Capability_RespondsToLateRequest(t *testing.T) { ctx := t.Context() - s := requests.NewStore[*ReportRequest]() + s := requests.NewStore() cp := NewCapability(s, fc, 1*time.Second, mockAggregatorFactory, mockEncoderFactory, lggr, 10) require.NoError(t, 
cp.Start(ctx)) @@ -440,7 +440,7 @@ func TestOCR3Capability_RespondsToLateRequest(t *testing.T) { require.NoError(t, err) // Mock the oracle returning a response prior to the request being sent - cp.reqHandler.SendResponse(ctx, ReportResponse{ + cp.reqHandler.SendResponse(ctx, requests.Response{ Value: obsv, WorkflowExecutionID: workflowExecutionTestID, }) @@ -471,7 +471,7 @@ func TestOCR3Capability_RespondingToLateRequestDoesNotBlockOnSlowResponseConsume ctx := t.Context() - s := requests.NewStore[*ReportRequest]() + s := requests.NewStore() cp := NewCapability(s, fc, 1*time.Second, mockAggregatorFactory, mockEncoderFactory, lggr, 0) require.NoError(t, cp.Start(ctx)) @@ -500,7 +500,7 @@ func TestOCR3Capability_RespondingToLateRequestDoesNotBlockOnSlowResponseConsume require.NoError(t, err) // Mock the oracle returning a response prior to the request being sent - cp.reqHandler.SendResponse(ctx, ReportResponse{ + cp.reqHandler.SendResponse(ctx, requests.Response{ Value: obsv, WorkflowExecutionID: workflowExecutionTestID, }) diff --git a/pkg/capabilities/consensus/ocr3/factory.go b/pkg/capabilities/consensus/ocr3/factory.go index 1b60526b7..0e4a4c386 100644 --- a/pkg/capabilities/consensus/ocr3/factory.go +++ b/pkg/capabilities/consensus/ocr3/factory.go @@ -9,8 +9,8 @@ import ( "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/requests" "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/types" - "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/requests" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/services" ) @@ -24,7 +24,7 @@ const ( ) type factory struct { - store *requests.Store[*ReportRequest, ReportResponse] + store *requests.Store capability *capability batchSize int outcomePruningThreshold uint64 @@ -33,7 +33,7 @@ type factory struct { services.StateMachine } -func newFactory(s *requests.Store[*ReportRequest, ReportResponse], c *capability, lggr logger.Logger) (*factory, error) { +func newFactory(s *requests.Store, c *capability, lggr logger.Logger) (*factory, error) { return &factory{ store: s, capability: c, diff --git a/pkg/capabilities/consensus/ocr3/ocr3.go b/pkg/capabilities/consensus/ocr3/ocr3.go index 123e4a943..e4b6046f5 100644 --- a/pkg/capabilities/consensus/ocr3/ocr3.go +++ b/pkg/capabilities/consensus/ocr3/ocr3.go @@ -6,8 +6,8 @@ import ( "github.com/jonboulle/clockwork" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/requests" "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/types" - "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/requests" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/loop" "github.com/smartcontractkit/chainlink-common/pkg/loop/reportingplugins" @@ -32,7 +32,7 @@ type Config struct { EncoderFactory types.EncoderFactory SendBufferSize int - store *requests.Store[*ReportRequest, ReportResponse] + store *requests.Store capability *capability clock clockwork.Clock } @@ -56,7 +56,7 @@ func NewOCR3(config Config) *Capability { } if config.store == nil { - config.store = requests.NewStore[*ReportRequest]() + config.store = requests.NewStore() } if config.capability == nil { diff --git a/pkg/capabilities/consensus/ocr3/reporting_plugin.go b/pkg/capabilities/consensus/ocr3/reporting_plugin.go index 
a4d518c47..31f27fc44 100644 --- a/pkg/capabilities/consensus/ocr3/reporting_plugin.go +++ b/pkg/capabilities/consensus/ocr3/reporting_plugin.go @@ -8,16 +8,15 @@ import ( "slices" "time" + "github.com/smartcontractkit/libocr/quorumhelper" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" - "github.com/smartcontractkit/libocr/quorumhelper" - ocrcommon "github.com/smartcontractkit/libocr/commontypes" "github.com/smartcontractkit/libocr/offchainreporting2/types" "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" - "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/requests" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/requests" "google.golang.org/protobuf/types/known/structpb" @@ -39,14 +38,14 @@ type CapabilityIface interface { type reportingPlugin struct { batchSize int - s *requests.Store[*ReportRequest, ReportResponse] + s *requests.Store r CapabilityIface config ocr3types.ReportingPluginConfig outcomePruningThreshold uint64 lggr logger.Logger } -func NewReportingPlugin(s *requests.Store[*ReportRequest, ReportResponse], r CapabilityIface, batchSize int, config ocr3types.ReportingPluginConfig, +func NewReportingPlugin(s *requests.Store, r CapabilityIface, batchSize int, config ocr3types.ReportingPluginConfig, outcomePruningThreshold uint64, lggr logger.Logger) (*reportingPlugin, error) { return &reportingPlugin{ s: s, @@ -104,7 +103,7 @@ func (r *reportingPlugin) Observation(ctx context.Context, outctx ocr3types.Outc } reqs := r.s.GetByIDs(weids) - reqMap := map[string]*ReportRequest{} + reqMap := map[string]*requests.Request{} for _, req := range reqs { reqMap[req.WorkflowExecutionID] = req } diff --git a/pkg/capabilities/consensus/ocr3/reporting_plugin_test.go b/pkg/capabilities/consensus/ocr3/reporting_plugin_test.go index 30c9d976a..7974ffed8 100644 --- a/pkg/capabilities/consensus/ocr3/reporting_plugin_test.go +++ b/pkg/capabilities/consensus/ocr3/reporting_plugin_test.go @@ -17,8 +17,8 @@ import ( "github.com/smartcontractkit/libocr/offchainreporting2/types" "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/requests" pbtypes "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/types" - "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/requests" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/values" "github.com/smartcontractkit/chainlink-common/pkg/values/pb" @@ -27,7 +27,7 @@ import ( func TestReportingPlugin_Query_ErrorInQueueCall(t *testing.T) { ctx := t.Context() lggr := logger.Test(t) - s := requests.NewStore[*ReportRequest]() + s := requests.NewStore() batchSize := 0 rp, err := NewReportingPlugin(s, nil, batchSize, ocr3types.ReportingPluginConfig{}, defaultOutcomePruningThreshold, lggr) require.NoError(t, err) @@ -42,14 +42,14 @@ func TestReportingPlugin_Query_ErrorInQueueCall(t *testing.T) { func TestReportingPlugin_Query(t *testing.T) { ctx := t.Context() lggr := logger.Test(t) - s := requests.NewStore[*ReportRequest]() + s := requests.NewStore() rp, err := NewReportingPlugin(s, nil, defaultBatchSize, ocr3types.ReportingPluginConfig{}, defaultOutcomePruningThreshold, lggr) require.NoError(t, err) eid := uuid.New().String() wowner := uuid.New().String() - err = s.Add(&ReportRequest{ + err = s.Add(&requests.Request{ WorkflowID: workflowTestID, 
WorkflowExecutionID: eid, WorkflowOwner: wowner, @@ -153,7 +153,7 @@ func (mc *mockCapability) UnregisterWorkflowID(workflowID string) { func TestReportingPlugin_Observation(t *testing.T) { ctx := t.Context() lggr := logger.Test(t) - s := requests.NewStore[*ReportRequest]() + s := requests.NewStore() mcap := &mockCapability{ aggregator: &aggregator{}, encoder: &enc{}, @@ -170,7 +170,7 @@ func TestReportingPlugin_Observation(t *testing.T) { eid := uuid.New().String() wowner := uuid.New().String() - err = s.Add(&ReportRequest{ + err = s.Add(&requests.Request{ WorkflowID: workflowTestID, WorkflowExecutionID: eid, WorkflowOwner: wowner, @@ -210,7 +210,7 @@ func TestReportingPlugin_Observation(t *testing.T) { func TestReportingPlugin_Observation_NilIds(t *testing.T) { ctx := t.Context() lggr := logger.Test(t) - s := requests.NewStore[*ReportRequest]() + s := requests.NewStore() mcap := &mockCapability{ aggregator: &aggregator{}, encoder: &enc{}, @@ -244,7 +244,7 @@ func TestReportingPlugin_Observation_NilIds(t *testing.T) { func TestReportingPlugin_Observation_NoResults(t *testing.T) { ctx := t.Context() lggr := logger.Test(t) - s := requests.NewStore[*ReportRequest]() + s := requests.NewStore() mcap := &mockCapability{ aggregator: &aggregator{}, encoder: &enc{}, @@ -271,7 +271,7 @@ func TestReportingPlugin_Observation_NoResults(t *testing.T) { func TestReportingPlugin_Outcome(t *testing.T) { lggr := logger.Test(t) - s := requests.NewStore[*ReportRequest]() + s := requests.NewStore() aggregator := &aggregator{} mcap := &mockCapability{ aggregator: aggregator, @@ -331,7 +331,7 @@ func TestReportingPlugin_Outcome(t *testing.T) { func TestReportingPlugin_Outcome_AggregatorErrorDoesntInterruptOtherWorkflows(t *testing.T) { lggr := logger.Test(t) - s := requests.NewStore[*ReportRequest]() + s := requests.NewStore() aggregator := &erroringAggregator{} mcap := &mockCapability{ aggregator: aggregator, @@ -407,7 +407,7 @@ func TestReportingPlugin_Outcome_AggregatorErrorDoesntInterruptOtherWorkflows(t func TestReportingPlugin_Outcome_NilDerefs(t *testing.T) { ctx := t.Context() lggr := logger.Test(t) - s := requests.NewStore[*ReportRequest]() + s := requests.NewStore() mcap := &mockCapability{ aggregator: &aggregator{}, encoder: &enc{}, @@ -465,7 +465,7 @@ func TestReportingPlugin_Outcome_NilDerefs(t *testing.T) { func TestReportingPlugin_Outcome_AggregatorErrorDoesntInterruptOtherIDs(t *testing.T) { ctx := t.Context() lggr := logger.Test(t) - s := requests.NewStore[*ReportRequest]() + s := requests.NewStore() mcap := &mockCapability{ aggregator: &aggregator{}, encoder: &enc{}, @@ -531,7 +531,7 @@ func TestReportingPlugin_Outcome_AggregatorErrorDoesntInterruptOtherIDs(t *testi func TestReportingPlugin_Reports_ShouldReportFalse(t *testing.T) { lggr := logger.Test(t) - s := requests.NewStore[*ReportRequest]() + s := requests.NewStore() mcap := &mockCapability{ aggregator: &aggregator{}, encoder: &enc{}, @@ -585,7 +585,7 @@ func TestReportingPlugin_Reports_ShouldReportFalse(t *testing.T) { func TestReportingPlugin_Reports_NilDerefs(t *testing.T) { ctx := t.Context() lggr := logger.Test(t) - s := requests.NewStore[*ReportRequest]() + s := requests.NewStore() mcap := &mockCapability{ aggregator: &aggregator{}, encoder: &enc{}, @@ -628,7 +628,7 @@ func TestReportingPlugin_Reports_NilDerefs(t *testing.T) { func TestReportingPlugin_Reports_ShouldReportTrue(t *testing.T) { lggr := logger.Test(t) dynamicEncoderName := "special_encoder" - s := requests.NewStore[*ReportRequest]() + s := requests.NewStore() mcap := 
&mockCapability{ t: t, aggregator: &aggregator{}, @@ -711,7 +711,7 @@ func TestReportingPlugin_Reports_ShouldReportTrue(t *testing.T) { func TestReportingPlugin_Outcome_ShouldPruneOldOutcomes(t *testing.T) { ctx := t.Context() lggr := logger.Test(t) - s := requests.NewStore[*ReportRequest]() + s := requests.NewStore() mcap := &mockCapability{ aggregator: &aggregator{}, encoder: &enc{}, @@ -828,7 +828,7 @@ func TestReportingPlugin_Outcome_ShouldPruneOldOutcomes(t *testing.T) { func TestReportPlugin_Outcome_ShouldReturnMedianTimestamp(t *testing.T) { ctx := t.Context() lggr := logger.Test(t) - s := requests.NewStore[*ReportRequest]() + s := requests.NewStore() mcap := &mockCapability{ aggregator: &aggregator{}, encoder: &enc{}, @@ -960,7 +960,7 @@ func TestReportPlugin_Outcome_ShouldReturnMedianTimestamp(t *testing.T) { func TestReportPlugin_Outcome_ShouldReturnOverriddenEncoder(t *testing.T) { lggr := logger.Test(t) - s := requests.NewStore[*ReportRequest]() + s := requests.NewStore() mcap := &mockCapability{ aggregator: &aggregator{}, encoder: &enc{}, diff --git a/pkg/capabilities/consensus/ocr3/requests/handler.go b/pkg/capabilities/consensus/ocr3/requests/handler.go new file mode 100644 index 000000000..bda3e27d1 --- /dev/null +++ b/pkg/capabilities/consensus/ocr3/requests/handler.go @@ -0,0 +1,153 @@ +package requests + +import ( + "context" + "fmt" + "time" + + "github.com/jonboulle/clockwork" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/services" +) + +type responseCacheEntry struct { + response Response + entryTime time.Time +} + +type Handler struct { + services.Service + eng *services.Engine + + store *Store + + pendingRequests map[string]*Request + + responseCache map[string]*responseCacheEntry + cacheExpiryTime time.Duration + + responseCh chan Response + requestCh chan *Request + + clock clockwork.Clock +} + +func NewHandler(lggr logger.Logger, s *Store, clock clockwork.Clock, responseExpiryTime time.Duration) *Handler { + h := &Handler{ + store: s, + pendingRequests: map[string]*Request{}, + responseCache: map[string]*responseCacheEntry{}, + responseCh: make(chan Response), + requestCh: make(chan *Request), + clock: clock, + cacheExpiryTime: responseExpiryTime, + } + h.Service, h.eng = services.Config{ + Name: "Handler", + Start: h.start, + }.NewServiceEngine(lggr) + return h +} + +func (h *Handler) SendResponse(ctx context.Context, resp Response) { + select { + case <-ctx.Done(): + return + case h.responseCh <- resp: + } +} + +func (h *Handler) SendRequest(ctx context.Context, r *Request) { + select { + case <-ctx.Done(): + return + case h.requestCh <- r: + } +} + +func (h *Handler) start(_ context.Context) error { + h.eng.Go(h.worker) + return nil +} + +func (h *Handler) worker(ctx context.Context) { + responseCacheExpiryTicker := h.clock.NewTicker(h.cacheExpiryTime) + defer responseCacheExpiryTicker.Stop() + + // Set to tick at 1 second as this is a sufficient resolution for expiring requests without causing too much overhead + pendingRequestsExpiryTicker := h.clock.NewTicker(1 * time.Second) + defer pendingRequestsExpiryTicker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-pendingRequestsExpiryTicker.Chan(): + h.expirePendingRequests(ctx) + case <-responseCacheExpiryTicker.Chan(): + h.expireCachedResponses() + case req := <-h.requestCh: + h.pendingRequests[req.WorkflowExecutionID] = req + + existingResponse := h.responseCache[req.WorkflowExecutionID] + if existingResponse != nil { + 
delete(h.responseCache, req.WorkflowExecutionID) + h.eng.Debugw("Found cached response for request", "workflowExecutionID", req.WorkflowExecutionID) + h.sendResponse(ctx, req, existingResponse.response) + continue + } + + if err := h.store.Add(req); err != nil { + h.eng.Errorw("failed to add request to store", "err", err) + } + + case resp := <-h.responseCh: + req, wasPresent := h.store.evict(resp.WorkflowExecutionID) + if !wasPresent { + h.responseCache[resp.WorkflowExecutionID] = &responseCacheEntry{ + response: resp, + entryTime: h.clock.Now(), + } + h.eng.Debugw("Caching response without request", "workflowExecutionID", resp.WorkflowExecutionID) + continue + } + + h.sendResponse(ctx, req, resp) + } + } +} + +func (h *Handler) sendResponse(ctx context.Context, req *Request, resp Response) { + select { + case <-ctx.Done(): + return + case req.CallbackCh <- resp: + close(req.CallbackCh) + delete(h.pendingRequests, req.WorkflowExecutionID) + } +} + +func (h *Handler) expirePendingRequests(ctx context.Context) { + now := h.clock.Now() + + for _, req := range h.pendingRequests { + if now.After(req.ExpiresAt) { + resp := Response{ + WorkflowExecutionID: req.WorkflowExecutionID, + Err: fmt.Errorf("timeout exceeded: could not process request before expiry %s", req.WorkflowExecutionID), + } + h.store.evict(req.WorkflowExecutionID) + h.sendResponse(ctx, req, resp) + } + } +} + +func (h *Handler) expireCachedResponses() { + for k, v := range h.responseCache { + if h.clock.Since(v.entryTime) > h.cacheExpiryTime { + delete(h.responseCache, k) + h.eng.Debugw("Expired response", "workflowExecutionID", k) + } + } +} diff --git a/pkg/capabilities/consensus/requests/handler_test.go b/pkg/capabilities/consensus/ocr3/requests/handler_test.go similarity index 69% rename from pkg/capabilities/consensus/requests/handler_test.go rename to pkg/capabilities/consensus/ocr3/requests/handler_test.go index 3b6a0e038..86b6e0a0d 100644 --- a/pkg/capabilities/consensus/requests/handler_test.go +++ b/pkg/capabilities/consensus/ocr3/requests/handler_test.go @@ -8,10 +8,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3" - "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/requests" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/requests" "github.com/smartcontractkit/chainlink-common/pkg/logger" - "github.com/smartcontractkit/chainlink-common/pkg/services/servicetest" "github.com/smartcontractkit/chainlink-common/pkg/values" ) @@ -20,11 +18,11 @@ func Test_Handler_SendsResponse(t *testing.T) { lggr := logger.Test(t) ctx := t.Context() - h := requests.NewHandler(lggr, requests.NewStore[*ocr3.ReportRequest, ocr3.ReportResponse](), clockwork.NewFakeClockAt(time.Now()), 1*time.Second) + h := requests.NewHandler(lggr, requests.NewStore(), clockwork.NewFakeClockAt(time.Now()), 1*time.Second) servicetest.Run(t, h) - responseCh := make(chan ocr3.ReportResponse, 10) - h.SendRequest(ctx, &ocr3.ReportRequest{ + responseCh := make(chan requests.Response, 10) + h.SendRequest(ctx, &requests.Request{ WorkflowExecutionID: "test", CallbackCh: responseCh, ExpiresAt: time.Now().Add(1 * time.Hour), @@ -33,7 +31,7 @@ func Test_Handler_SendsResponse(t *testing.T) { testVal, err := values.NewMap(map[string]any{"result": "testval"}) require.NoError(t, err) - h.SendResponse(ctx, ocr3.ReportResponse{ + h.SendResponse(ctx, requests.Response{ WorkflowExecutionID: "test", Value: 
testVal, Err: nil, @@ -47,19 +45,19 @@ func Test_Handler_SendsResponseToLateRequest(t *testing.T) { lggr := logger.Test(t) ctx := t.Context() - h := requests.NewHandler(lggr, requests.NewStore[*ocr3.ReportRequest](), clockwork.NewFakeClockAt(time.Now()), 1*time.Second) + h := requests.NewHandler(lggr, requests.NewStore(), clockwork.NewFakeClockAt(time.Now()), 1*time.Second) servicetest.Run(t, h) testVal, err := values.NewMap(map[string]any{"result": "testval"}) require.NoError(t, err) - h.SendResponse(ctx, ocr3.ReportResponse{ + h.SendResponse(ctx, requests.Response{ WorkflowExecutionID: "test", Value: testVal, Err: nil, }) - responseCh := make(chan ocr3.ReportResponse, 10) - h.SendRequest(ctx, &ocr3.ReportRequest{ + responseCh := make(chan requests.Response, 10) + h.SendRequest(ctx, &requests.Request{ WorkflowExecutionID: "test", CallbackCh: responseCh, ExpiresAt: time.Now().Add(1 * time.Hour), @@ -73,20 +71,20 @@ func Test_Handler_SendsResponseToLateRequestOnlyOnce(t *testing.T) { lggr := logger.Test(t) ctx := t.Context() - h := requests.NewHandler(lggr, requests.NewStore[*ocr3.ReportRequest](), clockwork.NewFakeClockAt(time.Now()), 1*time.Second) + h := requests.NewHandler(lggr, requests.NewStore(), clockwork.NewFakeClockAt(time.Now()), 1*time.Second) servicetest.Run(t, h) testVal, err := values.NewMap(map[string]any{"result": "testval"}) require.NoError(t, err) - h.SendResponse(ctx, ocr3.ReportResponse{ + h.SendResponse(ctx, requests.Response{ WorkflowExecutionID: "test", Value: testVal, Err: nil, }) - responseCh := make(chan ocr3.ReportResponse, 10) - h.SendRequest(ctx, &ocr3.ReportRequest{ + responseCh := make(chan requests.Response, 10) + h.SendRequest(ctx, &requests.Request{ WorkflowExecutionID: "test", CallbackCh: responseCh, ExpiresAt: time.Now().Add(1 * time.Hour), @@ -97,8 +95,8 @@ func Test_Handler_SendsResponseToLateRequestOnlyOnce(t *testing.T) { resp := <-responseCh require.Equal(t, testVal, resp.Value) - responseCh = make(chan ocr3.ReportResponse, 10) - h.SendRequest(ctx, &ocr3.ReportRequest{ + responseCh = make(chan requests.Response, 10) + h.SendRequest(ctx, &requests.Request{ WorkflowExecutionID: "test", CallbackCh: responseCh, ExpiresAt: time.Now().Add(1 * time.Hour), @@ -116,11 +114,11 @@ func Test_Handler_PendingRequestsExpiry(t *testing.T) { lggr := logger.Test(t) clock := clockwork.NewFakeClockAt(time.Now()) - h := requests.NewHandler(lggr, requests.NewStore[*ocr3.ReportRequest](), clock, 1*time.Second) + h := requests.NewHandler(lggr, requests.NewStore(), clock, 1*time.Second) servicetest.Run(t, h) - responseCh := make(chan ocr3.ReportResponse, 10) - h.SendRequest(ctx, &ocr3.ReportRequest{ + responseCh := make(chan requests.Response, 10) + h.SendRequest(ctx, &requests.Request{ WorkflowExecutionID: "test", CallbackCh: responseCh, ExpiresAt: time.Now().Add(1 * time.Second), diff --git a/pkg/capabilities/consensus/ocr3/report_request.go b/pkg/capabilities/consensus/ocr3/requests/request.go similarity index 62% rename from pkg/capabilities/consensus/ocr3/report_request.go rename to pkg/capabilities/consensus/ocr3/requests/request.go index 935d3958c..4d427a038 100644 --- a/pkg/capabilities/consensus/ocr3/report_request.go +++ b/pkg/capabilities/consensus/ocr3/requests/request.go @@ -1,15 +1,13 @@ -package ocr3 +package requests import ( - "context" - "fmt" "time" "github.com/smartcontractkit/chainlink-common/pkg/services" "github.com/smartcontractkit/chainlink-common/pkg/values" ) -type ReportRequest struct { +type Request struct { Observations *values.List 
`mapstructure:"-"` OverriddenEncoderName string OverriddenEncoderConfig *values.Map @@ -17,7 +15,7 @@ type ReportRequest struct { // CallbackCh is a channel to send a response back to the requester // after the request has been processed or timed out. - CallbackCh chan ReportResponse + CallbackCh chan Response StopCh services.StopChan WorkflowExecutionID string @@ -31,33 +29,8 @@ type ReportRequest struct { KeyID string } -func (r *ReportRequest) ID() string { - return r.WorkflowExecutionID -} - -func (r *ReportRequest) ExpiryTime() time.Time { - return r.ExpiresAt -} - -func (r *ReportRequest) SendResponse(ctx context.Context, resp ReportResponse) { - select { - case <-ctx.Done(): - return - case r.CallbackCh <- resp: - close(r.CallbackCh) - } -} - -func (r *ReportRequest) SendTimeout(ctx context.Context) { - timeoutResponse := ReportResponse{ - WorkflowExecutionID: r.WorkflowExecutionID, - Err: fmt.Errorf("timeout exceeded: could not process request before expiry, workflowExecutionID %s", r.WorkflowExecutionID), - } - r.SendResponse(ctx, timeoutResponse) -} - -func (r *ReportRequest) Copy() *ReportRequest { - return &ReportRequest{ +func (r *Request) Copy() *Request { + return &Request{ Observations: r.Observations.CopyList(), OverriddenEncoderConfig: r.OverriddenEncoderConfig.CopyMap(), @@ -79,12 +52,8 @@ func (r *ReportRequest) Copy() *ReportRequest { } } -type ReportResponse struct { +type Response struct { WorkflowExecutionID string Value *values.Map Err error } - -func (r ReportResponse) RequestID() string { - return r.WorkflowExecutionID -} diff --git a/pkg/capabilities/consensus/requests/store.go b/pkg/capabilities/consensus/ocr3/requests/store.go similarity index 52% rename from pkg/capabilities/consensus/requests/store.go rename to pkg/capabilities/consensus/ocr3/requests/store.go index b9d664057..41419cdf8 100644 --- a/pkg/capabilities/consensus/requests/store.go +++ b/pkg/capabilities/consensus/ocr3/requests/store.go @@ -6,29 +6,32 @@ import ( "sync" ) -// Store is a generic store for ongoing consensus requests. -// It is thread-safe and uses a map to store requests. -type Store[T ConsensusRequest[T, R], R ConsensusResponse] struct { +// Store stores ongoing consensus requests in an +// in-memory map. +// Note: this object is intended to be thread-safe, +// so any read requests should first deep-copy the returned +// request object via request.Copy(). +type Store struct { requestIDs []string - requests map[string]T + requests map[string]*Request mu sync.RWMutex } -func NewStore[T ConsensusRequest[T, R], R ConsensusResponse]() *Store[T, R] { - return &Store[T, R]{ +func NewStore() *Store { + return &Store{ requestIDs: []string{}, - requests: map[string]T{}, + requests: map[string]*Request{}, } } -// GetByIDs retrieves requests by their IDs. +// GetByIDs is best-effort, doesn't return requests that are not in store // The method deep-copies requests before returning them. -func (s *Store[T, R]) GetByIDs(requestIDs []string) []T { +func (s *Store) GetByIDs(requestIDs []string) []*Request { s.mu.RLock() defer s.mu.RUnlock() - o := []T{} + o := []*Request{} for _, r := range requestIDs { gr, ok := s.requests[r] if ok { @@ -39,15 +42,15 @@ func (s *Store[T, R]) GetByIDs(requestIDs []string) []T { return o } -// FirstN retrieves up to `batchSize` requests. +// FirstN returns up to `bathSize` requests. // The method deep-copies requests before returning them. 
-func (s *Store[T, R]) FirstN(batchSize int) ([]T, error) { +func (s *Store) FirstN(batchSize int) ([]*Request, error) { s.mu.RLock() defer s.mu.RUnlock() if batchSize == 0 { return nil, errors.New("batchsize cannot be 0") } - got := []T{} + got := []*Request{} if len(s.requestIDs) == 0 { return got, nil } @@ -67,34 +70,31 @@ func (s *Store[T, R]) FirstN(batchSize int) ([]T, error) { return got, nil } -// Add adds a new request to the store. -func (s *Store[T, R]) Add(req T) error { +func (s *Store) Add(req *Request) error { s.mu.Lock() defer s.mu.Unlock() - if _, ok := s.requests[req.ID()]; ok { - return fmt.Errorf("request with id %s already exists", req.ID()) + if _, ok := s.requests[req.WorkflowExecutionID]; ok { + return fmt.Errorf("request with id %s already exists", req.WorkflowExecutionID) } - s.requestIDs = append(s.requestIDs, req.ID()) - s.requests[req.ID()] = req + s.requestIDs = append(s.requestIDs, req.WorkflowExecutionID) + s.requests[req.WorkflowExecutionID] = req return nil } -// Get retrieves a request by its ID. -// The method deep-copies the request before returning it. -func (s *Store[T, R]) Get(requestID string) T { +// Get returns the request corresponding to request ID. +// The method deep-copies requests before returning them. +func (s *Store) Get(requestID string) *Request { s.mu.RLock() defer s.mu.RUnlock() rid, ok := s.requests[requestID] if ok { return rid.Copy() } - var zero T - return zero + return nil } -// Evict removes a request from the store by its ID. -func (s *Store[T, R]) Evict(requestID string) (T, bool) { +func (s *Store) evict(requestID string) (*Request, bool) { s.mu.Lock() defer s.mu.Unlock() diff --git a/pkg/capabilities/consensus/requests/store_test.go b/pkg/capabilities/consensus/ocr3/requests/store_test.go similarity index 70% rename from pkg/capabilities/consensus/requests/store_test.go rename to pkg/capabilities/consensus/ocr3/requests/store_test.go index 9203bb017..b871b0b96 100644 --- a/pkg/capabilities/consensus/requests/store_test.go +++ b/pkg/capabilities/consensus/ocr3/requests/store_test.go @@ -1,4 +1,4 @@ -package requests_test +package requests import ( "context" @@ -11,17 +11,14 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" "github.com/smartcontractkit/chainlink-common/pkg/values" - - "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3" - "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/requests" ) func TestOCR3Store(t *testing.T) { n := time.Now() - s := requests.NewStore[*ocr3.ReportRequest]() + s := NewStore() rid := uuid.New().String() - req := &ocr3.ReportRequest{ + req := &Request{ WorkflowExecutionID: rid, ExpiresAt: n.Add(10 * time.Second), } @@ -37,11 +34,9 @@ func TestOCR3Store(t *testing.T) { }) t.Run("evict", func(t *testing.T) { - _, wasPresent := s.Evict(rid) + _, wasPresent := s.evict(rid) assert.True(t, wasPresent) - reqs, err := s.FirstN(10) - require.NoError(t, err) - assert.Len(t, reqs, 0) + assert.Len(t, s.requests, 0) }) t.Run("firstN", func(t *testing.T) { @@ -57,7 +52,7 @@ func TestOCR3Store(t *testing.T) { t.Run("firstN, batchSize larger than queue", func(t *testing.T) { for i := 0; i < 10; i++ { - err := s.Add(&ocr3.ReportRequest{WorkflowExecutionID: uuid.New().String(), ExpiresAt: n.Add(1 * time.Hour)}) + err := s.Add(&Request{WorkflowExecutionID: uuid.New().String(), ExpiresAt: n.Add(1 * time.Hour)}) require.NoError(t, err) } items, err := s.FirstN(100) @@ -75,45 +70,42 @@ func TestOCR3Store(t *testing.T) { } func 
TestOCR3Store_ManagesStateConsistently(t *testing.T) { - s := requests.NewStore[*ocr3.ReportRequest]() + s := NewStore() rid := uuid.New().String() - req := &ocr3.ReportRequest{ + req := &Request{ WorkflowExecutionID: rid, } err := s.Add(req) require.NoError(t, err) - reqs, err := s.FirstN(10) - require.NoError(t, err) + assert.Len(t, s.requests, 1) + assert.Len(t, s.requestIDs, 1) - assert.Len(t, reqs, 1) + s.GetByIDs([]string{rid}) + assert.Len(t, s.requests, 1) + assert.Len(t, s.requestIDs, 1) - reqs = s.GetByIDs([]string{rid}) - assert.Len(t, reqs, 1) - - _, ok := s.Evict(rid) + _, ok := s.evict(rid) assert.True(t, ok) - reqs, err = s.FirstN(10) - require.NoError(t, err) - assert.Len(t, reqs, 0) + assert.Len(t, s.requests, 0) + assert.Len(t, s.requestIDs, 0) err = s.Add(req) require.NoError(t, err) - reqs, err = s.FirstN(10) - require.NoError(t, err) - assert.Len(t, reqs, 1) + assert.Len(t, s.requests, 1) + assert.Len(t, s.requestIDs, 1) } func TestOCR3Store_ReadRequestsCopy(t *testing.T) { - s := requests.NewStore[*ocr3.ReportRequest]() + s := NewStore() rid := uuid.New().String() - cb := make(chan ocr3.ReportResponse, 1) + cb := make(chan Response, 1) stopCh := make(chan struct{}, 1) obs, err := values.NewList( []any{"hello", 1}, ) require.NoError(t, err) - req := &ocr3.ReportRequest{ + req := &Request{ WorkflowExecutionID: rid, WorkflowID: "wid", WorkflowName: "name", @@ -132,17 +124,17 @@ func TestOCR3Store_ReadRequestsCopy(t *testing.T) { testCases := []struct { name string - get func(ctx context.Context, rid string) *ocr3.ReportRequest + get func(ctx context.Context, rid string) *Request }{ { name: "get", - get: func(ctx context.Context, rid string) *ocr3.ReportRequest { + get: func(ctx context.Context, rid string) *Request { return s.Get(rid) }, }, { name: "firstN", - get: func(ctx context.Context, rid string) *ocr3.ReportRequest { + get: func(ctx context.Context, rid string) *Request { rs, err2 := s.FirstN(1) require.NoError(t, err2) assert.Len(t, rs, 1) @@ -151,7 +143,7 @@ func TestOCR3Store_ReadRequestsCopy(t *testing.T) { }, { name: "getByIDs", - get: func(ctx context.Context, rid string) *ocr3.ReportRequest { + get: func(ctx context.Context, rid string) *Request { rs := s.GetByIDs([]string{rid}) assert.Len(t, rs, 1) return rs[0] @@ -178,7 +170,7 @@ func TestOCR3Store_ReadRequestsCopy(t *testing.T) { gr.StopCh <- struct{}{} <-stopCh - gr.CallbackCh <- ocr3.ReportResponse{} + gr.CallbackCh <- Response{} <-cb }) } diff --git a/pkg/capabilities/consensus/ocr3/transmitter_test.go b/pkg/capabilities/consensus/ocr3/transmitter_test.go index 0a7179723..f651ceafd 100644 --- a/pkg/capabilities/consensus/ocr3/transmitter_test.go +++ b/pkg/capabilities/consensus/ocr3/transmitter_test.go @@ -17,8 +17,8 @@ import ( "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" "github.com/smartcontractkit/chainlink-common/pkg/capabilities" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/requests" pbtypes "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/types" - "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/requests" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/services/servicetest" "github.com/smartcontractkit/chainlink-common/pkg/types/core/mocks" @@ -31,7 +31,7 @@ func TestTransmitter(t *testing.T) { repID := []byte{0xf0, 0xe0} ctx := t.Context() lggr := logger.Test(t) - s := requests.NewStore[*ReportRequest]() + s := 
requests.NewStore() weid := uuid.New().String() @@ -120,7 +120,7 @@ func TestTransmitter_ShouldReportFalse(t *testing.T) { wowner := "foo-owner" ctx := t.Context() lggr := logger.Test(t) - s := requests.NewStore[*ReportRequest]() + s := requests.NewStore() weid := uuid.New().String() diff --git a/pkg/capabilities/consensus/requests/handler.go b/pkg/capabilities/consensus/requests/handler.go deleted file mode 100644 index ece247548..000000000 --- a/pkg/capabilities/consensus/requests/handler.go +++ /dev/null @@ -1,160 +0,0 @@ -package requests - -import ( - "context" - "time" - - "github.com/jonboulle/clockwork" - - "github.com/smartcontractkit/chainlink-common/pkg/logger" - "github.com/smartcontractkit/chainlink-common/pkg/services" -) - -type responseCacheEntry[R ConsensusResponse] struct { - response R - entryTime time.Time -} - -type ConsensusRequest[T any, R ConsensusResponse] interface { - ID() string - Copy() T - ExpiryTime() time.Time - SendResponse(ctx context.Context, response R) - SendTimeout(ctx context.Context) -} - -type ConsensusResponse interface { - RequestID() string -} - -type Handler[T ConsensusRequest[T, R], R ConsensusResponse] struct { - services.Service - eng *services.Engine - - store *Store[T, R] - - pendingRequests map[string]T - - responseCache map[string]*responseCacheEntry[R] - cacheExpiryTime time.Duration - - responseCh chan R - requestCh chan T - - clock clockwork.Clock -} - -func NewHandler[T ConsensusRequest[T, R], R ConsensusResponse](lggr logger.Logger, s *Store[T, R], clock clockwork.Clock, responseExpiryTime time.Duration) *Handler[T, R] { - h := &Handler[T, R]{ - store: s, - pendingRequests: map[string]T{}, - responseCache: map[string]*responseCacheEntry[R]{}, - responseCh: make(chan R), - requestCh: make(chan T), - clock: clock, - cacheExpiryTime: responseExpiryTime, - } - h.Service, h.eng = services.Config{ - Name: "Handler", - Start: h.start, - }.NewServiceEngine(lggr) - return h -} - -func (h *Handler[T, R]) SendResponse(ctx context.Context, resp R) { - select { - case <-ctx.Done(): - return - case h.responseCh <- resp: - } -} - -func (h *Handler[T, R]) SendRequest(ctx context.Context, r T) { - select { - case <-ctx.Done(): - return - case h.requestCh <- r: - } -} - -func (h *Handler[T, R]) start(_ context.Context) error { - h.eng.Go(h.worker) - return nil -} - -func (h *Handler[T, R]) worker(ctx context.Context) { - responseCacheExpiryTicker := h.clock.NewTicker(h.cacheExpiryTime) - defer responseCacheExpiryTicker.Stop() - - // Set to tick at 1 second as this is a sufficient resolution for expiring requests without causing too much overhead - pendingRequestsExpiryTicker := h.clock.NewTicker(1 * time.Second) - defer pendingRequestsExpiryTicker.Stop() - - for { - select { - case <-ctx.Done(): - return - case <-pendingRequestsExpiryTicker.Chan(): - h.expirePendingRequests(ctx) - case <-responseCacheExpiryTicker.Chan(): - h.expireCachedResponses() - case req := <-h.requestCh: - h.pendingRequests[req.ID()] = req - - existingResponse := h.responseCache[req.ID()] - if existingResponse != nil { - delete(h.responseCache, req.ID()) - h.eng.Debugw("Found cached response for request", "requestID", req.ID) - h.sendResponse(ctx, req, existingResponse.response) - continue - } - - if err := h.store.Add(req); err != nil { - h.eng.Errorw("failed to add request to store", "err", err) - } - - case resp := <-h.responseCh: - req, wasPresent := h.store.Evict(resp.RequestID()) - if !wasPresent { - h.responseCache[resp.RequestID()] = &responseCacheEntry[R]{ - response: 
resp, - entryTime: h.clock.Now(), - } - h.eng.Debugw("Caching response without request", "requestID", resp.RequestID()) - continue - } - - h.sendResponse(ctx, req, resp) - } - } -} - -func (h *Handler[T, R]) sendResponse(ctx context.Context, req T, resp R) { - req.SendResponse(ctx, resp) - delete(h.pendingRequests, req.ID()) -} - -func (h *Handler[T, R]) sendTimeout(ctx context.Context, req T) { - req.SendTimeout(ctx) - delete(h.pendingRequests, req.ID()) -} - -func (h *Handler[T, R]) expirePendingRequests(ctx context.Context) { - now := h.clock.Now() - - for _, req := range h.pendingRequests { - if now.After(req.ExpiryTime()) { - h.store.Evict(req.ID()) - h.sendTimeout(ctx, req) - } - } -} - -func (h *Handler[T, R]) expireCachedResponses() { - for k, v := range h.responseCache { - if h.clock.Since(v.entryTime) > h.cacheExpiryTime { - delete(h.responseCache, k) - h.eng.Debugw("Expired response", "requestID", k) - } - } -} diff --git a/pkg/capabilities/v2/consensus/consensus.pb.go b/pkg/capabilities/v2/consensus/consensus.pb.go index 1cf1ddccb..8d4601074 100644 --- a/pkg/capabilities/v2/consensus/consensus.pb.go +++ b/pkg/capabilities/v2/consensus/consensus.pb.go @@ -27,9 +27,9 @@ var File_capabilities_v2_consensus_consensus_proto protoreflect.FileDescriptor const file_capabilities_v2_consensus_consensus_proto_rawDesc = "" + "\n" + - ")capabilities/v2/consensus/consensus.proto\x12\x14cre.sdk.v2.consensus\x1a\x16values/pb/values.proto\x1a0capabilities/v2/protoc/pkg/pb/cre_metadata.proto\x1a\x1dworkflows/sdk/v2/pb/sdk.proto2^\n" + + ")capabilities/v2/consensus/consensus.proto\x12\x14cre.sdk.v2.consensus\x1a\x16values/pb/values.proto\x1a0capabilities/v2/protoc/pkg/pb/cre_metadata.proto\x1a\x1dworkflows/sdk/v2/pb/sdk.proto2g\n" + "\tConsensus\x12:\n" + - "\x06Simple\x12!.cre.sdk.v2.SimpleConsensusInputs\x1a\r.values.Value\x1a\x15\x82\xb5\x18\x11\x12\x0fconsensus@1.0.0BLZJgithub.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/consensusb\x06proto3" + "\x06Simple\x12!.cre.sdk.v2.SimpleConsensusInputs\x1a\r.values.Value\x1a\x1e\x82\xb5\x18\x1a\x12\x18offchain_reporting@1.0.0BLZJgithub.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/consensusb\x06proto3" var file_capabilities_v2_consensus_consensus_proto_goTypes = []any{ (*pb.SimpleConsensusInputs)(nil), // 0: cre.sdk.v2.SimpleConsensusInputs diff --git a/pkg/capabilities/v2/consensus/consensus.proto b/pkg/capabilities/v2/consensus/consensus.proto index bf21e83ff..57d0343d5 100644 --- a/pkg/capabilities/v2/consensus/consensus.proto +++ b/pkg/capabilities/v2/consensus/consensus.proto @@ -11,7 +11,7 @@ import "workflows/sdk/v2/pb/sdk.proto"; service Consensus { option (cre.metadata.capability) = { mode: DON - capability_id: "consensus@1.0.0" + capability_id: "offchain_reporting@1.0.0" }; rpc Simple (cre.sdk.v2.SimpleConsensusInputs) returns (values.Value); } diff --git a/pkg/capabilities/v2/consensus/consensus_sdk_gen.go b/pkg/capabilities/v2/consensus/consensus_sdk_gen.go index 7c25aca8d..9b2c83195 100644 --- a/pkg/capabilities/v2/consensus/consensus_sdk_gen.go +++ b/pkg/capabilities/v2/consensus/consensus_sdk_gen.go @@ -23,7 +23,7 @@ func (c *Consensus) Simple(runtime sdk.DonRuntime, input *pb1.SimpleConsensusInp return sdk.PromiseFromResult[*pb.Value](nil, err) } return sdk.Then(runtime.CallCapability(&sdkpb.CapabilityRequest{ - Id: "consensus@1.0.0", + Id: "offchain_reporting@1.0.0", Payload: wrapped, Method: "Simple", }), func(i *sdkpb.CapabilityResponse) (*pb.Value, error) { diff --git 
a/pkg/capabilities/v2/consensus/consensusmock/consensus_mock_gen.go b/pkg/capabilities/v2/consensus/consensusmock/consensus_mock_gen.go index d20b590d9..521bdb98d 100644 --- a/pkg/capabilities/v2/consensus/consensusmock/consensus_mock_gen.go +++ b/pkg/capabilities/v2/consensus/consensusmock/consensus_mock_gen.go @@ -67,5 +67,5 @@ func (cap *ConsensusCapability) InvokeTrigger(ctx context.Context, request *sdkp } func (cap *ConsensusCapability) ID() string { - return "consensus@1.0.0" + return "offchain_reporting@1.0.0" } diff --git a/pkg/capabilities/v2/consensus/server/consensus_server_gen.go b/pkg/capabilities/v2/consensus/server/consensus_server_gen.go index f0187d96f..33641fad5 100644 --- a/pkg/capabilities/v2/consensus/server/consensus_server_gen.go +++ b/pkg/capabilities/v2/consensus/server/consensus_server_gen.go @@ -66,7 +66,7 @@ func (cs *ConsensusServer) Close() error { defer cancel() if cs.capabilityRegistry != nil { - if err := cs.capabilityRegistry.Remove(ctx, "consensus@1.0.0"); err != nil { + if err := cs.capabilityRegistry.Remove(ctx, "offchain_reporting@1.0.0"); err != nil { return err } } @@ -93,7 +93,7 @@ type consensusCapability struct { func (c *consensusCapability) Info(ctx context.Context) (capabilities.CapabilityInfo, error) { // Maybe we do need to split it out, even if the user doesn't see it - return capabilities.NewCapabilityInfo("consensus@1.0.0", capabilities.CapabilityTypeCombined, c.ConsensusCapability.Description()) + return capabilities.NewCapabilityInfo("offchain_reporting@1.0.0", capabilities.CapabilityTypeCombined, c.ConsensusCapability.Description()) } var _ capabilities.ExecutableAndTriggerCapability = (*consensusCapability)(nil) From 938ff7f43b0fcc9fbeaa421bc347b880f77e9ddd Mon Sep 17 00:00:00 2001 From: Silas Lenihan Date: Wed, 11 Jun 2025 14:09:04 -0400 Subject: [PATCH 15/16] Reapply "requests handling (#1247)" This reverts commit 00a1fdb4172af042269eefc66880b0b900ab89d4. 
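For context, the reapplied package exposes the generic Store/Handler pair keyed by the ConsensusRequest/ConsensusResponse interfaces, with ocr3.ReportRequest/ReportResponse as the concrete types. A rough sketch of how a caller wires this up, mirroring handler_test.go (the fake clock, one-second cache expiry, "test" execution ID and the test function name are illustrative only):

    package requests_test

    import (
    	"testing"
    	"time"

    	"github.com/jonboulle/clockwork"
    	"github.com/stretchr/testify/require"

    	"github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3"
    	"github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/requests"
    	"github.com/smartcontractkit/chainlink-common/pkg/logger"
    	"github.com/smartcontractkit/chainlink-common/pkg/services/servicetest"
    	"github.com/smartcontractkit/chainlink-common/pkg/values"
    )

    // Illustrative wiring of the generic handler with the ocr3 request/response types.
    func TestHandlerWiringSketch(t *testing.T) {
    	ctx := t.Context()
    	lggr := logger.Test(t)

    	// Store and handler are parameterised on the concrete request/response pair.
    	s := requests.NewStore[*ocr3.ReportRequest, ocr3.ReportResponse]()
    	h := requests.NewHandler(lggr, s, clockwork.NewFakeClockAt(time.Now()), 1*time.Second)
    	servicetest.Run(t, h)

    	// Queue a request; the handler delivers the matching response on CallbackCh.
    	responseCh := make(chan ocr3.ReportResponse, 10)
    	h.SendRequest(ctx, &ocr3.ReportRequest{
    		WorkflowExecutionID: "test",
    		CallbackCh:          responseCh,
    		ExpiresAt:           time.Now().Add(1 * time.Hour),
    	})

    	testVal, err := values.NewMap(map[string]any{"result": "testval"})
    	require.NoError(t, err)

    	// The response is matched to the pending request by WorkflowExecutionID.
    	h.SendResponse(ctx, ocr3.ReportResponse{WorkflowExecutionID: "test", Value: testVal})
    	require.Equal(t, testVal, (<-responseCh).Value)
    }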
--- .../consensus/ocr3/benchmark_test.go | 15 +- pkg/capabilities/consensus/ocr3/capability.go | 18 +- .../consensus/ocr3/capability_test.go | 24 +-- pkg/capabilities/consensus/ocr3/factory.go | 6 +- pkg/capabilities/consensus/ocr3/ocr3.go | 6 +- .../request.go => report_request.go} | 43 ++++- .../consensus/ocr3/reporting_plugin.go | 11 +- .../consensus/ocr3/reporting_plugin_test.go | 36 ++-- .../consensus/ocr3/requests/handler.go | 153 ----------------- .../consensus/ocr3/transmitter_test.go | 6 +- .../consensus/requests/handler.go | 160 ++++++++++++++++++ .../{ocr3 => }/requests/handler_test.go | 38 +++-- .../consensus/{ocr3 => }/requests/store.go | 52 +++--- .../{ocr3 => }/requests/store_test.go | 60 ++++--- pkg/capabilities/v2/consensus/consensus.pb.go | 4 +- pkg/capabilities/v2/consensus/consensus.proto | 2 +- .../v2/consensus/consensus_sdk_gen.go | 2 +- .../consensusmock/consensus_mock_gen.go | 2 +- .../consensus/server/consensus_server_gen.go | 4 +- 19 files changed, 346 insertions(+), 296 deletions(-) rename pkg/capabilities/consensus/ocr3/{requests/request.go => report_request.go} (62%) delete mode 100644 pkg/capabilities/consensus/ocr3/requests/handler.go create mode 100644 pkg/capabilities/consensus/requests/handler.go rename pkg/capabilities/consensus/{ocr3 => }/requests/handler_test.go (69%) rename pkg/capabilities/consensus/{ocr3 => }/requests/store.go (52%) rename pkg/capabilities/consensus/{ocr3 => }/requests/store_test.go (70%) diff --git a/pkg/capabilities/consensus/ocr3/benchmark_test.go b/pkg/capabilities/consensus/ocr3/benchmark_test.go index 09f9b983a..d124184d2 100644 --- a/pkg/capabilities/consensus/ocr3/benchmark_test.go +++ b/pkg/capabilities/consensus/ocr3/benchmark_test.go @@ -8,18 +8,19 @@ import ( "time" "github.com/shopspring/decimal" - ocrcommon "github.com/smartcontractkit/libocr/commontypes" - "github.com/smartcontractkit/libocr/offchainreporting2/types" - "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" "github.com/stretchr/testify/require" "go.uber.org/zap/zapcore" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" + ocrcommon "github.com/smartcontractkit/libocr/commontypes" + "github.com/smartcontractkit/libocr/offchainreporting2/types" + "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3" "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/datafeeds" - "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/requests" pbtypes "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/types" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/requests" "github.com/smartcontractkit/chainlink-common/pkg/capabilities/datastreams" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/values" @@ -108,7 +109,7 @@ func runObservationBenchmarkWithParams(b *testing.B, lggr logger.Logger, numWork ) // Create request store with requests for each workflow - store := requests.NewStore() + store := requests.NewStore[*ocr3.ReportRequest, ocr3.ReportResponse]() // Create capability with LLO aggregators for each workflow mockCap := &mockCapability{ @@ -135,7 +136,7 @@ func runObservationBenchmarkWithParams(b *testing.B, lggr logger.Logger, numWork require.NoError(b, err) // Create and add request to store - req := &requests.Request{ + req := &ocr3.ReportRequest{ WorkflowID: 
workflowID, WorkflowExecutionID: executionID, WorkflowName: fmt.Sprintf("Workflow %d", i), @@ -229,7 +230,7 @@ func runBenchmarkWithParams(b *testing.B, lggr logger.Logger, numWorkflows, numS ) // Create request store - store := requests.NewStore() + store := requests.NewStore[*ocr3.ReportRequest]() // Create capability with LLO aggregators for each workflow mockCap := &mockCapability{ diff --git a/pkg/capabilities/consensus/ocr3/capability.go b/pkg/capabilities/consensus/ocr3/capability.go index ac0f1f4d6..2639f5653 100644 --- a/pkg/capabilities/consensus/ocr3/capability.go +++ b/pkg/capabilities/consensus/ocr3/capability.go @@ -10,8 +10,8 @@ import ( "google.golang.org/protobuf/proto" "github.com/smartcontractkit/chainlink-common/pkg/capabilities" - "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/requests" "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/types" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/requests" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/metering" "github.com/smartcontractkit/chainlink-common/pkg/services" @@ -39,9 +39,9 @@ type capability struct { eng *services.Engine capabilities.CapabilityInfo - capabilities.Validator[config, inputs, requests.Response] + capabilities.Validator[config, inputs, ReportResponse] - reqHandler *requests.Handler + reqHandler *requests.Handler[*ReportRequest, ReportResponse] requestTimeout time.Duration requestTimeoutLock sync.RWMutex @@ -63,11 +63,11 @@ type capability struct { var _ CapabilityIface = (*capability)(nil) var _ capabilities.ExecutableCapability = (*capability)(nil) -func NewCapability(s *requests.Store, clock clockwork.Clock, requestTimeout time.Duration, aggregatorFactory types.AggregatorFactory, encoderFactory types.EncoderFactory, lggr logger.Logger, +func NewCapability(s *requests.Store[*ReportRequest, ReportResponse], clock clockwork.Clock, requestTimeout time.Duration, aggregatorFactory types.AggregatorFactory, encoderFactory types.EncoderFactory, lggr logger.Logger, callbackChannelBufferSize int) *capability { o := &capability{ CapabilityInfo: info, - Validator: capabilities.NewValidator[config, inputs, requests.Response](capabilities.ValidatorArgs{Info: info}), + Validator: capabilities.NewValidator[config, inputs, ReportResponse](capabilities.ValidatorArgs{Info: info}), clock: clock, requestTimeout: requestTimeout, aggregatorFactory: aggregatorFactory, @@ -195,7 +195,7 @@ func (o *capability) Execute(ctx context.Context, r capabilities.CapabilityReque o.eng.Debugw("Execute - terminating execution", "workflowExecutionID", r.Metadata.WorkflowExecutionID) responseErr = capabilities.ErrStopExecution } - out := requests.Response{ + out := ReportResponse{ WorkflowExecutionID: r.Metadata.WorkflowExecutionID, Value: inputs, Err: responseErr, @@ -254,8 +254,8 @@ func (o *capability) queueRequestForProcessing( metadata capabilities.RequestMetadata, i *inputs, c *config, -) (<-chan requests.Response, error) { - callbackCh := make(chan requests.Response, o.callbackChannelBufferSize) +) (<-chan ReportResponse, error) { + callbackCh := make(chan ReportResponse, o.callbackChannelBufferSize) // Use the capability-level request timeout unless the request's config specifies // its own timeout, in which case we'll use that instead. 
This allows the workflow spec @@ -267,7 +267,7 @@ func (o *capability) queueRequestForProcessing( } o.requestTimeoutLock.RUnlock() - r := &requests.Request{ + r := &ReportRequest{ StopCh: make(chan struct{}), CallbackCh: callbackCh, WorkflowExecutionID: metadata.WorkflowExecutionID, diff --git a/pkg/capabilities/consensus/ocr3/capability_test.go b/pkg/capabilities/consensus/ocr3/capability_test.go index 952d02dc8..9f7a07c20 100644 --- a/pkg/capabilities/consensus/ocr3/capability_test.go +++ b/pkg/capabilities/consensus/ocr3/capability_test.go @@ -13,8 +13,8 @@ import ( "github.com/stretchr/testify/require" "github.com/smartcontractkit/chainlink-common/pkg/capabilities" - "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/requests" "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/types" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/requests" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/utils" "github.com/smartcontractkit/chainlink-common/pkg/values" @@ -48,7 +48,7 @@ func TestOCR3Capability_Schema(t *testing.T) { fc := clockwork.NewFakeClockAt(n) lggr := logger.Nop() - s := requests.NewStore() + s := requests.NewStore[*ReportRequest]() cp := NewCapability(s, fc, 1*time.Second, mockAggregatorFactory, mockEncoderFactory, lggr, 10) schema, err := cp.Schema() @@ -88,7 +88,7 @@ func TestOCR3Capability(t *testing.T) { ctx := t.Context() - s := requests.NewStore() + s := requests.NewStore[*ReportRequest]() cp := NewCapability(s, fc, 1*time.Second, mockAggregatorFactory, mockEncoderFactory, lggr, 10) require.NoError(t, cp.Start(ctx)) @@ -129,7 +129,7 @@ func TestOCR3Capability(t *testing.T) { // Mock the oracle returning a response mresp, err := values.NewMap(map[string]any{"observations": obsv}) - cp.reqHandler.SendResponse(ctx, requests.Response{ + cp.reqHandler.SendResponse(ctx, ReportResponse{ Value: mresp, WorkflowExecutionID: workflowExecutionTestID, }) @@ -155,7 +155,7 @@ func TestOCR3Capability_Eviction(t *testing.T) { defer cancel() rea := time.Second - s := requests.NewStore() + s := requests.NewStore[*ReportRequest]() cp := NewCapability(s, fc, rea, mockAggregatorFactory, mockEncoderFactory, lggr, 10) require.NoError(t, cp.Start(ctx)) @@ -223,7 +223,7 @@ func TestOCR3Capability_EvictionUsingConfig(t *testing.T) { defer cancel() // This is the default expired at rea := time.Hour - s := requests.NewStore() + s := requests.NewStore[*ReportRequest]() cp := NewCapability(s, fc, rea, mockAggregatorFactory, mockEncoderFactory, lggr, 10) require.NoError(t, cp.Start(ctx)) @@ -290,7 +290,7 @@ func TestOCR3Capability_Registration(t *testing.T) { lggr := logger.Test(t) ctx := t.Context() - s := requests.NewStore() + s := requests.NewStore[*ReportRequest]() cp := NewCapability(s, fc, 1*time.Second, mockAggregatorFactory, mockEncoderFactory, lggr, 10) require.NoError(t, cp.Start(ctx)) @@ -336,7 +336,7 @@ func TestOCR3Capability_ValidateConfig(t *testing.T) { fc := clockwork.NewFakeClockAt(n) lggr := logger.Test(t) - s := requests.NewStore() + s := requests.NewStore[*ReportRequest]() o := NewCapability(s, fc, 1*time.Second, mockAggregatorFactory, mockEncoderFactory, lggr, 10) @@ -411,7 +411,7 @@ func TestOCR3Capability_RespondsToLateRequest(t *testing.T) { ctx := t.Context() - s := requests.NewStore() + s := requests.NewStore[*ReportRequest]() cp := NewCapability(s, fc, 1*time.Second, mockAggregatorFactory, mockEncoderFactory, lggr, 10) require.NoError(t, 
cp.Start(ctx)) @@ -440,7 +440,7 @@ func TestOCR3Capability_RespondsToLateRequest(t *testing.T) { require.NoError(t, err) // Mock the oracle returning a response prior to the request being sent - cp.reqHandler.SendResponse(ctx, requests.Response{ + cp.reqHandler.SendResponse(ctx, ReportResponse{ Value: obsv, WorkflowExecutionID: workflowExecutionTestID, }) @@ -471,7 +471,7 @@ func TestOCR3Capability_RespondingToLateRequestDoesNotBlockOnSlowResponseConsume ctx := t.Context() - s := requests.NewStore() + s := requests.NewStore[*ReportRequest]() cp := NewCapability(s, fc, 1*time.Second, mockAggregatorFactory, mockEncoderFactory, lggr, 0) require.NoError(t, cp.Start(ctx)) @@ -500,7 +500,7 @@ func TestOCR3Capability_RespondingToLateRequestDoesNotBlockOnSlowResponseConsume require.NoError(t, err) // Mock the oracle returning a response prior to the request being sent - cp.reqHandler.SendResponse(ctx, requests.Response{ + cp.reqHandler.SendResponse(ctx, ReportResponse{ Value: obsv, WorkflowExecutionID: workflowExecutionTestID, }) diff --git a/pkg/capabilities/consensus/ocr3/factory.go b/pkg/capabilities/consensus/ocr3/factory.go index 0e4a4c386..1b60526b7 100644 --- a/pkg/capabilities/consensus/ocr3/factory.go +++ b/pkg/capabilities/consensus/ocr3/factory.go @@ -9,8 +9,8 @@ import ( "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" - "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/requests" "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/types" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/requests" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/services" ) @@ -24,7 +24,7 @@ const ( ) type factory struct { - store *requests.Store + store *requests.Store[*ReportRequest, ReportResponse] capability *capability batchSize int outcomePruningThreshold uint64 @@ -33,7 +33,7 @@ type factory struct { services.StateMachine } -func newFactory(s *requests.Store, c *capability, lggr logger.Logger) (*factory, error) { +func newFactory(s *requests.Store[*ReportRequest, ReportResponse], c *capability, lggr logger.Logger) (*factory, error) { return &factory{ store: s, capability: c, diff --git a/pkg/capabilities/consensus/ocr3/ocr3.go b/pkg/capabilities/consensus/ocr3/ocr3.go index e4b6046f5..123e4a943 100644 --- a/pkg/capabilities/consensus/ocr3/ocr3.go +++ b/pkg/capabilities/consensus/ocr3/ocr3.go @@ -6,8 +6,8 @@ import ( "github.com/jonboulle/clockwork" - "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/requests" "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/types" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/requests" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/loop" "github.com/smartcontractkit/chainlink-common/pkg/loop/reportingplugins" @@ -32,7 +32,7 @@ type Config struct { EncoderFactory types.EncoderFactory SendBufferSize int - store *requests.Store + store *requests.Store[*ReportRequest, ReportResponse] capability *capability clock clockwork.Clock } @@ -56,7 +56,7 @@ func NewOCR3(config Config) *Capability { } if config.store == nil { - config.store = requests.NewStore() + config.store = requests.NewStore[*ReportRequest]() } if config.capability == nil { diff --git a/pkg/capabilities/consensus/ocr3/requests/request.go b/pkg/capabilities/consensus/ocr3/report_request.go similarity index 
62% rename from pkg/capabilities/consensus/ocr3/requests/request.go rename to pkg/capabilities/consensus/ocr3/report_request.go index 4d427a038..935d3958c 100644 --- a/pkg/capabilities/consensus/ocr3/requests/request.go +++ b/pkg/capabilities/consensus/ocr3/report_request.go @@ -1,13 +1,15 @@ -package requests +package ocr3 import ( + "context" + "fmt" "time" "github.com/smartcontractkit/chainlink-common/pkg/services" "github.com/smartcontractkit/chainlink-common/pkg/values" ) -type Request struct { +type ReportRequest struct { Observations *values.List `mapstructure:"-"` OverriddenEncoderName string OverriddenEncoderConfig *values.Map @@ -15,7 +17,7 @@ type Request struct { // CallbackCh is a channel to send a response back to the requester // after the request has been processed or timed out. - CallbackCh chan Response + CallbackCh chan ReportResponse StopCh services.StopChan WorkflowExecutionID string @@ -29,8 +31,33 @@ type Request struct { KeyID string } -func (r *Request) Copy() *Request { - return &Request{ +func (r *ReportRequest) ID() string { + return r.WorkflowExecutionID +} + +func (r *ReportRequest) ExpiryTime() time.Time { + return r.ExpiresAt +} + +func (r *ReportRequest) SendResponse(ctx context.Context, resp ReportResponse) { + select { + case <-ctx.Done(): + return + case r.CallbackCh <- resp: + close(r.CallbackCh) + } +} + +func (r *ReportRequest) SendTimeout(ctx context.Context) { + timeoutResponse := ReportResponse{ + WorkflowExecutionID: r.WorkflowExecutionID, + Err: fmt.Errorf("timeout exceeded: could not process request before expiry, workflowExecutionID %s", r.WorkflowExecutionID), + } + r.SendResponse(ctx, timeoutResponse) +} + +func (r *ReportRequest) Copy() *ReportRequest { + return &ReportRequest{ Observations: r.Observations.CopyList(), OverriddenEncoderConfig: r.OverriddenEncoderConfig.CopyMap(), @@ -52,8 +79,12 @@ func (r *Request) Copy() *Request { } } -type Response struct { +type ReportResponse struct { WorkflowExecutionID string Value *values.Map Err error } + +func (r ReportResponse) RequestID() string { + return r.WorkflowExecutionID +} diff --git a/pkg/capabilities/consensus/ocr3/reporting_plugin.go b/pkg/capabilities/consensus/ocr3/reporting_plugin.go index 31f27fc44..a4d518c47 100644 --- a/pkg/capabilities/consensus/ocr3/reporting_plugin.go +++ b/pkg/capabilities/consensus/ocr3/reporting_plugin.go @@ -8,15 +8,16 @@ import ( "slices" "time" - "github.com/smartcontractkit/libocr/quorumhelper" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" + "github.com/smartcontractkit/libocr/quorumhelper" + ocrcommon "github.com/smartcontractkit/libocr/commontypes" "github.com/smartcontractkit/libocr/offchainreporting2/types" "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" - "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/requests" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/requests" "google.golang.org/protobuf/types/known/structpb" @@ -38,14 +39,14 @@ type CapabilityIface interface { type reportingPlugin struct { batchSize int - s *requests.Store + s *requests.Store[*ReportRequest, ReportResponse] r CapabilityIface config ocr3types.ReportingPluginConfig outcomePruningThreshold uint64 lggr logger.Logger } -func NewReportingPlugin(s *requests.Store, r CapabilityIface, batchSize int, config ocr3types.ReportingPluginConfig, +func NewReportingPlugin(s *requests.Store[*ReportRequest, ReportResponse], r CapabilityIface, batchSize int, config 
ocr3types.ReportingPluginConfig, outcomePruningThreshold uint64, lggr logger.Logger) (*reportingPlugin, error) { return &reportingPlugin{ s: s, @@ -103,7 +104,7 @@ func (r *reportingPlugin) Observation(ctx context.Context, outctx ocr3types.Outc } reqs := r.s.GetByIDs(weids) - reqMap := map[string]*requests.Request{} + reqMap := map[string]*ReportRequest{} for _, req := range reqs { reqMap[req.WorkflowExecutionID] = req } diff --git a/pkg/capabilities/consensus/ocr3/reporting_plugin_test.go b/pkg/capabilities/consensus/ocr3/reporting_plugin_test.go index 7974ffed8..30c9d976a 100644 --- a/pkg/capabilities/consensus/ocr3/reporting_plugin_test.go +++ b/pkg/capabilities/consensus/ocr3/reporting_plugin_test.go @@ -17,8 +17,8 @@ import ( "github.com/smartcontractkit/libocr/offchainreporting2/types" "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" - "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/requests" pbtypes "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/types" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/requests" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/values" "github.com/smartcontractkit/chainlink-common/pkg/values/pb" @@ -27,7 +27,7 @@ import ( func TestReportingPlugin_Query_ErrorInQueueCall(t *testing.T) { ctx := t.Context() lggr := logger.Test(t) - s := requests.NewStore() + s := requests.NewStore[*ReportRequest]() batchSize := 0 rp, err := NewReportingPlugin(s, nil, batchSize, ocr3types.ReportingPluginConfig{}, defaultOutcomePruningThreshold, lggr) require.NoError(t, err) @@ -42,14 +42,14 @@ func TestReportingPlugin_Query_ErrorInQueueCall(t *testing.T) { func TestReportingPlugin_Query(t *testing.T) { ctx := t.Context() lggr := logger.Test(t) - s := requests.NewStore() + s := requests.NewStore[*ReportRequest]() rp, err := NewReportingPlugin(s, nil, defaultBatchSize, ocr3types.ReportingPluginConfig{}, defaultOutcomePruningThreshold, lggr) require.NoError(t, err) eid := uuid.New().String() wowner := uuid.New().String() - err = s.Add(&requests.Request{ + err = s.Add(&ReportRequest{ WorkflowID: workflowTestID, WorkflowExecutionID: eid, WorkflowOwner: wowner, @@ -153,7 +153,7 @@ func (mc *mockCapability) UnregisterWorkflowID(workflowID string) { func TestReportingPlugin_Observation(t *testing.T) { ctx := t.Context() lggr := logger.Test(t) - s := requests.NewStore() + s := requests.NewStore[*ReportRequest]() mcap := &mockCapability{ aggregator: &aggregator{}, encoder: &enc{}, @@ -170,7 +170,7 @@ func TestReportingPlugin_Observation(t *testing.T) { eid := uuid.New().String() wowner := uuid.New().String() - err = s.Add(&requests.Request{ + err = s.Add(&ReportRequest{ WorkflowID: workflowTestID, WorkflowExecutionID: eid, WorkflowOwner: wowner, @@ -210,7 +210,7 @@ func TestReportingPlugin_Observation(t *testing.T) { func TestReportingPlugin_Observation_NilIds(t *testing.T) { ctx := t.Context() lggr := logger.Test(t) - s := requests.NewStore() + s := requests.NewStore[*ReportRequest]() mcap := &mockCapability{ aggregator: &aggregator{}, encoder: &enc{}, @@ -244,7 +244,7 @@ func TestReportingPlugin_Observation_NilIds(t *testing.T) { func TestReportingPlugin_Observation_NoResults(t *testing.T) { ctx := t.Context() lggr := logger.Test(t) - s := requests.NewStore() + s := requests.NewStore[*ReportRequest]() mcap := &mockCapability{ aggregator: &aggregator{}, encoder: &enc{}, @@ -271,7 +271,7 @@ func 
TestReportingPlugin_Observation_NoResults(t *testing.T) { func TestReportingPlugin_Outcome(t *testing.T) { lggr := logger.Test(t) - s := requests.NewStore() + s := requests.NewStore[*ReportRequest]() aggregator := &aggregator{} mcap := &mockCapability{ aggregator: aggregator, @@ -331,7 +331,7 @@ func TestReportingPlugin_Outcome(t *testing.T) { func TestReportingPlugin_Outcome_AggregatorErrorDoesntInterruptOtherWorkflows(t *testing.T) { lggr := logger.Test(t) - s := requests.NewStore() + s := requests.NewStore[*ReportRequest]() aggregator := &erroringAggregator{} mcap := &mockCapability{ aggregator: aggregator, @@ -407,7 +407,7 @@ func TestReportingPlugin_Outcome_AggregatorErrorDoesntInterruptOtherWorkflows(t func TestReportingPlugin_Outcome_NilDerefs(t *testing.T) { ctx := t.Context() lggr := logger.Test(t) - s := requests.NewStore() + s := requests.NewStore[*ReportRequest]() mcap := &mockCapability{ aggregator: &aggregator{}, encoder: &enc{}, @@ -465,7 +465,7 @@ func TestReportingPlugin_Outcome_NilDerefs(t *testing.T) { func TestReportingPlugin_Outcome_AggregatorErrorDoesntInterruptOtherIDs(t *testing.T) { ctx := t.Context() lggr := logger.Test(t) - s := requests.NewStore() + s := requests.NewStore[*ReportRequest]() mcap := &mockCapability{ aggregator: &aggregator{}, encoder: &enc{}, @@ -531,7 +531,7 @@ func TestReportingPlugin_Outcome_AggregatorErrorDoesntInterruptOtherIDs(t *testi func TestReportingPlugin_Reports_ShouldReportFalse(t *testing.T) { lggr := logger.Test(t) - s := requests.NewStore() + s := requests.NewStore[*ReportRequest]() mcap := &mockCapability{ aggregator: &aggregator{}, encoder: &enc{}, @@ -585,7 +585,7 @@ func TestReportingPlugin_Reports_ShouldReportFalse(t *testing.T) { func TestReportingPlugin_Reports_NilDerefs(t *testing.T) { ctx := t.Context() lggr := logger.Test(t) - s := requests.NewStore() + s := requests.NewStore[*ReportRequest]() mcap := &mockCapability{ aggregator: &aggregator{}, encoder: &enc{}, @@ -628,7 +628,7 @@ func TestReportingPlugin_Reports_NilDerefs(t *testing.T) { func TestReportingPlugin_Reports_ShouldReportTrue(t *testing.T) { lggr := logger.Test(t) dynamicEncoderName := "special_encoder" - s := requests.NewStore() + s := requests.NewStore[*ReportRequest]() mcap := &mockCapability{ t: t, aggregator: &aggregator{}, @@ -711,7 +711,7 @@ func TestReportingPlugin_Reports_ShouldReportTrue(t *testing.T) { func TestReportingPlugin_Outcome_ShouldPruneOldOutcomes(t *testing.T) { ctx := t.Context() lggr := logger.Test(t) - s := requests.NewStore() + s := requests.NewStore[*ReportRequest]() mcap := &mockCapability{ aggregator: &aggregator{}, encoder: &enc{}, @@ -828,7 +828,7 @@ func TestReportingPlugin_Outcome_ShouldPruneOldOutcomes(t *testing.T) { func TestReportPlugin_Outcome_ShouldReturnMedianTimestamp(t *testing.T) { ctx := t.Context() lggr := logger.Test(t) - s := requests.NewStore() + s := requests.NewStore[*ReportRequest]() mcap := &mockCapability{ aggregator: &aggregator{}, encoder: &enc{}, @@ -960,7 +960,7 @@ func TestReportPlugin_Outcome_ShouldReturnMedianTimestamp(t *testing.T) { func TestReportPlugin_Outcome_ShouldReturnOverriddenEncoder(t *testing.T) { lggr := logger.Test(t) - s := requests.NewStore() + s := requests.NewStore[*ReportRequest]() mcap := &mockCapability{ aggregator: &aggregator{}, encoder: &enc{}, diff --git a/pkg/capabilities/consensus/ocr3/requests/handler.go b/pkg/capabilities/consensus/ocr3/requests/handler.go deleted file mode 100644 index bda3e27d1..000000000 --- a/pkg/capabilities/consensus/ocr3/requests/handler.go +++ 
/dev/null @@ -1,153 +0,0 @@ -package requests - -import ( - "context" - "fmt" - "time" - - "github.com/jonboulle/clockwork" - - "github.com/smartcontractkit/chainlink-common/pkg/logger" - "github.com/smartcontractkit/chainlink-common/pkg/services" -) - -type responseCacheEntry struct { - response Response - entryTime time.Time -} - -type Handler struct { - services.Service - eng *services.Engine - - store *Store - - pendingRequests map[string]*Request - - responseCache map[string]*responseCacheEntry - cacheExpiryTime time.Duration - - responseCh chan Response - requestCh chan *Request - - clock clockwork.Clock -} - -func NewHandler(lggr logger.Logger, s *Store, clock clockwork.Clock, responseExpiryTime time.Duration) *Handler { - h := &Handler{ - store: s, - pendingRequests: map[string]*Request{}, - responseCache: map[string]*responseCacheEntry{}, - responseCh: make(chan Response), - requestCh: make(chan *Request), - clock: clock, - cacheExpiryTime: responseExpiryTime, - } - h.Service, h.eng = services.Config{ - Name: "Handler", - Start: h.start, - }.NewServiceEngine(lggr) - return h -} - -func (h *Handler) SendResponse(ctx context.Context, resp Response) { - select { - case <-ctx.Done(): - return - case h.responseCh <- resp: - } -} - -func (h *Handler) SendRequest(ctx context.Context, r *Request) { - select { - case <-ctx.Done(): - return - case h.requestCh <- r: - } -} - -func (h *Handler) start(_ context.Context) error { - h.eng.Go(h.worker) - return nil -} - -func (h *Handler) worker(ctx context.Context) { - responseCacheExpiryTicker := h.clock.NewTicker(h.cacheExpiryTime) - defer responseCacheExpiryTicker.Stop() - - // Set to tick at 1 second as this is a sufficient resolution for expiring requests without causing too much overhead - pendingRequestsExpiryTicker := h.clock.NewTicker(1 * time.Second) - defer pendingRequestsExpiryTicker.Stop() - - for { - select { - case <-ctx.Done(): - return - case <-pendingRequestsExpiryTicker.Chan(): - h.expirePendingRequests(ctx) - case <-responseCacheExpiryTicker.Chan(): - h.expireCachedResponses() - case req := <-h.requestCh: - h.pendingRequests[req.WorkflowExecutionID] = req - - existingResponse := h.responseCache[req.WorkflowExecutionID] - if existingResponse != nil { - delete(h.responseCache, req.WorkflowExecutionID) - h.eng.Debugw("Found cached response for request", "workflowExecutionID", req.WorkflowExecutionID) - h.sendResponse(ctx, req, existingResponse.response) - continue - } - - if err := h.store.Add(req); err != nil { - h.eng.Errorw("failed to add request to store", "err", err) - } - - case resp := <-h.responseCh: - req, wasPresent := h.store.evict(resp.WorkflowExecutionID) - if !wasPresent { - h.responseCache[resp.WorkflowExecutionID] = &responseCacheEntry{ - response: resp, - entryTime: h.clock.Now(), - } - h.eng.Debugw("Caching response without request", "workflowExecutionID", resp.WorkflowExecutionID) - continue - } - - h.sendResponse(ctx, req, resp) - } - } -} - -func (h *Handler) sendResponse(ctx context.Context, req *Request, resp Response) { - select { - case <-ctx.Done(): - return - case req.CallbackCh <- resp: - close(req.CallbackCh) - delete(h.pendingRequests, req.WorkflowExecutionID) - } -} - -func (h *Handler) expirePendingRequests(ctx context.Context) { - now := h.clock.Now() - - for _, req := range h.pendingRequests { - if now.After(req.ExpiresAt) { - resp := Response{ - WorkflowExecutionID: req.WorkflowExecutionID, - Err: fmt.Errorf("timeout exceeded: could not process request before expiry %s", 
req.WorkflowExecutionID), - } - h.store.evict(req.WorkflowExecutionID) - h.sendResponse(ctx, req, resp) - } - } -} - -func (h *Handler) expireCachedResponses() { - for k, v := range h.responseCache { - if h.clock.Since(v.entryTime) > h.cacheExpiryTime { - delete(h.responseCache, k) - h.eng.Debugw("Expired response", "workflowExecutionID", k) - } - } -} diff --git a/pkg/capabilities/consensus/ocr3/transmitter_test.go b/pkg/capabilities/consensus/ocr3/transmitter_test.go index f651ceafd..0a7179723 100644 --- a/pkg/capabilities/consensus/ocr3/transmitter_test.go +++ b/pkg/capabilities/consensus/ocr3/transmitter_test.go @@ -17,8 +17,8 @@ import ( "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" "github.com/smartcontractkit/chainlink-common/pkg/capabilities" - "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/requests" pbtypes "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/types" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/requests" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/services/servicetest" "github.com/smartcontractkit/chainlink-common/pkg/types/core/mocks" @@ -31,7 +31,7 @@ func TestTransmitter(t *testing.T) { repID := []byte{0xf0, 0xe0} ctx := t.Context() lggr := logger.Test(t) - s := requests.NewStore() + s := requests.NewStore[*ReportRequest]() weid := uuid.New().String() @@ -120,7 +120,7 @@ func TestTransmitter_ShouldReportFalse(t *testing.T) { wowner := "foo-owner" ctx := t.Context() lggr := logger.Test(t) - s := requests.NewStore() + s := requests.NewStore[*ReportRequest]() weid := uuid.New().String() diff --git a/pkg/capabilities/consensus/requests/handler.go b/pkg/capabilities/consensus/requests/handler.go new file mode 100644 index 000000000..ece247548 --- /dev/null +++ b/pkg/capabilities/consensus/requests/handler.go @@ -0,0 +1,160 @@ +package requests + +import ( + "context" + "time" + + "github.com/jonboulle/clockwork" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/services" +) + +type responseCacheEntry[R ConsensusResponse] struct { + response R + entryTime time.Time +} + +type ConsensusRequest[T any, R ConsensusResponse] interface { + ID() string + Copy() T + ExpiryTime() time.Time + SendResponse(ctx context.Context, response R) + SendTimeout(ctx context.Context) +} + +type ConsensusResponse interface { + RequestID() string +} + +type Handler[T ConsensusRequest[T, R], R ConsensusResponse] struct { + services.Service + eng *services.Engine + + store *Store[T, R] + + pendingRequests map[string]T + + responseCache map[string]*responseCacheEntry[R] + cacheExpiryTime time.Duration + + responseCh chan R + requestCh chan T + + clock clockwork.Clock +} + +func NewHandler[T ConsensusRequest[T, R], R ConsensusResponse](lggr logger.Logger, s *Store[T, R], clock clockwork.Clock, responseExpiryTime time.Duration) *Handler[T, R] { + h := &Handler[T, R]{ + store: s, + pendingRequests: map[string]T{}, + responseCache: map[string]*responseCacheEntry[R]{}, + responseCh: make(chan R), + requestCh: make(chan T), + clock: clock, + cacheExpiryTime: responseExpiryTime, + } + h.Service, h.eng = services.Config{ + Name: "Handler", + Start: h.start, + }.NewServiceEngine(lggr) + return h +} + +func (h *Handler[T, R]) SendResponse(ctx context.Context, resp R) { + select { + case <-ctx.Done(): + return + case h.responseCh <- resp: + } +} + +func (h *Handler[T, R]) 
SendRequest(ctx context.Context, r T) { + select { + case <-ctx.Done(): + return + case h.requestCh <- r: + } +} + +func (h *Handler[T, R]) start(_ context.Context) error { + h.eng.Go(h.worker) + return nil +} + +func (h *Handler[T, R]) worker(ctx context.Context) { + responseCacheExpiryTicker := h.clock.NewTicker(h.cacheExpiryTime) + defer responseCacheExpiryTicker.Stop() + + // Set to tick at 1 second as this is a sufficient resolution for expiring requests without causing too much overhead + pendingRequestsExpiryTicker := h.clock.NewTicker(1 * time.Second) + defer pendingRequestsExpiryTicker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-pendingRequestsExpiryTicker.Chan(): + h.expirePendingRequests(ctx) + case <-responseCacheExpiryTicker.Chan(): + h.expireCachedResponses() + case req := <-h.requestCh: + h.pendingRequests[req.ID()] = req + + existingResponse := h.responseCache[req.ID()] + if existingResponse != nil { + delete(h.responseCache, req.ID()) + h.eng.Debugw("Found cached response for request", "requestID", req.ID) + h.sendResponse(ctx, req, existingResponse.response) + continue + } + + if err := h.store.Add(req); err != nil { + h.eng.Errorw("failed to add request to store", "err", err) + } + + case resp := <-h.responseCh: + req, wasPresent := h.store.Evict(resp.RequestID()) + if !wasPresent { + h.responseCache[resp.RequestID()] = &responseCacheEntry[R]{ + response: resp, + entryTime: h.clock.Now(), + } + h.eng.Debugw("Caching response without request", "requestID", resp.RequestID()) + continue + } + + h.sendResponse(ctx, req, resp) + } + } +} + +func (h *Handler[T, R]) sendResponse(ctx context.Context, req T, resp R) { + req.SendResponse(ctx, resp) + delete(h.pendingRequests, req.ID()) +} + +func (h *Handler[T, R]) sendTimeout(ctx context.Context, req T) { + req.SendTimeout(ctx) + delete(h.pendingRequests, req.ID()) +} + +func (h *Handler[T, R]) expirePendingRequests(ctx context.Context) { + now := h.clock.Now() + + for _, req := range h.pendingRequests { + if now.After(req.ExpiryTime()) { + h.store.Evict(req.ID()) + h.sendTimeout(ctx, req) + } + } +} + +func (h *Handler[T, R]) expireCachedResponses() { + for k, v := range h.responseCache { + if h.clock.Since(v.entryTime) > h.cacheExpiryTime { + delete(h.responseCache, k) + h.eng.Debugw("Expired response", "requestID", k) + } + } +} diff --git a/pkg/capabilities/consensus/ocr3/requests/handler_test.go b/pkg/capabilities/consensus/requests/handler_test.go similarity index 69% rename from pkg/capabilities/consensus/ocr3/requests/handler_test.go rename to pkg/capabilities/consensus/requests/handler_test.go index 86b6e0a0d..3b6a0e038 100644 --- a/pkg/capabilities/consensus/ocr3/requests/handler_test.go +++ b/pkg/capabilities/consensus/requests/handler_test.go @@ -8,8 +8,10 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3/requests" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/requests" "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/services/servicetest" "github.com/smartcontractkit/chainlink-common/pkg/values" ) @@ -18,11 +20,11 @@ func Test_Handler_SendsResponse(t *testing.T) { lggr := logger.Test(t) ctx := t.Context() - h := requests.NewHandler(lggr, requests.NewStore(), clockwork.NewFakeClockAt(time.Now()), 1*time.Second) + 
h := requests.NewHandler(lggr, requests.NewStore[*ocr3.ReportRequest, ocr3.ReportResponse](), clockwork.NewFakeClockAt(time.Now()), 1*time.Second) servicetest.Run(t, h) - responseCh := make(chan requests.Response, 10) - h.SendRequest(ctx, &requests.Request{ + responseCh := make(chan ocr3.ReportResponse, 10) + h.SendRequest(ctx, &ocr3.ReportRequest{ WorkflowExecutionID: "test", CallbackCh: responseCh, ExpiresAt: time.Now().Add(1 * time.Hour), @@ -31,7 +33,7 @@ func Test_Handler_SendsResponse(t *testing.T) { testVal, err := values.NewMap(map[string]any{"result": "testval"}) require.NoError(t, err) - h.SendResponse(ctx, requests.Response{ + h.SendResponse(ctx, ocr3.ReportResponse{ WorkflowExecutionID: "test", Value: testVal, Err: nil, @@ -45,19 +47,19 @@ func Test_Handler_SendsResponseToLateRequest(t *testing.T) { lggr := logger.Test(t) ctx := t.Context() - h := requests.NewHandler(lggr, requests.NewStore(), clockwork.NewFakeClockAt(time.Now()), 1*time.Second) + h := requests.NewHandler(lggr, requests.NewStore[*ocr3.ReportRequest](), clockwork.NewFakeClockAt(time.Now()), 1*time.Second) servicetest.Run(t, h) testVal, err := values.NewMap(map[string]any{"result": "testval"}) require.NoError(t, err) - h.SendResponse(ctx, requests.Response{ + h.SendResponse(ctx, ocr3.ReportResponse{ WorkflowExecutionID: "test", Value: testVal, Err: nil, }) - responseCh := make(chan requests.Response, 10) - h.SendRequest(ctx, &requests.Request{ + responseCh := make(chan ocr3.ReportResponse, 10) + h.SendRequest(ctx, &ocr3.ReportRequest{ WorkflowExecutionID: "test", CallbackCh: responseCh, ExpiresAt: time.Now().Add(1 * time.Hour), @@ -71,20 +73,20 @@ func Test_Handler_SendsResponseToLateRequestOnlyOnce(t *testing.T) { lggr := logger.Test(t) ctx := t.Context() - h := requests.NewHandler(lggr, requests.NewStore(), clockwork.NewFakeClockAt(time.Now()), 1*time.Second) + h := requests.NewHandler(lggr, requests.NewStore[*ocr3.ReportRequest](), clockwork.NewFakeClockAt(time.Now()), 1*time.Second) servicetest.Run(t, h) testVal, err := values.NewMap(map[string]any{"result": "testval"}) require.NoError(t, err) - h.SendResponse(ctx, requests.Response{ + h.SendResponse(ctx, ocr3.ReportResponse{ WorkflowExecutionID: "test", Value: testVal, Err: nil, }) - responseCh := make(chan requests.Response, 10) - h.SendRequest(ctx, &requests.Request{ + responseCh := make(chan ocr3.ReportResponse, 10) + h.SendRequest(ctx, &ocr3.ReportRequest{ WorkflowExecutionID: "test", CallbackCh: responseCh, ExpiresAt: time.Now().Add(1 * time.Hour), @@ -95,8 +97,8 @@ func Test_Handler_SendsResponseToLateRequestOnlyOnce(t *testing.T) { resp := <-responseCh require.Equal(t, testVal, resp.Value) - responseCh = make(chan requests.Response, 10) - h.SendRequest(ctx, &requests.Request{ + responseCh = make(chan ocr3.ReportResponse, 10) + h.SendRequest(ctx, &ocr3.ReportRequest{ WorkflowExecutionID: "test", CallbackCh: responseCh, ExpiresAt: time.Now().Add(1 * time.Hour), @@ -114,11 +116,11 @@ func Test_Handler_PendingRequestsExpiry(t *testing.T) { lggr := logger.Test(t) clock := clockwork.NewFakeClockAt(time.Now()) - h := requests.NewHandler(lggr, requests.NewStore(), clock, 1*time.Second) + h := requests.NewHandler(lggr, requests.NewStore[*ocr3.ReportRequest](), clock, 1*time.Second) servicetest.Run(t, h) - responseCh := make(chan requests.Response, 10) - h.SendRequest(ctx, &requests.Request{ + responseCh := make(chan ocr3.ReportResponse, 10) + h.SendRequest(ctx, &ocr3.ReportRequest{ WorkflowExecutionID: "test", CallbackCh: responseCh, ExpiresAt: 
time.Now().Add(1 * time.Second), diff --git a/pkg/capabilities/consensus/ocr3/requests/store.go b/pkg/capabilities/consensus/requests/store.go similarity index 52% rename from pkg/capabilities/consensus/ocr3/requests/store.go rename to pkg/capabilities/consensus/requests/store.go index 41419cdf8..b9d664057 100644 --- a/pkg/capabilities/consensus/ocr3/requests/store.go +++ b/pkg/capabilities/consensus/requests/store.go @@ -6,32 +6,29 @@ import ( "sync" ) -// Store stores ongoing consensus requests in an -// in-memory map. -// Note: this object is intended to be thread-safe, -// so any read requests should first deep-copy the returned -// request object via request.Copy(). -type Store struct { +// Store is a generic store for ongoing consensus requests. +// It is thread-safe and uses a map to store requests. +type Store[T ConsensusRequest[T, R], R ConsensusResponse] struct { requestIDs []string - requests map[string]*Request + requests map[string]T mu sync.RWMutex } -func NewStore() *Store { - return &Store{ +func NewStore[T ConsensusRequest[T, R], R ConsensusResponse]() *Store[T, R] { + return &Store[T, R]{ requestIDs: []string{}, - requests: map[string]*Request{}, + requests: map[string]T{}, } } -// GetByIDs is best-effort, doesn't return requests that are not in store +// GetByIDs retrieves requests by their IDs. // The method deep-copies requests before returning them. -func (s *Store) GetByIDs(requestIDs []string) []*Request { +func (s *Store[T, R]) GetByIDs(requestIDs []string) []T { s.mu.RLock() defer s.mu.RUnlock() - o := []*Request{} + o := []T{} for _, r := range requestIDs { gr, ok := s.requests[r] if ok { @@ -42,15 +39,15 @@ func (s *Store) GetByIDs(requestIDs []string) []*Request { return o } -// FirstN returns up to `bathSize` requests. +// FirstN retrieves up to `batchSize` requests. // The method deep-copies requests before returning them. -func (s *Store) FirstN(batchSize int) ([]*Request, error) { +func (s *Store[T, R]) FirstN(batchSize int) ([]T, error) { s.mu.RLock() defer s.mu.RUnlock() if batchSize == 0 { return nil, errors.New("batchsize cannot be 0") } - got := []*Request{} + got := []T{} if len(s.requestIDs) == 0 { return got, nil } @@ -70,31 +67,34 @@ func (s *Store) FirstN(batchSize int) ([]*Request, error) { return got, nil } -func (s *Store) Add(req *Request) error { +// Add adds a new request to the store. +func (s *Store[T, R]) Add(req T) error { s.mu.Lock() defer s.mu.Unlock() - if _, ok := s.requests[req.WorkflowExecutionID]; ok { - return fmt.Errorf("request with id %s already exists", req.WorkflowExecutionID) + if _, ok := s.requests[req.ID()]; ok { + return fmt.Errorf("request with id %s already exists", req.ID()) } - s.requestIDs = append(s.requestIDs, req.WorkflowExecutionID) - s.requests[req.WorkflowExecutionID] = req + s.requestIDs = append(s.requestIDs, req.ID()) + s.requests[req.ID()] = req return nil } -// Get returns the request corresponding to request ID. -// The method deep-copies requests before returning them. -func (s *Store) Get(requestID string) *Request { +// Get retrieves a request by its ID. +// The method deep-copies the request before returning it. +func (s *Store[T, R]) Get(requestID string) T { s.mu.RLock() defer s.mu.RUnlock() rid, ok := s.requests[requestID] if ok { return rid.Copy() } - return nil + var zero T + return zero } -func (s *Store) evict(requestID string) (*Request, bool) { +// Evict removes a request from the store by its ID. 
+func (s *Store[T, R]) Evict(requestID string) (T, bool) { s.mu.Lock() defer s.mu.Unlock() diff --git a/pkg/capabilities/consensus/ocr3/requests/store_test.go b/pkg/capabilities/consensus/requests/store_test.go similarity index 70% rename from pkg/capabilities/consensus/ocr3/requests/store_test.go rename to pkg/capabilities/consensus/requests/store_test.go index b871b0b96..9203bb017 100644 --- a/pkg/capabilities/consensus/ocr3/requests/store_test.go +++ b/pkg/capabilities/consensus/requests/store_test.go @@ -1,4 +1,4 @@ -package requests +package requests_test import ( "context" @@ -11,14 +11,17 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" "github.com/smartcontractkit/chainlink-common/pkg/values" + + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/ocr3" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/consensus/requests" ) func TestOCR3Store(t *testing.T) { n := time.Now() - s := NewStore() + s := requests.NewStore[*ocr3.ReportRequest]() rid := uuid.New().String() - req := &Request{ + req := &ocr3.ReportRequest{ WorkflowExecutionID: rid, ExpiresAt: n.Add(10 * time.Second), } @@ -34,9 +37,11 @@ func TestOCR3Store(t *testing.T) { }) t.Run("evict", func(t *testing.T) { - _, wasPresent := s.evict(rid) + _, wasPresent := s.Evict(rid) assert.True(t, wasPresent) - assert.Len(t, s.requests, 0) + reqs, err := s.FirstN(10) + require.NoError(t, err) + assert.Len(t, reqs, 0) }) t.Run("firstN", func(t *testing.T) { @@ -52,7 +57,7 @@ func TestOCR3Store(t *testing.T) { t.Run("firstN, batchSize larger than queue", func(t *testing.T) { for i := 0; i < 10; i++ { - err := s.Add(&Request{WorkflowExecutionID: uuid.New().String(), ExpiresAt: n.Add(1 * time.Hour)}) + err := s.Add(&ocr3.ReportRequest{WorkflowExecutionID: uuid.New().String(), ExpiresAt: n.Add(1 * time.Hour)}) require.NoError(t, err) } items, err := s.FirstN(100) @@ -70,42 +75,45 @@ func TestOCR3Store(t *testing.T) { } func TestOCR3Store_ManagesStateConsistently(t *testing.T) { - s := NewStore() + s := requests.NewStore[*ocr3.ReportRequest]() rid := uuid.New().String() - req := &Request{ + req := &ocr3.ReportRequest{ WorkflowExecutionID: rid, } err := s.Add(req) require.NoError(t, err) - assert.Len(t, s.requests, 1) - assert.Len(t, s.requestIDs, 1) + reqs, err := s.FirstN(10) + require.NoError(t, err) - s.GetByIDs([]string{rid}) - assert.Len(t, s.requests, 1) - assert.Len(t, s.requestIDs, 1) + assert.Len(t, reqs, 1) - _, ok := s.evict(rid) + reqs = s.GetByIDs([]string{rid}) + assert.Len(t, reqs, 1) + + _, ok := s.Evict(rid) assert.True(t, ok) - assert.Len(t, s.requests, 0) - assert.Len(t, s.requestIDs, 0) + reqs, err = s.FirstN(10) + require.NoError(t, err) + assert.Len(t, reqs, 0) err = s.Add(req) require.NoError(t, err) - assert.Len(t, s.requests, 1) - assert.Len(t, s.requestIDs, 1) + reqs, err = s.FirstN(10) + require.NoError(t, err) + assert.Len(t, reqs, 1) } func TestOCR3Store_ReadRequestsCopy(t *testing.T) { - s := NewStore() + s := requests.NewStore[*ocr3.ReportRequest]() rid := uuid.New().String() - cb := make(chan Response, 1) + cb := make(chan ocr3.ReportResponse, 1) stopCh := make(chan struct{}, 1) obs, err := values.NewList( []any{"hello", 1}, ) require.NoError(t, err) - req := &Request{ + req := &ocr3.ReportRequest{ WorkflowExecutionID: rid, WorkflowID: "wid", WorkflowName: "name", @@ -124,17 +132,17 @@ func TestOCR3Store_ReadRequestsCopy(t *testing.T) { testCases := []struct { name string - get func(ctx context.Context, rid string) *Request + get func(ctx 
context.Context, rid string) *ocr3.ReportRequest }{ { name: "get", - get: func(ctx context.Context, rid string) *Request { + get: func(ctx context.Context, rid string) *ocr3.ReportRequest { return s.Get(rid) }, }, { name: "firstN", - get: func(ctx context.Context, rid string) *Request { + get: func(ctx context.Context, rid string) *ocr3.ReportRequest { rs, err2 := s.FirstN(1) require.NoError(t, err2) assert.Len(t, rs, 1) @@ -143,7 +151,7 @@ func TestOCR3Store_ReadRequestsCopy(t *testing.T) { }, { name: "getByIDs", - get: func(ctx context.Context, rid string) *Request { + get: func(ctx context.Context, rid string) *ocr3.ReportRequest { rs := s.GetByIDs([]string{rid}) assert.Len(t, rs, 1) return rs[0] @@ -170,7 +178,7 @@ func TestOCR3Store_ReadRequestsCopy(t *testing.T) { gr.StopCh <- struct{}{} <-stopCh - gr.CallbackCh <- Response{} + gr.CallbackCh <- ocr3.ReportResponse{} <-cb }) } diff --git a/pkg/capabilities/v2/consensus/consensus.pb.go b/pkg/capabilities/v2/consensus/consensus.pb.go index 8d4601074..1cf1ddccb 100644 --- a/pkg/capabilities/v2/consensus/consensus.pb.go +++ b/pkg/capabilities/v2/consensus/consensus.pb.go @@ -27,9 +27,9 @@ var File_capabilities_v2_consensus_consensus_proto protoreflect.FileDescriptor const file_capabilities_v2_consensus_consensus_proto_rawDesc = "" + "\n" + - ")capabilities/v2/consensus/consensus.proto\x12\x14cre.sdk.v2.consensus\x1a\x16values/pb/values.proto\x1a0capabilities/v2/protoc/pkg/pb/cre_metadata.proto\x1a\x1dworkflows/sdk/v2/pb/sdk.proto2g\n" + + ")capabilities/v2/consensus/consensus.proto\x12\x14cre.sdk.v2.consensus\x1a\x16values/pb/values.proto\x1a0capabilities/v2/protoc/pkg/pb/cre_metadata.proto\x1a\x1dworkflows/sdk/v2/pb/sdk.proto2^\n" + "\tConsensus\x12:\n" + - "\x06Simple\x12!.cre.sdk.v2.SimpleConsensusInputs\x1a\r.values.Value\x1a\x1e\x82\xb5\x18\x1a\x12\x18offchain_reporting@1.0.0BLZJgithub.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/consensusb\x06proto3" + "\x06Simple\x12!.cre.sdk.v2.SimpleConsensusInputs\x1a\r.values.Value\x1a\x15\x82\xb5\x18\x11\x12\x0fconsensus@1.0.0BLZJgithub.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/consensusb\x06proto3" var file_capabilities_v2_consensus_consensus_proto_goTypes = []any{ (*pb.SimpleConsensusInputs)(nil), // 0: cre.sdk.v2.SimpleConsensusInputs diff --git a/pkg/capabilities/v2/consensus/consensus.proto b/pkg/capabilities/v2/consensus/consensus.proto index 57d0343d5..bf21e83ff 100644 --- a/pkg/capabilities/v2/consensus/consensus.proto +++ b/pkg/capabilities/v2/consensus/consensus.proto @@ -11,7 +11,7 @@ import "workflows/sdk/v2/pb/sdk.proto"; service Consensus { option (cre.metadata.capability) = { mode: DON - capability_id: "offchain_reporting@1.0.0" + capability_id: "consensus@1.0.0" }; rpc Simple (cre.sdk.v2.SimpleConsensusInputs) returns (values.Value); } diff --git a/pkg/capabilities/v2/consensus/consensus_sdk_gen.go b/pkg/capabilities/v2/consensus/consensus_sdk_gen.go index 9b2c83195..7c25aca8d 100644 --- a/pkg/capabilities/v2/consensus/consensus_sdk_gen.go +++ b/pkg/capabilities/v2/consensus/consensus_sdk_gen.go @@ -23,7 +23,7 @@ func (c *Consensus) Simple(runtime sdk.DonRuntime, input *pb1.SimpleConsensusInp return sdk.PromiseFromResult[*pb.Value](nil, err) } return sdk.Then(runtime.CallCapability(&sdkpb.CapabilityRequest{ - Id: "offchain_reporting@1.0.0", + Id: "consensus@1.0.0", Payload: wrapped, Method: "Simple", }), func(i *sdkpb.CapabilityResponse) (*pb.Value, error) { diff --git a/pkg/capabilities/v2/consensus/consensusmock/consensus_mock_gen.go 
b/pkg/capabilities/v2/consensus/consensusmock/consensus_mock_gen.go index 521bdb98d..d20b590d9 100644 --- a/pkg/capabilities/v2/consensus/consensusmock/consensus_mock_gen.go +++ b/pkg/capabilities/v2/consensus/consensusmock/consensus_mock_gen.go @@ -67,5 +67,5 @@ func (cap *ConsensusCapability) InvokeTrigger(ctx context.Context, request *sdkp } func (cap *ConsensusCapability) ID() string { - return "offchain_reporting@1.0.0" + return "consensus@1.0.0" } diff --git a/pkg/capabilities/v2/consensus/server/consensus_server_gen.go b/pkg/capabilities/v2/consensus/server/consensus_server_gen.go index 33641fad5..f0187d96f 100644 --- a/pkg/capabilities/v2/consensus/server/consensus_server_gen.go +++ b/pkg/capabilities/v2/consensus/server/consensus_server_gen.go @@ -66,7 +66,7 @@ func (cs *ConsensusServer) Close() error { defer cancel() if cs.capabilityRegistry != nil { - if err := cs.capabilityRegistry.Remove(ctx, "offchain_reporting@1.0.0"); err != nil { + if err := cs.capabilityRegistry.Remove(ctx, "consensus@1.0.0"); err != nil { return err } } @@ -93,7 +93,7 @@ type consensusCapability struct { func (c *consensusCapability) Info(ctx context.Context) (capabilities.CapabilityInfo, error) { // Maybe we do need to split it out, even if the user doesn't see it - return capabilities.NewCapabilityInfo("offchain_reporting@1.0.0", capabilities.CapabilityTypeCombined, c.ConsensusCapability.Description()) + return capabilities.NewCapabilityInfo("consensus@1.0.0", capabilities.CapabilityTypeCombined, c.ConsensusCapability.Description()) } var _ capabilities.ExecutableAndTriggerCapability = (*consensusCapability)(nil) From e845ca66dba04d841ce9309d5aa3499bffeaffbe Mon Sep 17 00:00:00 2001 From: Silas Lenihan Date: Wed, 11 Jun 2025 14:43:38 -0400 Subject: [PATCH 16/16] Made ToSchemaFullName public --- pkg/beholder/schema.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pkg/beholder/schema.go b/pkg/beholder/schema.go index 97f34837a..2f33040c1 100644 --- a/pkg/beholder/schema.go +++ b/pkg/beholder/schema.go @@ -30,14 +30,14 @@ func toSchemaName(m proto.Message) string { } // toSchemaName returns a protobuf message name (full) -func toSchemaFullName(m proto.Message) string { +func ToSchemaFullName(m proto.Message) string { return string(protoimpl.X.MessageTypeOf(m).Descriptor().FullName()) } // toSchemaPath maps a protobuf message to a Beholder schema path func toSchemaPath(m proto.Message, basePath string) string { // Notice: a name like 'platform.on_chain.forwarder.ReportProcessed' - protoName := toSchemaFullName(m) + protoName := ToSchemaFullName(m) // We map to a Beholder schema path like '/platform/on-chain/forwarder/report_processed.proto' protoPath := protoName @@ -85,7 +85,7 @@ func appendRequiredAttrDomain(attrKVs []any, m proto.Message) []any { } // Notice: a name like 'platform.on_chain.forwarder.ReportProcessed' - protoName := toSchemaFullName(m) + protoName := ToSchemaFullName(m) // Extract first path component (entrypoint package) as a domain domain := "unknown"
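
Note on the generalized request handler introduced above: the Handler[T, R] worker keys pending requests by req.ID(), and a response that arrives before its request is parked in responseCache and handed out exactly once when the request finally shows up (the Test_Handler_SendsResponseToLateRequestOnlyOnce case). The sketch below is a simplified, single-threaded illustration of that cache-and-replay-once behaviour under assumed toy types; the real handler runs this logic inside one worker goroutine fed by request/response channels and clock tickers, and its request/response types come from the ocr3 package.

package main

import (
	"fmt"
	"time"
)

// response and request are illustrative stand-ins; the real types carry
// workflow execution IDs, values, and callback channels.
type response struct {
	requestID string
	value     string
}

type request struct {
	id string
	ch chan response
}

type cachedResponse struct {
	resp    response
	entered time.Time
}

// handler mirrors the shape of the patch's Handler: pending requests keyed by
// ID plus a cache for responses whose request has not arrived yet.
type handler struct {
	pending       map[string]request
	responseCache map[string]cachedResponse
	cacheTTL      time.Duration
}

func newHandler(ttl time.Duration) *handler {
	return &handler{
		pending:       map[string]request{},
		responseCache: map[string]cachedResponse{},
		cacheTTL:      ttl,
	}
}

func (h *handler) onRequest(req request) {
	if cached, ok := h.responseCache[req.id]; ok {
		delete(h.responseCache, req.id) // replay at most once
		req.ch <- cached.resp
		return
	}
	h.pending[req.id] = req
}

func (h *handler) onResponse(resp response) {
	req, ok := h.pending[resp.requestID]
	if !ok {
		// No matching request yet: park the response until it arrives or expires.
		h.responseCache[resp.requestID] = cachedResponse{resp: resp, entered: time.Now()}
		return
	}
	delete(h.pending, resp.requestID)
	req.ch <- resp
}

// expireCached drops parked responses older than cacheTTL; in the real handler
// this is driven by a clock ticker in the worker loop.
func (h *handler) expireCached(now time.Time) {
	for id, c := range h.responseCache {
		if now.Sub(c.entered) > h.cacheTTL {
			delete(h.responseCache, id)
		}
	}
}

func main() {
	h := newHandler(time.Second)

	// Response arrives before its request (the "late request" case).
	h.onResponse(response{requestID: "exec-1", value: "report"})

	ch := make(chan response, 1)
	h.onRequest(request{id: "exec-1", ch: ch})
	fmt.Println((<-ch).value) // "report"

	// A second request for the same ID gets nothing: the cached entry was consumed.
	ch2 := make(chan response, 1)
	h.onRequest(request{id: "exec-1", ch: ch2})
	fmt.Println(len(ch2)) // 0
}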
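Note on the generic Store[T, R]: it keeps an insertion-ordered slice of IDs plus a mutex-guarded map, and every read path (Get, GetByIDs, FirstN) deep-copies via the request's Copy() method so callers never mutate stored state. The following is a minimal, non-generic sketch of that design, not the package's actual API; the type and method names here are illustrative only.

package main

import (
	"fmt"
	"sync"
)

// request stands in for a ConsensusRequest implementation such as
// *ocr3.ReportRequest; only the parts needed for the sketch are modelled.
type request struct {
	ID      string
	Payload string
}

// copy returns a deep copy so readers never share memory with the stored value.
func (r *request) copy() *request {
	c := *r
	return &c
}

// store keeps insertion order in ids and the requests themselves in m.
type store struct {
	mu  sync.RWMutex
	ids []string
	m   map[string]*request
}

func newStore() *store {
	return &store{m: map[string]*request{}}
}

func (s *store) Add(r *request) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	if _, ok := s.m[r.ID]; ok {
		return fmt.Errorf("request with id %s already exists", r.ID)
	}
	s.ids = append(s.ids, r.ID)
	s.m[r.ID] = r
	return nil
}

// FirstN returns up to n requests in insertion order, deep-copied.
func (s *store) FirstN(n int) []*request {
	s.mu.RLock()
	defer s.mu.RUnlock()
	out := []*request{}
	for _, id := range s.ids {
		if len(out) == n {
			break
		}
		if r, ok := s.m[id]; ok {
			out = append(out, r.copy())
		}
	}
	return out
}

// Evict removes a request and reports whether it was present; FirstN simply
// skips IDs that are no longer in the map.
func (s *store) Evict(id string) (*request, bool) {
	s.mu.Lock()
	defer s.mu.Unlock()
	r, ok := s.m[id]
	if ok {
		delete(s.m, id)
	}
	return r, ok
}

func main() {
	s := newStore()
	_ = s.Add(&request{ID: "exec-1", Payload: "report"})

	got := s.FirstN(10)
	got[0].Payload = "mutated"           // only the copy changes
	fmt.Println(s.FirstN(10)[0].Payload) // still "report"

	_, present := s.Evict("exec-1")
	fmt.Println(present) // true
}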
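Note on the beholder schema helpers in the last hunk: the comments describe mapping a protobuf full name such as platform.on_chain.forwarder.ReportProcessed to a schema path like /platform/on-chain/forwarder/report_processed.proto, but the transformation body itself is outside this diff. The sketch below is an assumed illustration of that mapping (hyphenating package segments and snake_casing the message name); the real toSchemaPath may differ in details.

package main

import (
	"fmt"
	"strings"
	"unicode"
)

// toSchemaPathSketch illustrates the name-to-path mapping described in the
// comments of pkg/beholder/schema.go; it is not the package's implementation.
func toSchemaPathSketch(fullName string) string {
	parts := strings.Split(fullName, ".")

	// Package segments: on_chain -> on-chain.
	for i, p := range parts[:len(parts)-1] {
		parts[i] = strings.ReplaceAll(p, "_", "-")
	}

	// Message name: ReportProcessed -> report_processed.
	var b strings.Builder
	for i, r := range parts[len(parts)-1] {
		if unicode.IsUpper(r) {
			if i > 0 {
				b.WriteByte('_')
			}
			b.WriteRune(unicode.ToLower(r))
			continue
		}
		b.WriteRune(r)
	}
	parts[len(parts)-1] = b.String()

	return "/" + strings.Join(parts, "/") + ".proto"
}

func main() {
	fmt.Println(toSchemaPathSketch("platform.on_chain.forwarder.ReportProcessed"))
	// /platform/on-chain/forwarder/report_processed.proto
}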