diff --git a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/OpenAIClientBuilder.java b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/OpenAIClientBuilder.java index 151f613e1fa2..121ab89c37dc 100644 --- a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/OpenAIClientBuilder.java +++ b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/OpenAIClientBuilder.java @@ -3,7 +3,7 @@ // Code generated by Microsoft (R) TypeSpec Code Generator. package com.azure.ai.openai; -import static com.azure.ai.openai.implementation.NonAzureOpenAIClientImpl.OPEN_AI_ENDPOINT; +import static com.azure.ai.openai.implementation.NonAzureOpenAIClientImpl.OPEN_AI_ENDPOINT_PATTERN; import com.azure.ai.openai.implementation.NonAzureOpenAIClientImpl; import com.azure.ai.openai.implementation.OpenAIClientImpl; @@ -333,7 +333,7 @@ private HttpPipeline createHttpPipeline() { private NonAzureOpenAIClientImpl buildInnerNonAzureOpenAIClient() { HttpPipeline localPipeline = (pipeline != null) ? pipeline : createHttpPipeline(); NonAzureOpenAIClientImpl client - = new NonAzureOpenAIClientImpl(localPipeline, JacksonAdapter.createDefaultSerializerAdapter()); + = new NonAzureOpenAIClientImpl(localPipeline, JacksonAdapter.createDefaultSerializerAdapter(), endpoint); return client; } @@ -362,11 +362,11 @@ public OpenAIClient buildClient() { private static final ClientLogger LOGGER = new ClientLogger(OpenAIClientBuilder.class); /** - * OpenAI service can be used by either not setting the endpoint or by setting the endpoint to start with - * "https://api.openai.com/" + * OpenAI service can be used by either not setting the endpoint or by setting the endpoint to a + * URL like "https://api.openai.com/v1" or "https://eu.api.openai.com/v1". */ private boolean useNonAzureOpenAIService() { - return endpoint == null || endpoint.startsWith(OPEN_AI_ENDPOINT); + return endpoint == null || OPEN_AI_ENDPOINT_PATTERN.matcher(endpoint).matches(); } private void validateClient() { diff --git a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/NonAzureOpenAIClientImpl.java b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/NonAzureOpenAIClientImpl.java index 940766d2e8f8..d84bb2b5a626 100644 --- a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/NonAzureOpenAIClientImpl.java +++ b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/NonAzureOpenAIClientImpl.java @@ -28,6 +28,7 @@ import com.azure.core.util.Context; import com.azure.core.util.FluxUtil; import com.azure.core.util.serializer.SerializerAdapter; +import java.util.regex.Pattern; import reactor.core.publisher.Mono; import java.util.Map; @@ -55,7 +56,7 @@ public HttpPipeline getHttpPipeline() { private final SerializerAdapter serializerAdapter; /** - * Gets The serializer to serialize an object into a string. + * Gets the serializer to serialize an object into a string. * * @return the serializerAdapter value. */ @@ -63,20 +64,39 @@ public SerializerAdapter getSerializerAdapter() { return this.serializerAdapter; } + /** The endpoint to send API requests to. */ + private final String endpoint; + /** - * This is the endpoint that non-azure OpenAI supports. Currently, it has only v1 version. + * Gets the endpoint that this client is configured to send requests to. + * + * @return the endpoint value. + */ + public String getEndpoint() { + return this.endpoint; + } + + /** + * This is the generic endpoint that non-azure OpenAI supports. Currently, it has only v1 version. */ public static final String OPEN_AI_ENDPOINT = "https://api.openai.com/v1"; + /** + * Pattern for validating native OpenAI API endpoint URLs. This allows for subdomains to support + * regional endpoints. + */ + public static final Pattern OPEN_AI_ENDPOINT_PATTERN = Pattern.compile("https://(\\w*\\.)?api\\.openai\\.com/v1"); + /** * Initializes an instance of OpenAIClient client. * * @param httpPipeline The HTTP pipeline to send requests through. * @param serializerAdapter The serializer to serialize an object into a string. */ - public NonAzureOpenAIClientImpl(HttpPipeline httpPipeline, SerializerAdapter serializerAdapter) { + public NonAzureOpenAIClientImpl(HttpPipeline httpPipeline, SerializerAdapter serializerAdapter, String endpoint) { this.httpPipeline = httpPipeline; this.serializerAdapter = serializerAdapter; + this.endpoint = endpoint == null ? OPEN_AI_ENDPOINT : endpoint; this.service = RestProxy.create(NonAzureOpenAIClientService.class, this.httpPipeline, this.getSerializerAdapter()); } @@ -587,8 +607,8 @@ public Mono> getEmbeddingsWithResponseAsync(String modelId, final String accept = "application/json"; // modelId is part of the request body in nonAzure OpenAI final BinaryData embeddingsOptionsUpdated = addModelIdJson(embeddingsOptions, modelId); - return FluxUtil.withContext(context -> service.getEmbeddings(OPEN_AI_ENDPOINT, accept, embeddingsOptionsUpdated, - requestOptions, context)); + return FluxUtil.withContext( + context -> service.getEmbeddings(getEndpoint(), accept, embeddingsOptionsUpdated, requestOptions, context)); } /** @@ -644,8 +664,7 @@ public Response getEmbeddingsWithResponse(String modelId, BinaryData final String accept = "application/json"; // modelId is part of the request body in nonAzure OpenAI final BinaryData embeddingsOptionsUpdated = addModelIdJson(embeddingsOptions, modelId); - return service.getEmbeddingsSync(OPEN_AI_ENDPOINT, accept, embeddingsOptionsUpdated, requestOptions, - Context.NONE); + return service.getEmbeddingsSync(getEndpoint(), accept, embeddingsOptionsUpdated, requestOptions, Context.NONE); } /** @@ -736,8 +755,8 @@ public Mono> getCompletionsWithResponseAsync(String modelId final String accept = "application/json"; // modelId is part of the request body in nonAzure OpenAI final BinaryData completionsOptionsUpdated = addModelIdJson(completionsOptions, modelId); - return FluxUtil.withContext(context -> service.getCompletions(OPEN_AI_ENDPOINT, accept, - completionsOptionsUpdated, requestOptions, context)); + return FluxUtil.withContext(context -> service.getCompletions(getEndpoint(), accept, completionsOptionsUpdated, + requestOptions, context)); } /** @@ -826,7 +845,7 @@ public Response getCompletionsWithResponse(String modelId, BinaryDat final String accept = "application/json"; // modelId is part of the request body in nonAzure OpenAI final BinaryData completionsOptionsUpdated = addModelIdJson(completionsOptions, modelId); - return service.getCompletionsSync(OPEN_AI_ENDPOINT, accept, completionsOptionsUpdated, requestOptions, + return service.getCompletionsSync(getEndpoint(), accept, completionsOptionsUpdated, requestOptions, Context.NONE); } @@ -909,7 +928,7 @@ public Mono> getChatCompletionsWithResponseAsync(String mod final String accept = "application/json"; // modelId is part of the request body in nonAzure OpenAI final BinaryData chatCompletionsOptionsUpdated = addModelIdJson(chatCompletionsOptions, modelId); - return FluxUtil.withContext(context -> service.getChatCompletions(OPEN_AI_ENDPOINT, accept, + return FluxUtil.withContext(context -> service.getChatCompletions(getEndpoint(), accept, chatCompletionsOptionsUpdated, requestOptions, context)); } @@ -991,7 +1010,7 @@ public Response getChatCompletionsWithResponse(String modelId, Binar final String accept = "application/json"; // modelId is part of the request body in nonAzure OpenAI final BinaryData chatCompletionsOptionsUpdated = addModelIdJson(chatCompletionsOptions, modelId); - return service.getChatCompletionsSync(OPEN_AI_ENDPOINT, accept, chatCompletionsOptionsUpdated, requestOptions, + return service.getChatCompletionsSync(getEndpoint(), accept, chatCompletionsOptionsUpdated, requestOptions, Context.NONE); } @@ -1044,7 +1063,7 @@ public Mono> getImageGenerationsWithResponseAsync(String mo final String accept = "application/json"; // modelId is part of the request body in nonAzure OpenAI final BinaryData imageGenerationOptionsUpdated = addModelIdJson(imageGenerationOptions, modelId); - return FluxUtil.withContext(context -> service.getImageGenerations(OPEN_AI_ENDPOINT, accept, + return FluxUtil.withContext(context -> service.getImageGenerations(getEndpoint(), accept, imageGenerationOptionsUpdated, requestOptions, context)); } @@ -1096,7 +1115,7 @@ public Response getImageGenerationsWithResponse(String modelId, Bina RequestOptions requestOptions) { final String accept = "application/json"; final BinaryData imageGenerationOptionsUpdated = addModelIdJson(imageGenerationOptions, modelId); - return service.getImageGenerationsSync(OPEN_AI_ENDPOINT, accept, imageGenerationOptionsUpdated, requestOptions, + return service.getImageGenerationsSync(getEndpoint(), accept, imageGenerationOptionsUpdated, requestOptions, Context.NONE); } @@ -1175,7 +1194,7 @@ public static BinaryData addModelIdJson(BinaryData inputJson, String modelId) { public Mono> getAudioTranscriptionAsResponseObjectWithResponseAsync(String modelId, BinaryData audioTranscriptionOptions, RequestOptions requestOptions) { final String accept = "application/json"; - return FluxUtil.withContext(context -> service.getAudioTranscriptionAsResponseObject(OPEN_AI_ENDPOINT, accept, + return FluxUtil.withContext(context -> service.getAudioTranscriptionAsResponseObject(getEndpoint(), accept, audioTranscriptionOptions, requestOptions, context)); } @@ -1236,7 +1255,7 @@ public Mono> getAudioTranscriptionAsResponseObjectWithRespo public Response getAudioTranscriptionAsResponseObjectWithResponse(String modelId, BinaryData audioTranscriptionOptions, RequestOptions requestOptions) { final String accept = "application/json"; - return service.getAudioTranscriptionAsResponseObjectSync(OPEN_AI_ENDPOINT, accept, audioTranscriptionOptions, + return service.getAudioTranscriptionAsResponseObjectSync(getEndpoint(), accept, audioTranscriptionOptions, requestOptions, Context.NONE); } @@ -1277,7 +1296,7 @@ public Response getAudioTranscriptionAsResponseObjectWithResponse(St public Mono> getAudioTranscriptionAsPlainTextWithResponseAsync(String modelId, BinaryData audioTranscriptionOptions, RequestOptions requestOptions) { final String accept = "application/json"; - return FluxUtil.withContext(context -> service.getAudioTranscriptionAsPlainText(OPEN_AI_ENDPOINT, accept, + return FluxUtil.withContext(context -> service.getAudioTranscriptionAsPlainText(getEndpoint(), accept, audioTranscriptionOptions, requestOptions, context)); } @@ -1317,7 +1336,7 @@ public Mono> getAudioTranscriptionAsPlainTextWithResponseAs public Response getAudioTranscriptionAsPlainTextWithResponse(String modelId, BinaryData audioTranscriptionOptions, RequestOptions requestOptions) { final String accept = "application/json"; - return service.getAudioTranscriptionAsPlainTextSync(OPEN_AI_ENDPOINT, accept, audioTranscriptionOptions, + return service.getAudioTranscriptionAsPlainTextSync(getEndpoint(), accept, audioTranscriptionOptions, requestOptions, Context.NONE); } @@ -1377,7 +1396,7 @@ public Response getAudioTranscriptionAsPlainTextWithResponse(String public Mono> getAudioTranslationAsResponseObjectWithResponseAsync(String deploymentOrModelName, BinaryData audioTranslationOptions, RequestOptions requestOptions) { final String accept = "application/json"; - return FluxUtil.withContext(context -> service.getAudioTranslationAsResponseObject(OPEN_AI_ENDPOINT, accept, + return FluxUtil.withContext(context -> service.getAudioTranslationAsResponseObject(getEndpoint(), accept, audioTranslationOptions, requestOptions, context)); } @@ -1437,7 +1456,7 @@ public Mono> getAudioTranslationAsResponseObjectWithRespons public Response getAudioTranslationAsResponseObjectWithResponse(String modelId, BinaryData audioTranslationOptions, RequestOptions requestOptions) { final String accept = "application/json"; - return service.getAudioTranslationAsResponseObjectSync(OPEN_AI_ENDPOINT, accept, audioTranslationOptions, + return service.getAudioTranslationAsResponseObjectSync(getEndpoint(), accept, audioTranslationOptions, requestOptions, Context.NONE); } @@ -1476,7 +1495,7 @@ public Response getAudioTranslationAsResponseObjectWithResponse(Stri public Mono> getAudioTranslationAsPlainTextWithResponseAsync(String modelId, BinaryData audioTranslationOptions, RequestOptions requestOptions) { final String accept = "application/json"; - return FluxUtil.withContext(context -> service.getAudioTranslationAsPlainText(OPEN_AI_ENDPOINT, accept, + return FluxUtil.withContext(context -> service.getAudioTranslationAsPlainText(getEndpoint(), accept, audioTranslationOptions, requestOptions, context)); } @@ -1515,7 +1534,7 @@ public Mono> getAudioTranslationAsPlainTextWithResponseAsyn public Response getAudioTranslationAsPlainTextWithResponse(String modelId, BinaryData audioTranslationOptions, RequestOptions requestOptions) { final String accept = "application/json"; - return service.getAudioTranslationAsPlainTextSync(OPEN_AI_ENDPOINT, accept, audioTranslationOptions, + return service.getAudioTranslationAsPlainTextSync(getEndpoint(), accept, audioTranslationOptions, requestOptions, Context.NONE); } @@ -1553,7 +1572,7 @@ public Response getAudioTranslationAsPlainTextWithResponse(String mo public Mono> generateSpeechFromTextWithResponseAsync(String modelId, BinaryData speechGenerationOptions, RequestOptions requestOptions) { final String accept = "application/octet-stream, application/json"; - return FluxUtil.withContext(context -> service.generateSpeechFromText(OPEN_AI_ENDPOINT, accept, + return FluxUtil.withContext(context -> service.generateSpeechFromText(getEndpoint(), accept, speechGenerationOptions, requestOptions, context)); } @@ -1590,7 +1609,7 @@ public Mono> generateSpeechFromTextWithResponseAsync(String public Response generateSpeechFromTextWithResponse(BinaryData speechGenerationOptions, RequestOptions requestOptions) { final String accept = "application/octet-stream, application/json"; - return service.generateSpeechFromTextSync(OPEN_AI_ENDPOINT, accept, speechGenerationOptions, requestOptions, + return service.generateSpeechFromTextSync(getEndpoint(), accept, speechGenerationOptions, requestOptions, Context.NONE); } @@ -1635,7 +1654,7 @@ public Response generateSpeechFromTextWithResponse(BinaryData speech @ServiceMethod(returns = ReturnType.SINGLE) public Mono> listFilesWithResponseAsync(RequestOptions requestOptions) { final String accept = "application/json"; - return FluxUtil.withContext(context -> service.listFiles(OPEN_AI_ENDPOINT, accept, requestOptions, context)); + return FluxUtil.withContext(context -> service.listFiles(getEndpoint(), accept, requestOptions, context)); } /** @@ -1679,7 +1698,7 @@ public Mono> listFilesWithResponseAsync(RequestOptions requ @ServiceMethod(returns = ReturnType.SINGLE) public Response listFilesWithResponse(RequestOptions requestOptions) { final String accept = "application/json"; - return service.listFilesSync(OPEN_AI_ENDPOINT, accept, requestOptions, Context.NONE); + return service.listFilesSync(getEndpoint(), accept, requestOptions, Context.NONE); } /** @@ -1713,8 +1732,8 @@ public Mono> uploadFileWithResponseAsync(BinaryData uploadF RequestOptions requestOptions) { final String contentType = "multipart/form-data"; final String accept = "application/json"; - return FluxUtil.withContext(context -> service.uploadFile(OPEN_AI_ENDPOINT, contentType, accept, - uploadFileRequest, requestOptions, context)); + return FluxUtil.withContext(context -> service.uploadFile(getEndpoint(), contentType, accept, uploadFileRequest, + requestOptions, context)); } /** @@ -1746,7 +1765,7 @@ public Mono> uploadFileWithResponseAsync(BinaryData uploadF public Response uploadFileWithResponse(BinaryData uploadFileRequest, RequestOptions requestOptions) { final String contentType = "multipart/form-data"; final String accept = "application/json"; - return service.uploadFileSync(OPEN_AI_ENDPOINT, contentType, accept, uploadFileRequest, requestOptions, + return service.uploadFileSync(getEndpoint(), contentType, accept, uploadFileRequest, requestOptions, Context.NONE); } @@ -1775,7 +1794,7 @@ public Response uploadFileWithResponse(BinaryData uploadFileRequest, public Mono> deleteFileWithResponseAsync(String fileId, RequestOptions requestOptions) { final String accept = "application/json"; return FluxUtil - .withContext(context -> service.deleteFile(OPEN_AI_ENDPOINT, fileId, accept, requestOptions, context)); + .withContext(context -> service.deleteFile(getEndpoint(), fileId, accept, requestOptions, context)); } /** @@ -1801,7 +1820,7 @@ public Mono> deleteFileWithResponseAsync(String fileId, Req @ServiceMethod(returns = ReturnType.SINGLE) public Response deleteFileWithResponse(String fileId, RequestOptions requestOptions) { final String accept = "application/json"; - return service.deleteFileSync(OPEN_AI_ENDPOINT, fileId, accept, requestOptions, Context.NONE); + return service.deleteFileSync(getEndpoint(), fileId, accept, requestOptions, Context.NONE); } /** @@ -1833,8 +1852,7 @@ public Response deleteFileWithResponse(String fileId, RequestOptions @ServiceMethod(returns = ReturnType.SINGLE) public Mono> getFileWithResponseAsync(String fileId, RequestOptions requestOptions) { final String accept = "application/json"; - return FluxUtil - .withContext(context -> service.getFile(OPEN_AI_ENDPOINT, fileId, accept, requestOptions, context)); + return FluxUtil.withContext(context -> service.getFile(getEndpoint(), fileId, accept, requestOptions, context)); } /** @@ -1865,7 +1883,7 @@ public Mono> getFileWithResponseAsync(String fileId, Reques @ServiceMethod(returns = ReturnType.SINGLE) public Response getFileWithResponse(String fileId, RequestOptions requestOptions) { final String accept = "application/json"; - return service.getFileSync(OPEN_AI_ENDPOINT, fileId, accept, requestOptions, Context.NONE); + return service.getFileSync(getEndpoint(), fileId, accept, requestOptions, Context.NONE); } /** @@ -1888,7 +1906,7 @@ public Response getFileWithResponse(String fileId, RequestOptions re public Mono> getFileContentWithResponseAsync(String fileId, RequestOptions requestOptions) { final String accept = "application/json"; return FluxUtil - .withContext(context -> service.getFileContent(OPEN_AI_ENDPOINT, fileId, accept, requestOptions, context)); + .withContext(context -> service.getFileContent(getEndpoint(), fileId, accept, requestOptions, context)); } /** @@ -1910,7 +1928,7 @@ public Mono> getFileContentWithResponseAsync(String fileId, @ServiceMethod(returns = ReturnType.SINGLE) public Response getFileContentWithResponse(String fileId, RequestOptions requestOptions) { final String accept = "application/json"; - return service.getFileContentSync(OPEN_AI_ENDPOINT, fileId, accept, requestOptions, Context.NONE); + return service.getFileContentSync(getEndpoint(), fileId, accept, requestOptions, Context.NONE); } /** @@ -1986,7 +2004,7 @@ public Response getFileContentWithResponse(String fileId, RequestOpt @ServiceMethod(returns = ReturnType.SINGLE) public Mono> listBatchesWithResponseAsync(RequestOptions requestOptions) { final String accept = "application/json"; - return FluxUtil.withContext(context -> service.listBatches(OPEN_AI_ENDPOINT, accept, requestOptions, context)); + return FluxUtil.withContext(context -> service.listBatches(getEndpoint(), accept, requestOptions, context)); } /** @@ -2061,7 +2079,7 @@ public Mono> listBatchesWithResponseAsync(RequestOptions re @ServiceMethod(returns = ReturnType.SINGLE) public Response listBatchesWithResponse(RequestOptions requestOptions) { final String accept = "application/json"; - return service.listBatchesSync(OPEN_AI_ENDPOINT, accept, requestOptions, Context.NONE); + return service.listBatchesSync(getEndpoint(), accept, requestOptions, Context.NONE); } /** @@ -2137,7 +2155,7 @@ public Mono> createBatchWithResponseAsync(BinaryData create RequestOptions requestOptions) { final String accept = "application/json"; return FluxUtil.withContext( - context -> service.createBatch(OPEN_AI_ENDPOINT, accept, createBatchRequest, requestOptions, context)); + context -> service.createBatch(getEndpoint(), accept, createBatchRequest, requestOptions, context)); } /** @@ -2211,7 +2229,7 @@ public Mono> createBatchWithResponseAsync(BinaryData create @ServiceMethod(returns = ReturnType.SINGLE) public Response createBatchWithResponse(BinaryData createBatchRequest, RequestOptions requestOptions) { final String accept = "application/json"; - return service.createBatchSync(OPEN_AI_ENDPOINT, accept, createBatchRequest, requestOptions, Context.NONE); + return service.createBatchSync(getEndpoint(), accept, createBatchRequest, requestOptions, Context.NONE); } /** @@ -2272,7 +2290,7 @@ public Response createBatchWithResponse(BinaryData createBatchReques public Mono> getBatchWithResponseAsync(String batchId, RequestOptions requestOptions) { final String accept = "application/json"; return FluxUtil - .withContext(context -> service.getBatch(OPEN_AI_ENDPOINT, batchId, accept, requestOptions, context)); + .withContext(context -> service.getBatch(getEndpoint(), batchId, accept, requestOptions, context)); } /** @@ -2331,7 +2349,7 @@ public Mono> getBatchWithResponseAsync(String batchId, Requ @ServiceMethod(returns = ReturnType.SINGLE) public Response getBatchWithResponse(String batchId, RequestOptions requestOptions) { final String accept = "application/json"; - return service.getBatchSync(OPEN_AI_ENDPOINT, batchId, accept, requestOptions, Context.NONE); + return service.getBatchSync(getEndpoint(), batchId, accept, requestOptions, Context.NONE); } /** @@ -2392,7 +2410,7 @@ public Response getBatchWithResponse(String batchId, RequestOptions public Mono> cancelBatchWithResponseAsync(String batchId, RequestOptions requestOptions) { final String accept = "application/json"; return FluxUtil - .withContext(context -> service.cancelBatch(OPEN_AI_ENDPOINT, batchId, accept, requestOptions, context)); + .withContext(context -> service.cancelBatch(getEndpoint(), batchId, accept, requestOptions, context)); } /** @@ -2451,7 +2469,7 @@ public Mono> cancelBatchWithResponseAsync(String batchId, R @ServiceMethod(returns = ReturnType.SINGLE) public Response cancelBatchWithResponse(String batchId, RequestOptions requestOptions) { final String accept = "application/json"; - return service.cancelBatchSync(OPEN_AI_ENDPOINT, batchId, accept, requestOptions, Context.NONE); + return service.cancelBatchSync(getEndpoint(), batchId, accept, requestOptions, Context.NONE); } /** @@ -2516,8 +2534,8 @@ public Mono> createUploadWithResponseAsync(BinaryData reque RequestOptions requestOptions) { final String contentType = "application/json"; final String accept = "application/json"; - return FluxUtil.withContext(context -> service.createUpload(OPEN_AI_ENDPOINT, contentType, accept, requestBody, - requestOptions, context)); + return FluxUtil.withContext( + context -> service.createUpload(getEndpoint(), contentType, accept, requestBody, requestOptions, context)); } /** @@ -2580,8 +2598,7 @@ public Mono> createUploadWithResponseAsync(BinaryData reque public Response createUploadWithResponse(BinaryData requestBody, RequestOptions requestOptions) { final String contentType = "application/json"; final String accept = "application/json"; - return service.createUploadSync(OPEN_AI_ENDPOINT, contentType, accept, requestBody, requestOptions, - Context.NONE); + return service.createUploadSync(getEndpoint(), contentType, accept, requestBody, requestOptions, Context.NONE); } /** @@ -2617,7 +2634,7 @@ public Mono> addUploadPartWithResponseAsync(String uploadId RequestOptions requestOptions) { final String contentType = "multipart/form-data"; final String accept = "application/json"; - return FluxUtil.withContext(context -> service.addUploadPart(OPEN_AI_ENDPOINT, contentType, uploadId, accept, + return FluxUtil.withContext(context -> service.addUploadPart(getEndpoint(), contentType, uploadId, accept, requestBody, requestOptions, context)); } @@ -2653,7 +2670,7 @@ public Response addUploadPartWithResponse(String uploadId, BinaryDat RequestOptions requestOptions) { final String contentType = "multipart/form-data"; final String accept = "application/json"; - return service.addUploadPartSync(OPEN_AI_ENDPOINT, contentType, uploadId, accept, requestBody, requestOptions, + return service.addUploadPartSync(getEndpoint(), contentType, uploadId, accept, requestBody, requestOptions, Context.NONE); } @@ -2718,7 +2735,7 @@ public Mono> completeUploadWithResponseAsync(String uploadI RequestOptions requestOptions) { final String contentType = "application/json"; final String accept = "application/json"; - return FluxUtil.withContext(context -> service.completeUpload(OPEN_AI_ENDPOINT, uploadId, contentType, accept, + return FluxUtil.withContext(context -> service.completeUpload(getEndpoint(), uploadId, contentType, accept, requestBody, requestOptions, context)); } @@ -2782,7 +2799,7 @@ public Response completeUploadWithResponse(String uploadId, BinaryDa RequestOptions requestOptions) { final String contentType = "application/json"; final String accept = "application/json"; - return service.completeUploadSync(OPEN_AI_ENDPOINT, uploadId, contentType, accept, requestBody, requestOptions, + return service.completeUploadSync(getEndpoint(), uploadId, contentType, accept, requestBody, requestOptions, Context.NONE); } @@ -2826,7 +2843,7 @@ public Response completeUploadWithResponse(String uploadId, BinaryDa public Mono> cancelUploadWithResponseAsync(String uploadId, RequestOptions requestOptions) { final String accept = "application/json"; return FluxUtil - .withContext(context -> service.cancelUpload(OPEN_AI_ENDPOINT, uploadId, accept, requestOptions, context)); + .withContext(context -> service.cancelUpload(getEndpoint(), uploadId, accept, requestOptions, context)); } /** @@ -2867,6 +2884,6 @@ public Mono> cancelUploadWithResponseAsync(String uploadId, @ServiceMethod(returns = ReturnType.SINGLE) public Response cancelUploadWithResponse(String uploadId, RequestOptions requestOptions) { final String accept = "application/json"; - return service.cancelUploadSync(OPEN_AI_ENDPOINT, uploadId, accept, requestOptions, Context.NONE); + return service.cancelUploadSync(getEndpoint(), uploadId, accept, requestOptions, Context.NONE); } } diff --git a/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/OpenAIClientBuilderTest.java b/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/OpenAIClientBuilderTest.java new file mode 100644 index 000000000000..b5d0ee597c8b --- /dev/null +++ b/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/OpenAIClientBuilderTest.java @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.azure.ai.openai; + +import static com.azure.ai.openai.implementation.NonAzureOpenAIClientImpl.OPEN_AI_ENDPOINT_PATTERN; +import java.lang.reflect.Field; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.Test; +import static com.azure.ai.openai.implementation.NonAzureOpenAIClientImpl.OPEN_AI_ENDPOINT; + +public class OpenAIClientBuilderTest { + @Test + public void testEndpointPattern() { + assertTrue(OPEN_AI_ENDPOINT_PATTERN.matcher(OPEN_AI_ENDPOINT).matches()); + assertTrue(OPEN_AI_ENDPOINT_PATTERN.matcher("https://eu.api.openai.com/v1").matches()); + assertTrue(OPEN_AI_ENDPOINT_PATTERN.matcher("https://asdf.api.openai.com/v1").matches()); + assertFalse(OPEN_AI_ENDPOINT_PATTERN.matcher("http://api.openai.com/v1").matches()); + assertFalse(OPEN_AI_ENDPOINT_PATTERN.matcher("https://api.openai.com/").matches()); + assertFalse(OPEN_AI_ENDPOINT_PATTERN.matcher("https://dead.beef.api.openai.com/v1").matches()); + assertFalse(OPEN_AI_ENDPOINT_PATTERN.matcher("https://api.openai.com.org/v1").matches()); + assertFalse(OPEN_AI_ENDPOINT_PATTERN.matcher("https://api.openai.com/v2").matches()); + } + + @Test + public void testInnerImplBasedOnEndpoint() throws NoSuchFieldException, IllegalAccessException { + Field azureClient = OpenAIClient.class.getDeclaredField("serviceClient"); + azureClient.setAccessible(true); + Field nonAzureClient = OpenAIClient.class.getDeclaredField("openAIServiceClient"); + nonAzureClient.setAccessible(true); + + OpenAIClient nullEndpointClient = new OpenAIClientBuilder().buildClient(); + assertNull(azureClient.get(nullEndpointClient)); + assertNotNull(nonAzureClient.get(nullEndpointClient)); + + OpenAIClient customEndpointClient + = new OpenAIClientBuilder().endpoint("https://my.custom.domain/").buildClient(); + assertNotNull(azureClient.get(customEndpointClient)); + assertNull(nonAzureClient.get(customEndpointClient)); + + OpenAIClient defaultEndpointClient = new OpenAIClientBuilder().endpoint(OPEN_AI_ENDPOINT).buildClient(); + assertNull(azureClient.get(defaultEndpointClient)); + assertNotNull(nonAzureClient.get(defaultEndpointClient)); + + OpenAIClient euEndpointClient + = new OpenAIClientBuilder().endpoint("https://eu.api.openai.com/v1").buildClient(); + assertNull(azureClient.get(euEndpointClient)); + assertNotNull(nonAzureClient.get(euEndpointClient)); + } +}