diff --git a/Makefile b/Makefile index 8cba3fed..cd463f44 100644 --- a/Makefile +++ b/Makefile @@ -76,8 +76,12 @@ lintfix: $(BIN)/golangci-lint $(BIN)/buf ## Automatically fix some lint errors golangci-lint run --fix --modules-download-mode=readonly --timeout=3m0s buf format -w +.PHONY: generate-annotations +generate-annotations: $(BIN)/buf $(BIN)/protoc-gen-go ## Generate only the proto annotations (no protoc-gen-connect-go needed) + cd ./cmd/protoc-gen-connect-go && buf generate + .PHONY: generate -generate: $(BIN)/buf $(BIN)/protoc-gen-go $(BIN)/protoc-gen-connect-go $(BIN)/license-header ## Regenerate code and licenses +generate: generate-annotations $(BIN)/buf $(BIN)/protoc-gen-go $(BIN)/protoc-gen-connect-go $(BIN)/license-header ## Regenerate code and licenses go mod tidy cd ./internal/conformance && go mod tidy buf generate diff --git a/cmd/protoc-gen-connect-go/buf.gen.yaml b/cmd/protoc-gen-connect-go/buf.gen.yaml new file mode 100644 index 00000000..22127b44 --- /dev/null +++ b/cmd/protoc-gen-connect-go/buf.gen.yaml @@ -0,0 +1,13 @@ +version: v2 +plugins: + - local: protoc-gen-go + out: gen + opt: paths=source_relative +# Currently, buf doesn't support per-module generation config, which +# is required to run a different set of plugins/plugin options for +# the testdata modules. As a workaround, we use `inputs` to restrict +# the generation config in this file to only the "proto" module, and +# have separate buf.gen.yaml files for the testdata modules. +# See: https://github.com/bufbuild/buf/issues/3060 +inputs: + - directory: proto diff --git a/cmd/protoc-gen-connect-go/buf.yaml b/cmd/protoc-gen-connect-go/buf.yaml new file mode 100644 index 00000000..ad64f54e --- /dev/null +++ b/cmd/protoc-gen-connect-go/buf.yaml @@ -0,0 +1,13 @@ +version: v2 +modules: + - path: proto + # this must be declared as a module in the same workspace as "proto", + # so it can import symbols from annotations.proto. + - path: internal/testdata/methodtimeouts +lint: + use: + - STANDARD + disallow_comment_ignores: true +breaking: + use: + - WIRE_JSON diff --git a/cmd/protoc-gen-connect-go/gen/connectrpc/go/options/v1/annotations.pb.go b/cmd/protoc-gen-connect-go/gen/connectrpc/go/options/v1/annotations.pb.go new file mode 100644 index 00000000..1716d27c --- /dev/null +++ b/cmd/protoc-gen-connect-go/gen/connectrpc/go/options/v1/annotations.pb.go @@ -0,0 +1,102 @@ +// Copyright 2021-2025 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.9 +// protoc (unknown) +// source: connectrpc/go/options/v1/annotations.proto + +package optionsv1 + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + descriptorpb "google.golang.org/protobuf/types/descriptorpb" + reflect "reflect" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +var file_connectrpc_go_options_v1_annotations_proto_extTypes = []protoimpl.ExtensionInfo{ + { + ExtendedType: (*descriptorpb.MethodOptions)(nil), + ExtensionType: (*MethodTimeouts)(nil), + Field: 50001, + Name: "connectrpc.go.options.v1.timeouts", + Tag: "bytes,50001,opt,name=timeouts", + Filename: "connectrpc/go/options/v1/annotations.proto", + }, +} + +// Extension fields to descriptorpb.MethodOptions. +var ( + // We should get a unique ID from protobuf-global-extension-registry@google.com + // for ConnectRPC project, and use it as the field number here. + // + // optional connectrpc.go.options.v1.MethodTimeouts timeouts = 50001; + E_Timeouts = &file_connectrpc_go_options_v1_annotations_proto_extTypes[0] +) + +var File_connectrpc_go_options_v1_annotations_proto protoreflect.FileDescriptor + +const file_connectrpc_go_options_v1_annotations_proto_rawDesc = "" + + "\n" + + "*connectrpc/go/options/v1/annotations.proto\x12\x18connectrpc.go.options.v1\x1a google/protobuf/descriptor.proto\x1a&connectrpc/go/options/v1/connect.proto:f\n" + + "\btimeouts\x12\x1e.google.protobuf.MethodOptions\x18ц\x03 \x01(\v2(.connectrpc.go.options.v1.MethodTimeoutsR\btimeoutsBYZWconnectrpc.com/connect/cmd/protoc-gen-connect-go/gen/connectrpc/go/options/v1;optionsv1b\x06proto3" + +var file_connectrpc_go_options_v1_annotations_proto_goTypes = []any{ + (*descriptorpb.MethodOptions)(nil), // 0: google.protobuf.MethodOptions + (*MethodTimeouts)(nil), // 1: connectrpc.go.options.v1.MethodTimeouts +} +var file_connectrpc_go_options_v1_annotations_proto_depIdxs = []int32{ + 0, // 0: connectrpc.go.options.v1.timeouts:extendee -> google.protobuf.MethodOptions + 1, // 1: connectrpc.go.options.v1.timeouts:type_name -> connectrpc.go.options.v1.MethodTimeouts + 2, // [2:2] is the sub-list for method output_type + 2, // [2:2] is the sub-list for method input_type + 1, // [1:2] is the sub-list for extension type_name + 0, // [0:1] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_connectrpc_go_options_v1_annotations_proto_init() } +func file_connectrpc_go_options_v1_annotations_proto_init() { + if File_connectrpc_go_options_v1_annotations_proto != nil { + return + } + file_connectrpc_go_options_v1_connect_proto_init() + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_connectrpc_go_options_v1_annotations_proto_rawDesc), len(file_connectrpc_go_options_v1_annotations_proto_rawDesc)), + NumEnums: 0, + NumMessages: 0, + NumExtensions: 1, + NumServices: 0, + }, + GoTypes: file_connectrpc_go_options_v1_annotations_proto_goTypes, + DependencyIndexes: file_connectrpc_go_options_v1_annotations_proto_depIdxs, + ExtensionInfos: file_connectrpc_go_options_v1_annotations_proto_extTypes, + }.Build() + File_connectrpc_go_options_v1_annotations_proto = out.File + file_connectrpc_go_options_v1_annotations_proto_goTypes = nil + file_connectrpc_go_options_v1_annotations_proto_depIdxs = nil +} diff --git a/cmd/protoc-gen-connect-go/gen/connectrpc/go/options/v1/connect.pb.go b/cmd/protoc-gen-connect-go/gen/connectrpc/go/options/v1/connect.pb.go new file mode 100644 index 00000000..46873c0a --- /dev/null +++ b/cmd/protoc-gen-connect-go/gen/connectrpc/go/options/v1/connect.pb.go @@ -0,0 +1,150 @@ +// Copyright 2021-2025 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.9 +// protoc (unknown) +// source: connectrpc/go/options/v1/connect.proto + +package optionsv1 + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type MethodTimeouts struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Timeout in milliseconds to read the entire request + // (including the body). + // A negative value means no timeout. + ReadMs int64 `protobuf:"varint,1,opt,name=read_ms,json=readMs,proto3" json:"read_ms,omitempty"` + // Timeout in milliseconds for writing the response. + // A negative value means no timeout. + WriteMs int64 `protobuf:"varint,2,opt,name=write_ms,json=writeMs,proto3" json:"write_ms,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *MethodTimeouts) Reset() { + *x = MethodTimeouts{} + mi := &file_connectrpc_go_options_v1_connect_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *MethodTimeouts) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*MethodTimeouts) ProtoMessage() {} + +func (x *MethodTimeouts) ProtoReflect() protoreflect.Message { + mi := &file_connectrpc_go_options_v1_connect_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use MethodTimeouts.ProtoReflect.Descriptor instead. +func (*MethodTimeouts) Descriptor() ([]byte, []int) { + return file_connectrpc_go_options_v1_connect_proto_rawDescGZIP(), []int{0} +} + +func (x *MethodTimeouts) GetReadMs() int64 { + if x != nil { + return x.ReadMs + } + return 0 +} + +func (x *MethodTimeouts) GetWriteMs() int64 { + if x != nil { + return x.WriteMs + } + return 0 +} + +var File_connectrpc_go_options_v1_connect_proto protoreflect.FileDescriptor + +const file_connectrpc_go_options_v1_connect_proto_rawDesc = "" + + "\n" + + "&connectrpc/go/options/v1/connect.proto\x12\x18connectrpc.go.options.v1\"D\n" + + "\x0eMethodTimeouts\x12\x17\n" + + "\aread_ms\x18\x01 \x01(\x03R\x06readMs\x12\x19\n" + + "\bwrite_ms\x18\x02 \x01(\x03R\awriteMsBYZWconnectrpc.com/connect/cmd/protoc-gen-connect-go/gen/connectrpc/go/options/v1;optionsv1b\x06proto3" + +var ( + file_connectrpc_go_options_v1_connect_proto_rawDescOnce sync.Once + file_connectrpc_go_options_v1_connect_proto_rawDescData []byte +) + +func file_connectrpc_go_options_v1_connect_proto_rawDescGZIP() []byte { + file_connectrpc_go_options_v1_connect_proto_rawDescOnce.Do(func() { + file_connectrpc_go_options_v1_connect_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_connectrpc_go_options_v1_connect_proto_rawDesc), len(file_connectrpc_go_options_v1_connect_proto_rawDesc))) + }) + return file_connectrpc_go_options_v1_connect_proto_rawDescData +} + +var file_connectrpc_go_options_v1_connect_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_connectrpc_go_options_v1_connect_proto_goTypes = []any{ + (*MethodTimeouts)(nil), // 0: connectrpc.go.options.v1.MethodTimeouts +} +var file_connectrpc_go_options_v1_connect_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_connectrpc_go_options_v1_connect_proto_init() } +func file_connectrpc_go_options_v1_connect_proto_init() { + if File_connectrpc_go_options_v1_connect_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_connectrpc_go_options_v1_connect_proto_rawDesc), len(file_connectrpc_go_options_v1_connect_proto_rawDesc)), + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_connectrpc_go_options_v1_connect_proto_goTypes, + DependencyIndexes: file_connectrpc_go_options_v1_connect_proto_depIdxs, + MessageInfos: file_connectrpc_go_options_v1_connect_proto_msgTypes, + }.Build() + File_connectrpc_go_options_v1_connect_proto = out.File + file_connectrpc_go_options_v1_connect_proto_goTypes = nil + file_connectrpc_go_options_v1_connect_proto_depIdxs = nil +} diff --git a/cmd/protoc-gen-connect-go/internal/testdata/methodtimeouts/buf.gen.yaml b/cmd/protoc-gen-connect-go/internal/testdata/methodtimeouts/buf.gen.yaml new file mode 100644 index 00000000..2787be73 --- /dev/null +++ b/cmd/protoc-gen-connect-go/internal/testdata/methodtimeouts/buf.gen.yaml @@ -0,0 +1,9 @@ +version: v2 +plugins: + - local: protoc-gen-go + out: gen + opt: paths=source_relative + - local: protoc-gen-connect-go + out: gen + opt: paths=source_relative +clean: true diff --git a/cmd/protoc-gen-connect-go/internal/testdata/methodtimeouts/gen/genconnect/methodtimeouts.connect.go b/cmd/protoc-gen-connect-go/internal/testdata/methodtimeouts/gen/genconnect/methodtimeouts.connect.go new file mode 100644 index 00000000..012e12bb --- /dev/null +++ b/cmd/protoc-gen-connect-go/internal/testdata/methodtimeouts/gen/genconnect/methodtimeouts.connect.go @@ -0,0 +1,243 @@ +// Copyright 2021-2025 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by protoc-gen-connect-go. DO NOT EDIT. +// +// Source: methodtimeouts.proto + +package genconnect + +import ( + connect "connectrpc.com/connect" + gen "connectrpc.com/connect/cmd/protoc-gen-connect-go/internal/testdata/methodtimeouts/gen" + context "context" + errors "errors" + http "net/http" + strings "strings" + time "time" +) + +// This is a compile-time assertion to ensure that this generated file and the connect package are +// compatible. If you get a compiler error that this constant is not defined, this code was +// generated with a version of connect newer than the one compiled into your binary. You can fix the +// problem by either regenerating this code with an older version of connect or updating the connect +// version compiled into your binary. +const _ = connect.IsAtLeastVersion1_13_0 + +const ( + // TestServiceName is the fully-qualified name of the TestService service. + TestServiceName = "connect.test.method_timeouts.TestService" +) + +// These constants are the fully-qualified names of the RPCs defined in this package. They're +// exposed at runtime as Spec.Procedure and as the final two segments of the HTTP route. +// +// Note that these are different from the fully-qualified method names used by +// google.golang.org/protobuf/reflect/protoreflect. To convert from these constants to +// reflection-formatted method names, remove the leading slash and convert the remaining slash to a +// period. +const ( + // TestServiceMethod0Procedure is the fully-qualified name of the TestService's Method0 RPC. + TestServiceMethod0Procedure = "/connect.test.method_timeouts.TestService/Method0" + // TestServiceMethod1Procedure is the fully-qualified name of the TestService's Method1 RPC. + TestServiceMethod1Procedure = "/connect.test.method_timeouts.TestService/Method1" + // TestServiceMethod2Procedure is the fully-qualified name of the TestService's Method2 RPC. + TestServiceMethod2Procedure = "/connect.test.method_timeouts.TestService/Method2" + // TestServiceMethod3Procedure is the fully-qualified name of the TestService's Method3 RPC. + TestServiceMethod3Procedure = "/connect.test.method_timeouts.TestService/Method3" + // TestServiceMethod4Procedure is the fully-qualified name of the TestService's Method4 RPC. + TestServiceMethod4Procedure = "/connect.test.method_timeouts.TestService/Method4" +) + +// TestServiceClient is a client for the connect.test.method_timeouts.TestService service. +type TestServiceClient interface { + Method0(context.Context, *connect.Request[gen.Request]) (*connect.Response[gen.Response], error) + Method1(context.Context, *connect.Request[gen.Request]) (*connect.Response[gen.Response], error) + Method2(context.Context, *connect.Request[gen.Request]) (*connect.Response[gen.Response], error) + Method3(context.Context, *connect.Request[gen.Request]) (*connect.Response[gen.Response], error) + Method4(context.Context, *connect.Request[gen.Request]) (*connect.Response[gen.Response], error) +} + +// NewTestServiceClient constructs a client for the connect.test.method_timeouts.TestService +// service. By default, it uses the Connect protocol with the binary Protobuf Codec, asks for +// gzipped responses, and sends uncompressed requests. To use the gRPC or gRPC-Web protocols, supply +// the connect.WithGRPC() or connect.WithGRPCWeb() options. +// +// The URL supplied here should be the base URL for the Connect or gRPC server (for example, +// http://api.acme.com or https://acme.com/grpc). +func NewTestServiceClient(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) TestServiceClient { + baseURL = strings.TrimRight(baseURL, "/") + testServiceMethods := gen.File_methodtimeouts_proto.Services().ByName("TestService").Methods() + return &testServiceClient{ + method0: connect.NewClient[gen.Request, gen.Response]( + httpClient, + baseURL+TestServiceMethod0Procedure, + connect.WithSchema(testServiceMethods.ByName("Method0")), + connect.WithClientOptions(opts...), + ), + method1: connect.NewClient[gen.Request, gen.Response]( + httpClient, + baseURL+TestServiceMethod1Procedure, + connect.WithSchema(testServiceMethods.ByName("Method1")), + connect.WithClientOptions(opts...), + ), + method2: connect.NewClient[gen.Request, gen.Response]( + httpClient, + baseURL+TestServiceMethod2Procedure, + connect.WithSchema(testServiceMethods.ByName("Method2")), + connect.WithClientOptions(opts...), + ), + method3: connect.NewClient[gen.Request, gen.Response]( + httpClient, + baseURL+TestServiceMethod3Procedure, + connect.WithSchema(testServiceMethods.ByName("Method3")), + connect.WithClientOptions(opts...), + ), + method4: connect.NewClient[gen.Request, gen.Response]( + httpClient, + baseURL+TestServiceMethod4Procedure, + connect.WithSchema(testServiceMethods.ByName("Method4")), + connect.WithClientOptions(opts...), + ), + } +} + +// testServiceClient implements TestServiceClient. +type testServiceClient struct { + method0 *connect.Client[gen.Request, gen.Response] + method1 *connect.Client[gen.Request, gen.Response] + method2 *connect.Client[gen.Request, gen.Response] + method3 *connect.Client[gen.Request, gen.Response] + method4 *connect.Client[gen.Request, gen.Response] +} + +// Method0 calls connect.test.method_timeouts.TestService.Method0. +func (c *testServiceClient) Method0(ctx context.Context, req *connect.Request[gen.Request]) (*connect.Response[gen.Response], error) { + return c.method0.CallUnary(ctx, req) +} + +// Method1 calls connect.test.method_timeouts.TestService.Method1. +func (c *testServiceClient) Method1(ctx context.Context, req *connect.Request[gen.Request]) (*connect.Response[gen.Response], error) { + return c.method1.CallUnary(ctx, req) +} + +// Method2 calls connect.test.method_timeouts.TestService.Method2. +func (c *testServiceClient) Method2(ctx context.Context, req *connect.Request[gen.Request]) (*connect.Response[gen.Response], error) { + return c.method2.CallUnary(ctx, req) +} + +// Method3 calls connect.test.method_timeouts.TestService.Method3. +func (c *testServiceClient) Method3(ctx context.Context, req *connect.Request[gen.Request]) (*connect.Response[gen.Response], error) { + return c.method3.CallUnary(ctx, req) +} + +// Method4 calls connect.test.method_timeouts.TestService.Method4. +func (c *testServiceClient) Method4(ctx context.Context, req *connect.Request[gen.Request]) (*connect.Response[gen.Response], error) { + return c.method4.CallUnary(ctx, req) +} + +// TestServiceHandler is an implementation of the connect.test.method_timeouts.TestService service. +type TestServiceHandler interface { + Method0(context.Context, *connect.Request[gen.Request]) (*connect.Response[gen.Response], error) + Method1(context.Context, *connect.Request[gen.Request]) (*connect.Response[gen.Response], error) + Method2(context.Context, *connect.Request[gen.Request]) (*connect.Response[gen.Response], error) + Method3(context.Context, *connect.Request[gen.Request]) (*connect.Response[gen.Response], error) + Method4(context.Context, *connect.Request[gen.Request]) (*connect.Response[gen.Response], error) +} + +// NewTestServiceHandler builds an HTTP handler from the service implementation. It returns the path +// on which to mount the handler and the handler itself. +// +// By default, handlers support the Connect, gRPC, and gRPC-Web protocols with the binary Protobuf +// and JSON codecs. They also support gzip compression. +func NewTestServiceHandler(svc TestServiceHandler, opts ...connect.HandlerOption) (string, http.Handler) { + testServiceMethods := gen.File_methodtimeouts_proto.Services().ByName("TestService").Methods() + testServiceMethod0Handler := connect.NewUnaryHandler( + TestServiceMethod0Procedure, + svc.Method0, + connect.WithSchema(testServiceMethods.ByName("Method0")), + connect.WithHandlerOptions(opts...), + ) + testServiceMethod1Handler := connect.NewUnaryHandler( + TestServiceMethod1Procedure, + svc.Method1, + connect.WithSchema(testServiceMethods.ByName("Method1")), + connect.WithReadTimeout(time.Duration(1000)*time.Millisecond), + connect.WithWriteTimeout(time.Duration(2000)*time.Millisecond), + connect.WithHandlerOptions(opts...), + ) + testServiceMethod2Handler := connect.NewUnaryHandler( + TestServiceMethod2Procedure, + svc.Method2, + connect.WithSchema(testServiceMethods.ByName("Method2")), + connect.WithReadTimeout(time.Duration(-1)*time.Millisecond), + connect.WithWriteTimeout(time.Duration(-1)*time.Millisecond), + connect.WithHandlerOptions(opts...), + ) + testServiceMethod3Handler := connect.NewUnaryHandler( + TestServiceMethod3Procedure, + svc.Method3, + connect.WithSchema(testServiceMethods.ByName("Method3")), + connect.WithReadTimeout(time.Duration(1000)*time.Millisecond), + connect.WithWriteTimeout(time.Duration(0)*time.Millisecond), + connect.WithHandlerOptions(opts...), + ) + testServiceMethod4Handler := connect.NewUnaryHandler( + TestServiceMethod4Procedure, + svc.Method4, + connect.WithSchema(testServiceMethods.ByName("Method4")), + connect.WithReadTimeout(time.Duration(1000)*time.Millisecond), + connect.WithWriteTimeout(time.Duration(0)*time.Millisecond), + connect.WithHandlerOptions(opts...), + ) + return "/connect.test.method_timeouts.TestService/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case TestServiceMethod0Procedure: + testServiceMethod0Handler.ServeHTTP(w, r) + case TestServiceMethod1Procedure: + testServiceMethod1Handler.ServeHTTP(w, r) + case TestServiceMethod2Procedure: + testServiceMethod2Handler.ServeHTTP(w, r) + case TestServiceMethod3Procedure: + testServiceMethod3Handler.ServeHTTP(w, r) + case TestServiceMethod4Procedure: + testServiceMethod4Handler.ServeHTTP(w, r) + default: + http.NotFound(w, r) + } + }) +} + +// UnimplementedTestServiceHandler returns CodeUnimplemented from all methods. +type UnimplementedTestServiceHandler struct{} + +func (UnimplementedTestServiceHandler) Method0(context.Context, *connect.Request[gen.Request]) (*connect.Response[gen.Response], error) { + return nil, connect.NewError(connect.CodeUnimplemented, errors.New("connect.test.method_timeouts.TestService.Method0 is not implemented")) +} + +func (UnimplementedTestServiceHandler) Method1(context.Context, *connect.Request[gen.Request]) (*connect.Response[gen.Response], error) { + return nil, connect.NewError(connect.CodeUnimplemented, errors.New("connect.test.method_timeouts.TestService.Method1 is not implemented")) +} + +func (UnimplementedTestServiceHandler) Method2(context.Context, *connect.Request[gen.Request]) (*connect.Response[gen.Response], error) { + return nil, connect.NewError(connect.CodeUnimplemented, errors.New("connect.test.method_timeouts.TestService.Method2 is not implemented")) +} + +func (UnimplementedTestServiceHandler) Method3(context.Context, *connect.Request[gen.Request]) (*connect.Response[gen.Response], error) { + return nil, connect.NewError(connect.CodeUnimplemented, errors.New("connect.test.method_timeouts.TestService.Method3 is not implemented")) +} + +func (UnimplementedTestServiceHandler) Method4(context.Context, *connect.Request[gen.Request]) (*connect.Response[gen.Response], error) { + return nil, connect.NewError(connect.CodeUnimplemented, errors.New("connect.test.method_timeouts.TestService.Method4 is not implemented")) +} diff --git a/cmd/protoc-gen-connect-go/internal/testdata/methodtimeouts/gen/methodtimeouts.pb.go b/cmd/protoc-gen-connect-go/internal/testdata/methodtimeouts/gen/methodtimeouts.pb.go new file mode 100644 index 00000000..26c4bf98 --- /dev/null +++ b/cmd/protoc-gen-connect-go/internal/testdata/methodtimeouts/gen/methodtimeouts.pb.go @@ -0,0 +1,184 @@ +// Copyright 2021-2025 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.9 +// protoc (unknown) +// source: methodtimeouts.proto + +package gen + +import ( + _ "connectrpc.com/connect/cmd/protoc-gen-connect-go/gen/connectrpc/go/options/v1" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type Request struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Request) Reset() { + *x = Request{} + mi := &file_methodtimeouts_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Request) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Request) ProtoMessage() {} + +func (x *Request) ProtoReflect() protoreflect.Message { + mi := &file_methodtimeouts_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Request.ProtoReflect.Descriptor instead. +func (*Request) Descriptor() ([]byte, []int) { + return file_methodtimeouts_proto_rawDescGZIP(), []int{0} +} + +type Response struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Response) Reset() { + *x = Response{} + mi := &file_methodtimeouts_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Response) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Response) ProtoMessage() {} + +func (x *Response) ProtoReflect() protoreflect.Message { + mi := &file_methodtimeouts_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Response.ProtoReflect.Descriptor instead. +func (*Response) Descriptor() ([]byte, []int) { + return file_methodtimeouts_proto_rawDescGZIP(), []int{1} +} + +var File_methodtimeouts_proto protoreflect.FileDescriptor + +const file_methodtimeouts_proto_rawDesc = "" + + "\n" + + "\x14methodtimeouts.proto\x12\x1cconnect.test.method_timeouts\x1a*connectrpc/go/options/v1/annotations.proto\"\t\n" + + "\aRequest\"\n" + + "\n" + + "\bResponse2\x8b\x04\n" + + "\vTestService\x12Z\n" + + "\aMethod0\x12%.connect.test.method_timeouts.Request\x1a&.connect.test.method_timeouts.Response\"\x00\x12d\n" + + "\aMethod1\x12%.connect.test.method_timeouts.Request\x1a&.connect.test.method_timeouts.Response\"\n" + + "\x8a\xb5\x18\x06\b\xe8\a\x10\xd0\x0f\x12t\n" + + "\aMethod2\x12%.connect.test.method_timeouts.Request\x1a&.connect.test.method_timeouts.Response\"\x1a\x8a\xb5\x18\x16\b\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01\x10\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01\x12a\n" + + "\aMethod3\x12%.connect.test.method_timeouts.Request\x1a&.connect.test.method_timeouts.Response\"\a\x8a\xb5\x18\x03\b\xe8\a\x12a\n" + + "\aMethod4\x12%.connect.test.method_timeouts.Request\x1a&.connect.test.method_timeouts.Response\"\a\x8a\xb5\x18\x03\b\xe8\aB[ZYconnectrpc.com/connect/cmd/protoc-gen-connect-go/internal/testdata/methodtimeouts/gen;genb\x06proto3" + +var ( + file_methodtimeouts_proto_rawDescOnce sync.Once + file_methodtimeouts_proto_rawDescData []byte +) + +func file_methodtimeouts_proto_rawDescGZIP() []byte { + file_methodtimeouts_proto_rawDescOnce.Do(func() { + file_methodtimeouts_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_methodtimeouts_proto_rawDesc), len(file_methodtimeouts_proto_rawDesc))) + }) + return file_methodtimeouts_proto_rawDescData +} + +var file_methodtimeouts_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_methodtimeouts_proto_goTypes = []any{ + (*Request)(nil), // 0: connect.test.method_timeouts.Request + (*Response)(nil), // 1: connect.test.method_timeouts.Response +} +var file_methodtimeouts_proto_depIdxs = []int32{ + 0, // 0: connect.test.method_timeouts.TestService.Method0:input_type -> connect.test.method_timeouts.Request + 0, // 1: connect.test.method_timeouts.TestService.Method1:input_type -> connect.test.method_timeouts.Request + 0, // 2: connect.test.method_timeouts.TestService.Method2:input_type -> connect.test.method_timeouts.Request + 0, // 3: connect.test.method_timeouts.TestService.Method3:input_type -> connect.test.method_timeouts.Request + 0, // 4: connect.test.method_timeouts.TestService.Method4:input_type -> connect.test.method_timeouts.Request + 1, // 5: connect.test.method_timeouts.TestService.Method0:output_type -> connect.test.method_timeouts.Response + 1, // 6: connect.test.method_timeouts.TestService.Method1:output_type -> connect.test.method_timeouts.Response + 1, // 7: connect.test.method_timeouts.TestService.Method2:output_type -> connect.test.method_timeouts.Response + 1, // 8: connect.test.method_timeouts.TestService.Method3:output_type -> connect.test.method_timeouts.Response + 1, // 9: connect.test.method_timeouts.TestService.Method4:output_type -> connect.test.method_timeouts.Response + 5, // [5:10] is the sub-list for method output_type + 0, // [0:5] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_methodtimeouts_proto_init() } +func file_methodtimeouts_proto_init() { + if File_methodtimeouts_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_methodtimeouts_proto_rawDesc), len(file_methodtimeouts_proto_rawDesc)), + NumEnums: 0, + NumMessages: 2, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_methodtimeouts_proto_goTypes, + DependencyIndexes: file_methodtimeouts_proto_depIdxs, + MessageInfos: file_methodtimeouts_proto_msgTypes, + }.Build() + File_methodtimeouts_proto = out.File + file_methodtimeouts_proto_goTypes = nil + file_methodtimeouts_proto_depIdxs = nil +} diff --git a/cmd/protoc-gen-connect-go/internal/testdata/methodtimeouts/methodtimeouts.proto b/cmd/protoc-gen-connect-go/internal/testdata/methodtimeouts/methodtimeouts.proto new file mode 100644 index 00000000..0973b0c0 --- /dev/null +++ b/cmd/protoc-gen-connect-go/internal/testdata/methodtimeouts/methodtimeouts.proto @@ -0,0 +1,56 @@ +// Copyright 2021-2025 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package connect.test.method_timeouts; + +import "connectrpc/go/options/v1/annotations.proto"; + +option go_package = "connectrpc.com/connect/cmd/protoc-gen-connect-go/internal/testdata/methodtimeouts/gen;gen"; + +message Request {} + +message Response {} + +service TestService { + rpc Method0(Request) returns (Response) {} + + rpc Method1(Request) returns (Response) { + option (connectrpc.go.options.v1.timeouts) = { + read_ms: 1000 + write_ms: 2000 + }; + } + + rpc Method2(Request) returns (Response) { + option (connectrpc.go.options.v1.timeouts) = { + read_ms: -1 + write_ms: -1 + }; + } + + rpc Method3(Request) returns (Response) { + option (connectrpc.go.options.v1.timeouts) = { + read_ms: 1000 + }; + } + + rpc Method4(Request) returns (Response) { + option (connectrpc.go.options.v1.timeouts) = { + read_ms: 1000 + write_ms: 0 + }; + } +} diff --git a/cmd/protoc-gen-connect-go/main.go b/cmd/protoc-gen-connect-go/main.go index b4580c1e..5c7706d1 100644 --- a/cmd/protoc-gen-connect-go/main.go +++ b/cmd/protoc-gen-connect-go/main.go @@ -74,7 +74,9 @@ import ( "unicode/utf8" connect "connectrpc.com/connect" + optionsv1 "connectrpc.com/connect/cmd/protoc-gen-connect-go/gen/connectrpc/go/options/v1" "google.golang.org/protobuf/compiler/protogen" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/types/descriptorpb" "google.golang.org/protobuf/types/pluginpb" @@ -85,6 +87,7 @@ const ( errorsPackage = protogen.GoImportPath("errors") httpPackage = protogen.GoImportPath("net/http") stringsPackage = protogen.GoImportPath("strings") + timePackage = protogen.GoImportPath("time") connectPackage = protogen.GoImportPath("connectrpc.com/connect") generatedFilenameExtension = ".connect.go" @@ -502,6 +505,7 @@ func generateServerConstructor(g *protogen.GeneratedFile, file *protogen.File, s isStreamingServer := method.Desc.IsStreamingServer() isStreamingClient := method.Desc.IsStreamingClient() idempotency := methodIdempotency(method) + timeouts := methodTimeouts(method) switch { case isStreamingClient && !isStreamingServer: if simple { @@ -534,6 +538,10 @@ func generateServerConstructor(g *protogen.GeneratedFile, file *protogen.File, s g.P(connectPackage.Ident("WithIdempotency"), "(", connectPackage.Ident("IdempotencyIdempotent"), "),") case connect.IdempotencyUnknown: } + if timeouts != nil { + g.P(connectPackage.Ident("WithReadTimeout"), "(", timePackage.Ident("Duration"), "(", timeouts.ReadMs, ")*", timePackage.Ident("Millisecond"), "),") + g.P(connectPackage.Ident("WithWriteTimeout"), "(", timePackage.Ident("Duration"), "(", timeouts.WriteMs, ")*", timePackage.Ident("Millisecond"), "),") + } g.P(connectPackage.Ident("WithHandlerOptions"), "(opts...),") g.P(")") } @@ -678,6 +686,24 @@ func methodIdempotency(method *protogen.Method) connect.IdempotencyLevel { return connect.IdempotencyUnknown } +func methodTimeouts(method *protogen.Method) *optionsv1.MethodTimeouts { + methodOptions, ok := method.Desc.Options().(*descriptorpb.MethodOptions) + if !ok { + return nil + } + + if !proto.HasExtension(methodOptions, optionsv1.E_Timeouts) { + return nil + } + ext := proto.GetExtension(methodOptions, optionsv1.E_Timeouts) + + timeouts, timeoutsOk := ext.(*optionsv1.MethodTimeouts) + if !timeoutsOk { + return nil + } + return timeouts +} + // Raggedy comments in the generated code are driving me insane. This // word-wrapping function is ruinously inefficient, but it gets the job done. func wrapComments(g *protogen.GeneratedFile, elems ...any) { diff --git a/cmd/protoc-gen-connect-go/main_test.go b/cmd/protoc-gen-connect-go/main_test.go index b3d1d8f8..903f0cbd 100644 --- a/cmd/protoc-gen-connect-go/main_test.go +++ b/cmd/protoc-gen-connect-go/main_test.go @@ -18,6 +18,7 @@ import ( "bytes" "context" "embed" + "fmt" "io" "net/http" "net/http/httptest" @@ -36,10 +37,12 @@ import ( "google.golang.org/protobuf/types/descriptorpb" "google.golang.org/protobuf/types/pluginpb" + optionsv1 "connectrpc.com/connect/cmd/protoc-gen-connect-go/gen/connectrpc/go/options/v1" defaultpackage "connectrpc.com/connect/cmd/protoc-gen-connect-go/internal/testdata/defaultpackage/gen" defaultpackageconnect "connectrpc.com/connect/cmd/protoc-gen-connect-go/internal/testdata/defaultpackage/gen/genconnect" diffpackage "connectrpc.com/connect/cmd/protoc-gen-connect-go/internal/testdata/diffpackage/gen" diffpackagediff "connectrpc.com/connect/cmd/protoc-gen-connect-go/internal/testdata/diffpackage/gen/gendiff" + methodtimeouts "connectrpc.com/connect/cmd/protoc-gen-connect-go/internal/testdata/methodtimeouts/gen" noservice "connectrpc.com/connect/cmd/protoc-gen-connect-go/internal/testdata/noservice/gen" samepackage "connectrpc.com/connect/cmd/protoc-gen-connect-go/internal/testdata/samepackage/gen" simple "connectrpc.com/connect/cmd/protoc-gen-connect-go/internal/testdata/simple/gen" @@ -150,6 +153,32 @@ func TestGenerate(t *testing.T) { assert.NotZero(t, file.GetContent()) testCmpToTestdata(t, file.GetContent(), "internal/testdata/diffpackage/gen/gendiff/diffpackage.connect.go") }) + // Check generated code with method timeouts. + t.Run("methodtimeouts.proto", func(t *testing.T) { + t.Parallel() + fmt.Println(protodesc.ToFileDescriptorProto(optionsv1.File_connectrpc_go_options_v1_annotations_proto).GetName()) + methodTimeoutsFileDesc := protodesc.ToFileDescriptorProto(methodtimeouts.File_methodtimeouts_proto) + req := &pluginpb.CodeGeneratorRequest{ + FileToGenerate: []string{"methodtimeouts.proto"}, + ProtoFile: []*descriptorpb.FileDescriptorProto{ + // ProtoFile needs dependencies listed before the files that depend on them. + protodesc.ToFileDescriptorProto(descriptorpb.File_google_protobuf_descriptor_proto), + protodesc.ToFileDescriptorProto(optionsv1.File_connectrpc_go_options_v1_connect_proto), + protodesc.ToFileDescriptorProto(optionsv1.File_connectrpc_go_options_v1_annotations_proto), + methodTimeoutsFileDesc, + }, + SourceFileDescriptors: []*descriptorpb.FileDescriptorProto{methodTimeoutsFileDesc}, + CompilerVersion: compilerVersion, + } + rsp := testGenerate(t, req) + assert.Nil(t, rsp.Error) + + assert.Equal(t, len(rsp.File), 1) + file := rsp.File[0] + assert.Equal(t, file.GetName(), "connectrpc.com/connect/cmd/protoc-gen-connect-go/internal/testdata/methodtimeouts/gen/genconnect/methodtimeouts.connect.go") + assert.NotZero(t, file.GetContent()) + testCmpToTestdata(t, file.GetContent(), "internal/testdata/methodtimeouts/gen/genconnect/methodtimeouts.connect.go") + }) // Validate package_suffix option. t.Run("ping.proto:invalid_package_suffix", func(t *testing.T) { t.Parallel() diff --git a/cmd/protoc-gen-connect-go/proto/buf.md b/cmd/protoc-gen-connect-go/proto/buf.md new file mode 100644 index 00000000..047a703d --- /dev/null +++ b/cmd/protoc-gen-connect-go/proto/buf.md @@ -0,0 +1,52 @@ +# protoc-gen-connect-go options + +This module provides proto extensions for customising per-method timeouts on +[ConnectRPC Go](https://github.com/connectrpc/connect-go) service handlers. + +## Usage + +Add this module as a dependency in your `buf.yaml`: + +```yaml +version: v2 +deps: + - buf.build/connectrpc/connect-go +``` + +Then import the annotations in your proto file and annotate your RPC methods: + +```proto +syntax = "proto3"; + +import "connectrpc/go/options/v1/annotations.proto"; + +service GreetService { + rpc Greet(GreetRequest) returns (GreetResponse) { + option (connectrpc.go.options.v1.timeouts) = { + read_ms: 3000 // 3 seconds + write_ms: 10000 // 10 seconds + }; + } + + // For streaming RPCs, you can set -1 to disable timeouts entirely. + rpc Subscribe(SubscribeRequest) returns (stream SubscribeResponse) { + option (connectrpc.go.options.v1.timeouts) = { + read_ms: -1 + write_ms: -1 + }; + } +} +``` + +## Timeout values + +| Value | Meaning | +| ----- | --------------------------------------------- | +| `0` | Use the server-wide default timeout (if any) | +| `> 0` | Timeout in milliseconds | +| `< 0` | No timeouts (recommended for streaming RPCs) | + +## Code generation + +These options are consumed by `protoc-gen-connect-go` during code generation. +No additional runtime dependencies are required. diff --git a/cmd/protoc-gen-connect-go/proto/buf.yaml b/cmd/protoc-gen-connect-go/proto/buf.yaml new file mode 100644 index 00000000..13aabbe7 --- /dev/null +++ b/cmd/protoc-gen-connect-go/proto/buf.yaml @@ -0,0 +1,9 @@ +version: v2 +name: buf.build/connectrpc/protoc-gen-connect-go +lint: + use: + - STANDARD + disallow_comment_ignores: true +breaking: + use: + - WIRE_JSON diff --git a/cmd/protoc-gen-connect-go/proto/connectrpc/go/options/v1/annotations.proto b/cmd/protoc-gen-connect-go/proto/connectrpc/go/options/v1/annotations.proto new file mode 100644 index 00000000..6ed6678a --- /dev/null +++ b/cmd/protoc-gen-connect-go/proto/connectrpc/go/options/v1/annotations.proto @@ -0,0 +1,29 @@ +// Copyright 2021-2025 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package connectrpc.go.options.v1; + +import "google/protobuf/descriptor.proto"; +import "connectrpc/go/options/v1/connect.proto"; + +option go_package = "connectrpc.com/connect/cmd/protoc-gen-connect-go/gen/connectrpc/go/options/v1;optionsv1"; + +extend google.protobuf.MethodOptions { + // We should get a unique ID from protobuf-global-extension-registry@google.com + // for ConnectRPC project, and use it as the field number here. + MethodTimeouts timeouts = 50001; +} + diff --git a/cmd/protoc-gen-connect-go/proto/connectrpc/go/options/v1/connect.proto b/cmd/protoc-gen-connect-go/proto/connectrpc/go/options/v1/connect.proto new file mode 100644 index 00000000..9c4258d3 --- /dev/null +++ b/cmd/protoc-gen-connect-go/proto/connectrpc/go/options/v1/connect.proto @@ -0,0 +1,29 @@ +// Copyright 2021-2025 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package connectrpc.go.options.v1; + +option go_package = "connectrpc.com/connect/cmd/protoc-gen-connect-go/gen/connectrpc/go/options/v1;optionsv1"; + +message MethodTimeouts { + // Timeout in milliseconds to read the entire request + // (including the body). + // A negative value means no timeout. + int64 read_ms = 1; + // Timeout in milliseconds for writing the response. + // A negative value means no timeout. + int64 write_ms = 2; +} \ No newline at end of file diff --git a/connect.go b/connect.go index 274a41ee..2bf3cede 100644 --- a/connect.go +++ b/connect.go @@ -30,6 +30,7 @@ import ( "io" "net/http" "net/url" + "time" ) // Version is the semantic version of the connect module. @@ -319,6 +320,8 @@ type Spec struct { Procedure string // for example, "/acme.foo.v1.FooService/Bar" IsClient bool // otherwise we're in a handler IdempotencyLevel IdempotencyLevel + ReadTimeout time.Duration + WriteTimeout time.Duration } // Peer describes the other party to an RPC. diff --git a/handler.go b/handler.go index e33934ea..d3432976 100644 --- a/handler.go +++ b/handler.go @@ -17,6 +17,7 @@ package connect import ( "context" "net/http" + "time" ) // A Handler is the server-side implementation of a single RPC defined by a @@ -253,6 +254,12 @@ func NewBidiStreamHandler[Req, Res any]( // ServeHTTP implements [http.Handler]. func (h *Handler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) { + responseController := http.NewResponseController(responseWriter) + if err := applyDeadlines(h.spec.ReadTimeout, h.spec.WriteTimeout, responseController); err != nil { + responseWriter.WriteHeader(http.StatusInternalServerError) + return + } + // We don't need to defer functions to close the request body or read to // EOF: the stream we construct later on already does that, and we only // return early when dealing with misbehaving clients. In those cases, it's @@ -333,6 +340,45 @@ func (h *Handler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Re _ = connCloser.Close(h.implementation(ctx, connCloser)) } +// getDeadline returns a pointer to a time.Time with the given timeout. +// If the timeout is 0 (i.e. not set), nil is returned. +// If the timeout is negative, the zero value is returned to indicate no deadline. +// Otherwise, a time.Time with the given timeout is returned. +func getDeadline(timeout time.Duration) *time.Time { + if timeout == 0 { + return nil + } + if timeout < 0 { + return &time.Time{} + } + t := time.Now().Add(timeout) + return &t +} + +// deadlineSetter sets read/write deadlines. *http.ResponseController implements this. +type deadlineSetter interface { + SetReadDeadline(t time.Time) error + SetWriteDeadline(t time.Time) error +} + +// applyDeadlines applies read and write timeouts to the setter (e.g. an http.ResponseController). +// It returns an error if setting any deadline fails. +func applyDeadlines(readTimeout, writeTimeout time.Duration, setter deadlineSetter) error { + readDeadline := getDeadline(readTimeout) + if readDeadline != nil { + if err := setter.SetReadDeadline(*readDeadline); err != nil { + return err + } + } + writeDeadline := getDeadline(writeTimeout) + if writeDeadline != nil { + if err := setter.SetWriteDeadline(*writeDeadline); err != nil { + return err + } + } + return nil +} + type handlerConfig struct { CompressionPools map[string]*compressionPool CompressionNames []string @@ -347,6 +393,8 @@ type handlerConfig struct { BufferPool *bufferPool ReadMaxBytes int SendMaxBytes int + ReadTimeout time.Duration + WriteTimeout time.Duration StreamType StreamType } @@ -374,6 +422,8 @@ func (c *handlerConfig) newSpec() Spec { Schema: c.Schema, StreamType: c.StreamType, IdempotencyLevel: c.IdempotencyLevel, + ReadTimeout: c.ReadTimeout, + WriteTimeout: c.WriteTimeout, } } diff --git a/handler_test.go b/handler_test.go new file mode 100644 index 00000000..b4949251 --- /dev/null +++ b/handler_test.go @@ -0,0 +1,242 @@ +// Copyright 2021-2025 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package connect + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "testing" + "time" + + "connectrpc.com/connect/internal/assert" +) + +// fakeDeadlineSetter records deadlines passed to SetReadDeadline/SetWriteDeadline for testing. +type fakeDeadlineSetter struct { + readDeadline *time.Time + writeDeadline *time.Time + setReadErr error + setWriteErr error +} + +func (f *fakeDeadlineSetter) SetReadDeadline(t time.Time) error { + f.readDeadline = &t + return f.setReadErr +} + +func (f *fakeDeadlineSetter) SetWriteDeadline(t time.Time) error { + f.writeDeadline = &t + return f.setWriteErr +} + +// responseWriterWithFailingDeadlines implements http.ResponseWriter and the optional +// SetReadDeadline/SetWriteDeadline methods so that http.ResponseController will call them. +// Returning an error from those methods allows testing that ServeHTTP returns 500. +type responseWriterWithFailingDeadlines struct { + http.ResponseWriter + readErr error + writeErr error +} + +func (w *responseWriterWithFailingDeadlines) SetReadDeadline(time.Time) error { return w.readErr } +func (w *responseWriterWithFailingDeadlines) SetWriteDeadline(time.Time) error { return w.writeErr } + +func TestServeHTTPReturns500WhenDeadlineFailsToSet(t *testing.T) { + t.Parallel() + deadlineErr := errors.New("set deadline failed") + + handler := NewUnaryHandler( + "/test.Service/Method", + func(context.Context, *Request[struct{}]) (*Response[struct{}], error) { + return &Response[struct{}]{}, nil + }, + WithReadTimeout(time.Second), + WithWriteTimeout(time.Second), + ) + + tests := []struct { + name string + readErr error + writeErr error + wantCode int + }{ + {"SetReadDeadline error returns 500", deadlineErr, nil, http.StatusInternalServerError}, + {"SetWriteDeadline error returns 500", nil, deadlineErr, http.StatusInternalServerError}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + rec := httptest.NewRecorder() + w := &responseWriterWithFailingDeadlines{ + ResponseWriter: rec, + readErr: tt.readErr, + writeErr: tt.writeErr, + } + + req := httptest.NewRequest(http.MethodPost, "http://test/", nil) + handler.ServeHTTP(w, req) + + assert.Equal(t, rec.Code, tt.wantCode) + }) + } +} + +func TestGetDeadline(t *testing.T) { + t.Parallel() + tests := []struct { + name string + timeout time.Duration + // wantNil is true when the result should be nil. + wantNil bool + // wantZero is true when the result should be a non-nil pointer to time.Time{}. + wantZero bool + }{ + { + name: "zero returns nil", + timeout: 0, + wantNil: true, + }, + { + name: "negative returns zero value", + timeout: -1, + wantZero: true, + }, + { + name: "positive returns future time", + timeout: 5 * time.Second, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + before := time.Now() + got := getDeadline(tt.timeout) + after := time.Now() + + if tt.wantNil { + assert.Nil(t, got) + return + } + + assert.NotNil(t, got) + if tt.wantZero { + assert.Equal(t, *got, time.Time{}) + return + } + + // Positive timeout: getDeadline uses time.Now() between our before and after, + // so deadline is in [before+timeout, after+timeout]. + assert.True(t, !got.Before(before.Add(tt.timeout)), + assert.Sprintf("deadline %v should not be before %v", got, before.Add(tt.timeout))) + assert.True(t, !got.After(after.Add(tt.timeout)), + assert.Sprintf("deadline %v should not be after %v", got, after.Add(tt.timeout))) + }) + } +} + +func TestApplyDeadlines(t *testing.T) { + t.Parallel() + setErr := errors.New("set deadline failed") + tests := []struct { + name string + readTimeout time.Duration + writeTimeout time.Duration + setReadErr error + setWriteErr error + wantErr bool + // wantReadSet: was SetReadDeadline called (timeout was non-zero)? + wantReadSet bool + // wantWriteSet: was SetWriteDeadline called? + wantWriteSet bool + }{ + { + name: "both zero neither set", + readTimeout: 0, + writeTimeout: 0, + wantReadSet: false, + wantWriteSet: false, + }, + { + name: "read only set", + readTimeout: 5 * time.Second, + writeTimeout: 0, + wantReadSet: true, + wantWriteSet: false, + }, + { + name: "write only set", + readTimeout: 0, + writeTimeout: time.Second, + wantReadSet: false, + wantWriteSet: true, + }, + { + name: "both set", + readTimeout: time.Second, + writeTimeout: 2 * time.Second, + wantReadSet: true, + wantWriteSet: true, + }, + { + name: "SetReadDeadline error returned", + readTimeout: time.Second, + setReadErr: setErr, + wantErr: true, + }, + { + name: "SetWriteDeadline error returned", + writeTimeout: time.Second, + setWriteErr: setErr, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + fake := &fakeDeadlineSetter{setReadErr: tt.setReadErr, setWriteErr: tt.setWriteErr} + before := time.Now() + err := applyDeadlines(tt.readTimeout, tt.writeTimeout, fake) + after := time.Now() + if tt.wantErr { + assert.NotNil(t, err) + return + } + assert.Nil(t, err) + if tt.wantReadSet { + assert.NotNil(t, fake.readDeadline) + assert.True(t, !fake.readDeadline.Before(before.Add(tt.readTimeout)), + assert.Sprintf("read deadline %v before %v", *fake.readDeadline, before.Add(tt.readTimeout))) + assert.True(t, !fake.readDeadline.After(after.Add(tt.readTimeout)), + assert.Sprintf("read deadline %v after %v", *fake.readDeadline, after.Add(tt.readTimeout))) + } else { + assert.Nil(t, fake.readDeadline) + } + if tt.wantWriteSet { + assert.NotNil(t, fake.writeDeadline) + assert.True(t, !fake.writeDeadline.Before(before.Add(tt.writeTimeout)), + assert.Sprintf("write deadline %v before %v", *fake.writeDeadline, before.Add(tt.writeTimeout))) + assert.True(t, !fake.writeDeadline.After(after.Add(tt.writeTimeout)), + assert.Sprintf("write deadline %v after %v", *fake.writeDeadline, after.Add(tt.writeTimeout))) + } else { + assert.Nil(t, fake.writeDeadline) + } + }) + } +} diff --git a/option.go b/option.go index fe0a2cd9..b8256071 100644 --- a/option.go +++ b/option.go @@ -19,6 +19,7 @@ import ( "context" "io" "net/http" + "time" ) // A ClientOption configures a [Client]. @@ -177,6 +178,14 @@ func WithConditionalHandlerOptions(conditional func(spec Spec) []HandlerOption) return &conditionalHandlerOptions{conditional: conditional} } +func WithReadTimeout(timeout time.Duration) HandlerOption { + return &readTimeoutOption{timeout} +} + +func WithWriteTimeout(timeout time.Duration) HandlerOption { + return &writeTimeoutOption{timeout} +} + // Option implements both [ClientOption] and [HandlerOption], so it can be // applied both client-side and server-side. type Option interface { @@ -645,3 +654,19 @@ func (o *conditionalHandlerOptions) applyToHandler(config *handlerConfig) { option.applyToHandler(config) } } + +type readTimeoutOption struct { + timeout time.Duration +} + +func (o *readTimeoutOption) applyToHandler(config *handlerConfig) { + config.ReadTimeout = o.timeout +} + +type writeTimeoutOption struct { + timeout time.Duration +} + +func (o *writeTimeoutOption) applyToHandler(config *handlerConfig) { + config.WriteTimeout = o.timeout +}