diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8dace7b8f..b9e727bb2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,7 +15,7 @@ jobs: build: name: Basic Tests - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - name: Check out code @@ -36,7 +36,7 @@ jobs: grpc_web: name: gRPC-Web Tests - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - name: Check out code diff --git a/Directory.Packages.props b/Directory.Packages.props index 1c82cda73..35b4940a7 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -33,7 +33,7 @@ - + diff --git a/examples/GreeterByServiceDefinition/Client/Client.csproj b/examples/GreeterByServiceDefinition/Client/Client.csproj new file mode 100644 index 000000000..8256c46ea --- /dev/null +++ b/examples/GreeterByServiceDefinition/Client/Client.csproj @@ -0,0 +1,18 @@ + + + + Exe + net7.0 + + + + + + + + + + + + + diff --git a/examples/GreeterByServiceDefinition/Client/Program.cs b/examples/GreeterByServiceDefinition/Client/Program.cs new file mode 100644 index 000000000..2b5958daf --- /dev/null +++ b/examples/GreeterByServiceDefinition/Client/Program.cs @@ -0,0 +1,30 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using Greet; +using Grpc.Net.Client; + +using var channel = GrpcChannel.ForAddress("https://localhost:5001"); +var client = new Greeter.GreeterClient(channel); + +var reply = await client.SayHelloAsync(new HelloRequest { Name = "GreeterClient" }); +Console.WriteLine("Greeting: " + reply.Message); + +Console.WriteLine("Shutting down"); +Console.WriteLine("Press any key to exit..."); +Console.ReadKey(); diff --git a/examples/GreeterByServiceDefinition/GreeterByServiceDefinition.sln b/examples/GreeterByServiceDefinition/GreeterByServiceDefinition.sln new file mode 100644 index 000000000..b19669e10 --- /dev/null +++ b/examples/GreeterByServiceDefinition/GreeterByServiceDefinition.sln @@ -0,0 +1,84 @@ + +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio Version 17 +VisualStudioVersion = 17.6.33829.357 +MinimumVisualStudioVersion = 10.0.40219.1 +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Server", "Server\Server.csproj", "{534AC5F8-2DF2-40BD-87A5-B3D8310118C4}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Client", "Client\Client.csproj", "{48A1D3BC-A14B-436A-8822-6DE2BEF8B747}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "ref", "ref", "{32B810EF-93B2-46C2-879A-BBA345A10E71}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Grpc.Core.Api", "..\..\src\Grpc.Core.Api\Grpc.Core.Api.csproj", "{BF8BD8C9-70D7-486F-BE4D-9ED2C7EA8CB1}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Grpc.Net.Common", "..\..\src\Grpc.Net.Common\Grpc.Net.Common.csproj", "{912BCAE2-04D8-4FFE-B9A5-C7FAEA7EF808}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Grpc.Net.ClientFactory", "..\..\src\Grpc.Net.ClientFactory\Grpc.Net.ClientFactory.csproj", "{3F49C6CE-D3AC-4609-B416-5E180DA59C7F}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Grpc.AspNetCore.Server.ClientFactory", "..\..\src\Grpc.AspNetCore.Server.ClientFactory\Grpc.AspNetCore.Server.ClientFactory.csproj", "{B38F8199-FD16-4E02-B1E2-CECEBF29A638}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Grpc.AspNetCore.Server", "..\..\src\Grpc.AspNetCore.Server\Grpc.AspNetCore.Server.csproj", "{5E857C51-76FF-4263-9BD7-CCB7997795F9}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Grpc.AspNetCore", "..\..\src\Grpc.AspNetCore\Grpc.AspNetCore.csproj", "{7D83B407-3C89-4671-BF97-A1B196633B0D}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Grpc.Net.Client", "..\..\src\Grpc.Net.Client\Grpc.Net.Client.csproj", "{CD4371A4-F789-4752-B1C0-DD95B1D6A090}" +EndProject +Global + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|Any CPU = Debug|Any CPU + Release|Any CPU = Release|Any CPU + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {534AC5F8-2DF2-40BD-87A5-B3D8310118C4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {534AC5F8-2DF2-40BD-87A5-B3D8310118C4}.Debug|Any CPU.Build.0 = Debug|Any CPU + {534AC5F8-2DF2-40BD-87A5-B3D8310118C4}.Release|Any CPU.ActiveCfg = Release|Any CPU + {534AC5F8-2DF2-40BD-87A5-B3D8310118C4}.Release|Any CPU.Build.0 = Release|Any CPU + {48A1D3BC-A14B-436A-8822-6DE2BEF8B747}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {48A1D3BC-A14B-436A-8822-6DE2BEF8B747}.Debug|Any CPU.Build.0 = Debug|Any CPU + {48A1D3BC-A14B-436A-8822-6DE2BEF8B747}.Release|Any CPU.ActiveCfg = Release|Any CPU + {48A1D3BC-A14B-436A-8822-6DE2BEF8B747}.Release|Any CPU.Build.0 = Release|Any CPU + {BF8BD8C9-70D7-486F-BE4D-9ED2C7EA8CB1}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {BF8BD8C9-70D7-486F-BE4D-9ED2C7EA8CB1}.Debug|Any CPU.Build.0 = Debug|Any CPU + {BF8BD8C9-70D7-486F-BE4D-9ED2C7EA8CB1}.Release|Any CPU.ActiveCfg = Release|Any CPU + {BF8BD8C9-70D7-486F-BE4D-9ED2C7EA8CB1}.Release|Any CPU.Build.0 = Release|Any CPU + {912BCAE2-04D8-4FFE-B9A5-C7FAEA7EF808}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {912BCAE2-04D8-4FFE-B9A5-C7FAEA7EF808}.Debug|Any CPU.Build.0 = Debug|Any CPU + {912BCAE2-04D8-4FFE-B9A5-C7FAEA7EF808}.Release|Any CPU.ActiveCfg = Release|Any CPU + {912BCAE2-04D8-4FFE-B9A5-C7FAEA7EF808}.Release|Any CPU.Build.0 = Release|Any CPU + {3F49C6CE-D3AC-4609-B416-5E180DA59C7F}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {3F49C6CE-D3AC-4609-B416-5E180DA59C7F}.Debug|Any CPU.Build.0 = Debug|Any CPU + {3F49C6CE-D3AC-4609-B416-5E180DA59C7F}.Release|Any CPU.ActiveCfg = Release|Any CPU + {3F49C6CE-D3AC-4609-B416-5E180DA59C7F}.Release|Any CPU.Build.0 = Release|Any CPU + {B38F8199-FD16-4E02-B1E2-CECEBF29A638}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {B38F8199-FD16-4E02-B1E2-CECEBF29A638}.Debug|Any CPU.Build.0 = Debug|Any CPU + {B38F8199-FD16-4E02-B1E2-CECEBF29A638}.Release|Any CPU.ActiveCfg = Release|Any CPU + {B38F8199-FD16-4E02-B1E2-CECEBF29A638}.Release|Any CPU.Build.0 = Release|Any CPU + {5E857C51-76FF-4263-9BD7-CCB7997795F9}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {5E857C51-76FF-4263-9BD7-CCB7997795F9}.Debug|Any CPU.Build.0 = Debug|Any CPU + {5E857C51-76FF-4263-9BD7-CCB7997795F9}.Release|Any CPU.ActiveCfg = Release|Any CPU + {5E857C51-76FF-4263-9BD7-CCB7997795F9}.Release|Any CPU.Build.0 = Release|Any CPU + {7D83B407-3C89-4671-BF97-A1B196633B0D}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {7D83B407-3C89-4671-BF97-A1B196633B0D}.Debug|Any CPU.Build.0 = Debug|Any CPU + {7D83B407-3C89-4671-BF97-A1B196633B0D}.Release|Any CPU.ActiveCfg = Release|Any CPU + {7D83B407-3C89-4671-BF97-A1B196633B0D}.Release|Any CPU.Build.0 = Release|Any CPU + {CD4371A4-F789-4752-B1C0-DD95B1D6A090}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {CD4371A4-F789-4752-B1C0-DD95B1D6A090}.Debug|Any CPU.Build.0 = Debug|Any CPU + {CD4371A4-F789-4752-B1C0-DD95B1D6A090}.Release|Any CPU.ActiveCfg = Release|Any CPU + {CD4371A4-F789-4752-B1C0-DD95B1D6A090}.Release|Any CPU.Build.0 = Release|Any CPU + EndGlobalSection + GlobalSection(SolutionProperties) = preSolution + HideSolutionNode = FALSE + EndGlobalSection + GlobalSection(NestedProjects) = preSolution + {BF8BD8C9-70D7-486F-BE4D-9ED2C7EA8CB1} = {32B810EF-93B2-46C2-879A-BBA345A10E71} + {912BCAE2-04D8-4FFE-B9A5-C7FAEA7EF808} = {32B810EF-93B2-46C2-879A-BBA345A10E71} + {3F49C6CE-D3AC-4609-B416-5E180DA59C7F} = {32B810EF-93B2-46C2-879A-BBA345A10E71} + {B38F8199-FD16-4E02-B1E2-CECEBF29A638} = {32B810EF-93B2-46C2-879A-BBA345A10E71} + {5E857C51-76FF-4263-9BD7-CCB7997795F9} = {32B810EF-93B2-46C2-879A-BBA345A10E71} + {7D83B407-3C89-4671-BF97-A1B196633B0D} = {32B810EF-93B2-46C2-879A-BBA345A10E71} + {CD4371A4-F789-4752-B1C0-DD95B1D6A090} = {32B810EF-93B2-46C2-879A-BBA345A10E71} + EndGlobalSection + GlobalSection(ExtensibilityGlobals) = postSolution + SolutionGuid = {D22B3129-3BFB-41FA-9FCE-E45EBEF8C2DD} + EndGlobalSection +EndGlobal diff --git a/examples/GreeterByServiceDefinition/Proto/greet.proto b/examples/GreeterByServiceDefinition/Proto/greet.proto new file mode 100644 index 000000000..26d0c794d --- /dev/null +++ b/examples/GreeterByServiceDefinition/Proto/greet.proto @@ -0,0 +1,33 @@ +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package greet; + +// The greeting service definition. +service Greeter { + // Sends a greeting + rpc SayHello (HelloRequest) returns (HelloReply); +} + +// The request message containing the user's name. +message HelloRequest { + string name = 1; +} + +// The response message containing the greetings +message HelloReply { + string message = 1; +} diff --git a/examples/GreeterByServiceDefinition/Server/Program.cs b/examples/GreeterByServiceDefinition/Server/Program.cs new file mode 100644 index 000000000..69de37f8c --- /dev/null +++ b/examples/GreeterByServiceDefinition/Server/Program.cs @@ -0,0 +1,36 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using Grpc.Core; +using Server; +using Microsoft.AspNetCore.Builder; + +var builder = WebApplication.CreateBuilder(args); +builder.Services.AddGrpc(); + +var app = builder.Build(); +app.MapGrpcService(getGreeterService); + +app.Run(); + +static ServerServiceDefinition getGreeterService(IServiceProvider serviceProvider) +{ + var loggerFactory = serviceProvider.GetRequiredService(); + var service = new GreeterService(loggerFactory); + return Greet.Greeter.BindService(service); +} diff --git a/examples/GreeterByServiceDefinition/Server/Server.csproj b/examples/GreeterByServiceDefinition/Server/Server.csproj new file mode 100644 index 000000000..099b6adbd --- /dev/null +++ b/examples/GreeterByServiceDefinition/Server/Server.csproj @@ -0,0 +1,16 @@ + + + + net9.0 + + + + + + + + + + + + diff --git a/examples/GreeterByServiceDefinition/Server/Services/GreeterService.cs b/examples/GreeterByServiceDefinition/Server/Services/GreeterService.cs new file mode 100644 index 000000000..1ca09856d --- /dev/null +++ b/examples/GreeterByServiceDefinition/Server/Services/GreeterService.cs @@ -0,0 +1,41 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System.Threading.Tasks; +using Greet; +using Grpc.Core; +using Microsoft.Extensions.Logging; + +namespace Server +{ + public class GreeterService : Greeter.GreeterBase + { + private readonly ILogger _logger; + + public GreeterService(ILoggerFactory loggerFactory) + { + _logger = loggerFactory.CreateLogger(); + } + + public override Task SayHello(HelloRequest request, ServerCallContext context) + { + _logger.LogInformation($"Sending hello to {request.Name}"); + return Task.FromResult(new HelloReply { Message = "Hello " + request.Name }); + } + } +} diff --git a/examples/GreeterByServiceDefinition/Server/appsettings.Development.json b/examples/GreeterByServiceDefinition/Server/appsettings.Development.json new file mode 100644 index 000000000..fe20c40cc --- /dev/null +++ b/examples/GreeterByServiceDefinition/Server/appsettings.Development.json @@ -0,0 +1,10 @@ +{ + "Logging": { + "LogLevel": { + "Default": "Debug", + "System": "Information", + "Grpc": "Information", + "Microsoft": "Information" + } + } +} diff --git a/examples/GreeterByServiceDefinition/Server/appsettings.json b/examples/GreeterByServiceDefinition/Server/appsettings.json new file mode 100644 index 000000000..f5f63744b --- /dev/null +++ b/examples/GreeterByServiceDefinition/Server/appsettings.json @@ -0,0 +1,13 @@ +{ + "Logging": { + "LogLevel": { + "Default": "Information" + } + }, + "AllowedHosts": "*", + "Kestrel": { + "EndpointDefaults": { + "Protocols": "Http2" + } + } +} diff --git a/examples/README.md b/examples/README.md index 14a93010b..c018c0a0f 100644 --- a/examples/README.md +++ b/examples/README.md @@ -328,3 +328,13 @@ The error example shows how to use a richer error model with `Grpc.StatusProto`. * Error handling * Validation * [`google.rpc.Status`](https://cloud.google.com/apis/design/errors#error_model) + + +## [GreeterByServiceDefinition](./GreeterByServiceDefinition) + +This sample is similar with [Greeter](#greeter), but its service instance for server is mapped by using `ServerServiceDefinition`. + +##### Scenarios: + +* Mapping server service by using `ServerServiceDefinition` +* Unary call diff --git a/examples/Spar/Server/ClientApp/package-lock.json b/examples/Spar/Server/ClientApp/package-lock.json index 7cf245d58..b80027769 100644 --- a/examples/Spar/Server/ClientApp/package-lock.json +++ b/examples/Spar/Server/ClientApp/package-lock.json @@ -3837,10 +3837,11 @@ "dev": true }, "node_modules/elliptic": { - "version": "6.6.0", - "resolved": "https://registry.npmjs.org/elliptic/-/elliptic-6.6.0.tgz", - "integrity": "sha512-dpwoQcLc/2WLQvJvLRHKZ+f9FgOdjnq11rurqwekGQygGPsYSK29OMMD2WalatiqQ+XGFDglTNixpPfI+lpaAA==", + "version": "6.6.1", + "resolved": "https://registry.npmjs.org/elliptic/-/elliptic-6.6.1.tgz", + "integrity": "sha512-RaddvvMatK2LJHqFJ+YA4WysVN5Ita9E35botqIYspQ4TkRAlCicdzKOjlyv/1Za5RyTNn7di//eEV0uTAfe3g==", "dev": true, + "license": "MIT", "dependencies": { "bn.js": "^4.11.9", "brorand": "^1.1.0", @@ -12413,9 +12414,9 @@ "dev": true }, "elliptic": { - "version": "6.6.0", - "resolved": "https://registry.npmjs.org/elliptic/-/elliptic-6.6.0.tgz", - "integrity": "sha512-dpwoQcLc/2WLQvJvLRHKZ+f9FgOdjnq11rurqwekGQygGPsYSK29OMMD2WalatiqQ+XGFDglTNixpPfI+lpaAA==", + "version": "6.6.1", + "resolved": "https://registry.npmjs.org/elliptic/-/elliptic-6.6.1.tgz", + "integrity": "sha512-RaddvvMatK2LJHqFJ+YA4WysVN5Ita9E35botqIYspQ4TkRAlCicdzKOjlyv/1Za5RyTNn7di//eEV0uTAfe3g==", "dev": true, "requires": { "bn.js": "^4.11.9", diff --git a/nuget.config b/nuget.config index e4d73b867..d5ff7157f 100644 --- a/nuget.config +++ b/nuget.config @@ -3,7 +3,6 @@ - diff --git a/src/Grpc.AspNetCore.Server/GrpcEndpointRouteBuilderExtensions.cs b/src/Grpc.AspNetCore.Server/GrpcEndpointRouteBuilderExtensions.cs index f583a2c4c..3be616df8 100644 --- a/src/Grpc.AspNetCore.Server/GrpcEndpointRouteBuilderExtensions.cs +++ b/src/Grpc.AspNetCore.Server/GrpcEndpointRouteBuilderExtensions.cs @@ -19,6 +19,7 @@ using System.Diagnostics.CodeAnalysis; using Grpc.AspNetCore.Server.Internal; using Grpc.AspNetCore.Server.Model.Internal; +using Grpc.Core; using Grpc.Shared; using Microsoft.AspNetCore.Routing; using Microsoft.Extensions.DependencyInjection; @@ -48,6 +49,43 @@ public static class GrpcEndpointRouteBuilderExtensions return new GrpcServiceEndpointConventionBuilder(endpointConventionBuilders); } + /// + /// Maps incoming requests to the specified instance. + /// + /// The to add the route to. + /// The instance of . + /// A for endpoints associated with the service. + [RequiresUnreferencedCode("Due to type erasure in ServerServiceDefinition, MapGrpcService is incompatible with trimming.")] + public static GrpcServiceEndpointConventionBuilder MapGrpcService(this IEndpointRouteBuilder builder, ServerServiceDefinition serviceDefinition) + { + ArgumentNullException.ThrowIfNull(builder, nameof(builder)); + ArgumentNullException.ThrowIfNull(serviceDefinition, nameof(serviceDefinition)); + + var serviceRouteBuilder = builder.ServiceProvider.GetRequiredService(); + var endpointConventionBuilders = serviceRouteBuilder.Build(builder, serviceDefinition); + + return new GrpcServiceEndpointConventionBuilder(endpointConventionBuilders); + } + + /// + /// Maps incoming requests to the instance from the specified factory. + /// + /// The to add the route to. + /// The factory for instance. + /// A for endpoints associated with the service. + [RequiresUnreferencedCode("Due to type erasure in ServerServiceDefinition, MapGrpcService is incompatible with trimming.")] + public static GrpcServiceEndpointConventionBuilder MapGrpcService(this IEndpointRouteBuilder builder, Func getServiceDefinition) + { + ArgumentNullException.ThrowIfNull(builder, nameof(builder)); + ArgumentNullException.ThrowIfNull(getServiceDefinition, nameof(getServiceDefinition)); + + var serviceDefinition = getServiceDefinition(builder.ServiceProvider); + var serviceRouteBuilder = builder.ServiceProvider.GetRequiredService(); + var endpointConventionBuilders = serviceRouteBuilder.Build(builder, serviceDefinition); + + return new GrpcServiceEndpointConventionBuilder(endpointConventionBuilders); + } + private static void ValidateServicesRegistered(IServiceProvider serviceProvider) { var marker = serviceProvider.GetService(typeof(GrpcMarkerService)); diff --git a/src/Grpc.AspNetCore.Server/GrpcServiceExtensions.cs b/src/Grpc.AspNetCore.Server/GrpcServiceExtensions.cs index 0f9b8852d..b45dbf396 100644 --- a/src/Grpc.AspNetCore.Server/GrpcServiceExtensions.cs +++ b/src/Grpc.AspNetCore.Server/GrpcServiceExtensions.cs @@ -68,6 +68,7 @@ public static IGrpcServerBuilder AddGrpc(this IServiceCollection services) #endif services.AddOptions(); services.TryAddSingleton(); + services.TryAddSingleton(typeof(ServerCallHandlerFactory)); services.TryAddSingleton(typeof(ServerCallHandlerFactory<>)); services.TryAddSingleton(typeof(IGrpcServiceActivator<>), typeof(DefaultGrpcServiceActivator<>)); services.TryAddSingleton(typeof(IGrpcInterceptorActivator<>), typeof(DefaultGrpcInterceptorActivator<>)); @@ -75,6 +76,7 @@ public static IGrpcServerBuilder AddGrpc(this IServiceCollection services) // Model services.TryAddSingleton(); + services.TryAddSingleton(typeof(ServiceRouteBuilder)); services.TryAddSingleton(typeof(ServiceRouteBuilder<>)); services.TryAddEnumerable(ServiceDescriptor.Singleton(typeof(IServiceMethodProvider<>), typeof(BinderServiceMethodProvider<>))); diff --git a/src/Grpc.AspNetCore.Server/Internal/CallHandlers/ClientStreamingServerCallHandler.cs b/src/Grpc.AspNetCore.Server/Internal/CallHandlers/ClientStreamingServerCallHandler.cs index 99437edc9..0272ac6bd 100644 --- a/src/Grpc.AspNetCore.Server/Internal/CallHandlers/ClientStreamingServerCallHandler.cs +++ b/src/Grpc.AspNetCore.Server/Internal/CallHandlers/ClientStreamingServerCallHandler.cs @@ -24,6 +24,56 @@ namespace Grpc.AspNetCore.Server.Internal.CallHandlers; +internal sealed class ClientStreamingServerCallHandler : ServerCallHandlerBase + where TRequest : class + where TResponse : class +{ + private readonly ClientStreamingServerMethodInvoker _invoker; + + public ClientStreamingServerCallHandler( + ClientStreamingServerMethodInvoker invoker, + ILoggerFactory loggerFactory) + : base(invoker, loggerFactory) + { + _invoker = invoker; + } + + protected override async Task HandleCallAsyncCore(HttpContext httpContext, HttpContextServerCallContext serverCallContext) + { + // Disable request body data rate for client streaming + DisableMinRequestBodyDataRateAndMaxRequestBodySize(httpContext); + + TResponse? response; + + var streamReader = new HttpContextStreamReader(serverCallContext, MethodInvoker.Method.RequestMarshaller.ContextualDeserializer); + try + { + response = await _invoker.Invoke(httpContext, serverCallContext, streamReader); + } + finally + { + streamReader.Complete(); + } + + if (response == null) + { + // This is consistent with Grpc.Core when a null value is returned + throw new RpcException(new Status(StatusCode.Cancelled, "No message returned from method.")); + } + + // Check if deadline exceeded while method was invoked. If it has then skip trying to write + // the response message because it will always fail. + // Note that the call is still going so the deadline could still be exceeded after this point. + if (serverCallContext.DeadlineManager?.IsDeadlineExceededStarted ?? false) + { + return; + } + + var responseBodyWriter = httpContext.Response.BodyWriter; + await responseBodyWriter.WriteSingleMessageAsync(response, serverCallContext, MethodInvoker.Method.ResponseMarshaller.ContextualSerializer); + } +} + internal sealed class ClientStreamingServerCallHandler<[DynamicallyAccessedMembers(GrpcProtocolConstants.ServiceAccessibility)] TService, TRequest, TResponse> : ServerCallHandlerBase where TRequest : class where TResponse : class diff --git a/src/Grpc.AspNetCore.Server/Internal/CallHandlers/DuplexStreamingServerCallHandler.cs b/src/Grpc.AspNetCore.Server/Internal/CallHandlers/DuplexStreamingServerCallHandler.cs index 2649b1fb5..b7811e131 100644 --- a/src/Grpc.AspNetCore.Server/Internal/CallHandlers/DuplexStreamingServerCallHandler.cs +++ b/src/Grpc.AspNetCore.Server/Internal/CallHandlers/DuplexStreamingServerCallHandler.cs @@ -23,6 +23,39 @@ namespace Grpc.AspNetCore.Server.Internal.CallHandlers; +internal sealed class DuplexStreamingServerCallHandler : ServerCallHandlerBase + where TRequest : class + where TResponse : class +{ + private readonly DuplexStreamingServerMethodInvoker _invoker; + + public DuplexStreamingServerCallHandler( + DuplexStreamingServerMethodInvoker invoker, + ILoggerFactory loggerFactory) + : base(invoker, loggerFactory) + { + _invoker = invoker; + } + + protected override async Task HandleCallAsyncCore(HttpContext httpContext, HttpContextServerCallContext serverCallContext) + { + // Disable request body data rate for client streaming + DisableMinRequestBodyDataRateAndMaxRequestBodySize(httpContext); + + var streamReader = new HttpContextStreamReader(serverCallContext, MethodInvoker.Method.RequestMarshaller.ContextualDeserializer); + var streamWriter = new HttpContextStreamWriter(serverCallContext, MethodInvoker.Method.ResponseMarshaller.ContextualSerializer); + try + { + await _invoker.Invoke(httpContext, serverCallContext, streamReader, streamWriter); + } + finally + { + streamReader.Complete(); + streamWriter.Complete(); + } + } +} + internal sealed class DuplexStreamingServerCallHandler<[DynamicallyAccessedMembers(GrpcProtocolConstants.ServiceAccessibility)] TService, TRequest, TResponse> : ServerCallHandlerBase where TRequest : class where TResponse : class diff --git a/src/Grpc.AspNetCore.Server/Internal/CallHandlers/ServerCallHandlerBase.cs b/src/Grpc.AspNetCore.Server/Internal/CallHandlers/ServerCallHandlerBase.cs index f4db8601f..2b447a910 100644 --- a/src/Grpc.AspNetCore.Server/Internal/CallHandlers/ServerCallHandlerBase.cs +++ b/src/Grpc.AspNetCore.Server/Internal/CallHandlers/ServerCallHandlerBase.cs @@ -30,7 +30,140 @@ namespace Grpc.AspNetCore.Server.Internal.CallHandlers; -internal abstract class ServerCallHandlerBase<[DynamicallyAccessedMembers(GrpcProtocolConstants.ServiceAccessibility)] TService, TRequest, TResponse> +internal abstract class ServerCallHandlerBase + where TRequest : class + where TResponse : class +{ + private const string LoggerName = "Grpc.AspNetCore.Server.ServerCallHandler"; + + protected ServerMethodInvokerBase MethodInvoker { get; } + protected ILogger Logger { get; } + + protected ServerCallHandlerBase( + ServerMethodInvokerBase methodInvoker, + ILoggerFactory loggerFactory) + { + MethodInvoker = methodInvoker; + Logger = loggerFactory.CreateLogger(LoggerName); + } + + public Task HandleCallAsync(HttpContext httpContext) + { + if (GrpcProtocolHelpers.IsInvalidContentType(httpContext, out var error)) + { + return ProcessInvalidContentTypeRequest(httpContext, error); + } + + if (!GrpcProtocolConstants.IsHttp2(httpContext.Request.Protocol) +#if NET6_0_OR_GREATER + && !GrpcProtocolConstants.IsHttp3(httpContext.Request.Protocol) +#endif + ) + { + return ProcessNonHttp2Request(httpContext); + } + + var serverCallContext = new HttpContextServerCallContext(httpContext, MethodInvoker.Options, typeof(TRequest), typeof(TResponse), Logger); + httpContext.Features.Set(serverCallContext); + + GrpcProtocolHelpers.AddProtocolHeaders(httpContext.Response); + + try + { + serverCallContext.Initialize(); + + var handleCallTask = HandleCallAsyncCore(httpContext, serverCallContext); + + if (handleCallTask.IsCompletedSuccessfully) + { + return serverCallContext.EndCallAsync(); + } + else + { + return AwaitHandleCall(serverCallContext, MethodInvoker.Method, handleCallTask); + } + } + catch (Exception ex) + { + return serverCallContext.ProcessHandlerErrorAsync(ex, MethodInvoker.Method.Name); + } + + static async Task AwaitHandleCall(HttpContextServerCallContext serverCallContext, Method method, Task handleCall) + { + try + { + await handleCall; + await serverCallContext.EndCallAsync(); + } + catch (Exception ex) + { + await serverCallContext.ProcessHandlerErrorAsync(ex, method.Name); + } + } + } + + protected abstract Task HandleCallAsyncCore(HttpContext httpContext, HttpContextServerCallContext serverCallContext); + + /// + /// This should only be called from client streaming calls + /// + /// + protected void DisableMinRequestBodyDataRateAndMaxRequestBodySize(HttpContext httpContext) + { + var minRequestBodyDataRateFeature = httpContext.Features.Get(); + if (minRequestBodyDataRateFeature != null) + { + minRequestBodyDataRateFeature.MinDataRate = null; + } + + var maxRequestBodySizeFeature = httpContext.Features.Get(); + if (maxRequestBodySizeFeature != null) + { + if (!maxRequestBodySizeFeature.IsReadOnly) + { + maxRequestBodySizeFeature.MaxRequestBodySize = null; + } + else + { + // IsReadOnly could be true if middleware has already started reading the request body + // In that case we can't disable the max request body size for the request stream + GrpcServerLog.UnableToDisableMaxRequestBodySize(Logger); + } + } + } + + private Task ProcessNonHttp2Request(HttpContext httpContext) + { + GrpcServerLog.UnsupportedRequestProtocol(Logger, httpContext.Request.Protocol); + + var protocolError = $"Request protocol '{httpContext.Request.Protocol}' is not supported."; + GrpcProtocolHelpers.BuildHttpErrorResponse(httpContext.Response, StatusCodes.Status426UpgradeRequired, StatusCode.Internal, protocolError); + httpContext.Response.Headers[HeaderNames.Upgrade] = GrpcProtocolConstants.Http2Protocol; + return Task.CompletedTask; + } + + private Task ProcessInvalidContentTypeRequest(HttpContext httpContext, string error) + { + // This might be a CORS preflight request and CORS middleware hasn't been configured + if (GrpcProtocolHelpers.IsCorsPreflightRequest(httpContext)) + { + GrpcServerLog.UnhandledCorsPreflightRequest(Logger); + + GrpcProtocolHelpers.BuildHttpErrorResponse(httpContext.Response, StatusCodes.Status405MethodNotAllowed, StatusCode.Internal, "Unhandled CORS preflight request received. CORS may not be configured correctly in the application."); + httpContext.Response.Headers[HeaderNames.Allow] = HttpMethods.Post; + return Task.CompletedTask; + } + else + { + GrpcServerLog.UnsupportedRequestContentType(Logger, httpContext.Request.ContentType); + + GrpcProtocolHelpers.BuildHttpErrorResponse(httpContext.Response, StatusCodes.Status415UnsupportedMediaType, StatusCode.Internal, error); + return Task.CompletedTask; + } + } +} + +internal abstract class ServerCallHandlerBase<[DynamicallyAccessedMembers(GrpcProtocolConstants.ServiceAccessibility)]TService, TRequest, TResponse> where TService : class where TRequest : class where TResponse : class diff --git a/src/Grpc.AspNetCore.Server/Internal/CallHandlers/ServerStreamingServerCallHandler.cs b/src/Grpc.AspNetCore.Server/Internal/CallHandlers/ServerStreamingServerCallHandler.cs index da49752c9..0b315677f 100644 --- a/src/Grpc.AspNetCore.Server/Internal/CallHandlers/ServerStreamingServerCallHandler.cs +++ b/src/Grpc.AspNetCore.Server/Internal/CallHandlers/ServerStreamingServerCallHandler.cs @@ -23,6 +23,37 @@ namespace Grpc.AspNetCore.Server.Internal.CallHandlers; +internal sealed class ServerStreamingServerCallHandler : ServerCallHandlerBase + where TRequest : class + where TResponse : class +{ + private readonly ServerStreamingServerMethodInvoker _invoker; + + public ServerStreamingServerCallHandler( + ServerStreamingServerMethodInvoker invoker, + ILoggerFactory loggerFactory) + : base(invoker, loggerFactory) + { + _invoker = invoker; + } + + protected override async Task HandleCallAsyncCore(HttpContext httpContext, HttpContextServerCallContext serverCallContext) + { + // Decode request + var request = await httpContext.Request.BodyReader.ReadSingleMessageAsync(serverCallContext, MethodInvoker.Method.RequestMarshaller.ContextualDeserializer); + + var streamWriter = new HttpContextStreamWriter(serverCallContext, MethodInvoker.Method.ResponseMarshaller.ContextualSerializer); + try + { + await _invoker.Invoke(httpContext, serverCallContext, request, streamWriter); + } + finally + { + streamWriter.Complete(); + } + } +} + internal sealed class ServerStreamingServerCallHandler<[DynamicallyAccessedMembers(GrpcProtocolConstants.ServiceAccessibility)] TService, TRequest, TResponse> : ServerCallHandlerBase where TRequest : class where TResponse : class diff --git a/src/Grpc.AspNetCore.Server/Internal/CallHandlers/UnaryServerCallHandler.cs b/src/Grpc.AspNetCore.Server/Internal/CallHandlers/UnaryServerCallHandler.cs index f0d3047cd..557036d18 100644 --- a/src/Grpc.AspNetCore.Server/Internal/CallHandlers/UnaryServerCallHandler.cs +++ b/src/Grpc.AspNetCore.Server/Internal/CallHandlers/UnaryServerCallHandler.cs @@ -24,6 +24,45 @@ namespace Grpc.AspNetCore.Server.Internal.CallHandlers; +internal sealed class UnaryServerCallHandler : ServerCallHandlerBase + where TRequest : class + where TResponse : class +{ + private readonly UnaryServerMethodInvoker _invoker; + + public UnaryServerCallHandler( + UnaryServerMethodInvoker invoker, + ILoggerFactory loggerFactory) + : base(invoker, loggerFactory) + { + _invoker = invoker; + } + + protected override async Task HandleCallAsyncCore(HttpContext httpContext, HttpContextServerCallContext serverCallContext) + { + var request = await httpContext.Request.BodyReader.ReadSingleMessageAsync(serverCallContext, MethodInvoker.Method.RequestMarshaller.ContextualDeserializer); + + var response = await _invoker.Invoke(httpContext, serverCallContext, request); + + if (response == null) + { + // This is consistent with Grpc.Core when a null value is returned + throw new RpcException(new Status(StatusCode.Cancelled, "No message returned from method.")); + } + + // Check if deadline exceeded while method was invoked. If it has then skip trying to write + // the response message because it will always fail. + // Note that the call is still going so the deadline could still be exceeded after this point. + if (serverCallContext.DeadlineManager?.IsDeadlineExceededStarted ?? false) + { + return; + } + + var responseBodyWriter = httpContext.Response.BodyWriter; + await responseBodyWriter.WriteSingleMessageAsync(response, serverCallContext, MethodInvoker.Method.ResponseMarshaller.ContextualSerializer); + } +} + internal sealed class UnaryServerCallHandler<[DynamicallyAccessedMembers(GrpcProtocolConstants.ServiceAccessibility)] TService, TRequest, TResponse> : ServerCallHandlerBase where TRequest : class where TResponse : class diff --git a/src/Grpc.AspNetCore.Server/Internal/EndpointServiceBinder.cs b/src/Grpc.AspNetCore.Server/Internal/EndpointServiceBinder.cs new file mode 100644 index 000000000..23ef15e6f --- /dev/null +++ b/src/Grpc.AspNetCore.Server/Internal/EndpointServiceBinder.cs @@ -0,0 +1,126 @@ +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.Logging; +using Grpc.Core; +using Microsoft.AspNetCore.Routing.Patterns; +using Grpc.AspNetCore.Server.Model.Internal; + +namespace Grpc.AspNetCore.Server.Internal; + +/// +/// The service binder to bind into ASP.Net core web application server. +/// +internal class EndpointServiceBinder : ServiceBinderBase +{ + private readonly ServerCallHandlerFactory _serverCallHandlerFactory; + private readonly IEndpointRouteBuilder _routeBuilder; + private readonly ILogger _logger; + public List EndpointConventionBuilders { get; } + public List MethodModels { get; } + + public EndpointServiceBinder( + ServerCallHandlerFactory serverCallHandlerFactory, + IEndpointRouteBuilder routeBuilder, + ILoggerFactory loggerFactory) + { + _serverCallHandlerFactory = serverCallHandlerFactory; + _routeBuilder = routeBuilder; + _logger = loggerFactory.CreateLogger(); + EndpointConventionBuilders = new List(); + MethodModels = new List(); + } + + public override void AddMethod(Method method, UnaryServerMethod? handler) + { + if(handler?.Method.DeclaringType == null) + { + throw new InvalidOperationException($"Instance methods are only allowed as server implementation for Grpc.Core.ServerServiceDefinition."); + } + var serviceType = handler.Method.DeclaringType; + var metadata = CreateMetadata(serviceType, handler); + var callHandler = _serverCallHandlerFactory.CreateUnary(method, handler); + var pattern = RoutePatternFactory.Parse(method.FullName); + AddMethod(new MethodModel(method, pattern, metadata, callHandler.HandleCallAsync)); + } + + public override void AddMethod(Method method, ClientStreamingServerMethod? handler) + { + if (handler?.Method.DeclaringType == null) + { + throw new InvalidOperationException($"Instance methods are only allowed as server implementation for Grpc.Core.ServerServiceDefinition."); + } + var serviceType = handler.Method.DeclaringType; + var metadata = CreateMetadata(serviceType, handler); + var callHandler = _serverCallHandlerFactory.CreateClientStreaming(method, handler); + var pattern = RoutePatternFactory.Parse(method.FullName); + AddMethod(new MethodModel(method, pattern, metadata, callHandler.HandleCallAsync)); + } + + public override void AddMethod(Method method, ServerStreamingServerMethod? handler) + { + if (handler?.Method.DeclaringType == null) + { + throw new InvalidOperationException($"Instance methods are only allowed as server implementation for Grpc.Core.ServerServiceDefinition."); + } + var serviceType = handler.Method.DeclaringType; + var metadata = CreateMetadata(serviceType, handler); + var callHandler = _serverCallHandlerFactory.CreateServerStreaming(method, handler); + var pattern = RoutePatternFactory.Parse(method.FullName); + AddMethod(new MethodModel(method, pattern, metadata, callHandler.HandleCallAsync)); + } + + public override void AddMethod(Method method, DuplexStreamingServerMethod? handler) + { + if (handler?.Method.DeclaringType == null) + { + throw new InvalidOperationException($"Instance methods are only allowed as server implementation for Grpc.Core.ServerServiceDefinition."); + } + var serviceType = handler.Method.DeclaringType; + var metadata = CreateMetadata(serviceType, handler); + var callHandler = _serverCallHandlerFactory.CreateDuplexStreaming(method, handler); + var pattern = RoutePatternFactory.Parse(method.FullName); + AddMethod(new MethodModel(method, pattern, metadata, callHandler.HandleCallAsync)); + } + + private IList CreateMetadata(Type serviceType, Delegate handler) + { + var metadata = new List(); + // Add type metadata first so it has a lower priority + metadata.AddRange(serviceType.GetCustomAttributes(inherit: true)); + // Add method metadata last so it has a higher priority + metadata.AddRange(handler.Method.GetCustomAttributes(inherit: true)); + + // Accepting CORS preflight means gRPC will allow requests with OPTIONS + preflight headers. + // If CORS middleware hasn't been configured then the request will reach gRPC handler. + // gRPC will return 405 response and log that CORS has not been configured. + metadata.Add(new HttpMethodMetadata(new[] { "POST" }, acceptCorsPreflight: true)); + + return metadata; + } + + private void AddMethod(MethodModel method) + { + var endpointBuilder = _routeBuilder.Map(method.Pattern, method.RequestDelegate); + endpointBuilder.Add(ep => + { + ep.DisplayName = $"gRPC - {method.Pattern.RawText}"; + + foreach (var item in method.Metadata) + { + ep.Metadata.Add(item); + } + }); + EndpointConventionBuilders.Add(endpointBuilder); + MethodModels.Add(method); + + var httpMethod = method.Metadata.OfType().LastOrDefault(); + + ServiceRouteBuilderLog.LogAddedServiceMethod( + _logger, + method.Method.Name, + method.Method.ServiceName, + method.Method.Type, + httpMethod?.HttpMethods ?? Array.Empty(), + method.Pattern.RawText ?? string.Empty); + } +} diff --git a/src/Grpc.AspNetCore.Server/Internal/ServerCallHandlerFactory.cs b/src/Grpc.AspNetCore.Server/Internal/ServerCallHandlerFactory.cs index 7da6ff722..5fbcebe44 100644 --- a/src/Grpc.AspNetCore.Server/Internal/ServerCallHandlerFactory.cs +++ b/src/Grpc.AspNetCore.Server/Internal/ServerCallHandlerFactory.cs @@ -28,10 +28,138 @@ namespace Grpc.AspNetCore.Server.Internal; +internal interface IServerCallHandlerFactory +{ + bool IgnoreUnknownServices { get; } + bool IgnoreUnknownMethods { get; } + RequestDelegate CreateUnimplementedMethod(); + RequestDelegate CreateUnimplementedService(); +} + +/// +/// Creates server call handlers for . Provides a place to get services that call handlers will use. +/// +internal partial class ServerCallHandlerFactory : IServerCallHandlerFactory +{ + private readonly ILoggerFactory _loggerFactory; + private readonly GrpcServiceOptions _globalOptions; + + public ServerCallHandlerFactory( + ILoggerFactory loggerFactory, + IOptions globalOptions) + { + _loggerFactory = loggerFactory; + _globalOptions = globalOptions.Value; + } + + // Internal for testing + internal MethodOptions CreateMethodOptions() + { + return MethodOptions.Create(new[] { _globalOptions }); + } + + public UnaryServerCallHandler CreateUnary(Method method, UnaryServerMethod invoker) + where TRequest : class + where TResponse : class + { + var options = CreateMethodOptions(); + var methodInvoker = new UnaryServerMethodInvoker(invoker, method, options); + + return new UnaryServerCallHandler(methodInvoker, _loggerFactory); + } + + public ClientStreamingServerCallHandler CreateClientStreaming(Method method, ClientStreamingServerMethod invoker) + where TRequest : class + where TResponse : class + { + var options = CreateMethodOptions(); + var methodInvoker = new ClientStreamingServerMethodInvoker(invoker, method, options); + + return new ClientStreamingServerCallHandler(methodInvoker, _loggerFactory); + } + + public DuplexStreamingServerCallHandler CreateDuplexStreaming(Method method, DuplexStreamingServerMethod invoker) + where TRequest : class + where TResponse : class + { + var options = CreateMethodOptions(); + var methodInvoker = new DuplexStreamingServerMethodInvoker(invoker, method, options); + + return new DuplexStreamingServerCallHandler(methodInvoker, _loggerFactory); + } + + public ServerStreamingServerCallHandler CreateServerStreaming(Method method, ServerStreamingServerMethod invoker) + where TRequest : class + where TResponse : class + { + var options = CreateMethodOptions(); + var methodInvoker = new ServerStreamingServerMethodInvoker(invoker, method, options); + + return new ServerStreamingServerCallHandler(methodInvoker, _loggerFactory); + } + + public RequestDelegate CreateUnimplementedMethod() + { + var logger = _loggerFactory.CreateLogger(); + + return httpContext => + { + // CORS preflight request should be handled by CORS middleware. + // If it isn't then return 404 from endpoint request delegate. + if (GrpcProtocolHelpers.IsCorsPreflightRequest(httpContext)) + { + httpContext.Response.StatusCode = StatusCodes.Status404NotFound; + return Task.CompletedTask; + } + + GrpcProtocolHelpers.AddProtocolHeaders(httpContext.Response); + + var unimplementedMethod = httpContext.Request.RouteValues["unimplementedMethod"]?.ToString() ?? ""; + Log.MethodUnimplemented(logger, unimplementedMethod); + if (GrpcEventSource.Log.IsEnabled()) + { + GrpcEventSource.Log.CallUnimplemented(httpContext.Request.Path.Value!); + } + GrpcProtocolHelpers.SetStatus(GrpcProtocolHelpers.GetTrailersDestination(httpContext.Response), new Status(StatusCode.Unimplemented, "Method is unimplemented.")); + return Task.CompletedTask; + }; + } + + public bool IgnoreUnknownServices => _globalOptions.IgnoreUnknownServices ?? false; + public bool IgnoreUnknownMethods => false; + + public RequestDelegate CreateUnimplementedService() + { + var logger = _loggerFactory.CreateLogger(); + + return httpContext => + { + // CORS preflight request should be handled by CORS middleware. + // If it isn't then return 404 from endpoint request delegate. + if (GrpcProtocolHelpers.IsCorsPreflightRequest(httpContext)) + { + httpContext.Response.StatusCode = StatusCodes.Status404NotFound; + return Task.CompletedTask; + } + + GrpcProtocolHelpers.AddProtocolHeaders(httpContext.Response); + + var unimplementedService = httpContext.Request.RouteValues["unimplementedService"]?.ToString() ?? ""; + Log.ServiceUnimplemented(logger, unimplementedService); + if (GrpcEventSource.Log.IsEnabled()) + { + GrpcEventSource.Log.CallUnimplemented(httpContext.Request.Path.Value!); + } + GrpcProtocolHelpers.SetStatus(GrpcProtocolHelpers.GetTrailersDestination(httpContext.Response), new Status(StatusCode.Unimplemented, "Service is unimplemented.")); + return Task.CompletedTask; + }; + } +} + /// /// Creates server call handlers. Provides a place to get services that call handlers will use. /// -internal sealed partial class ServerCallHandlerFactory<[DynamicallyAccessedMembers(GrpcProtocolConstants.ServiceAccessibility)] TService> where TService : class +internal sealed partial class ServerCallHandlerFactory<[DynamicallyAccessedMembers(GrpcProtocolConstants.ServiceAccessibility)] TService> : IServerCallHandlerFactory where TService : class { private readonly ILoggerFactory _loggerFactory; private readonly IGrpcServiceActivator _serviceActivator; diff --git a/src/Grpc.AspNetCore.Server/Model/Internal/ServiceRouteBuilder.cs b/src/Grpc.AspNetCore.Server/Model/Internal/ServiceRouteBuilder.cs index b30d28a39..498d567bb 100644 --- a/src/Grpc.AspNetCore.Server/Model/Internal/ServiceRouteBuilder.cs +++ b/src/Grpc.AspNetCore.Server/Model/Internal/ServiceRouteBuilder.cs @@ -23,10 +23,67 @@ using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Routing; using Microsoft.Extensions.Logging; -using Log = Grpc.AspNetCore.Server.Model.Internal.ServiceRouteBuilderLog; +using Helper = Grpc.AspNetCore.Server.Model.Internal.ServiceRouteBuilderHelper; namespace Grpc.AspNetCore.Server.Model.Internal; +internal class ServiceRouteBuilder +{ + private readonly ServerCallHandlerFactory _serverCallHandlerFactory; + private readonly ServiceMethodsRegistry _serviceMethodsRegistry; + private readonly ILoggerFactory _loggerFactory; + private readonly ILogger _logger; + + public ServiceRouteBuilder( + ServerCallHandlerFactory callHandlerFactory, + ServiceMethodsRegistry serviceMethodsRegistry, + ILoggerFactory loggerFactory) + { + _serverCallHandlerFactory = callHandlerFactory; + _serviceMethodsRegistry = serviceMethodsRegistry; + _loggerFactory = loggerFactory; + _logger = loggerFactory.CreateLogger(); + } + + [RequiresUnreferencedCode("Due to type erasure in ServerServiceDefinition, Build is incompatible with trimming.")] + internal List Build(IEndpointRouteBuilder endpointRouteBuilder, ServerServiceDefinition serverServiceDefinition) + { + ServiceRouteBuilderLog.DiscoveringServiceMethods(_logger, typeof(ServerServiceDefinition)); + + var serviceBinder = new EndpointServiceBinder(_serverCallHandlerFactory, endpointRouteBuilder, _loggerFactory); + + serverServiceDefinition.BindService(serviceBinder); + var endpointConventionBuilders = serviceBinder.EndpointConventionBuilders; + + if(serviceBinder.MethodModels.Count > 0) + { + foreach(var method in serviceBinder.MethodModels) + { + var serviceMethodAttribute = method.Metadata + .Select(data => data as BindServiceMethodAttribute) + .SingleOrDefault(data => data is not null); + var serviceType = serviceMethodAttribute?.BindType ?? typeof(ServerServiceDefinition); + Helper.AddImplementedEndpoint(_logger, serviceType, endpointConventionBuilders, endpointRouteBuilder, method); + } + } + else + { + ServiceRouteBuilderLog.NoServiceMethodsDiscovered(_logger, typeof(ServerServiceDefinition)); + } + + Helper.CreateUnimplementedEndpoints( + endpointRouteBuilder, + _serviceMethodsRegistry, + _serverCallHandlerFactory, + serviceBinder.MethodModels, + endpointConventionBuilders); + + _serviceMethodsRegistry.Methods.AddRange(serviceBinder.MethodModels); + + return endpointConventionBuilders; + } +} + internal sealed class ServiceRouteBuilder<[DynamicallyAccessedMembers(GrpcProtocolConstants.ServiceAccessibility)] TService> where TService : class { private readonly IEnumerable> _serviceMethodProviders; @@ -48,7 +105,7 @@ public ServiceRouteBuilder( internal List Build(IEndpointRouteBuilder endpointRouteBuilder) { - Log.DiscoveringServiceMethods(_logger, typeof(TService)); + ServiceRouteBuilderLog.DiscoveringServiceMethods(_logger, typeof(TService)); var serviceMethodProviderContext = new ServiceMethodProviderContext(_serverCallHandlerFactory); foreach (var serviceMethodProvider in _serviceMethodProviders) @@ -61,39 +118,15 @@ internal List Build(IEndpointRouteBuilder endpointRo { foreach (var method in serviceMethodProviderContext.Methods) { - var endpointBuilder = endpointRouteBuilder.Map(method.Pattern, method.RequestDelegate); - - endpointBuilder.Add(ep => - { - ep.DisplayName = $"gRPC - {method.Pattern.RawText}"; - - ep.Metadata.Add(new GrpcMethodMetadata(typeof(TService), method.Method)); - foreach (var item in method.Metadata) - { - ep.Metadata.Add(item); - } - }); - - endpointConventionBuilders.Add(endpointBuilder); - - // Report the last HttpMethodMetadata added. It's the metadata used by routing. - var httpMethod = method.Metadata.OfType().LastOrDefault(); - - Log.AddedServiceMethod( - _logger, - method.Method.Name, - method.Method.ServiceName, - method.Method.Type, - httpMethod?.HttpMethods ?? Array.Empty(), - method.Pattern.RawText ?? string.Empty); + Helper.AddImplementedEndpoint(_logger, typeof(TService), endpointConventionBuilders, endpointRouteBuilder, method); } } else { - Log.NoServiceMethodsDiscovered(_logger, typeof(TService)); + ServiceRouteBuilderLog.NoServiceMethodsDiscovered(_logger, typeof(TService)); } - CreateUnimplementedEndpoints( + Helper.CreateUnimplementedEndpoints( endpointRouteBuilder, _serviceMethodsRegistry, _serverCallHandlerFactory, @@ -104,11 +137,48 @@ internal List Build(IEndpointRouteBuilder endpointRo return endpointConventionBuilders; } +} + +internal static class ServiceRouteBuilderHelper +{ + internal static void AddImplementedEndpoint( + ILogger logger, + [DynamicallyAccessedMembers(GrpcProtocolConstants.ServiceAccessibility)] Type serviceType, + List endpointConventionBuilders, + IEndpointRouteBuilder endpointRouteBuilder, + MethodModel method) + { + var endpointBuilder = endpointRouteBuilder.Map(method.Pattern, method.RequestDelegate); + + endpointBuilder.Add(ep => + { + ep.DisplayName = $"gRPC - {method.Pattern.RawText}"; + + ep.Metadata.Add(new GrpcMethodMetadata(serviceType, method.Method)); + foreach (var item in method.Metadata) + { + ep.Metadata.Add(item); + } + }); + + endpointConventionBuilders.Add(endpointBuilder); + + // Report the last HttpMethodMetadata added. It's the metadata used by routing. + var httpMethod = method.Metadata.OfType().LastOrDefault(); + + ServiceRouteBuilderLog.LogAddedServiceMethod( + logger, + method.Method.Name, + method.Method.ServiceName, + method.Method.Type, + httpMethod?.HttpMethods ?? Array.Empty(), + method.Pattern.RawText ?? string.Empty); + } internal static void CreateUnimplementedEndpoints( IEndpointRouteBuilder endpointRouteBuilder, ServiceMethodsRegistry serviceMethodsRegistry, - ServerCallHandlerFactory serverCallHandlerFactory, + IServerCallHandlerFactory serverCallHandlerFactory, List serviceMethods, List endpointConventionBuilders) { @@ -161,7 +231,7 @@ internal static partial class ServiceRouteBuilderLog [LoggerMessage(Level = LogLevel.Trace, EventId = 1, EventName = "AddedServiceMethod", Message = "Added gRPC method '{MethodName}' to service '{ServiceName}'. Method type: {MethodType}, HTTP method: {HttpMethod}, route pattern: '{RoutePattern}'.")] private static partial void AddedServiceMethod(ILogger logger, string methodName, string serviceName, MethodType methodType, string HttpMethod, string routePattern); - public static void AddedServiceMethod(ILogger logger, string methodName, string serviceName, MethodType methodType, IReadOnlyList httpMethods, string routePattern) + public static void LogAddedServiceMethod(ILogger logger, string methodName, string serviceName, MethodType methodType, IReadOnlyList httpMethods, string routePattern) { if (logger.IsEnabled(LogLevel.Trace)) { diff --git a/src/Grpc.Core.Api/ServerServiceDefinition.cs b/src/Grpc.Core.Api/ServerServiceDefinition.cs index cc5f1e6da..255549697 100644 --- a/src/Grpc.Core.Api/ServerServiceDefinition.cs +++ b/src/Grpc.Core.Api/ServerServiceDefinition.cs @@ -38,7 +38,7 @@ internal ServerServiceDefinition(List> addMethodAction /// /// Forwards all the previously stored AddMethod calls to the service binder. /// - internal void BindService(ServiceBinderBase serviceBinder) + public void BindService(ServiceBinderBase serviceBinder) { foreach (var addMethodAction in addMethodActions) { diff --git a/src/Grpc.Net.Client/Balancer/Subchannel.cs b/src/Grpc.Net.Client/Balancer/Subchannel.cs index db0a2ee42..d80d5adeb 100644 --- a/src/Grpc.Net.Client/Balancer/Subchannel.cs +++ b/src/Grpc.Net.Client/Balancer/Subchannel.cs @@ -237,6 +237,7 @@ public void UpdateAddresses(IReadOnlyList addresses) /// public void RequestConnection() { + var connectionRequested = false; lock (Lock) { switch (_state) @@ -245,7 +246,8 @@ public void RequestConnection() SubchannelLog.ConnectionRequested(_logger, Id); // Only start connecting underlying transport if in an idle state. - UpdateConnectivityState(ConnectivityState.Connecting, "Connection requested."); + // Update connectivity state outside of subchannel lock to avoid deadlock. + connectionRequested = true; break; case ConnectivityState.Connecting: case ConnectivityState.Ready: @@ -264,6 +266,11 @@ public void RequestConnection() } } + if (connectionRequested) + { + UpdateConnectivityState(ConnectivityState.Connecting, "Connection requested."); + } + // Don't capture the current ExecutionContext and its AsyncLocals onto the connect var restoreFlow = false; if (!ExecutionContext.IsFlowSuppressed()) @@ -448,6 +455,8 @@ internal bool UpdateConnectivityState(ConnectivityState state, string successDet internal bool UpdateConnectivityState(ConnectivityState state, Status status) { + Debug.Assert(!Monitor.IsEntered(Lock), "Ensure the subchannel lock isn't held here. Updating channel state with the subchannel lock can cause a deadlock."); + lock (Lock) { // Don't update subchannel state if the state is the same or the subchannel has been shutdown. @@ -462,7 +471,7 @@ internal bool UpdateConnectivityState(ConnectivityState state, Status status) } _state = state; } - + // Notify channel outside of lock to avoid deadlocks. _manager.OnSubchannelStateChange(this, state, status); return true; diff --git a/src/Grpc.Net.Client/GrpcChannel.cs b/src/Grpc.Net.Client/GrpcChannel.cs index 05e170e04..1ee1d41a6 100644 --- a/src/Grpc.Net.Client/GrpcChannel.cs +++ b/src/Grpc.Net.Client/GrpcChannel.cs @@ -16,7 +16,6 @@ #endregion -using System.Collections.Concurrent; using System.Diagnostics; using Grpc.Core; #if SUPPORT_LOAD_BALANCING @@ -51,7 +50,7 @@ public sealed partial class GrpcChannel : ChannelBase, IDisposable internal const long DefaultMaxRetryBufferPerCallSize = 1024 * 1024; // 1 MB private readonly object _lock; - private readonly ConcurrentDictionary _methodInfoCache; + private readonly ThreadSafeLookup _methodInfoCache; private readonly Func _createMethodInfoFunc; private readonly Dictionary? _serviceConfigMethods; private readonly bool _isSecure; @@ -109,7 +108,7 @@ public sealed partial class GrpcChannel : ChannelBase, IDisposable internal GrpcChannel(Uri address, GrpcChannelOptions channelOptions) : base(address.Authority) { _lock = new object(); - _methodInfoCache = new ConcurrentDictionary(); + _methodInfoCache = new ThreadSafeLookup(); // Dispose the HTTP client/handler if... // 1. No client/handler was specified and so the channel created the client itself diff --git a/src/Grpc.Net.Client/Internal/ThreadSafeLookup.cs b/src/Grpc.Net.Client/Internal/ThreadSafeLookup.cs new file mode 100644 index 000000000..b7a37b8d1 --- /dev/null +++ b/src/Grpc.Net.Client/Internal/ThreadSafeLookup.cs @@ -0,0 +1,104 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System.Collections.Concurrent; + +internal sealed class ThreadSafeLookup where TKey : notnull +{ + // Avoid allocating ConcurrentDictionary until the threshold is reached. + // Looking up a key in an array is as fast as a dictionary for small collections and uses much less memory. + internal const int Threshold = 10; + + private KeyValuePair[] _array = Array.Empty>(); + private ConcurrentDictionary? _dictionary; + + /// + /// Gets the value for the key if it exists. If the key does not exist then the value is created using the valueFactory. + /// The value is created outside of a lock and there is no guarentee which value will be stored or returned. + /// + public TValue GetOrAdd(TKey key, Func valueFactory) + { + if (_dictionary != null) + { + return _dictionary.GetOrAdd(key, valueFactory); + } + + if (TryGetValue(_array, key, out var value)) + { + return value; + } + + var newValue = valueFactory(key); + + lock (this) + { + if (_dictionary != null) + { + _dictionary.TryAdd(key, newValue); + } + else + { + // Double check inside lock if the key was added to the array by another thread. + if (TryGetValue(_array, key, out value)) + { + return value; + } + + if (_array.Length > Threshold - 1) + { + // Array length exceeds threshold so switch to dictionary. + var newDict = new ConcurrentDictionary(); + foreach (var kvp in _array) + { + newDict.TryAdd(kvp.Key, kvp.Value); + } + newDict.TryAdd(key, newValue); + + _dictionary = newDict; + _array = Array.Empty>(); + } + else + { + // Add new value by creating a new array with old plus new value. + var newArray = new KeyValuePair[_array.Length + 1]; + Array.Copy(_array, newArray, _array.Length); + newArray[newArray.Length - 1] = new KeyValuePair(key, newValue); + + _array = newArray; + } + } + } + + return newValue; + } + + private static bool TryGetValue(KeyValuePair[] array, TKey key, out TValue value) + { + foreach (var kvp in array) + { + if (EqualityComparer.Default.Equals(kvp.Key, key)) + { + value = kvp.Value; + return true; + } + } + + value = default!; + return false; + } +} diff --git a/src/Shared/Server/ClientStreamingServerMethodInvoker.cs b/src/Shared/Server/ClientStreamingServerMethodInvoker.cs index efbad4804..b8178685d 100644 --- a/src/Shared/Server/ClientStreamingServerMethodInvoker.cs +++ b/src/Shared/Server/ClientStreamingServerMethodInvoker.cs @@ -25,6 +25,52 @@ namespace Grpc.Shared.Server; +/// +/// Client streaming server method invoker for . +/// +/// Request message type for this method. +/// Response message type for this method. +internal sealed class ClientStreamingServerMethodInvoker : ServerMethodInvokerBase + where TRequest : class + where TResponse : class +{ + private readonly ClientStreamingServerMethod _invoker; + + /// + /// Creates a new instance of . + /// + /// The client streaming method to invoke. + /// The description of the gRPC method. + /// The options used to execute the method. + public ClientStreamingServerMethodInvoker( + ClientStreamingServerMethod invoker, + Method method, + MethodOptions options) + : base(method, options) + { + _invoker = invoker; + + if (Options.HasInterceptors) + { + var interceptorPipeline = new InterceptorPipelineBuilder(Options.Interceptors); + _invoker = interceptorPipeline.ClientStreamingPipeline(_invoker); + } + } + + /// + /// Invoke the client streaming method with the specified . + /// + /// The for the current request. + /// The . + /// The reader. + /// A that represents the asynchronous method. The + /// property returns the message. + public async Task Invoke(HttpContext _, ServerCallContext serverCallContext, IAsyncStreamReader requestStream) + { + return await _invoker(requestStream, serverCallContext); + } +} + /// /// Client streaming server method invoker. /// diff --git a/src/Shared/Server/DuplexStreamingServerMethodInvoker.cs b/src/Shared/Server/DuplexStreamingServerMethodInvoker.cs index e195fb88f..69c08f3ec 100644 --- a/src/Shared/Server/DuplexStreamingServerMethodInvoker.cs +++ b/src/Shared/Server/DuplexStreamingServerMethodInvoker.cs @@ -25,6 +25,52 @@ namespace Grpc.Shared.Server; +/// +/// Duplex streaming server method invoker for . +/// +/// Request message type for this method. +/// Response message type for this method. +internal sealed class DuplexStreamingServerMethodInvoker : ServerMethodInvokerBase + where TRequest : class + where TResponse : class +{ + private readonly DuplexStreamingServerMethod _invoker; + + /// + /// Creates a new instance of . + /// + /// The duplex streaming method to invoke. + /// The description of the gRPC method. + /// The options used to execute the method. + public DuplexStreamingServerMethodInvoker( + DuplexStreamingServerMethod invoker, + Method method, + MethodOptions options) + : base(method, options) + { + _invoker = invoker; + + if (Options.HasInterceptors) + { + var interceptorPipeline = new InterceptorPipelineBuilder(Options.Interceptors); + _invoker = interceptorPipeline.DuplexStreamingPipeline(_invoker); + } + } + + /// + /// Invoke the duplex streaming method with the specified . + /// + /// The for the current request. + /// The . + /// The reader. + /// The writer. + /// A that represents the asynchronous method. + public async Task Invoke(HttpContext _, ServerCallContext serverCallContext, IAsyncStreamReader requestStream, IServerStreamWriter responseStream) + { + await _invoker(requestStream, responseStream, serverCallContext); + } +} + /// /// Duplex streaming server method invoker. /// diff --git a/src/Shared/Server/ServerMethodInvokerBase.cs b/src/Shared/Server/ServerMethodInvokerBase.cs index b2afd3ef4..b1c1f3a41 100644 --- a/src/Shared/Server/ServerMethodInvokerBase.cs +++ b/src/Shared/Server/ServerMethodInvokerBase.cs @@ -23,6 +23,39 @@ namespace Grpc.Shared.Server; +/// +/// Server method invoker base type for . +/// +/// Request message type for this method. +/// Response message type for this method. +internal abstract class ServerMethodInvokerBase + where TRequest : class + where TResponse : class +{ + /// + /// Gets the description of the gRPC method. + /// + public Method Method { get; } + + /// + /// Gets the options used to execute the method. + /// + public MethodOptions Options { get; } + + /// + /// Creates a new instance of . + /// + /// The description of the gRPC method. + /// The options used to execute the method. + private protected ServerMethodInvokerBase( + Method method, + MethodOptions options) + { + Method = method; + Options = options; + } +} + /// /// Server method invoker base type. /// diff --git a/src/Shared/Server/ServerStreamingServerMethodInvoker.cs b/src/Shared/Server/ServerStreamingServerMethodInvoker.cs index 268e17463..50d2b4471 100644 --- a/src/Shared/Server/ServerStreamingServerMethodInvoker.cs +++ b/src/Shared/Server/ServerStreamingServerMethodInvoker.cs @@ -25,6 +25,52 @@ namespace Grpc.Shared.Server; +/// +/// Server streaming server method invoker for . +/// +/// Request message type for this method. +/// Response message type for this method. +internal sealed class ServerStreamingServerMethodInvoker : ServerMethodInvokerBase + where TRequest : class + where TResponse : class +{ + private readonly ServerStreamingServerMethod _invoker; + + /// + /// Creates a new instance of . + /// + /// The server streaming method to invoke. + /// The description of the gRPC method. + /// The options used to execute the method. + public ServerStreamingServerMethodInvoker( + ServerStreamingServerMethod invoker, + Method method, + MethodOptions options) + : base(method, options) + { + _invoker = invoker; + + if (Options.HasInterceptors) + { + var interceptorPipeline = new InterceptorPipelineBuilder(Options.Interceptors); + _invoker = interceptorPipeline.ServerStreamingPipeline(_invoker); + } + } + + /// + /// Invoke the server streaming method with the specified . + /// + /// The for the current request. + /// The . + /// The message. + /// The stream writer. + /// A that represents the asynchronous method. + public async Task Invoke(HttpContext _, ServerCallContext serverCallContext, TRequest request, IServerStreamWriter streamWriter) + { + await _invoker(request, streamWriter, serverCallContext); + } +} + /// /// Server streaming server method invoker. /// diff --git a/src/Shared/Server/UnaryServerMethodInvoker.cs b/src/Shared/Server/UnaryServerMethodInvoker.cs index cbd1115fc..d88c951e4 100644 --- a/src/Shared/Server/UnaryServerMethodInvoker.cs +++ b/src/Shared/Server/UnaryServerMethodInvoker.cs @@ -26,6 +26,52 @@ namespace Grpc.Shared.Server; +/// +/// Unary server method invoker for . +/// +/// Request message type for this method. +/// Response message type for this method. +internal sealed class UnaryServerMethodInvoker : ServerMethodInvokerBase + where TRequest : class + where TResponse : class +{ + private readonly UnaryServerMethod _invoker; + + /// + /// Creates a new instance of . + /// + /// The unary method to invoke. + /// The description of the gRPC method. + /// The options used to execute the method. + public UnaryServerMethodInvoker( + UnaryServerMethod invoker, + Method method, + MethodOptions options) + : base(method, options) + { + _invoker = invoker; + + if (Options.HasInterceptors) + { + var interceptorPipeline = new InterceptorPipelineBuilder(Options.Interceptors); + _invoker = interceptorPipeline.UnaryPipeline(_invoker); + } + } + + /// + /// Invoke the unary method with the specified . + /// + /// The for the current request. + /// The . + /// The message. + /// A that represents the asynchronous method. The + /// property returns the message. + public Task Invoke(HttpContext _, ServerCallContext serverCallContext, TRequest request) + { + return _invoker(request, serverCallContext); + } +} + /// /// Unary server method invoker. /// diff --git a/test/Grpc.AspNetCore.Server.Tests/GrpcEndpointRouteBuilderExtensionsTests.cs b/test/Grpc.AspNetCore.Server.Tests/GrpcEndpointRouteBuilderExtensionsTests.cs index 423444573..5cd5789b5 100644 --- a/test/Grpc.AspNetCore.Server.Tests/GrpcEndpointRouteBuilderExtensionsTests.cs +++ b/test/Grpc.AspNetCore.Server.Tests/GrpcEndpointRouteBuilderExtensionsTests.cs @@ -103,6 +103,40 @@ public void MapGrpcService_CanBindSubSubclass_CreatesEndpoints() BindServiceCore(); } + [Test] + public void MapGrpcService_ServerServiceDefinition_CreateEndPoints() + { + // Arrange + var services = ServicesHelpers.CreateServices(); + + var routeBuilder = CreateTestEndpointRouteBuilder(services.BuildServiceProvider(validateScopes: true)); + + // Act + var service = GreeterWithAttribute.BindService(new GreeterWithAttributeService()); + routeBuilder.MapGrpcService(service); + + // Assert + AssertForBindServiceCore(routeBuilder); + } + + [Test] + public void MapGrpcService_GetServerServiceDefinition_CreateEndPoints() + { + // Arrange + var services = ServicesHelpers.CreateServices(); + + var routeBuilder = CreateTestEndpointRouteBuilder(services.BuildServiceProvider(validateScopes: true)); + + // Act + static ServerServiceDefinition serverServiceDefinition(IServiceProvider provider) + => GreeterWithAttribute.BindService(new GreeterWithAttributeService()); + + routeBuilder.MapGrpcService(serverServiceDefinition); + + // Assert + AssertForBindServiceCore(routeBuilder); + } + private void BindServiceCore() where TService : class { // Arrange @@ -114,6 +148,11 @@ private void BindServiceCore() where TService : class routeBuilder.MapGrpcService(); // Assert + AssertForBindServiceCore(routeBuilder); + } + + private void AssertForBindServiceCore(IEndpointRouteBuilder routeBuilder) + { var endpoints = routeBuilder.DataSources .SelectMany(ds => ds.Endpoints) .Where(e => e.Metadata.GetMetadata() != null) diff --git a/test/Grpc.AspNetCore.Server.Tests/TestObjects/Services/WithAttribute/GreeterWithAttribute.cs b/test/Grpc.AspNetCore.Server.Tests/TestObjects/Services/WithAttribute/GreeterWithAttribute.cs index 08dabb7ca..6a6e2bae8 100644 --- a/test/Grpc.AspNetCore.Server.Tests/TestObjects/Services/WithAttribute/GreeterWithAttribute.cs +++ b/test/Grpc.AspNetCore.Server.Tests/TestObjects/Services/WithAttribute/GreeterWithAttribute.cs @@ -1,4 +1,4 @@ -#region Copyright notice and license +#region Copyright notice and license // Copyright 2019 The gRPC Authors // @@ -57,7 +57,10 @@ public abstract partial class GreeterBase public static ServerServiceDefinition BindService(GreeterBase serviceImpl) { - throw new NotImplementedException(); + return ServerServiceDefinition.CreateBuilder() + .AddMethod(__Method_SayHello, serviceImpl.SayHello) + .AddMethod(__Method_SayHellos, serviceImpl.SayHellos) + .Build(); } public static void BindService(ServiceBinderBase serviceBinder, GreeterBase serviceImpl) diff --git a/test/Grpc.Net.Client.Tests/Balancer/ConnectionManagerTests.cs b/test/Grpc.Net.Client.Tests/Balancer/ConnectionManagerTests.cs index 743639886..7a5221728 100644 --- a/test/Grpc.Net.Client.Tests/Balancer/ConnectionManagerTests.cs +++ b/test/Grpc.Net.Client.Tests/Balancer/ConnectionManagerTests.cs @@ -17,6 +17,7 @@ #endregion #if SUPPORT_LOAD_BALANCING +using System.Diagnostics; using System.Net; using System.Threading.Channels; using Greet; @@ -29,6 +30,7 @@ using Grpc.Tests.Shared; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Extensions.Logging.Testing; using NUnit.Framework; using ChannelState = Grpc.Net.Client.Balancer.ChannelState; @@ -535,6 +537,89 @@ public async Task PickAsync_DoesNotDeadlockAfterReconnect_WithZeroAddressResolve await pickTask.DefaultTimeout(); } + [Test] + public async Task PickAsync_UpdateAddressesWhileRequestingConnection_DoesNotDeadlock() + { + var services = new ServiceCollection(); + services.AddNUnitLogger(); + + var testSink = new TestSink(); + var testProvider = new TestLoggerProvider(testSink); + + services.AddLogging(b => + { + b.AddProvider(testProvider); + }); + + await using var serviceProvider = services.BuildServiceProvider(); + var loggerFactory = serviceProvider.GetRequiredService(); + + var resolver = new TestResolver(loggerFactory); + resolver.UpdateAddresses(new List + { + new BalancerAddress("localhost", 80) + }); + + var channelOptions = new GrpcChannelOptions(); + + var transportFactory = new TestSubchannelTransportFactory(); + var clientChannel = CreateConnectionManager(loggerFactory, resolver, transportFactory, new[] { new PickFirstBalancerFactory() }); + // Configure balancer similar to how GrpcChannel constructor does it + clientChannel.ConfigureBalancer(c => new ChildHandlerLoadBalancer( + c, + channelOptions.ServiceConfig, + clientChannel)); + + await clientChannel.ConnectAsync(waitForReady: true, cancellationToken: CancellationToken.None); + + transportFactory.Transports.ForEach(t => t.Disconnect()); + + var requestConnectionSyncPoint = new SyncPoint(runContinuationsAsynchronously: true); + testSink.MessageLogged += (w) => + { + if (w.EventId.Name == "ConnectionRequested") + { + requestConnectionSyncPoint.WaitToContinue().Wait(); + } + }; + + // Task should pause when requesting connection because of the logger sink. + var pickTask = Task.Run(() => clientChannel.PickAsync( + new PickContext { Request = new HttpRequestMessage() }, + waitForReady: true, + CancellationToken.None).AsTask()); + + // Wait until we're paused on requesting a connection. + await requestConnectionSyncPoint.WaitForSyncPoint().DefaultTimeout(); + + // Update addresses while requesting a connection. + var updateAddressesTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var updateAddressesTask = Task.Run(() => + { + updateAddressesTcs.TrySetResult(null); + resolver.UpdateAddresses(new List + { + new BalancerAddress("localhost", 81) + }); + }); + + // There isn't a clean way to wait for UpdateAddresses to be waiting for the subchannel lock. + // Use a long delay to ensure we're waiting for the lock and are in the right state. + await updateAddressesTcs.Task.DefaultTimeout(); + await Task.Delay(500); + requestConnectionSyncPoint.Continue(); + + // Ensure the pick completes without deadlock. + try + { + await pickTask.DefaultTimeout(); + } + catch (TimeoutException ex) + { + throw new InvalidOperationException("Likely deadlock when picking subchannel.", ex); + } + } + [Test] public async Task PickAsync_ExecutionContext_DoesNotCaptureAsyncLocalsInConnect() { diff --git a/test/Grpc.Net.Client.Tests/Balancer/ConnectivityStateTests.cs b/test/Grpc.Net.Client.Tests/Balancer/ConnectivityStateTests.cs index 9b9d60ea4..621f042c7 100644 --- a/test/Grpc.Net.Client.Tests/Balancer/ConnectivityStateTests.cs +++ b/test/Grpc.Net.Client.Tests/Balancer/ConnectivityStateTests.cs @@ -1,4 +1,4 @@ -#region Copyright notice and license +#region Copyright notice and license // Copyright 2019 The gRPC Authors // @@ -29,6 +29,7 @@ using Grpc.Net.Client.Tests.Infrastructure.Balancer; using Grpc.Tests.Shared; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; using NUnit.Framework; namespace Grpc.Net.Client.Tests.Balancer; @@ -52,11 +53,13 @@ public async Task ResolverReturnsNoAddresses_CallWithWaitForReady_Wait() }); var services = new ServiceCollection(); + services.AddNUnitLogger(); services.AddSingleton(); services.AddSingleton(); services.AddSingleton(new TestSubchannelTransportFactory()); var serviceProvider = services.BuildServiceProvider(); + var logger = serviceProvider.GetRequiredService>(); var invoker = HttpClientCallInvokerFactory.Create(testMessageHandler, "test:///localhost", configure: o => { o.Credentials = ChannelCredentials.Insecure; @@ -72,6 +75,8 @@ public async Task ResolverReturnsNoAddresses_CallWithWaitForReady_Wait() Assert.IsNull(authority); var resolver = serviceProvider.GetRequiredService(); + + logger.LogInformation("UpdateAddresses"); resolver.UpdateAddresses(new List { new BalancerAddress("localhost", 81) diff --git a/test/Grpc.Net.Client.Tests/ThreadSafeLookupTests.cs b/test/Grpc.Net.Client.Tests/ThreadSafeLookupTests.cs new file mode 100644 index 000000000..57d073c61 --- /dev/null +++ b/test/Grpc.Net.Client.Tests/ThreadSafeLookupTests.cs @@ -0,0 +1,69 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +namespace Grpc.Net.Client.Tests; + +[TestFixture] +public class ThreadSafeLookupTests +{ + [Test] + public void GetOrAdd_ReturnsCorrectValueForNewKey() + { + var lookup = new ThreadSafeLookup(); + var result = lookup.GetOrAdd(1, k => "Value-1"); + + Assert.AreEqual("Value-1", result); + } + + [Test] + public void GetOrAdd_ReturnsExistingValueForExistingKey() + { + var lookup = new ThreadSafeLookup(); + lookup.GetOrAdd(1, k => "InitialValue"); + var result = lookup.GetOrAdd(1, k => "NewValue"); + + Assert.AreEqual("InitialValue", result); + } + + [Test] + public void GetOrAdd_SwitchesToDictionaryAfterThreshold() + { + var addCount = (ThreadSafeLookup.Threshold * 2); + var lookup = new ThreadSafeLookup(); + + for (var i = 0; i <= addCount; i++) + { + lookup.GetOrAdd(i, k => $"Value-{k}"); + } + + var result = lookup.GetOrAdd(addCount, k => $"NewValue-{addCount}"); + + Assert.AreEqual($"Value-{addCount}", result); + } + + [Test] + public void GetOrAdd_HandlesConcurrentAccess() + { + var lookup = new ThreadSafeLookup(); + Parallel.For(0, 1000, i => + { + var value = lookup.GetOrAdd(i, k => $"Value-{k}"); + Assert.AreEqual($"Value-{i}", value); + }); + } +} diff --git a/test/Shared/TestResolver.cs b/test/Shared/TestResolver.cs index 8136f981b..c6584d219 100644 --- a/test/Shared/TestResolver.cs +++ b/test/Shared/TestResolver.cs @@ -32,6 +32,7 @@ namespace Grpc.Tests.Shared; internal class TestResolver : PollingResolver { + private readonly object _lock; private readonly Func? _onRefreshAsync; private readonly TaskCompletionSource _hasResolvedTcs; private readonly ILogger _logger; @@ -45,6 +46,7 @@ public TestResolver(ILoggerFactory loggerFactory) : this(loggerFactory, null) public TestResolver(ILoggerFactory? loggerFactory = null, Func? onRefreshAsync = null) : base(loggerFactory ?? NullLoggerFactory.Instance) { + _lock = new object(); _onRefreshAsync = onRefreshAsync; _hasResolvedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; @@ -64,8 +66,11 @@ public void UpdateError(Status status) public void UpdateResult(ResolverResult result) { - _result = result; - Listener?.Invoke(result); + lock (_lock) + { + _result = result; + Listener?.Invoke(result); + } } protected override async Task ResolveAsync(CancellationToken cancellationToken) @@ -75,7 +80,10 @@ protected override async Task ResolveAsync(CancellationToken cancellationToken) await _onRefreshAsync(); } - Listener(_result ?? ResolverResult.ForResult(Array.Empty(), serviceConfig: null, serviceConfigStatus: null)); + lock (_lock) + { + Listener(_result ?? ResolverResult.ForResult(Array.Empty(), serviceConfig: null, serviceConfigStatus: null)); + } _hasResolvedTcs.TrySetResult(null); } }