Skip to content

Custom inference service jon #127939

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/127939.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 127939
summary: Add Custom inference service
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ static TransportVersion def(int id) {
public static final TransportVersion RESCORE_VECTOR_ALLOW_ZERO_BACKPORT_8_19 = def(8_841_0_27);
public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19 = def(8_841_0_28);
public static final TransportVersion ESQL_REPORT_SHARD_PARTITIONING_8_19 = def(8_841_0_29);
public static final TransportVersion ADD_INFERENCE_CUSTOM_MODEL_8_19 = def(8_841_0_30);
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
Expand Down Expand Up @@ -252,7 +253,7 @@ static TransportVersion def(int id) {
public static final TransportVersion FIELD_CAPS_ADD_CLUSTER_ALIAS = def(9_073_0_00);
public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT = def(9_074_0_00);
public static final TransportVersion ESQL_FIELD_ATTRIBUTE_DROP_TYPE = def(9_075_0_00);

public static final TransportVersion ADD_INFERENCE_CUSTOM_MODEL = def(9_076_0_00);
/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {

public void testGetServicesWithoutTaskType() throws IOException {
List<Object> services = getAllServices();
assertThat(services.size(), equalTo(22));
assertThat(services.size(), equalTo(23));

var providers = providers(services);

Expand All @@ -39,6 +39,7 @@ public void testGetServicesWithoutTaskType() throws IOException {
"azureaistudio",
"azureopenai",
"cohere",
"custom",
"deepseek",
"elastic",
"elasticsearch",
Expand Down Expand Up @@ -70,7 +71,7 @@ private Iterable<String> providers(List<Object> services) {

public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
List<Object> services = getServices(TaskType.TEXT_EMBEDDING);
assertThat(services.size(), equalTo(16));
assertThat(services.size(), equalTo(17));

var providers = providers(services);

Expand All @@ -83,6 +84,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
"azureaistudio",
"azureopenai",
"cohere",
"custom",
"elasticsearch",
"googleaistudio",
"googlevertexai",
Expand All @@ -101,7 +103,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {

public void testGetServicesWithRerankTaskType() throws IOException {
List<Object> services = getServices(TaskType.RERANK);
assertThat(services.size(), equalTo(7));
assertThat(services.size(), equalTo(8));

var providers = providers(services);

Expand All @@ -111,6 +113,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
List.of(
"alibabacloud-ai-search",
"cohere",
"custom",
"elasticsearch",
"googlevertexai",
"jinaai",
Expand All @@ -123,7 +126,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {

public void testGetServicesWithCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.COMPLETION);
assertThat(services.size(), equalTo(10));
assertThat(services.size(), equalTo(11));

var providers = providers(services);

Expand All @@ -137,6 +140,7 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
"azureaistudio",
"azureopenai",
"cohere",
"custom",
"deepseek",
"googleaistudio",
"openai",
Expand All @@ -157,7 +161,7 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException {

public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
List<Object> services = getServices(TaskType.SPARSE_EMBEDDING);
assertThat(services.size(), equalTo(6));
assertThat(services.size(), equalTo(7));

var providers = providers(services);

Expand All @@ -166,6 +170,7 @@ public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
containsInAnyOrder(
List.of(
"alibabacloud-ai-search",
"custom",
"elastic",
"elasticsearch",
"hugging_face",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,15 @@
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.custom.CustomSecretSettings;
import org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings;
import org.elasticsearch.xpack.inference.services.custom.CustomTaskSettings;
import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.NoopResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
Expand Down Expand Up @@ -154,6 +163,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
addAlibabaCloudSearchNamedWriteables(namedWriteables);
addJinaAINamedWriteables(namedWriteables);
addVoyageAINamedWriteables(namedWriteables);
addCustomNamedWriteables(namedWriteables);

addUnifiedNamedWriteables(namedWriteables);

Expand All @@ -165,6 +175,38 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
return namedWriteables;
}

private static void addCustomNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(ServiceSettings.class, CustomServiceSettings.NAME, CustomServiceSettings::new)
);

namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, CustomTaskSettings.NAME, CustomTaskSettings::new));

namedWriteables.add(new NamedWriteableRegistry.Entry(SecretSettings.class, CustomSecretSettings.NAME, CustomSecretSettings::new));

namedWriteables.add(
new NamedWriteableRegistry.Entry(CustomResponseParser.class, TextEmbeddingResponseParser.NAME, TextEmbeddingResponseParser::new)
);

namedWriteables.add(
new NamedWriteableRegistry.Entry(
CustomResponseParser.class,
SparseEmbeddingResponseParser.NAME,
SparseEmbeddingResponseParser::new
)
);

namedWriteables.add(
new NamedWriteableRegistry.Entry(CustomResponseParser.class, RerankResponseParser.NAME, RerankResponseParser::new)
);

namedWriteables.add(new NamedWriteableRegistry.Entry(CustomResponseParser.class, NoopResponseParser.NAME, NoopResponseParser::new));

namedWriteables.add(
new NamedWriteableRegistry.Entry(CustomResponseParser.class, CompletionResponseParser.NAME, CompletionResponseParser::new)
);
}

private static void addUnifiedNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
var writeables = UnifiedCompletionRequest.getNamedWriteables();
namedWriteables.addAll(writeables);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService;
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService;
import org.elasticsearch.xpack.inference.services.cohere.CohereService;
import org.elasticsearch.xpack.inference.services.custom.CustomService;
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
Expand Down Expand Up @@ -396,6 +397,7 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
context -> new JinaAIService(httpFactory.get(), serviceComponents.get()),
context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()),
context -> new DeepSeekService(httpFactory.get(), serviceComponents.get()),
context -> new CustomService(httpFactory.get(), serviceComponents.get()),
ElasticsearchInternalService::new
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public abstract class BaseResponseHandler implements ResponseHandler {
public static final String METHOD_NOT_ALLOWED = "Received a method not allowed status code";

protected final String requestType;
private final ResponseParser parseFunction;
protected final ResponseParser parseFunction;
private final Function<HttpResult, ErrorResponse> errorParseFunction;
private final boolean canHandleStreamingResponses;

Expand Down
Loading
Loading