From 30ece0b4494d805435f6d6e955198d908771d4f4 Mon Sep 17 00:00:00 2001 From: Micah Press Date: Wed, 27 Nov 2024 13:29:08 -0800 Subject: [PATCH 1/6] Change conditional logic for constructing non-Azure OpenAI clients With the release of OpenAI's EU endpoint, the client builder needs to create non-Azure OpenAI clients if the endpoint is null or if it matches the default API URL or the EU URL. --- .../azure/ai/openai/OpenAIClientBuilder.java | 7 +++--- .../NonAzureOpenAIClientImpl.java | 22 +++++++++++++++++-- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/OpenAIClientBuilder.java b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/OpenAIClientBuilder.java index 151f613e1fa2..b6d715353103 100644 --- a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/OpenAIClientBuilder.java +++ b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/OpenAIClientBuilder.java @@ -3,6 +3,7 @@ // Code generated by Microsoft (R) TypeSpec Code Generator. package com.azure.ai.openai; +import static com.azure.ai.openai.implementation.NonAzureOpenAIClientImpl.EU_OPEN_AI_ENDPOINT; import static com.azure.ai.openai.implementation.NonAzureOpenAIClientImpl.OPEN_AI_ENDPOINT; import com.azure.ai.openai.implementation.NonAzureOpenAIClientImpl; @@ -333,7 +334,7 @@ private HttpPipeline createHttpPipeline() { private NonAzureOpenAIClientImpl buildInnerNonAzureOpenAIClient() { HttpPipeline localPipeline = (pipeline != null) ? pipeline : createHttpPipeline(); NonAzureOpenAIClientImpl client - = new NonAzureOpenAIClientImpl(localPipeline, JacksonAdapter.createDefaultSerializerAdapter()); + = new NonAzureOpenAIClientImpl(localPipeline, JacksonAdapter.createDefaultSerializerAdapter(), endpoint); return client; } @@ -363,10 +364,10 @@ public OpenAIClient buildClient() { /** * OpenAI service can be used by either not setting the endpoint or by setting the endpoint to start with - * "https://api.openai.com/" + * "https://api.openai.com/" or "https://eu.api.openai.com/". */ private boolean useNonAzureOpenAIService() { - return endpoint == null || endpoint.startsWith(OPEN_AI_ENDPOINT); + return endpoint == null || endpoint.startsWith(OPEN_AI_ENDPOINT) || endpoint.startsWith(EU_OPEN_AI_ENDPOINT); } private void validateClient() { diff --git a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/NonAzureOpenAIClientImpl.java b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/NonAzureOpenAIClientImpl.java index 940766d2e8f8..67fde50c9ad7 100644 --- a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/NonAzureOpenAIClientImpl.java +++ b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/NonAzureOpenAIClientImpl.java @@ -63,20 +63,38 @@ public SerializerAdapter getSerializerAdapter() { return this.serializerAdapter; } + /** The endpoint to send API requests to. */ + private final String endpoint; + /** - * This is the endpoint that non-azure OpenAI supports. Currently, it has only v1 version. + * Gets the endpoint that this client is configured to send requests to. + * + * @return the endpoint value. + */ + public String getEndpoint() { + return this.endpoint; + } + + /** + * This is the generic endpoint that non-azure OpenAI supports. Currently, it has only v1 version. */ public static final String OPEN_AI_ENDPOINT = "https://api.openai.com/v1"; + /** + * This is the EU-based endpoint that non-azure OpenAI supports. Currently, it has only v1 version. + */ + public static final String EU_OPEN_AI_ENDPOINT = "https://eu.api.openai.com/v1"; + /** * Initializes an instance of OpenAIClient client. * * @param httpPipeline The HTTP pipeline to send requests through. * @param serializerAdapter The serializer to serialize an object into a string. */ - public NonAzureOpenAIClientImpl(HttpPipeline httpPipeline, SerializerAdapter serializerAdapter) { + public NonAzureOpenAIClientImpl(HttpPipeline httpPipeline, SerializerAdapter serializerAdapter, String endpoint) { this.httpPipeline = httpPipeline; this.serializerAdapter = serializerAdapter; + this.endpoint = endpoint == null ? OPEN_AI_ENDPOINT : endpoint; this.service = RestProxy.create(NonAzureOpenAIClientService.class, this.httpPipeline, this.getSerializerAdapter()); } From 4666e4be90e41ae9a1cbcd0205f6ad322f5842e2 Mon Sep 17 00:00:00 2001 From: Micah Press Date: Wed, 27 Nov 2024 13:43:20 -0800 Subject: [PATCH 2/6] Use `getEndpoint()` in API call implementations --- .../NonAzureOpenAIClientImpl.java | 99 +++++++++---------- 1 file changed, 48 insertions(+), 51 deletions(-) diff --git a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/NonAzureOpenAIClientImpl.java b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/NonAzureOpenAIClientImpl.java index 67fde50c9ad7..aae01582184b 100644 --- a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/NonAzureOpenAIClientImpl.java +++ b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/NonAzureOpenAIClientImpl.java @@ -605,8 +605,8 @@ public Mono> getEmbeddingsWithResponseAsync(String modelId, final String accept = "application/json"; // modelId is part of the request body in nonAzure OpenAI final BinaryData embeddingsOptionsUpdated = addModelIdJson(embeddingsOptions, modelId); - return FluxUtil.withContext(context -> service.getEmbeddings(OPEN_AI_ENDPOINT, accept, embeddingsOptionsUpdated, - requestOptions, context)); + return FluxUtil.withContext( + context -> service.getEmbeddings(getEndpoint(), accept, embeddingsOptionsUpdated, requestOptions, context)); } /** @@ -662,8 +662,7 @@ public Response getEmbeddingsWithResponse(String modelId, BinaryData final String accept = "application/json"; // modelId is part of the request body in nonAzure OpenAI final BinaryData embeddingsOptionsUpdated = addModelIdJson(embeddingsOptions, modelId); - return service.getEmbeddingsSync(OPEN_AI_ENDPOINT, accept, embeddingsOptionsUpdated, requestOptions, - Context.NONE); + return service.getEmbeddingsSync(getEndpoint(), accept, embeddingsOptionsUpdated, requestOptions, Context.NONE); } /** @@ -754,8 +753,8 @@ public Mono> getCompletionsWithResponseAsync(String modelId final String accept = "application/json"; // modelId is part of the request body in nonAzure OpenAI final BinaryData completionsOptionsUpdated = addModelIdJson(completionsOptions, modelId); - return FluxUtil.withContext(context -> service.getCompletions(OPEN_AI_ENDPOINT, accept, - completionsOptionsUpdated, requestOptions, context)); + return FluxUtil.withContext(context -> service.getCompletions(getEndpoint(), accept, completionsOptionsUpdated, + requestOptions, context)); } /** @@ -844,7 +843,7 @@ public Response getCompletionsWithResponse(String modelId, BinaryDat final String accept = "application/json"; // modelId is part of the request body in nonAzure OpenAI final BinaryData completionsOptionsUpdated = addModelIdJson(completionsOptions, modelId); - return service.getCompletionsSync(OPEN_AI_ENDPOINT, accept, completionsOptionsUpdated, requestOptions, + return service.getCompletionsSync(getEndpoint(), accept, completionsOptionsUpdated, requestOptions, Context.NONE); } @@ -927,7 +926,7 @@ public Mono> getChatCompletionsWithResponseAsync(String mod final String accept = "application/json"; // modelId is part of the request body in nonAzure OpenAI final BinaryData chatCompletionsOptionsUpdated = addModelIdJson(chatCompletionsOptions, modelId); - return FluxUtil.withContext(context -> service.getChatCompletions(OPEN_AI_ENDPOINT, accept, + return FluxUtil.withContext(context -> service.getChatCompletions(getEndpoint(), accept, chatCompletionsOptionsUpdated, requestOptions, context)); } @@ -1009,7 +1008,7 @@ public Response getChatCompletionsWithResponse(String modelId, Binar final String accept = "application/json"; // modelId is part of the request body in nonAzure OpenAI final BinaryData chatCompletionsOptionsUpdated = addModelIdJson(chatCompletionsOptions, modelId); - return service.getChatCompletionsSync(OPEN_AI_ENDPOINT, accept, chatCompletionsOptionsUpdated, requestOptions, + return service.getChatCompletionsSync(getEndpoint(), accept, chatCompletionsOptionsUpdated, requestOptions, Context.NONE); } @@ -1062,7 +1061,7 @@ public Mono> getImageGenerationsWithResponseAsync(String mo final String accept = "application/json"; // modelId is part of the request body in nonAzure OpenAI final BinaryData imageGenerationOptionsUpdated = addModelIdJson(imageGenerationOptions, modelId); - return FluxUtil.withContext(context -> service.getImageGenerations(OPEN_AI_ENDPOINT, accept, + return FluxUtil.withContext(context -> service.getImageGenerations(getEndpoint(), accept, imageGenerationOptionsUpdated, requestOptions, context)); } @@ -1114,7 +1113,7 @@ public Response getImageGenerationsWithResponse(String modelId, Bina RequestOptions requestOptions) { final String accept = "application/json"; final BinaryData imageGenerationOptionsUpdated = addModelIdJson(imageGenerationOptions, modelId); - return service.getImageGenerationsSync(OPEN_AI_ENDPOINT, accept, imageGenerationOptionsUpdated, requestOptions, + return service.getImageGenerationsSync(getEndpoint(), accept, imageGenerationOptionsUpdated, requestOptions, Context.NONE); } @@ -1193,7 +1192,7 @@ public static BinaryData addModelIdJson(BinaryData inputJson, String modelId) { public Mono> getAudioTranscriptionAsResponseObjectWithResponseAsync(String modelId, BinaryData audioTranscriptionOptions, RequestOptions requestOptions) { final String accept = "application/json"; - return FluxUtil.withContext(context -> service.getAudioTranscriptionAsResponseObject(OPEN_AI_ENDPOINT, accept, + return FluxUtil.withContext(context -> service.getAudioTranscriptionAsResponseObject(getEndpoint(), accept, audioTranscriptionOptions, requestOptions, context)); } @@ -1254,7 +1253,7 @@ public Mono> getAudioTranscriptionAsResponseObjectWithRespo public Response getAudioTranscriptionAsResponseObjectWithResponse(String modelId, BinaryData audioTranscriptionOptions, RequestOptions requestOptions) { final String accept = "application/json"; - return service.getAudioTranscriptionAsResponseObjectSync(OPEN_AI_ENDPOINT, accept, audioTranscriptionOptions, + return service.getAudioTranscriptionAsResponseObjectSync(getEndpoint(), accept, audioTranscriptionOptions, requestOptions, Context.NONE); } @@ -1295,7 +1294,7 @@ public Response getAudioTranscriptionAsResponseObjectWithResponse(St public Mono> getAudioTranscriptionAsPlainTextWithResponseAsync(String modelId, BinaryData audioTranscriptionOptions, RequestOptions requestOptions) { final String accept = "application/json"; - return FluxUtil.withContext(context -> service.getAudioTranscriptionAsPlainText(OPEN_AI_ENDPOINT, accept, + return FluxUtil.withContext(context -> service.getAudioTranscriptionAsPlainText(getEndpoint(), accept, audioTranscriptionOptions, requestOptions, context)); } @@ -1335,7 +1334,7 @@ public Mono> getAudioTranscriptionAsPlainTextWithResponseAs public Response getAudioTranscriptionAsPlainTextWithResponse(String modelId, BinaryData audioTranscriptionOptions, RequestOptions requestOptions) { final String accept = "application/json"; - return service.getAudioTranscriptionAsPlainTextSync(OPEN_AI_ENDPOINT, accept, audioTranscriptionOptions, + return service.getAudioTranscriptionAsPlainTextSync(getEndpoint(), accept, audioTranscriptionOptions, requestOptions, Context.NONE); } @@ -1395,7 +1394,7 @@ public Response getAudioTranscriptionAsPlainTextWithResponse(String public Mono> getAudioTranslationAsResponseObjectWithResponseAsync(String deploymentOrModelName, BinaryData audioTranslationOptions, RequestOptions requestOptions) { final String accept = "application/json"; - return FluxUtil.withContext(context -> service.getAudioTranslationAsResponseObject(OPEN_AI_ENDPOINT, accept, + return FluxUtil.withContext(context -> service.getAudioTranslationAsResponseObject(getEndpoint(), accept, audioTranslationOptions, requestOptions, context)); } @@ -1455,7 +1454,7 @@ public Mono> getAudioTranslationAsResponseObjectWithRespons public Response getAudioTranslationAsResponseObjectWithResponse(String modelId, BinaryData audioTranslationOptions, RequestOptions requestOptions) { final String accept = "application/json"; - return service.getAudioTranslationAsResponseObjectSync(OPEN_AI_ENDPOINT, accept, audioTranslationOptions, + return service.getAudioTranslationAsResponseObjectSync(getEndpoint(), accept, audioTranslationOptions, requestOptions, Context.NONE); } @@ -1494,7 +1493,7 @@ public Response getAudioTranslationAsResponseObjectWithResponse(Stri public Mono> getAudioTranslationAsPlainTextWithResponseAsync(String modelId, BinaryData audioTranslationOptions, RequestOptions requestOptions) { final String accept = "application/json"; - return FluxUtil.withContext(context -> service.getAudioTranslationAsPlainText(OPEN_AI_ENDPOINT, accept, + return FluxUtil.withContext(context -> service.getAudioTranslationAsPlainText(getEndpoint(), accept, audioTranslationOptions, requestOptions, context)); } @@ -1533,7 +1532,7 @@ public Mono> getAudioTranslationAsPlainTextWithResponseAsyn public Response getAudioTranslationAsPlainTextWithResponse(String modelId, BinaryData audioTranslationOptions, RequestOptions requestOptions) { final String accept = "application/json"; - return service.getAudioTranslationAsPlainTextSync(OPEN_AI_ENDPOINT, accept, audioTranslationOptions, + return service.getAudioTranslationAsPlainTextSync(getEndpoint(), accept, audioTranslationOptions, requestOptions, Context.NONE); } @@ -1571,7 +1570,7 @@ public Response getAudioTranslationAsPlainTextWithResponse(String mo public Mono> generateSpeechFromTextWithResponseAsync(String modelId, BinaryData speechGenerationOptions, RequestOptions requestOptions) { final String accept = "application/octet-stream, application/json"; - return FluxUtil.withContext(context -> service.generateSpeechFromText(OPEN_AI_ENDPOINT, accept, + return FluxUtil.withContext(context -> service.generateSpeechFromText(getEndpoint(), accept, speechGenerationOptions, requestOptions, context)); } @@ -1608,7 +1607,7 @@ public Mono> generateSpeechFromTextWithResponseAsync(String public Response generateSpeechFromTextWithResponse(BinaryData speechGenerationOptions, RequestOptions requestOptions) { final String accept = "application/octet-stream, application/json"; - return service.generateSpeechFromTextSync(OPEN_AI_ENDPOINT, accept, speechGenerationOptions, requestOptions, + return service.generateSpeechFromTextSync(getEndpoint(), accept, speechGenerationOptions, requestOptions, Context.NONE); } @@ -1653,7 +1652,7 @@ public Response generateSpeechFromTextWithResponse(BinaryData speech @ServiceMethod(returns = ReturnType.SINGLE) public Mono> listFilesWithResponseAsync(RequestOptions requestOptions) { final String accept = "application/json"; - return FluxUtil.withContext(context -> service.listFiles(OPEN_AI_ENDPOINT, accept, requestOptions, context)); + return FluxUtil.withContext(context -> service.listFiles(getEndpoint(), accept, requestOptions, context)); } /** @@ -1697,7 +1696,7 @@ public Mono> listFilesWithResponseAsync(RequestOptions requ @ServiceMethod(returns = ReturnType.SINGLE) public Response listFilesWithResponse(RequestOptions requestOptions) { final String accept = "application/json"; - return service.listFilesSync(OPEN_AI_ENDPOINT, accept, requestOptions, Context.NONE); + return service.listFilesSync(getEndpoint(), accept, requestOptions, Context.NONE); } /** @@ -1731,8 +1730,8 @@ public Mono> uploadFileWithResponseAsync(BinaryData uploadF RequestOptions requestOptions) { final String contentType = "multipart/form-data"; final String accept = "application/json"; - return FluxUtil.withContext(context -> service.uploadFile(OPEN_AI_ENDPOINT, contentType, accept, - uploadFileRequest, requestOptions, context)); + return FluxUtil.withContext(context -> service.uploadFile(getEndpoint(), contentType, accept, uploadFileRequest, + requestOptions, context)); } /** @@ -1764,7 +1763,7 @@ public Mono> uploadFileWithResponseAsync(BinaryData uploadF public Response uploadFileWithResponse(BinaryData uploadFileRequest, RequestOptions requestOptions) { final String contentType = "multipart/form-data"; final String accept = "application/json"; - return service.uploadFileSync(OPEN_AI_ENDPOINT, contentType, accept, uploadFileRequest, requestOptions, + return service.uploadFileSync(getEndpoint(), contentType, accept, uploadFileRequest, requestOptions, Context.NONE); } @@ -1793,7 +1792,7 @@ public Response uploadFileWithResponse(BinaryData uploadFileRequest, public Mono> deleteFileWithResponseAsync(String fileId, RequestOptions requestOptions) { final String accept = "application/json"; return FluxUtil - .withContext(context -> service.deleteFile(OPEN_AI_ENDPOINT, fileId, accept, requestOptions, context)); + .withContext(context -> service.deleteFile(getEndpoint(), fileId, accept, requestOptions, context)); } /** @@ -1819,7 +1818,7 @@ public Mono> deleteFileWithResponseAsync(String fileId, Req @ServiceMethod(returns = ReturnType.SINGLE) public Response deleteFileWithResponse(String fileId, RequestOptions requestOptions) { final String accept = "application/json"; - return service.deleteFileSync(OPEN_AI_ENDPOINT, fileId, accept, requestOptions, Context.NONE); + return service.deleteFileSync(getEndpoint(), fileId, accept, requestOptions, Context.NONE); } /** @@ -1851,8 +1850,7 @@ public Response deleteFileWithResponse(String fileId, RequestOptions @ServiceMethod(returns = ReturnType.SINGLE) public Mono> getFileWithResponseAsync(String fileId, RequestOptions requestOptions) { final String accept = "application/json"; - return FluxUtil - .withContext(context -> service.getFile(OPEN_AI_ENDPOINT, fileId, accept, requestOptions, context)); + return FluxUtil.withContext(context -> service.getFile(getEndpoint(), fileId, accept, requestOptions, context)); } /** @@ -1883,7 +1881,7 @@ public Mono> getFileWithResponseAsync(String fileId, Reques @ServiceMethod(returns = ReturnType.SINGLE) public Response getFileWithResponse(String fileId, RequestOptions requestOptions) { final String accept = "application/json"; - return service.getFileSync(OPEN_AI_ENDPOINT, fileId, accept, requestOptions, Context.NONE); + return service.getFileSync(getEndpoint(), fileId, accept, requestOptions, Context.NONE); } /** @@ -1906,7 +1904,7 @@ public Response getFileWithResponse(String fileId, RequestOptions re public Mono> getFileContentWithResponseAsync(String fileId, RequestOptions requestOptions) { final String accept = "application/json"; return FluxUtil - .withContext(context -> service.getFileContent(OPEN_AI_ENDPOINT, fileId, accept, requestOptions, context)); + .withContext(context -> service.getFileContent(getEndpoint(), fileId, accept, requestOptions, context)); } /** @@ -1928,7 +1926,7 @@ public Mono> getFileContentWithResponseAsync(String fileId, @ServiceMethod(returns = ReturnType.SINGLE) public Response getFileContentWithResponse(String fileId, RequestOptions requestOptions) { final String accept = "application/json"; - return service.getFileContentSync(OPEN_AI_ENDPOINT, fileId, accept, requestOptions, Context.NONE); + return service.getFileContentSync(getEndpoint(), fileId, accept, requestOptions, Context.NONE); } /** @@ -2004,7 +2002,7 @@ public Response getFileContentWithResponse(String fileId, RequestOpt @ServiceMethod(returns = ReturnType.SINGLE) public Mono> listBatchesWithResponseAsync(RequestOptions requestOptions) { final String accept = "application/json"; - return FluxUtil.withContext(context -> service.listBatches(OPEN_AI_ENDPOINT, accept, requestOptions, context)); + return FluxUtil.withContext(context -> service.listBatches(getEndpoint(), accept, requestOptions, context)); } /** @@ -2079,7 +2077,7 @@ public Mono> listBatchesWithResponseAsync(RequestOptions re @ServiceMethod(returns = ReturnType.SINGLE) public Response listBatchesWithResponse(RequestOptions requestOptions) { final String accept = "application/json"; - return service.listBatchesSync(OPEN_AI_ENDPOINT, accept, requestOptions, Context.NONE); + return service.listBatchesSync(getEndpoint(), accept, requestOptions, Context.NONE); } /** @@ -2155,7 +2153,7 @@ public Mono> createBatchWithResponseAsync(BinaryData create RequestOptions requestOptions) { final String accept = "application/json"; return FluxUtil.withContext( - context -> service.createBatch(OPEN_AI_ENDPOINT, accept, createBatchRequest, requestOptions, context)); + context -> service.createBatch(getEndpoint(), accept, createBatchRequest, requestOptions, context)); } /** @@ -2229,7 +2227,7 @@ public Mono> createBatchWithResponseAsync(BinaryData create @ServiceMethod(returns = ReturnType.SINGLE) public Response createBatchWithResponse(BinaryData createBatchRequest, RequestOptions requestOptions) { final String accept = "application/json"; - return service.createBatchSync(OPEN_AI_ENDPOINT, accept, createBatchRequest, requestOptions, Context.NONE); + return service.createBatchSync(getEndpoint(), accept, createBatchRequest, requestOptions, Context.NONE); } /** @@ -2290,7 +2288,7 @@ public Response createBatchWithResponse(BinaryData createBatchReques public Mono> getBatchWithResponseAsync(String batchId, RequestOptions requestOptions) { final String accept = "application/json"; return FluxUtil - .withContext(context -> service.getBatch(OPEN_AI_ENDPOINT, batchId, accept, requestOptions, context)); + .withContext(context -> service.getBatch(getEndpoint(), batchId, accept, requestOptions, context)); } /** @@ -2349,7 +2347,7 @@ public Mono> getBatchWithResponseAsync(String batchId, Requ @ServiceMethod(returns = ReturnType.SINGLE) public Response getBatchWithResponse(String batchId, RequestOptions requestOptions) { final String accept = "application/json"; - return service.getBatchSync(OPEN_AI_ENDPOINT, batchId, accept, requestOptions, Context.NONE); + return service.getBatchSync(getEndpoint(), batchId, accept, requestOptions, Context.NONE); } /** @@ -2410,7 +2408,7 @@ public Response getBatchWithResponse(String batchId, RequestOptions public Mono> cancelBatchWithResponseAsync(String batchId, RequestOptions requestOptions) { final String accept = "application/json"; return FluxUtil - .withContext(context -> service.cancelBatch(OPEN_AI_ENDPOINT, batchId, accept, requestOptions, context)); + .withContext(context -> service.cancelBatch(getEndpoint(), batchId, accept, requestOptions, context)); } /** @@ -2469,7 +2467,7 @@ public Mono> cancelBatchWithResponseAsync(String batchId, R @ServiceMethod(returns = ReturnType.SINGLE) public Response cancelBatchWithResponse(String batchId, RequestOptions requestOptions) { final String accept = "application/json"; - return service.cancelBatchSync(OPEN_AI_ENDPOINT, batchId, accept, requestOptions, Context.NONE); + return service.cancelBatchSync(getEndpoint(), batchId, accept, requestOptions, Context.NONE); } /** @@ -2534,8 +2532,8 @@ public Mono> createUploadWithResponseAsync(BinaryData reque RequestOptions requestOptions) { final String contentType = "application/json"; final String accept = "application/json"; - return FluxUtil.withContext(context -> service.createUpload(OPEN_AI_ENDPOINT, contentType, accept, requestBody, - requestOptions, context)); + return FluxUtil.withContext( + context -> service.createUpload(getEndpoint(), contentType, accept, requestBody, requestOptions, context)); } /** @@ -2598,8 +2596,7 @@ public Mono> createUploadWithResponseAsync(BinaryData reque public Response createUploadWithResponse(BinaryData requestBody, RequestOptions requestOptions) { final String contentType = "application/json"; final String accept = "application/json"; - return service.createUploadSync(OPEN_AI_ENDPOINT, contentType, accept, requestBody, requestOptions, - Context.NONE); + return service.createUploadSync(getEndpoint(), contentType, accept, requestBody, requestOptions, Context.NONE); } /** @@ -2635,7 +2632,7 @@ public Mono> addUploadPartWithResponseAsync(String uploadId RequestOptions requestOptions) { final String contentType = "multipart/form-data"; final String accept = "application/json"; - return FluxUtil.withContext(context -> service.addUploadPart(OPEN_AI_ENDPOINT, contentType, uploadId, accept, + return FluxUtil.withContext(context -> service.addUploadPart(getEndpoint(), contentType, uploadId, accept, requestBody, requestOptions, context)); } @@ -2671,7 +2668,7 @@ public Response addUploadPartWithResponse(String uploadId, BinaryDat RequestOptions requestOptions) { final String contentType = "multipart/form-data"; final String accept = "application/json"; - return service.addUploadPartSync(OPEN_AI_ENDPOINT, contentType, uploadId, accept, requestBody, requestOptions, + return service.addUploadPartSync(getEndpoint(), contentType, uploadId, accept, requestBody, requestOptions, Context.NONE); } @@ -2736,7 +2733,7 @@ public Mono> completeUploadWithResponseAsync(String uploadI RequestOptions requestOptions) { final String contentType = "application/json"; final String accept = "application/json"; - return FluxUtil.withContext(context -> service.completeUpload(OPEN_AI_ENDPOINT, uploadId, contentType, accept, + return FluxUtil.withContext(context -> service.completeUpload(getEndpoint(), uploadId, contentType, accept, requestBody, requestOptions, context)); } @@ -2800,7 +2797,7 @@ public Response completeUploadWithResponse(String uploadId, BinaryDa RequestOptions requestOptions) { final String contentType = "application/json"; final String accept = "application/json"; - return service.completeUploadSync(OPEN_AI_ENDPOINT, uploadId, contentType, accept, requestBody, requestOptions, + return service.completeUploadSync(getEndpoint(), uploadId, contentType, accept, requestBody, requestOptions, Context.NONE); } @@ -2844,7 +2841,7 @@ public Response completeUploadWithResponse(String uploadId, BinaryDa public Mono> cancelUploadWithResponseAsync(String uploadId, RequestOptions requestOptions) { final String accept = "application/json"; return FluxUtil - .withContext(context -> service.cancelUpload(OPEN_AI_ENDPOINT, uploadId, accept, requestOptions, context)); + .withContext(context -> service.cancelUpload(getEndpoint(), uploadId, accept, requestOptions, context)); } /** @@ -2885,6 +2882,6 @@ public Mono> cancelUploadWithResponseAsync(String uploadId, @ServiceMethod(returns = ReturnType.SINGLE) public Response cancelUploadWithResponse(String uploadId, RequestOptions requestOptions) { final String accept = "application/json"; - return service.cancelUploadSync(OPEN_AI_ENDPOINT, uploadId, accept, requestOptions, Context.NONE); + return service.cancelUploadSync(getEndpoint(), uploadId, accept, requestOptions, Context.NONE); } } From c85acd4f960d02e833ba311e96dffe108f894846 Mon Sep 17 00:00:00 2001 From: Micah Press Date: Wed, 27 Nov 2024 13:43:36 -0800 Subject: [PATCH 3/6] Add unit tests for the endpoint configuration option of the client builder --- .../ai/openai/OpenAIClientBuilderTest.java | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/OpenAIClientBuilderTest.java diff --git a/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/OpenAIClientBuilderTest.java b/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/OpenAIClientBuilderTest.java new file mode 100644 index 000000000000..3b5c29b2af71 --- /dev/null +++ b/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/OpenAIClientBuilderTest.java @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.azure.ai.openai; + +import java.lang.reflect.Field; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import org.junit.jupiter.api.Test; +import static com.azure.ai.openai.implementation.NonAzureOpenAIClientImpl.EU_OPEN_AI_ENDPOINT; +import static com.azure.ai.openai.implementation.NonAzureOpenAIClientImpl.OPEN_AI_ENDPOINT; + +public class OpenAIClientBuilderTest { + @Test + public void testInnerImplBasedOnEndpoint() throws NoSuchFieldException, IllegalAccessException { + Field azureClient = OpenAIClient.class.getDeclaredField("serviceClient"); + azureClient.setAccessible(true); + Field nonAzureClient = OpenAIClient.class.getDeclaredField("openAIServiceClient"); + nonAzureClient.setAccessible(true); + + OpenAIClient nullEndpointClient = new OpenAIClientBuilder().buildClient(); + assertNull(azureClient.get(nullEndpointClient)); + assertNotNull(nonAzureClient.get(nullEndpointClient)); + + OpenAIClient customEndpointClient + = new OpenAIClientBuilder().endpoint("https://my.custom.domain/").buildClient(); + assertNotNull(azureClient.get(customEndpointClient)); + assertNull(nonAzureClient.get(customEndpointClient)); + + OpenAIClient defaultEndpointClient = new OpenAIClientBuilder().endpoint(OPEN_AI_ENDPOINT).buildClient(); + assertNull(azureClient.get(defaultEndpointClient)); + assertNotNull(nonAzureClient.get(defaultEndpointClient)); + + OpenAIClient euEndpointClient = new OpenAIClientBuilder().endpoint(EU_OPEN_AI_ENDPOINT).buildClient(); + assertNull(azureClient.get(euEndpointClient)); + assertNotNull(nonAzureClient.get(euEndpointClient)); + } +} From ea6f97e18a924c40988e36a0357906c9b45268e0 Mon Sep 17 00:00:00 2001 From: Micah Press Date: Wed, 27 Nov 2024 11:47:54 -0800 Subject: [PATCH 4/6] Fix capitalization in javadoc --- .../ai/openai/implementation/NonAzureOpenAIClientImpl.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/NonAzureOpenAIClientImpl.java b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/NonAzureOpenAIClientImpl.java index aae01582184b..4d6eb81f6707 100644 --- a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/NonAzureOpenAIClientImpl.java +++ b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/NonAzureOpenAIClientImpl.java @@ -55,7 +55,7 @@ public HttpPipeline getHttpPipeline() { private final SerializerAdapter serializerAdapter; /** - * Gets The serializer to serialize an object into a string. + * Gets the serializer to serialize an object into a string. * * @return the serializerAdapter value. */ From e20a0c3e18b27cf60fca51f15496bcc8317b95c3 Mon Sep 17 00:00:00 2001 From: Micah Press Date: Tue, 10 Dec 2024 15:21:08 -0800 Subject: [PATCH 5/6] Use a regex instead of hard-coded endpoints to determine which client to use This commit generalizes the logic a bit more for deciding whether an Azure OpenAI client should be used or not. The regex allows for a single subdomain in the endpoint as long as the rest of it matches. --- .../azure/ai/openai/OpenAIClientBuilder.java | 10 ++++------ .../NonAzureOpenAIClientImpl.java | 6 ++++-- .../ai/openai/OpenAIClientBuilderTest.java | 19 +++++++++++++++++-- 3 files changed, 25 insertions(+), 10 deletions(-) diff --git a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/OpenAIClientBuilder.java b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/OpenAIClientBuilder.java index b6d715353103..75b409e44449 100644 --- a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/OpenAIClientBuilder.java +++ b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/OpenAIClientBuilder.java @@ -3,10 +3,8 @@ // Code generated by Microsoft (R) TypeSpec Code Generator. package com.azure.ai.openai; -import static com.azure.ai.openai.implementation.NonAzureOpenAIClientImpl.EU_OPEN_AI_ENDPOINT; -import static com.azure.ai.openai.implementation.NonAzureOpenAIClientImpl.OPEN_AI_ENDPOINT; - import com.azure.ai.openai.implementation.NonAzureOpenAIClientImpl; +import static com.azure.ai.openai.implementation.NonAzureOpenAIClientImpl.OPEN_AI_ENDPOINT_PATTERN; import com.azure.ai.openai.implementation.OpenAIClientImpl; import com.azure.core.annotation.Generated; import com.azure.core.annotation.ServiceClientBuilder; @@ -363,11 +361,11 @@ public OpenAIClient buildClient() { private static final ClientLogger LOGGER = new ClientLogger(OpenAIClientBuilder.class); /** - * OpenAI service can be used by either not setting the endpoint or by setting the endpoint to start with - * "https://api.openai.com/" or "https://eu.api.openai.com/". + * OpenAI service can be used by either not setting the endpoint or by setting the endpoint to a + * URL like "https://api.openai.com/v1" or "https://eu.api.openai.com/v1". */ private boolean useNonAzureOpenAIService() { - return endpoint == null || endpoint.startsWith(OPEN_AI_ENDPOINT) || endpoint.startsWith(EU_OPEN_AI_ENDPOINT); + return endpoint == null || OPEN_AI_ENDPOINT_PATTERN.matcher(endpoint).matches(); } private void validateClient() { diff --git a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/NonAzureOpenAIClientImpl.java b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/NonAzureOpenAIClientImpl.java index 4d6eb81f6707..d84bb2b5a626 100644 --- a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/NonAzureOpenAIClientImpl.java +++ b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/NonAzureOpenAIClientImpl.java @@ -28,6 +28,7 @@ import com.azure.core.util.Context; import com.azure.core.util.FluxUtil; import com.azure.core.util.serializer.SerializerAdapter; +import java.util.regex.Pattern; import reactor.core.publisher.Mono; import java.util.Map; @@ -81,9 +82,10 @@ public String getEndpoint() { public static final String OPEN_AI_ENDPOINT = "https://api.openai.com/v1"; /** - * This is the EU-based endpoint that non-azure OpenAI supports. Currently, it has only v1 version. + * Pattern for validating native OpenAI API endpoint URLs. This allows for subdomains to support + * regional endpoints. */ - public static final String EU_OPEN_AI_ENDPOINT = "https://eu.api.openai.com/v1"; + public static final Pattern OPEN_AI_ENDPOINT_PATTERN = Pattern.compile("https://(\\w*\\.)?api\\.openai\\.com/v1"); /** * Initializes an instance of OpenAIClient client. diff --git a/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/OpenAIClientBuilderTest.java b/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/OpenAIClientBuilderTest.java index 3b5c29b2af71..b5d0ee597c8b 100644 --- a/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/OpenAIClientBuilderTest.java +++ b/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/OpenAIClientBuilderTest.java @@ -2,14 +2,28 @@ // Licensed under the MIT License. package com.azure.ai.openai; +import static com.azure.ai.openai.implementation.NonAzureOpenAIClientImpl.OPEN_AI_ENDPOINT_PATTERN; import java.lang.reflect.Field; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; import org.junit.jupiter.api.Test; -import static com.azure.ai.openai.implementation.NonAzureOpenAIClientImpl.EU_OPEN_AI_ENDPOINT; import static com.azure.ai.openai.implementation.NonAzureOpenAIClientImpl.OPEN_AI_ENDPOINT; public class OpenAIClientBuilderTest { + @Test + public void testEndpointPattern() { + assertTrue(OPEN_AI_ENDPOINT_PATTERN.matcher(OPEN_AI_ENDPOINT).matches()); + assertTrue(OPEN_AI_ENDPOINT_PATTERN.matcher("https://eu.api.openai.com/v1").matches()); + assertTrue(OPEN_AI_ENDPOINT_PATTERN.matcher("https://asdf.api.openai.com/v1").matches()); + assertFalse(OPEN_AI_ENDPOINT_PATTERN.matcher("http://api.openai.com/v1").matches()); + assertFalse(OPEN_AI_ENDPOINT_PATTERN.matcher("https://api.openai.com/").matches()); + assertFalse(OPEN_AI_ENDPOINT_PATTERN.matcher("https://dead.beef.api.openai.com/v1").matches()); + assertFalse(OPEN_AI_ENDPOINT_PATTERN.matcher("https://api.openai.com.org/v1").matches()); + assertFalse(OPEN_AI_ENDPOINT_PATTERN.matcher("https://api.openai.com/v2").matches()); + } + @Test public void testInnerImplBasedOnEndpoint() throws NoSuchFieldException, IllegalAccessException { Field azureClient = OpenAIClient.class.getDeclaredField("serviceClient"); @@ -30,7 +44,8 @@ public void testInnerImplBasedOnEndpoint() throws NoSuchFieldException, IllegalA assertNull(azureClient.get(defaultEndpointClient)); assertNotNull(nonAzureClient.get(defaultEndpointClient)); - OpenAIClient euEndpointClient = new OpenAIClientBuilder().endpoint(EU_OPEN_AI_ENDPOINT).buildClient(); + OpenAIClient euEndpointClient + = new OpenAIClientBuilder().endpoint("https://eu.api.openai.com/v1").buildClient(); assertNull(azureClient.get(euEndpointClient)); assertNotNull(nonAzureClient.get(euEndpointClient)); } From 446114baaf503a036d4dd6753c9939855361b2fb Mon Sep 17 00:00:00 2001 From: Micah Press Date: Wed, 11 Dec 2024 09:44:57 -0800 Subject: [PATCH 6/6] Fix import order for build check --- .../src/main/java/com/azure/ai/openai/OpenAIClientBuilder.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/OpenAIClientBuilder.java b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/OpenAIClientBuilder.java index 75b409e44449..121ab89c37dc 100644 --- a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/OpenAIClientBuilder.java +++ b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/OpenAIClientBuilder.java @@ -3,8 +3,9 @@ // Code generated by Microsoft (R) TypeSpec Code Generator. package com.azure.ai.openai; -import com.azure.ai.openai.implementation.NonAzureOpenAIClientImpl; import static com.azure.ai.openai.implementation.NonAzureOpenAIClientImpl.OPEN_AI_ENDPOINT_PATTERN; + +import com.azure.ai.openai.implementation.NonAzureOpenAIClientImpl; import com.azure.ai.openai.implementation.OpenAIClientImpl; import com.azure.core.annotation.Generated; import com.azure.core.annotation.ServiceClientBuilder;