diff --git a/.github/workflows/validate.yml b/.github/workflows/validate.yml index 7f09a2325a..42a07c8785 100644 --- a/.github/workflows/validate.yml +++ b/.github/workflows/validate.yml @@ -121,6 +121,12 @@ jobs: mm.py ./src/main/java/io/dapr/examples/jobs/README.md env: DOCKER_HOST: ${{steps.setup_docker.outputs.sock}} + - name: Validate conversation ai example + working-directory: ./examples + run: | + mm.py ./src/main/java/io/dapr/examples/conversation/README.md + env: + DOCKER_HOST: ${{steps.setup_docker.outputs.sock}} - name: Validate invoke http example working-directory: ./examples run: | diff --git a/examples/components/conversation/conversation.yaml b/examples/components/conversation/conversation.yaml new file mode 100644 index 0000000000..efb651fef1 --- /dev/null +++ b/examples/components/conversation/conversation.yaml @@ -0,0 +1,7 @@ +apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: echo +spec: + type: conversation.echo + version: v1 diff --git a/examples/src/main/java/io/dapr/examples/conversation/DemoConversationAI.java b/examples/src/main/java/io/dapr/examples/conversation/DemoConversationAI.java new file mode 100644 index 0000000000..09c9570262 --- /dev/null +++ b/examples/src/main/java/io/dapr/examples/conversation/DemoConversationAI.java @@ -0,0 +1,49 @@ +/* + * Copyright 2021 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.examples.conversation; + +import io.dapr.client.DaprClientBuilder; +import io.dapr.client.DaprPreviewClient; +import io.dapr.client.domain.ConversationInput; +import io.dapr.client.domain.ConversationRequest; +import io.dapr.client.domain.ConversationResponse; +import reactor.core.publisher.Mono; + +import java.util.List; + +public class DemoConversationAI { + /** + * The main method to start the client. + * + * @param args Input arguments (unused). + */ + public static void main(String[] args) { + try (DaprPreviewClient client = new DaprClientBuilder().buildPreviewClient()) { + System.out.println("Sending the following input to LLM: Hello How are you? This is the my number 672-123-4567"); + + ConversationInput daprConversationInput = new ConversationInput("Hello How are you? " + + "This is the my number 672-123-4567"); + + // Component name is the name provided in the metadata block of the conversation.yaml file. + Mono responseMono = client.converse(new ConversationRequest("echo", + List.of(daprConversationInput)) + .setContextId("contextId") + .setScrubPii(true).setTemperature(1.1d)); + ConversationResponse response = responseMono.block(); + System.out.printf("Conversation output: %s", response.getConversationOutputs().get(0).getResult()); + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} diff --git a/examples/src/main/java/io/dapr/examples/conversation/README.md b/examples/src/main/java/io/dapr/examples/conversation/README.md new file mode 100644 index 0000000000..29468cfb35 --- /dev/null +++ b/examples/src/main/java/io/dapr/examples/conversation/README.md @@ -0,0 +1,114 @@ +## Manage Dapr via the Conversation API + +This example provides the different capabilities provided by Dapr Java SDK for Conversation. For further information about Conversation APIs please refer to [this link](https://docs.dapr.io/developing-applications/building-blocks/conversation/conversation-overview/) + +### Using the Conversation API + +The Java SDK exposes several methods for this - +* `client.converse(...)` for conversing with an LLM through Dapr. + +## Pre-requisites + +* [Dapr CLI](https://docs.dapr.io/getting-started/install-dapr-cli/). +* Java JDK 11 (or greater): + * [Microsoft JDK 11](https://docs.microsoft.com/en-us/java/openjdk/download#openjdk-11) + * [Oracle JDK 11](https://www.oracle.com/technetwork/java/javase/downloads/index.html#JDK11) + * [OpenJDK 11](https://jdk.java.net/11/) +* [Apache Maven](https://maven.apache.org/install.html) version 3.x. + +### Checking out the code + +Clone this repository: + +```sh +git clone https://github.com/dapr/java-sdk.git +cd java-sdk +``` + +Then build the Maven project: + +```sh +# make sure you are in the `java-sdk` directory +mvn install +``` + +Then get into the examples directory: + +```sh +cd examples +``` + +### Initialize Dapr + +Run `dapr init` to initialize Dapr in Self-Hosted Mode if it's not already initialized. + +### Running the example + +This example uses the Java SDK Dapr client in order to **Converse** with an LLM. +`DemoConversationAI.java` is the example class demonstrating these features. +Kindly check [DaprPreviewClient.java](https://github.com/dapr/java-sdk/blob/master/sdk/src/main/java/io/dapr/client/DaprPreviewClient.java) for a detailed description of the supported APIs. + +```java +public class DemoConversationAI { + /** + * The main method to start the client. + * + * @param args Input arguments (unused). + */ + public static void main(String[] args) { + try (DaprPreviewClient client = new DaprClientBuilder().buildPreviewClient()) { + System.out.println("Sending the following input to LLM: Hello How are you? This is the my number 672-123-4567"); + + ConversationInput daprConversationInput = new ConversationInput("Hello How are you? " + + "This is the my number 672-123-4567"); + + // Component name is the name provided in the metadata block of the conversation.yaml file. + Mono responseMono = client.converse(new ConversationRequest("echo", + List.of(daprConversationInput)) + .setContextId("contextId") + .setScrubPii(true).setTemperature(1.1d)); + ConversationResponse response = responseMono.block(); + System.out.printf("Conversation output: %s", response.getConversationOutpus().get(0).getResult()); + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} +``` + +Use the following command to run this example- + + + +```bash +dapr run --resources-path ./components/conversation --app-id myapp --app-port 8080 --dapr-http-port 3500 --dapr-grpc-port 51439 --log-level debug -- java -jar target/dapr-java-sdk-examples-exec.jar io.dapr.examples.conversation.DemoConversationAI +``` + + + +### Sample output +``` +== APP == Conversation output: Hello How are you? This is the my number +``` +### Cleanup + +To stop the app, run (or press CTRL+C): + + + +```bash +dapr stop --app-id myapp +``` + + + diff --git a/sdk-tests/src/test/java/io/dapr/it/testcontainers/DaprConversationIT.java b/sdk-tests/src/test/java/io/dapr/it/testcontainers/DaprConversationIT.java new file mode 100644 index 0000000000..013a5cdf0e --- /dev/null +++ b/sdk-tests/src/test/java/io/dapr/it/testcontainers/DaprConversationIT.java @@ -0,0 +1,133 @@ +/* + * Copyright 2021 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.it.testcontainers; + +import io.dapr.client.DaprPreviewClient; +import io.dapr.client.domain.ConversationInput; +import io.dapr.client.domain.ConversationRequest; +import io.dapr.client.domain.ConversationResponse; +import io.dapr.testcontainers.Component; +import io.dapr.testcontainers.DaprContainer; +import io.dapr.testcontainers.DaprLogLevel; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.boot.test.context.SpringBootTest.WebEnvironment; +import org.springframework.test.context.DynamicPropertyRegistry; +import org.springframework.test.context.DynamicPropertySource; +import org.testcontainers.containers.Network; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Random; + +import static io.dapr.it.testcontainers.ContainerConstants.DAPR_RUNTIME_IMAGE_TAG; + +@SpringBootTest( + webEnvironment = WebEnvironment.RANDOM_PORT, + classes = { + DaprPreviewClientConfiguration.class, + TestConversationApplication.class + } +) +@Testcontainers +@Tag("testcontainers") +public class DaprConversationIT { + + private static final Network DAPR_NETWORK = Network.newNetwork(); + private static final Random RANDOM = new Random(); + private static final int PORT = RANDOM.nextInt(1000) + 8000; + + @Container + private static final DaprContainer DAPR_CONTAINER = new DaprContainer(DAPR_RUNTIME_IMAGE_TAG) + .withAppName("conversation-dapr-app") + .withComponent(new Component("echo", "conversation.echo", "v1", new HashMap<>())) + .withNetwork(DAPR_NETWORK) + .withDaprLogLevel(DaprLogLevel.DEBUG) + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withAppChannelAddress("host.testcontainers.internal") + .withAppPort(PORT); + + /** + * Expose the Dapr port to the host. + * + * @param registry the dynamic property registry + */ + @DynamicPropertySource + static void daprProperties(DynamicPropertyRegistry registry) { + registry.add("dapr.http.endpoint", DAPR_CONTAINER::getHttpEndpoint); + registry.add("dapr.grpc.endpoint", DAPR_CONTAINER::getGrpcEndpoint); + registry.add("server.port", () -> PORT); + } + + @Autowired + private DaprPreviewClient daprPreviewClient; + + @BeforeEach + public void setUp(){ + org.testcontainers.Testcontainers.exposeHostPorts(PORT); + } + + @Test + public void testConversationSDKShouldHaveSameOutputAndInput() { + ConversationInput conversationInput = new ConversationInput("input this"); + List conversationInputList = new ArrayList<>(); + conversationInputList.add(conversationInput); + + ConversationResponse response = + this.daprPreviewClient.converse(new ConversationRequest("echo", conversationInputList)).block(); + + Assertions.assertEquals("", response.getContextId()); + Assertions.assertEquals("input this", response.getConversationOutputs().get(0).getResult()); + } + + @Test + public void testConversationSDKShouldScrubPIIWhenScrubPIIIsSetInRequestBody() { + List conversationInputList = new ArrayList<>(); + conversationInputList.add(new ConversationInput("input this abcd@gmail.com")); + conversationInputList.add(new ConversationInput("input this +12341567890")); + + ConversationResponse response = + this.daprPreviewClient.converse(new ConversationRequest("echo", conversationInputList) + .setScrubPii(true)).block(); + + Assertions.assertEquals("", response.getContextId()); + Assertions.assertEquals("input this ", + response.getConversationOutputs().get(0).getResult()); + Assertions.assertEquals("input this ", + response.getConversationOutputs().get(1).getResult()); + } + + @Test + public void testConversationSDKShouldScrubPIIOnlyForTheInputWhereScrubPIIIsSet() { + List conversationInputList = new ArrayList<>(); + conversationInputList.add(new ConversationInput("input this abcd@gmail.com")); + conversationInputList.add(new ConversationInput("input this +12341567890").setScrubPii(true)); + + ConversationResponse response = + this.daprPreviewClient.converse(new ConversationRequest("echo", conversationInputList)).block(); + + Assertions.assertEquals("", response.getContextId()); + Assertions.assertEquals("input this abcd@gmail.com", + response.getConversationOutputs().get(0).getResult()); + Assertions.assertEquals("input this ", + response.getConversationOutputs().get(1).getResult()); + } +} diff --git a/sdk-tests/src/test/java/io/dapr/it/testcontainers/DaprJobsIT.java b/sdk-tests/src/test/java/io/dapr/it/testcontainers/DaprJobsIT.java index 3cb433cf13..5b52c0267b 100644 --- a/sdk-tests/src/test/java/io/dapr/it/testcontainers/DaprJobsIT.java +++ b/sdk-tests/src/test/java/io/dapr/it/testcontainers/DaprJobsIT.java @@ -45,7 +45,7 @@ @SpringBootTest( webEnvironment = WebEnvironment.RANDOM_PORT, classes = { - TestDaprJobsConfiguration.class, + DaprPreviewClientConfiguration.class, TestJobsApplication.class } ) diff --git a/sdk-tests/src/test/java/io/dapr/it/testcontainers/TestDaprJobsConfiguration.java b/sdk-tests/src/test/java/io/dapr/it/testcontainers/DaprPreviewClientConfiguration.java similarity index 91% rename from sdk-tests/src/test/java/io/dapr/it/testcontainers/TestDaprJobsConfiguration.java rename to sdk-tests/src/test/java/io/dapr/it/testcontainers/DaprPreviewClientConfiguration.java index 5e0e2e8c89..66dce6d726 100644 --- a/sdk-tests/src/test/java/io/dapr/it/testcontainers/TestDaprJobsConfiguration.java +++ b/sdk-tests/src/test/java/io/dapr/it/testcontainers/DaprPreviewClientConfiguration.java @@ -14,11 +14,9 @@ package io.dapr.it.testcontainers; import io.dapr.client.DaprClientBuilder; -import io.dapr.client.DaprClientImpl; import io.dapr.client.DaprPreviewClient; import io.dapr.config.Properties; import io.dapr.config.Property; -import io.dapr.serializer.DefaultObjectSerializer; import org.springframework.beans.factory.annotation.Value; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -26,7 +24,7 @@ import java.util.Map; @Configuration -public class TestDaprJobsConfiguration { +public class DaprPreviewClientConfiguration { @Bean public DaprPreviewClient daprPreviewClient( @Value("${dapr.http.endpoint}") String daprHttpEndpoint, diff --git a/sdk-tests/src/test/java/io/dapr/it/testcontainers/TestConversationApplication.java b/sdk-tests/src/test/java/io/dapr/it/testcontainers/TestConversationApplication.java new file mode 100644 index 0000000000..2bb9eeac10 --- /dev/null +++ b/sdk-tests/src/test/java/io/dapr/it/testcontainers/TestConversationApplication.java @@ -0,0 +1,25 @@ +/* + * Copyright 2024 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.it.testcontainers; + +import org.springframework.boot.SpringApplication; +import org.springframework.boot.autoconfigure.SpringBootApplication; + +@SpringBootApplication +public class TestConversationApplication { + + public static void main(String[] args) { + SpringApplication.run(TestConversationApplication.class, args); + } +} diff --git a/sdk/src/main/java/io/dapr/client/DaprClientImpl.java b/sdk/src/main/java/io/dapr/client/DaprClientImpl.java index aad67bc23d..66c8772d66 100644 --- a/sdk/src/main/java/io/dapr/client/DaprClientImpl.java +++ b/sdk/src/main/java/io/dapr/client/DaprClientImpl.java @@ -27,6 +27,10 @@ import io.dapr.client.domain.CloudEvent; import io.dapr.client.domain.ComponentMetadata; import io.dapr.client.domain.ConfigurationItem; +import io.dapr.client.domain.ConversationInput; +import io.dapr.client.domain.ConversationOutput; +import io.dapr.client.domain.ConversationRequest; +import io.dapr.client.domain.ConversationResponse; import io.dapr.client.domain.DaprMetadata; import io.dapr.client.domain.DeleteJobRequest; import io.dapr.client.domain.DeleteStateRequest; @@ -99,7 +103,6 @@ import java.io.IOException; import java.time.Duration; import java.time.Instant; -import java.time.OffsetDateTime; import java.time.ZoneOffset; import java.time.format.DateTimeFormatter; import java.util.ArrayList; @@ -1552,6 +1555,79 @@ public Mono getMetadata() { }); } + /** + * {@inheritDoc} + */ + @Override + public Mono converse(ConversationRequest conversationRequest) { + + try { + validateConversationRequest(conversationRequest); + + DaprProtos.ConversationRequest.Builder protosConversationRequestBuilder = DaprProtos.ConversationRequest + .newBuilder().setTemperature(conversationRequest.getTemperature()) + .setScrubPII(conversationRequest.isScrubPii()) + .setName(conversationRequest.getName()); + + if (conversationRequest.getContextId() != null) { + protosConversationRequestBuilder.setContextID(conversationRequest.getContextId()); + } + + for (ConversationInput input : conversationRequest.getInputs()) { + if (input.getContent() == null || input.getContent().isEmpty()) { + throw new IllegalArgumentException("Conversation input content cannot be null or empty."); + } + + DaprProtos.ConversationInput.Builder conversationInputOrBuilder = DaprProtos.ConversationInput.newBuilder() + .setContent(input.getContent()) + .setScrubPII(input.isScrubPii()); + + if (input.getRole() != null) { + conversationInputOrBuilder.setRole(input.getRole().toString()); + } + + protosConversationRequestBuilder.addInputs(conversationInputOrBuilder.build()); + } + + Mono conversationResponseMono = Mono.deferContextual( + context -> this.createMono( + it -> intercept(context, asyncStub) + .converseAlpha1(protosConversationRequestBuilder.build(), it) + ) + ); + + return conversationResponseMono.map(conversationResponse -> { + + List conversationOutputs = new ArrayList<>(); + for (DaprProtos.ConversationResult conversationResult : conversationResponse.getOutputsList()) { + Map parameters = new HashMap<>(); + for (Map.Entry entrySet : conversationResult.getParametersMap().entrySet()) { + parameters.put(entrySet.getKey(), entrySet.getValue().toByteArray()); + } + + ConversationOutput conversationOutput = + new ConversationOutput(conversationResult.getResult(), parameters); + conversationOutputs.add(conversationOutput); + } + + return new ConversationResponse(conversationResponse.getContextID(), conversationOutputs); + }); + } catch (Exception ex) { + return DaprException.wrapMono(ex); + } + } + + private void validateConversationRequest(ConversationRequest conversationRequest) { + if ((conversationRequest.getName() == null) || (conversationRequest.getName().trim().isEmpty())) { + throw new IllegalArgumentException("LLM name cannot be null or empty."); + } + + if ((conversationRequest.getInputs() == null) || (conversationRequest + .getInputs().isEmpty())) { + throw new IllegalArgumentException("Conversation inputs cannot be null or empty."); + } + } + private DaprMetadata buildDaprMetadata(DaprProtos.GetMetadataResponse response) throws IOException { String id = response.getId(); String runtimeVersion = response.getRuntimeVersion(); diff --git a/sdk/src/main/java/io/dapr/client/DaprPreviewClient.java b/sdk/src/main/java/io/dapr/client/DaprPreviewClient.java index b4fba8ef38..89c6eded8f 100644 --- a/sdk/src/main/java/io/dapr/client/DaprPreviewClient.java +++ b/sdk/src/main/java/io/dapr/client/DaprPreviewClient.java @@ -17,6 +17,8 @@ import io.dapr.client.domain.BulkPublishRequest; import io.dapr.client.domain.BulkPublishResponse; import io.dapr.client.domain.BulkPublishResponseFailedEntry; +import io.dapr.client.domain.ConversationRequest; +import io.dapr.client.domain.ConversationResponse; import io.dapr.client.domain.DeleteJobRequest; import io.dapr.client.domain.GetJobRequest; import io.dapr.client.domain.GetJobResponse; @@ -304,4 +306,12 @@ Subscription subscribeToEvents( * @throws IllegalArgumentException If the request or its required fields like name are null or empty. */ public Mono deleteJob(DeleteJobRequest deleteJobRequest); + + /* + * Converse with an LLM. + * + * @param conversationRequest request to be passed to the LLM. + * @return {@link ConversationResponse}. + */ + public Mono converse(ConversationRequest conversationRequest); } diff --git a/sdk/src/main/java/io/dapr/client/domain/ConversationInput.java b/sdk/src/main/java/io/dapr/client/domain/ConversationInput.java new file mode 100644 index 0000000000..0a1dbfe8a6 --- /dev/null +++ b/sdk/src/main/java/io/dapr/client/domain/ConversationInput.java @@ -0,0 +1,84 @@ +/* + * Copyright 2021 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.client.domain; + +/** + * Represents an input message for a conversation with an LLM. + */ +public class ConversationInput { + + private final String content; + + private String role; + + private boolean scrubPii; + + /** + * Constructor. + * + * @param content for the llm. + */ + public ConversationInput(String content) { + this.content = content; + } + + /** + * The message content to send to the LLM. Required + * + * @return The content to be sent to the LLM. + */ + public String getContent() { + return content; + } + + /** + * The role for the LLM to assume. + * + * @return this. + */ + public String getRole() { + return role; + } + + /** + * Set the role for LLM to assume. + * + * @param role The role to assign to the message. + * @return this. + */ + public ConversationInput setRole(String role) { + this.role = role; + return this; + } + + /** + * Checks if Personally Identifiable Information (PII) should be scrubbed before sending to the LLM. + * + * @return {@code true} if PII should be scrubbed, {@code false} otherwise. + */ + public boolean isScrubPii() { + return scrubPii; + } + + /** + * Enable obfuscation of sensitive information present in the content field. Optional + * + * @param scrubPii A boolean indicating whether to remove PII. + * @return this. + */ + public ConversationInput setScrubPii(boolean scrubPii) { + this.scrubPii = scrubPii; + return this; + } +} diff --git a/sdk/src/main/java/io/dapr/client/domain/ConversationOutput.java b/sdk/src/main/java/io/dapr/client/domain/ConversationOutput.java new file mode 100644 index 0000000000..efe82e2eb3 --- /dev/null +++ b/sdk/src/main/java/io/dapr/client/domain/ConversationOutput.java @@ -0,0 +1,56 @@ +/* + * Copyright 2021 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.client.domain; + +import java.util.Collections; +import java.util.Map; + +/** + * Returns the conversation output. + */ +public class ConversationOutput { + + private final String result; + + private final Map parameters; + + /** + * Constructor. + * + * @param result result for one of the conversation input. + * @param parameters all custom fields. + */ + public ConversationOutput(String result, Map parameters) { + this.result = result; + this.parameters = Map.copyOf(parameters); + } + + /** + * Result for the one conversation input. + * + * @return result output from the LLM. + */ + public String getResult() { + return this.result; + } + + /** + * Parameters for all custom fields. + * + * @return parameters. + */ + public Map getParameters() { + return this.parameters; + } +} diff --git a/sdk/src/main/java/io/dapr/client/domain/ConversationRequest.java b/sdk/src/main/java/io/dapr/client/domain/ConversationRequest.java new file mode 100644 index 0000000000..8bac65b9a2 --- /dev/null +++ b/sdk/src/main/java/io/dapr/client/domain/ConversationRequest.java @@ -0,0 +1,119 @@ +/* + * Copyright 2021 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.client.domain; + +import java.util.List; + +/** + * Represents a conversation configuration with details about component name, + * conversation inputs, context identifier, PII scrubbing, and temperature control. + */ +public class ConversationRequest { + + private final String name; + private final List inputs; + private String contextId; + private boolean scrubPii; + private double temperature; + + /** + * Constructs a DaprConversation with a component name and conversation inputs. + * + * @param name The name of the Dapr conversation component. See a list of all available conversation components + * @see + * @param inputs the list of Dapr conversation inputs + */ + public ConversationRequest(String name, List inputs) { + this.name = name; + this.inputs = inputs; + } + + /** + * Gets the conversation component name. + * + * @return the conversation component name + */ + public String getName() { + return name; + } + + /** + * Gets the list of Dapr conversation input. + * + * @return the list of conversation input + */ + public List getInputs() { + return inputs; + } + + /** + * Gets the context identifier. + * + * @return the context identifier + */ + public String getContextId() { + return contextId; + } + + /** + * Sets the context identifier. + * + * @param contextId the context identifier to set + * @return the current instance of {@link ConversationRequest} + */ + public ConversationRequest setContextId(String contextId) { + this.contextId = contextId; + return this; + } + + /** + * Checks if PII scrubbing is enabled. + * + * @return true if PII scrubbing is enabled, false otherwise + */ + public boolean isScrubPii() { + return scrubPii; + } + + /** + * Enable obfuscation of sensitive information returning from the LLM. Optional. + * + * @param scrubPii whether to enable PII scrubbing + * @return the current instance of {@link ConversationRequest} + */ + public ConversationRequest setScrubPii(boolean scrubPii) { + this.scrubPii = scrubPii; + return this; + } + + /** + * Gets the temperature of the model. Used to optimize for consistency and creativity. Optional + * + * @return the temperature value + */ + public double getTemperature() { + return temperature; + } + + /** + * Sets the temperature of the model. Used to optimize for consistency and creativity. Optional + * + * @param temperature the temperature value to set + * @return the current instance of {@link ConversationRequest} + */ + public ConversationRequest setTemperature(double temperature) { + this.temperature = temperature; + return this; + } +} diff --git a/sdk/src/main/java/io/dapr/client/domain/ConversationResponse.java b/sdk/src/main/java/io/dapr/client/domain/ConversationResponse.java new file mode 100644 index 0000000000..8059365544 --- /dev/null +++ b/sdk/src/main/java/io/dapr/client/domain/ConversationResponse.java @@ -0,0 +1,56 @@ +/* + * Copyright 2021 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.client.domain; + +import java.util.Collections; +import java.util.List; + +/** + * Response from the Dapr Conversation API. + */ +public class ConversationResponse { + + private String contextId; + + private final List outputs; + + /** + * Constructor. + * + * @param contextId context id supplied to LLM. + * @param outputs outputs from the LLM. + */ + public ConversationResponse(String contextId, List outputs) { + this.contextId = contextId; + this.outputs = List.copyOf(outputs); + } + + /** + * The ID of an existing chat (like in ChatGPT). + * + * @return String identifier. + */ + public String getContextId() { + return this.contextId; + } + + /** + * Get list of conversation outputs. + * + * @return List{@link ConversationOutput}. + */ + public List getConversationOutputs() { + return this.outputs; + } +} diff --git a/sdk/src/test/java/io/dapr/client/DaprPreviewClientGrpcTest.java b/sdk/src/test/java/io/dapr/client/DaprPreviewClientGrpcTest.java index aec0f287ae..5cc49edfdd 100644 --- a/sdk/src/test/java/io/dapr/client/DaprPreviewClientGrpcTest.java +++ b/sdk/src/test/java/io/dapr/client/DaprPreviewClientGrpcTest.java @@ -26,6 +26,9 @@ import io.dapr.client.domain.GetJobRequest; import io.dapr.client.domain.GetJobResponse; import io.dapr.client.domain.JobSchedule; +import io.dapr.client.domain.ConversationInput; +import io.dapr.client.domain.ConversationRequest; +import io.dapr.client.domain.ConversationResponse; import io.dapr.client.domain.QueryStateItem; import io.dapr.client.domain.QueryStateRequest; import io.dapr.client.domain.QueryStateResponse; @@ -71,6 +74,7 @@ import java.util.concurrent.atomic.AtomicInteger; import static io.dapr.utils.TestUtils.assertThrowsDaprException; +import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -110,7 +114,7 @@ public void setup() throws IOException { daprHttp = mock(DaprHttp.class); when(daprStub.withInterceptors(any())).thenReturn(daprStub); previewClient = new DaprClientImpl( - channel, daprStub, daprHttp, new DefaultObjectSerializer(), new DefaultObjectSerializer()); + channel, daprStub, daprHttp, new DefaultObjectSerializer(), new DefaultObjectSerializer()); doNothing().when(channel).close(); } @@ -128,28 +132,28 @@ public void publishEventsExceptionThrownTest() { }).when(daprStub).bulkPublishEventAlpha1(any(DaprProtos.BulkPublishRequest.class), any()); assertThrowsDaprException( - StatusRuntimeException.class, - "INVALID_ARGUMENT", - "INVALID_ARGUMENT: bad bad argument", - () -> previewClient.publishEvents(new BulkPublishRequest<>(PUBSUB_NAME, TOPIC_NAME, - Collections.EMPTY_LIST)).block()); + StatusRuntimeException.class, + "INVALID_ARGUMENT", + "INVALID_ARGUMENT: bad bad argument", + () -> previewClient.publishEvents(new BulkPublishRequest<>(PUBSUB_NAME, TOPIC_NAME, + Collections.EMPTY_LIST)).block()); } @Test public void publishEventsCallbackExceptionThrownTest() { doAnswer((Answer) invocation -> { StreamObserver observer = - (StreamObserver) invocation.getArguments()[1]; + (StreamObserver) invocation.getArguments()[1]; observer.onError(newStatusRuntimeException("INVALID_ARGUMENT", "bad bad argument")); return null; }).when(daprStub).bulkPublishEventAlpha1(any(DaprProtos.BulkPublishRequest.class), any()); assertThrowsDaprException( - ExecutionException.class, - "INVALID_ARGUMENT", - "INVALID_ARGUMENT: bad bad argument", - () -> previewClient.publishEvents(new BulkPublishRequest<>(PUBSUB_NAME, TOPIC_NAME, - Collections.EMPTY_LIST)).block()); + ExecutionException.class, + "INVALID_ARGUMENT", + "INVALID_ARGUMENT: bad bad argument", + () -> previewClient.publishEvents(new BulkPublishRequest<>(PUBSUB_NAME, TOPIC_NAME, + Collections.EMPTY_LIST)).block()); } @Test @@ -157,7 +161,7 @@ public void publishEventsContentTypeMismatchException() throws IOException { DaprObjectSerializer mockSerializer = mock(DaprObjectSerializer.class); doAnswer((Answer) invocation -> { StreamObserver observer = - (StreamObserver) invocation.getArguments()[1]; + (StreamObserver) invocation.getArguments()[1]; observer.onNext(DaprProtos.BulkPublishResponse.getDefaultInstance()); observer.onCompleted(); return null; @@ -165,9 +169,9 @@ public void publishEventsContentTypeMismatchException() throws IOException { BulkPublishEntry entry = new BulkPublishEntry<>("1", "testEntry" - , "application/octet-stream", null); + , "application/octet-stream", null); BulkPublishRequest wrongReq = new BulkPublishRequest<>(PUBSUB_NAME, TOPIC_NAME, - Collections.singletonList(entry)); + Collections.singletonList(entry)); assertThrows(IllegalArgumentException.class, () -> previewClient.publishEvents(wrongReq).block()); } @@ -178,30 +182,30 @@ public void publishEventsSerializeException() throws IOException { previewClient = new DaprClientImpl(channel, daprStub, daprHttp, mockSerializer, new DefaultObjectSerializer()); doAnswer((Answer) invocation -> { StreamObserver observer = - (StreamObserver) invocation.getArguments()[1]; + (StreamObserver) invocation.getArguments()[1]; observer.onNext(DaprProtos.BulkPublishResponse.getDefaultInstance()); observer.onCompleted(); return null; }).when(daprStub).publishEvent(any(DaprProtos.PublishEventRequest.class), any()); BulkPublishEntry> entry = new BulkPublishEntry<>("1", new HashMap<>(), - "application/json", null); + "application/json", null); BulkPublishRequest> req = new BulkPublishRequest<>(PUBSUB_NAME, TOPIC_NAME, - Collections.singletonList(entry)); + Collections.singletonList(entry)); when(mockSerializer.serialize(any())).thenThrow(IOException.class); Mono>> result = previewClient.publishEvents(req); assertThrowsDaprException( - IOException.class, - "UNKNOWN", - "UNKNOWN: ", - () -> result.block()); + IOException.class, + "UNKNOWN", + "UNKNOWN: ", + () -> result.block()); } @Test public void publishEventsTest() { doAnswer((Answer) invocation -> { StreamObserver observer = - (StreamObserver) invocation.getArguments()[1]; + (StreamObserver) invocation.getArguments()[1]; DaprProtos.BulkPublishResponse.Builder builder = DaprProtos.BulkPublishResponse.newBuilder(); observer.onNext(builder.build()); observer.onCompleted(); @@ -209,9 +213,9 @@ public void publishEventsTest() { }).when(daprStub).bulkPublishEventAlpha1(any(DaprProtos.BulkPublishRequest.class), any()); BulkPublishEntry entry = new BulkPublishEntry<>("1", "test", - "text/plain", null); + "text/plain", null); BulkPublishRequest req = new BulkPublishRequest<>(PUBSUB_NAME, TOPIC_NAME, - Collections.singletonList(entry)); + Collections.singletonList(entry)); Mono> result = previewClient.publishEvents(req); BulkPublishResponse res = result.block(); Assertions.assertNotNull(res); @@ -222,7 +226,7 @@ public void publishEventsTest() { public void publishEventsWithoutMetaTest() { doAnswer((Answer) invocation -> { StreamObserver observer = - (StreamObserver) invocation.getArguments()[1]; + (StreamObserver) invocation.getArguments()[1]; DaprProtos.BulkPublishResponse.Builder builder = DaprProtos.BulkPublishResponse.newBuilder(); observer.onNext(builder.build()); observer.onCompleted(); @@ -230,7 +234,7 @@ public void publishEventsWithoutMetaTest() { }).when(daprStub).bulkPublishEventAlpha1(any(DaprProtos.BulkPublishRequest.class), any()); Mono> result = previewClient.publishEvents(PUBSUB_NAME, TOPIC_NAME, - "text/plain", Collections.singletonList("test")); + "text/plain", Collections.singletonList("test")); BulkPublishResponse res = result.block(); Assertions.assertNotNull(res); assertEquals( 0, res.getFailedEntries().size(), "expected no entries in failed entries list"); @@ -240,7 +244,7 @@ public void publishEventsWithoutMetaTest() { public void publishEventsWithRequestMetaTest() { doAnswer((Answer) invocation -> { StreamObserver observer = - (StreamObserver) invocation.getArguments()[1]; + (StreamObserver) invocation.getArguments()[1]; DaprProtos.BulkPublishResponse.Builder builder = DaprProtos.BulkPublishResponse.newBuilder(); observer.onNext(builder.build()); observer.onCompleted(); @@ -248,9 +252,9 @@ public void publishEventsWithRequestMetaTest() { }).when(daprStub).bulkPublishEventAlpha1(any(DaprProtos.BulkPublishRequest.class), any()); Mono> result = previewClient.publishEvents(PUBSUB_NAME, TOPIC_NAME, - "text/plain", new HashMap(){{ - put("ttlInSeconds", "123"); - }}, Collections.singletonList("test")); + "text/plain", new HashMap(){{ + put("ttlInSeconds", "123"); + }}, Collections.singletonList("test")); BulkPublishResponse res = result.block(); Assertions.assertNotNull(res); assertEquals( 0, res.getFailedEntries().size(), "expected no entry in failed entries list"); @@ -260,7 +264,7 @@ public void publishEventsWithRequestMetaTest() { public void publishEventsObjectTest() { doAnswer((Answer) invocation -> { StreamObserver observer = - (StreamObserver) invocation.getArguments()[1]; + (StreamObserver) invocation.getArguments()[1]; observer.onNext(DaprProtos.BulkPublishResponse.getDefaultInstance()); observer.onCompleted(); return null; @@ -271,7 +275,7 @@ public void publishEventsObjectTest() { } if (!"{\"id\":1,\"value\":\"Event\"}".equals(new String(entry.getEvent().toByteArray())) && - !"{\"value\":\"Event\",\"id\":1}".equals(new String(entry.getEvent().toByteArray()))) { + !"{\"value\":\"Event\",\"id\":1}".equals(new String(entry.getEvent().toByteArray()))) { return false; } return true; @@ -280,9 +284,9 @@ public void publishEventsObjectTest() { DaprClientGrpcTest.MyObject event = new DaprClientGrpcTest.MyObject(1, "Event"); BulkPublishEntry entry = new BulkPublishEntry<>("1", event, - "application/json", null); + "application/json", null); BulkPublishRequest req = new BulkPublishRequest<>(PUBSUB_NAME, TOPIC_NAME, - Collections.singletonList(entry)); + Collections.singletonList(entry)); BulkPublishResponse result = previewClient.publishEvents(req).block(); Assertions.assertNotNull(result); Assertions.assertEquals(0, result.getFailedEntries().size(), "expected no entries to be failed"); @@ -292,7 +296,7 @@ public void publishEventsObjectTest() { public void publishEventsContentTypeOverrideTest() { doAnswer((Answer) invocation -> { StreamObserver observer = - (StreamObserver) invocation.getArguments()[1]; + (StreamObserver) invocation.getArguments()[1]; observer.onNext(DaprProtos.BulkPublishResponse.getDefaultInstance()); observer.onCompleted(); return null; @@ -309,9 +313,9 @@ public void publishEventsContentTypeOverrideTest() { }), any()); BulkPublishEntry entry = new BulkPublishEntry<>("1", "hello", - "", null); + "", null); BulkPublishRequest req = new BulkPublishRequest<>(PUBSUB_NAME, TOPIC_NAME, - Collections.singletonList(entry)); + Collections.singletonList(entry)); BulkPublishResponse result = previewClient.publishEvents(req).block(); Assertions.assertNotNull(result); Assertions.assertEquals( 0, result.getFailedEntries().size(), "expected no entries to be failed"); @@ -351,7 +355,7 @@ public void queryState() throws JsonProcessingException { assertEquals(0, req.getMetadataCount()); StreamObserver observer = (StreamObserver) - invocation.getArguments()[1]; + invocation.getArguments()[1]; observer.onNext(responseEnvelope); observer.onCompleted(); return null; @@ -378,14 +382,14 @@ public void queryStateMetadataError() throws JsonProcessingException { assertEquals(1, req.getMetadataCount()); StreamObserver observer = (StreamObserver) - invocation.getArguments()[1]; + invocation.getArguments()[1]; observer.onNext(responseEnvelope); observer.onCompleted(); return null; }).when(daprStub).queryStateAlpha1(any(DaprProtos.QueryStateRequest.class), any()); QueryStateResponse response = previewClient.queryState(QUERY_STORE_NAME, "query", - new HashMap(){{ put("key", "error"); }}, String.class).block(); + new HashMap(){{ put("key", "error"); }}, String.class).block(); assertNotNull(response); assertEquals(1, response.getResults().size(), "result size must be 1"); assertEquals( "1", response.getResults().get(0).getKey(), "result must be same"); @@ -396,7 +400,7 @@ public void queryStateMetadataError() throws JsonProcessingException { public void tryLock() { DaprProtos.TryLockResponse.Builder builder = DaprProtos.TryLockResponse.newBuilder() - .setSuccess(true); + .setSuccess(true); DaprProtos.TryLockResponse response = builder.build(); @@ -408,7 +412,7 @@ public void tryLock() { assertEquals(10, req.getExpiryInSeconds()); StreamObserver observer = - (StreamObserver) invocation.getArguments()[1]; + (StreamObserver) invocation.getArguments()[1]; observer.onNext(response); observer.onCompleted(); return null; @@ -422,7 +426,7 @@ public void tryLock() { public void unLock() { DaprProtos.UnlockResponse.Builder builder = DaprProtos.UnlockResponse.newBuilder() - .setStatus(DaprProtos.UnlockResponse.Status.SUCCESS); + .setStatus(DaprProtos.UnlockResponse.Status.SUCCESS); DaprProtos.UnlockResponse response = builder.build(); @@ -433,7 +437,7 @@ public void unLock() { assertEquals("owner", req.getLockOwner()); StreamObserver observer = - (StreamObserver) invocation.getArguments()[1]; + (StreamObserver) invocation.getArguments()[1]; observer.onNext(response); observer.onCompleted(); return null; @@ -457,7 +461,7 @@ public void subscribeEventTest() throws Exception { doAnswer((Answer>) invocation -> { StreamObserver observer = - (StreamObserver) invocation.getArguments()[0]; + (StreamObserver) invocation.getArguments()[0]; var emitterThread = new Thread(() -> { try { started.acquire(); @@ -467,27 +471,27 @@ public void subscribeEventTest() throws Exception { observer.onNext(DaprProtos.SubscribeTopicEventsResponseAlpha1.getDefaultInstance()); for (int i = 0; i < numEvents; i++) { observer.onNext(DaprProtos.SubscribeTopicEventsResponseAlpha1.newBuilder() - .setEventMessage(DaprAppCallbackProtos.TopicEventRequest.newBuilder() - .setId(Integer.toString(i)) - .setPubsubName(pubsubName) - .setTopic(topicName) - .setData(ByteString.copyFromUtf8("\"" + data + "\"")) - .setDataContentType("application/json") - .build()) - .build()); + .setEventMessage(DaprAppCallbackProtos.TopicEventRequest.newBuilder() + .setId(Integer.toString(i)) + .setPubsubName(pubsubName) + .setTopic(topicName) + .setData(ByteString.copyFromUtf8("\"" + data + "\"")) + .setDataContentType("application/json") + .build()) + .build()); } for (int i = 0; i < numDrops; i++) { // Bad messages observer.onNext(DaprProtos.SubscribeTopicEventsResponseAlpha1.newBuilder() - .setEventMessage(DaprAppCallbackProtos.TopicEventRequest.newBuilder() - .setId(UUID.randomUUID().toString()) - .setPubsubName("bad pubsub") - .setTopic("bad topic") - .setData(ByteString.copyFromUtf8("\"\"")) - .setDataContentType("application/json") - .build()) - .build()); + .setEventMessage(DaprAppCallbackProtos.TopicEventRequest.newBuilder() + .setId(UUID.randomUUID().toString()) + .setPubsubName("bad pubsub") + .setTopic("bad topic") + .setData(ByteString.copyFromUtf8("\"\"")) + .setDataContentType("application/json") + .build()) + .build()); } observer.onCompleted(); }); @@ -517,38 +521,38 @@ public void onCompleted() { final AtomicInteger errorsToBeEmitted = new AtomicInteger(numErrors); var subscription = previewClient.subscribeToEvents( - "pubsubname", - "topic", - new SubscriptionListener<>() { - @Override - public Mono onEvent(CloudEvent event) { - if (event.getPubsubName().equals(pubsubName) && - event.getTopic().equals(topicName) && - event.getData().equals(data)) { - - // Simulate an error - if ((success.size() == 4 /* some random entry */) && errorsToBeEmitted.decrementAndGet() >= 0) { - throw new RuntimeException("simulated exception on event " + event.getId()); + "pubsubname", + "topic", + new SubscriptionListener<>() { + @Override + public Mono onEvent(CloudEvent event) { + if (event.getPubsubName().equals(pubsubName) && + event.getTopic().equals(topicName) && + event.getData().equals(data)) { + + // Simulate an error + if ((success.size() == 4 /* some random entry */) && errorsToBeEmitted.decrementAndGet() >= 0) { + throw new RuntimeException("simulated exception on event " + event.getId()); + } + + success.add(event.getId()); + if (success.size() >= numEvents) { + gotAll.release(); + } + return Mono.just(Status.SUCCESS); + } + + dropCounter.incrementAndGet(); + return Mono.just(Status.DROP); } - success.add(event.getId()); - if (success.size() >= numEvents) { - gotAll.release(); + @Override + public void onError(RuntimeException exception) { + errors.add(exception.getMessage()); } - return Mono.just(Status.SUCCESS); - } - - dropCounter.incrementAndGet(); - return Mono.just(Status.DROP); - } - - @Override - public void onError(RuntimeException exception) { - errors.add(exception.getMessage()); - } - }, - TypeRef.STRING); + }, + TypeRef.STRING); gotAll.acquire(); subscription.close(); @@ -558,17 +562,152 @@ public void onError(RuntimeException exception) { assertEquals(numErrors, errors.size()); } + @Test + public void converseShouldThrowIllegalArgumentExceptionWhenComponentNameIsNull() throws Exception { + List inputs = new ArrayList<>(); + inputs.add(new ConversationInput("Hello there !")); + + IllegalArgumentException exception = + assertThrows(IllegalArgumentException.class, () -> + previewClient.converse(new ConversationRequest(null, inputs)).block()); + assertEquals("LLM name cannot be null or empty.", exception.getMessage()); + } + + @Test + public void converseShouldThrowIllegalArgumentExceptionWhenConversationComponentIsEmpty() throws Exception { + List inputs = new ArrayList<>(); + inputs.add(new ConversationInput("Hello there !")); + + IllegalArgumentException exception = + assertThrows(IllegalArgumentException.class, () -> + previewClient.converse(new ConversationRequest("", inputs)).block()); + assertEquals("LLM name cannot be null or empty.", exception.getMessage()); + } + + @Test + public void converseShouldThrowIllegalArgumentExceptionWhenInputsIsEmpty() throws Exception { + List inputs = new ArrayList<>(); + + IllegalArgumentException exception = + assertThrows(IllegalArgumentException.class, () -> + previewClient.converse(new ConversationRequest("openai", inputs)).block()); + assertEquals("Conversation inputs cannot be null or empty.", exception.getMessage()); + } + + @Test + public void converseShouldThrowIllegalArgumentExceptionWhenInputsIsNull() throws Exception { + IllegalArgumentException exception = + assertThrows(IllegalArgumentException.class, () -> + previewClient.converse(new ConversationRequest("openai", null)).block()); + assertEquals("Conversation inputs cannot be null or empty.", exception.getMessage()); + } + + @Test + public void converseShouldThrowIllegalArgumentExceptionWhenInputContentIsNull() throws Exception { + List inputs = new ArrayList<>(); + inputs.add(new ConversationInput(null)); + + IllegalArgumentException exception = + assertThrows(IllegalArgumentException.class, () -> + previewClient.converse(new ConversationRequest("openai", inputs)).block()); + assertEquals("Conversation input content cannot be null or empty.", exception.getMessage()); + } + + @Test + public void converseShouldThrowIllegalArgumentExceptionWhenInputContentIsEmpty() throws Exception { + List inputs = new ArrayList<>(); + inputs.add(new ConversationInput("")); + + IllegalArgumentException exception = + assertThrows(IllegalArgumentException.class, () -> + previewClient.converse(new ConversationRequest("openai", inputs)).block()); + assertEquals("Conversation input content cannot be null or empty.", exception.getMessage()); + } + + @Test + public void converseShouldReturnConversationResponseWhenRequiredInputsAreValid() throws Exception { + DaprProtos.ConversationResponse conversationResponse = DaprProtos.ConversationResponse.newBuilder() + .addOutputs(DaprProtos.ConversationResult.newBuilder().setResult("Hello How are you").build()).build(); + + doAnswer(invocation -> { + StreamObserver observer = invocation.getArgument(1); + observer.onNext(conversationResponse); + observer.onCompleted(); + return null; + }).when(daprStub).converseAlpha1(any(DaprProtos.ConversationRequest.class), any()); + + List inputs = new ArrayList<>(); + inputs.add(new ConversationInput("Hello there")); + ConversationResponse response = + previewClient.converse(new ConversationRequest("openai", inputs)).block(); + + ArgumentCaptor captor = + ArgumentCaptor.forClass(DaprProtos.ConversationRequest.class); + verify(daprStub, times(1)).converseAlpha1(captor.capture(), Mockito.any()); + + DaprProtos.ConversationRequest conversationRequest = captor.getValue(); + + assertEquals("openai", conversationRequest.getName()); + assertEquals("Hello there", conversationRequest.getInputs(0).getContent()); + assertEquals("Hello How are you", + response.getConversationOutputs().get(0).getResult()); + } + + @Test + public void converseShouldReturnConversationResponseWhenRequiredAndOptionalInputsAreValid() throws Exception { + DaprProtos.ConversationResponse conversationResponse = DaprProtos.ConversationResponse.newBuilder() + .setContextID("contextId") + .addOutputs(DaprProtos.ConversationResult.newBuilder().setResult("Hello How are you").build()).build(); + + doAnswer(invocation -> { + StreamObserver observer = invocation.getArgument(1); + observer.onNext(conversationResponse); + observer.onCompleted(); + return null; + }).when(daprStub).converseAlpha1(any(DaprProtos.ConversationRequest.class), any()); + + ConversationInput daprConversationInput = new ConversationInput("Hello there") + .setRole("Assistant") + .setScrubPii(true); + + List inputs = new ArrayList<>(); + inputs.add(daprConversationInput); + + ConversationResponse response = + previewClient.converse(new ConversationRequest("openai", inputs) + .setContextId("contextId") + .setScrubPii(true) + .setTemperature(1.1d)).block(); + + ArgumentCaptor captor = + ArgumentCaptor.forClass(DaprProtos.ConversationRequest.class); + verify(daprStub, times(1)).converseAlpha1(captor.capture(), Mockito.any()); + + DaprProtos.ConversationRequest conversationRequest = captor.getValue(); + + assertEquals("openai", conversationRequest.getName()); + assertEquals("contextId", conversationRequest.getContextID()); + assertTrue(conversationRequest.getScrubPII()); + assertEquals(1.1d, conversationRequest.getTemperature(), 0d); + assertEquals("Hello there", conversationRequest.getInputs(0).getContent()); + assertTrue(conversationRequest.getInputs(0).getScrubPII()); + assertEquals("Assistant", conversationRequest.getInputs(0).getRole()); + assertEquals("contextId", response.getContextId()); + assertEquals("Hello How are you", + response.getConversationOutputs().get(0).getResult()); + } + @Test public void scheduleJobShouldSucceedWhenAllFieldsArePresentInRequest() { DateTimeFormatter iso8601Formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'") - .withZone(ZoneOffset.UTC); + .withZone(ZoneOffset.UTC); ScheduleJobRequest expectedScheduleJobRequest = new ScheduleJobRequest("testJob", - JobSchedule.fromString("*/5 * * * *")) - .setData("testData".getBytes()) - .setTtl(Instant.now().plus(1, ChronoUnit.DAYS)) - .setRepeat(5) - .setDueTime(Instant.now().plus(10, ChronoUnit.MINUTES)); + JobSchedule.fromString("*/5 * * * *")) + .setData("testData".getBytes()) + .setTtl(Instant.now().plus(1, ChronoUnit.DAYS)) + .setRepeat(5) + .setDueTime(Instant.now().plus(10, ChronoUnit.MINUTES)); doAnswer(invocation -> { StreamObserver observer = invocation.getArgument(1); @@ -579,14 +718,14 @@ public void scheduleJobShouldSucceedWhenAllFieldsArePresentInRequest() { assertDoesNotThrow(() -> previewClient.scheduleJob(expectedScheduleJobRequest).block()); ArgumentCaptor captor = - ArgumentCaptor.forClass(DaprProtos.ScheduleJobRequest.class); + ArgumentCaptor.forClass(DaprProtos.ScheduleJobRequest.class); verify(daprStub, times(1)).scheduleJobAlpha1(captor.capture(), Mockito.any()); DaprProtos.ScheduleJobRequest actualScheduleJobReq = captor.getValue(); assertEquals("testJob", actualScheduleJobReq.getJob().getName()); assertEquals("testData", - new String(actualScheduleJobReq.getJob().getData().getValue().toByteArray(), StandardCharsets.UTF_8)); + new String(actualScheduleJobReq.getJob().getData().getValue().toByteArray(), StandardCharsets.UTF_8)); assertEquals("*/5 * * * *", actualScheduleJobReq.getJob().getSchedule()); assertEquals(iso8601Formatter.format(expectedScheduleJobRequest.getTtl()), actualScheduleJobReq.getJob().getTtl()); assertEquals(expectedScheduleJobRequest.getRepeats(), actualScheduleJobReq.getJob().getRepeats()); @@ -596,7 +735,7 @@ public void scheduleJobShouldSucceedWhenAllFieldsArePresentInRequest() { @Test public void scheduleJobShouldSucceedWhenRequiredFieldsNameAndDueTimeArePresentInRequest() { DateTimeFormatter iso8601Formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'") - .withZone(ZoneOffset.UTC); + .withZone(ZoneOffset.UTC); doAnswer(invocation -> { StreamObserver observer = invocation.getArgument(1); @@ -605,11 +744,11 @@ public void scheduleJobShouldSucceedWhenRequiredFieldsNameAndDueTimeArePresentIn }).when(daprStub).scheduleJobAlpha1(any(DaprProtos.ScheduleJobRequest.class), any()); ScheduleJobRequest expectedScheduleJobRequest = - new ScheduleJobRequest("testJob", Instant.now().plus(10, ChronoUnit.MINUTES)); + new ScheduleJobRequest("testJob", Instant.now().plus(10, ChronoUnit.MINUTES)); assertDoesNotThrow(() -> previewClient.scheduleJob(expectedScheduleJobRequest).block()); ArgumentCaptor captor = - ArgumentCaptor.forClass(DaprProtos.ScheduleJobRequest.class); + ArgumentCaptor.forClass(DaprProtos.ScheduleJobRequest.class); verify(daprStub, times(1)).scheduleJobAlpha1(captor.capture(), Mockito.any()); DaprProtos.ScheduleJobRequest actualScheduleJobRequest = captor.getValue(); @@ -620,13 +759,13 @@ public void scheduleJobShouldSucceedWhenRequiredFieldsNameAndDueTimeArePresentIn assertEquals(0, job.getRepeats()); assertFalse(job.hasTtl()); assertEquals(iso8601Formatter.format(expectedScheduleJobRequest.getDueTime()), - actualScheduleJobRequest.getJob().getDueTime()); + actualScheduleJobRequest.getJob().getDueTime()); } @Test public void scheduleJobShouldSucceedWhenRequiredFieldsNameAndScheduleArePresentInRequest() { DateTimeFormatter iso8601Formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'") - .withZone(ZoneOffset.UTC); + .withZone(ZoneOffset.UTC); doAnswer(invocation -> { StreamObserver observer = invocation.getArgument(1); @@ -635,11 +774,11 @@ public void scheduleJobShouldSucceedWhenRequiredFieldsNameAndScheduleArePresentI }).when(daprStub).scheduleJobAlpha1(any(DaprProtos.ScheduleJobRequest.class), any()); ScheduleJobRequest expectedScheduleJobRequest = new ScheduleJobRequest("testJob", - JobSchedule.fromString("* * * * * *")); + JobSchedule.fromString("* * * * * *")); assertDoesNotThrow(() -> previewClient.scheduleJob(expectedScheduleJobRequest).block()); ArgumentCaptor captor = - ArgumentCaptor.forClass(DaprProtos.ScheduleJobRequest.class); + ArgumentCaptor.forClass(DaprProtos.ScheduleJobRequest.class); verify(daprStub, times(1)).scheduleJobAlpha1(captor.capture(), Mockito.any()); DaprProtos.ScheduleJobRequest actualScheduleJobRequest = captor.getValue(); @@ -681,24 +820,24 @@ public void scheduleJobShouldThrowWhenNameInRequestIsEmpty() { @Test public void getJobShouldReturnResponseWhenAllFieldsArePresentInRequest() { DateTimeFormatter iso8601Formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'") - .withZone(ZoneOffset.UTC); + .withZone(ZoneOffset.UTC); GetJobRequest getJobRequest = new GetJobRequest("testJob"); DaprProtos.Job job = DaprProtos.Job.newBuilder() - .setName("testJob") - .setTtl(OffsetDateTime.now().format(iso8601Formatter)) - .setData(Any.newBuilder().setValue(ByteString.copyFrom("testData".getBytes())).build()) - .setSchedule("*/5 * * * *") - .setRepeats(5) - .setDueTime(iso8601Formatter.format(Instant.now().plus(10, ChronoUnit.MINUTES))) - .build(); + .setName("testJob") + .setTtl(OffsetDateTime.now().format(iso8601Formatter)) + .setData(Any.newBuilder().setValue(ByteString.copyFrom("testData".getBytes())).build()) + .setSchedule("*/5 * * * *") + .setRepeats(5) + .setDueTime(iso8601Formatter.format(Instant.now().plus(10, ChronoUnit.MINUTES))) + .build(); doAnswer(invocation -> { StreamObserver observer = invocation.getArgument(1); observer.onNext(DaprProtos.GetJobResponse.newBuilder() - .setJob(job) - .build()); + .setJob(job) + .build()); observer.onCompleted(); return null; }).when(daprStub).getJobAlpha1(any(DaprProtos.GetJobRequest.class), any()); @@ -720,15 +859,15 @@ public void getJobShouldReturnResponseWithScheduleSetWhenResponseHasSchedule() { GetJobRequest getJobRequest = new GetJobRequest("testJob"); DaprProtos.Job job = DaprProtos.Job.newBuilder() - .setName("testJob") - .setSchedule("0 0 0 1 1 *") - .build(); + .setName("testJob") + .setSchedule("0 0 0 1 1 *") + .build(); doAnswer(invocation -> { StreamObserver observer = invocation.getArgument(1); observer.onNext(DaprProtos.GetJobResponse.newBuilder() - .setJob(job) - .build()); + .setJob(job) + .build()); observer.onCompleted(); return null; }).when(daprStub).getJobAlpha1(any(DaprProtos.GetJobRequest.class), any()); @@ -751,15 +890,15 @@ public void getJobShouldReturnResponseWithDueTimeSetWhenResponseHasDueTime() { String datetime = OffsetDateTime.now().toString(); DaprProtos.Job job = DaprProtos.Job.newBuilder() - .setName("testJob") - .setDueTime(datetime) - .build(); + .setName("testJob") + .setDueTime(datetime) + .build(); doAnswer(invocation -> { StreamObserver observer = invocation.getArgument(1); observer.onNext(DaprProtos.GetJobResponse.newBuilder() - .setJob(job) - .build()); + .setJob(job) + .build()); observer.onCompleted(); return null; }).when(daprStub).getJobAlpha1(any(DaprProtos.GetJobRequest.class), any()); @@ -846,15 +985,15 @@ public void deleteJobShouldThrowWhenNameIsEmptyRequest() { } private DaprProtos.QueryStateResponse buildQueryStateResponse(List> resp,String token) - throws JsonProcessingException { + throws JsonProcessingException { List items = new ArrayList<>(); for (QueryStateItem item: resp) { items.add(buildQueryStateItem(item)); } return DaprProtos.QueryStateResponse.newBuilder() - .addAllResults(items) - .setToken(token) - .build(); + .addAllResults(items) + .setToken(token) + .build(); } private DaprProtos.QueryStateItem buildQueryStateItem(QueryStateItem item) throws JsonProcessingException {