|
1 | 1 | package com.service; |
2 | 2 |
|
| 3 | +import java.util.HashMap; |
| 4 | +import java.util.List; |
| 5 | +import java.util.Map; |
| 6 | + |
3 | 7 | import io.dapr.client.DaprClientBuilder; |
4 | 8 | import io.dapr.client.DaprPreviewClient; |
5 | | -import io.dapr.client.domain.ConversationInput; |
6 | | -import io.dapr.client.domain.ConversationRequest; |
7 | | -import io.dapr.client.domain.ConversationResponse; |
8 | | -import reactor.core.publisher.Mono; |
9 | | - |
10 | | -import java.util.List; |
| 9 | +import io.dapr.client.domain.ConversationInputAlpha2; |
| 10 | +import io.dapr.client.domain.ConversationMessageContent; |
| 11 | +import io.dapr.client.domain.ConversationRequestAlpha2; |
| 12 | +import io.dapr.client.domain.ConversationResponseAlpha2; |
| 13 | +import io.dapr.client.domain.ConversationResultAlpha2; |
| 14 | +import io.dapr.client.domain.ConversationResultChoices; |
| 15 | +import io.dapr.client.domain.ConversationResultMessage; |
| 16 | +import io.dapr.client.domain.ConversationToolCalls; |
| 17 | +import io.dapr.client.domain.ConversationTools; |
| 18 | +import io.dapr.client.domain.ConversationToolsFunction; |
| 19 | +import io.dapr.client.domain.UserMessage; |
11 | 20 |
|
12 | 21 | public class Conversation { |
13 | 22 |
|
14 | | - public static void main(String[] args) { |
15 | | - String prompt = "What is Dapr?"; |
| 23 | + private static final String CONVERSATION_COMPONENT_NAME = "echo"; |
| 24 | + private static final String CONVERSATION_TEXT = "What is dapr?"; |
| 25 | + private static final String TOOL_CALL_INPUT = "What is the weather like in San Francisco in celsius?"; |
16 | 26 |
|
| 27 | + public static void main(String[] args) { |
17 | 28 | try (DaprPreviewClient client = new DaprClientBuilder().buildPreviewClient()) { |
18 | | - System.out.println("Input: " + prompt); |
19 | 29 |
|
20 | | - ConversationInput daprConversationInput = new ConversationInput(prompt); |
| 30 | + // Define tool function parameters schema |
| 31 | + Map<String, Object> locationProperty = new HashMap<>(); |
| 32 | + locationProperty.put("type", "string"); |
| 33 | + locationProperty.put("description", "The city and state, e.g. San Francisco, CA"); |
| 34 | + |
| 35 | + Map<String, Object> unitProperty = new HashMap<>(); |
| 36 | + unitProperty.put("type", "string"); |
| 37 | + unitProperty.put("enum", List.of("celsius", "fahrenheit")); |
| 38 | + unitProperty.put("description", "The temperature unit to use"); |
| 39 | + |
| 40 | + Map<String, Object> properties = new HashMap<>(); |
| 41 | + properties.put("location", locationProperty); |
| 42 | + properties.put("unit", unitProperty); |
| 43 | + |
| 44 | + Map<String, Object> parameters = new HashMap<>(); |
| 45 | + parameters.put("type", "object"); |
| 46 | + parameters.put("properties", properties); |
| 47 | + parameters.put("required", List.of("location")); |
| 48 | + |
| 49 | + // Create the tool function |
| 50 | + ConversationToolsFunction getWeatherFunction = new ConversationToolsFunction("get_weather", parameters) |
| 51 | + .setDescription("Get the current weather for a location"); |
| 52 | + ConversationTools weatherTool = new ConversationTools(getWeatherFunction); |
| 53 | + |
| 54 | + // ========================================== |
| 55 | + // Simple Conversation |
| 56 | + // ========================================== |
| 57 | + System.out.println("=== Simple Conversation ==="); |
| 58 | + |
| 59 | + UserMessage conversationMessage = new UserMessage( |
| 60 | + List.of(new ConversationMessageContent(CONVERSATION_TEXT))) |
| 61 | + .setName("TestUser"); |
| 62 | + ConversationInputAlpha2 conversationInput = new ConversationInputAlpha2(List.of(conversationMessage)); |
| 63 | + |
| 64 | + ConversationResponseAlpha2 conversationResponse = client.converseAlpha2( |
| 65 | + new ConversationRequestAlpha2(CONVERSATION_COMPONENT_NAME, List.of(conversationInput)) |
| 66 | + .setScrubPii(false) |
| 67 | + .setToolChoice("auto") |
| 68 | + .setTemperature(0.7) |
| 69 | + .setTools(List.of(weatherTool))).block(); |
| 70 | + |
| 71 | + System.out.println("Conversation input sent: " + CONVERSATION_TEXT); |
| 72 | + String outputContent = conversationResponse.getOutputs().get(0) |
| 73 | + .getChoices().get(0).getMessage().getContent(); |
| 74 | + System.out.println("Output response: " + outputContent); |
| 75 | + |
| 76 | + // ========================================== |
| 77 | + // Tool Calling |
| 78 | + // ========================================== |
| 79 | + System.out.println("\n=== Tool Calling ==="); |
| 80 | + |
| 81 | + UserMessage toolCallMessage = new UserMessage( |
| 82 | + List.of(new ConversationMessageContent(TOOL_CALL_INPUT))) |
| 83 | + .setName("TestUser"); |
| 84 | + ConversationInputAlpha2 toolCallInput = new ConversationInputAlpha2(List.of(toolCallMessage)); |
| 85 | + |
| 86 | + ConversationResponseAlpha2 toolCallResponse = client.converseAlpha2( |
| 87 | + new ConversationRequestAlpha2(CONVERSATION_COMPONENT_NAME, List.of(toolCallInput)) |
| 88 | + .setScrubPii(false) |
| 89 | + .setToolChoice("auto") |
| 90 | + .setTemperature(0.7) |
| 91 | + .setTools(List.of(weatherTool))).block(); |
| 92 | + |
| 93 | + System.out.println("Tool calling input sent: " + TOOL_CALL_INPUT); |
| 94 | + |
| 95 | + ConversationResultAlpha2 result = toolCallResponse.getOutputs().get(0); |
| 96 | + ConversationResultChoices choice = result.getChoices().get(0); |
| 97 | + ConversationResultMessage message = choice.getMessage(); |
| 98 | + |
| 99 | + System.out.println("Output message: " + message.getContent()); |
| 100 | + |
| 101 | + if (message.hasToolCalls()) { |
| 102 | + System.out.println("Tool calls detected:"); |
| 103 | + for (ConversationToolCalls toolCall : message.getToolCalls()) { |
| 104 | + String functionName = toolCall.getFunction().getName(); |
| 105 | + String functionArguments = toolCall.getFunction().getArguments(); |
| 106 | + |
| 107 | + System.out.println("Tool call: {\"id\": \"" + toolCall.getId() |
| 108 | + + "\", \"function\": {\"name\": \"" + functionName |
| 109 | + + "\", \"arguments\": " + functionArguments + "}}"); |
| 110 | + System.out.println("Function name: " + functionName); |
| 111 | + System.out.println("Function arguments: " + functionArguments); |
| 112 | + } |
| 113 | + } |
21 | 114 |
|
22 | | - // Component name is the name provided in the metadata block of the conversation.yaml file. |
23 | | - Mono<ConversationResponse> responseMono = client.converse(new ConversationRequest("echo", |
24 | | - List.of(daprConversationInput)) |
25 | | - .setContextId("contextId") |
26 | | - .setScrubPii(true).setTemperature(1.1d)); |
27 | | - ConversationResponse response = responseMono.block(); |
28 | | - System.out.printf("Output response: %s", response.getConversationOutputs().get(0).getResult()); |
29 | 115 | } catch (Exception e) { |
30 | 116 | throw new RuntimeException(e); |
31 | 117 | } |
|
0 commit comments