Skip to content

Add Hugging Face Rerank support #127966

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ static TransportVersion def(int id) {
public static final TransportVersion RESCORE_VECTOR_ALLOW_ZERO_BACKPORT_8_19 = def(8_841_0_27);
public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19 = def(8_841_0_28);
public static final TransportVersion ESQL_REPORT_SHARD_PARTITIONING_8_19 = def(8_841_0_29);
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED_8_19 = def(8_843_0_29);
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
Expand Down Expand Up @@ -254,7 +255,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_FIELD_ATTRIBUTE_DROP_TYPE = def(9_075_0_00);
public static final TransportVersion ESQL_TIME_SERIES_SOURCE_STATUS = def(9_076_0_00);
public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME = def(9_077_0_00);

public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED = def(9_078_0_00);
/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@
public class SettingsConfigurationTestUtils {

public static SettingsConfiguration getRandomSettingsConfigurationField() {
return new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING)).setDefaultValue(
randomAlphaOfLength(10)
)
return new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK))
.setDefaultValue(randomAlphaOfLength(10))
.setDescription(randomAlphaOfLength(10))
.setLabel(randomAlphaOfLength(10))
.setRequired(randomBoolean())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,20 @@ static String mockDenseServiceModelConfig() {
""";
}

static String mockRerankServiceModelConfig() {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if methods you've added to this class are actually used somewhere. Methods you've taken for reference are being called. The ones you've added - are not.

Copy link
Author

@Evgenii-Kazannik Evgenii-Kazannik May 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for noticing. It's used now

return """
{
"service": "test_reranking_service",
"service_settings": {
"model_id": "my_model",
"api_key": "abc64"
},
"task_settings": {
}
}
""";
}

static void deleteModel(String modelId) throws IOException {
var request = new Request("DELETE", "_inference/" + modelId);
var response = client().performRequest(request);
Expand Down Expand Up @@ -484,6 +498,10 @@ private String jsonBody(List<String> input, @Nullable String query) {
@SuppressWarnings("unchecked")
protected void assertNonEmptyInferenceResults(Map<String, Object> resultMap, int expectedNumberOfResults, TaskType taskType) {
switch (taskType) {
case RERANK -> {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like this method is not called with TaskType.RERANK param anywhere. meaning assertion isn't triggered.

var results = (List<Map<String, Object>>) resultMap.get(TaskType.RERANK.toString());
assertThat(results, hasSize(expectedNumberOfResults));
}
case SPARSE_EMBEDDING -> {
var results = (List<Map<String, Object>>) resultMap.get(TaskType.SPARSE_EMBEDDING.toString());
assertThat(results, hasSize(expectedNumberOfResults));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,12 @@ public void testCRUD() throws IOException {
for (int i = 0; i < 4; i++) {
putModel("te_model_" + i, mockDenseServiceModelConfig(), TaskType.TEXT_EMBEDDING);
}
for (int i = 0; i < 3; i++) {
putModel("re-model-" + i, mockRerankServiceModelConfig(), TaskType.RERANK);
}

var getAllModels = getAllModels();
int numModels = 12;
int numModels = 15;
assertThat(getAllModels, hasSize(numModels));

var getSparseModels = getModels("_all", TaskType.SPARSE_EMBEDDING);
Expand All @@ -71,6 +74,13 @@ public void testCRUD() throws IOException {
for (var denseModel : getDenseModels) {
assertEquals("text_embedding", denseModel.get("task_type"));
}

var getRerankModels = getModels("_all", TaskType.RERANK);
int numRerankModels = 4;
assertThat(getRerankModels, hasSize(numRerankModels));
for (var denseModel : getRerankModels) {
assertEquals("rerank", denseModel.get("task_type"));
}
String oldApiKey;
{
var singleModel = getModels("se_model_1", TaskType.SPARSE_EMBEDDING);
Expand Down Expand Up @@ -100,6 +110,9 @@ public void testCRUD() throws IOException {
for (int i = 0; i < 4; i++) {
deleteModel("te_model_" + i, TaskType.TEXT_EMBEDDING);
}
for (int i = 0; i < 3; i++) {
deleteModel("re-model-" + i, TaskType.RERANK);
}
}

public void testGetModelWithWrongTaskType() throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {

public void testGetServicesWithRerankTaskType() throws IOException {
List<Object> services = getServices(TaskType.RERANK);
assertThat(services.size(), equalTo(7));
assertThat(services.size(), equalTo(8));

var providers = providers(services);

Expand All @@ -115,7 +115,8 @@ public void testGetServicesWithRerankTaskType() throws IOException {
"googlevertexai",
"jinaai",
"test_reranking_service",
"voyageai"
"voyageai",
"hugging_face"
).toArray()
)
);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference;

import org.elasticsearch.inference.TaskType;

import java.io.IOException;
import java.util.List;
import java.util.Map;

public class MockRerankInferenceServiceIT extends InferenceBaseRestTest {

@SuppressWarnings("unchecked")
public void testMockService() throws IOException {
String inferenceEntityId = "test-mock";
var putModel = putModel(inferenceEntityId, mockRerankServiceModelConfig(), TaskType.RERANK);
var model = getModels(inferenceEntityId, TaskType.RERANK).get(0);

for (var modelMap : List.of(putModel, model)) {
assertEquals(inferenceEntityId, modelMap.get("inference_id"));
assertEquals(TaskType.RERANK, TaskType.fromString((String) modelMap.get("task_type")));
assertEquals("test_reranking_service", modelMap.get("service"));
}

List<String> input = List.of(randomAlphaOfLength(10));
var inference = infer(inferenceEntityId, input);
assertNonEmptyInferenceResults(inference, 1, TaskType.RERANK);
assertEquals(inference, infer(inferenceEntityId, input));
assertNotEquals(inference, infer(inferenceEntityId, randomValueOtherThan(input, () -> List.of(randomAlphaOfLength(10)))));
}

public void testMockServiceWithMultipleInputs() throws IOException {
String inferenceEntityId = "test-mock-with-multi-inputs";
putModel(inferenceEntityId, mockRerankServiceModelConfig(), TaskType.RERANK);
var queryParams = Map.of("timeout", "120s");

var inference = infer(
inferenceEntityId,
TaskType.RERANK,
List.of(randomAlphaOfLength(5), randomAlphaOfLength(10)),
"What if?",
queryParams
);

assertNonEmptyInferenceResults(inference, 2, TaskType.RERANK);
}

@SuppressWarnings("unchecked")
public void testMockService_DoesNotReturnSecretsInGetResponse() throws IOException {
String inferenceEntityId = "test-mock";
var putModel = putModel(inferenceEntityId, mockRerankServiceModelConfig(), TaskType.RERANK);
var model = getModels(inferenceEntityId, TaskType.RERANK).get(0);

var serviceSettings = (Map<String, Object>) model.get("service_settings");
assertNull(serviceSettings.get("api_key"));
assertNotNull(serviceSettings.get("model_id"));

var putServiceSettings = (Map<String, Object>) putModel.get("service_settings");
assertNull(putServiceSettings.get("api_key"));
assertNotNull(putServiceSettings.get("model_id"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings;
Expand Down Expand Up @@ -357,6 +359,16 @@ private static void addHuggingFaceNamedWriteables(List<NamedWriteableRegistry.En
namedWriteables.add(
new NamedWriteableRegistry.Entry(ServiceSettings.class, HuggingFaceServiceSettings.NAME, HuggingFaceServiceSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(TaskSettings.class, HuggingFaceRerankTaskSettings.NAME, HuggingFaceRerankTaskSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
HuggingFaceRerankServiceSettings.NAME,
HuggingFaceRerankServiceSettings::new
)
);
}

private static void addGoogleAiStudioNamedWritables(List<NamedWriteableRegistry.Entry> namedWriteables) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings;

import java.util.ArrayList;
import java.util.Arrays;
Expand Down Expand Up @@ -91,7 +92,10 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[
} else if (r.getEndpoints().isEmpty() == false
&& r.getEndpoints().get(0).getTaskSettings() instanceof GoogleVertexAiRerankTaskSettings googleVertexAiTaskSettings) {
configuredTopN = googleVertexAiTaskSettings.topN();
}
} else if (r.getEndpoints().isEmpty() == false
&& r.getEndpoints().get(0).getTaskSettings() instanceof HuggingFaceRerankTaskSettings huggingFaceRerankTaskSettings) {
configuredTopN = huggingFaceRerankTaskSettings.getTopNDocumentsOnly();
}
if (configuredTopN != null && configuredTopN < rankWindowSize) {
l.onFailure(
new IllegalArgumentException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionCreator;

import java.util.Collections;
import java.util.Map;

import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
Expand Down Expand Up @@ -57,6 +58,11 @@ public void parseRequestConfig(
) {
try {
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = Collections.emptyMap();

if (TaskType.RERANK.equals(taskType)) {
taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
}

ChunkingSettings chunkingSettings = null;
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
Expand All @@ -66,17 +72,21 @@ public void parseRequestConfig(
}

var model = createModel(
inferenceEntityId,
taskType,
serviceSettingsMap,
chunkingSettings,
serviceSettingsMap,
TaskType.unsupportedTaskTypeErrorMsg(taskType, name()),
ConfigurationParseContext.REQUEST
new HuggingFaceModelParameters(
inferenceEntityId,
taskType,
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
serviceSettingsMap,
TaskType.unsupportedTaskTypeErrorMsg(taskType, name()),
ConfigurationParseContext.REQUEST
)
);

throwIfNotEmptyMap(config, name());
throwIfNotEmptyMap(serviceSettingsMap, name());
throwIfNotEmptyMap(taskSettingsMap, name());

parsedModelListener.onResponse(model);
} catch (Exception e) {
Expand All @@ -93,52 +103,60 @@ public HuggingFaceModel parsePersistedConfigWithSecrets(
) {
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);
Map<String, Object> taskSettingsMap = Collections.emptyMap();

if (TaskType.RERANK.equals(taskType)) {
taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
}

ChunkingSettings chunkingSettings = null;
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
}

return createModel(
inferenceEntityId,
taskType,
serviceSettingsMap,
chunkingSettings,
secretSettingsMap,
parsePersistedConfigErrorMsg(inferenceEntityId, name()),
ConfigurationParseContext.PERSISTENT
new HuggingFaceModelParameters(
inferenceEntityId,
taskType,
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
secretSettingsMap,
parsePersistedConfigErrorMsg(inferenceEntityId, name()),
ConfigurationParseContext.PERSISTENT
)
);
}

@Override
public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config) {
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = Collections.emptyMap();

if (TaskType.RERANK.equals(taskType)) {
taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
}

ChunkingSettings chunkingSettings = null;
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
}

return createModel(
inferenceEntityId,
taskType,
serviceSettingsMap,
chunkingSettings,
null,
parsePersistedConfigErrorMsg(inferenceEntityId, name()),
ConfigurationParseContext.PERSISTENT
new HuggingFaceModelParameters(
inferenceEntityId,
taskType,
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
null,
parsePersistedConfigErrorMsg(inferenceEntityId, name()),
ConfigurationParseContext.PERSISTENT
)
);
}

protected abstract HuggingFaceModel createModel(
String inferenceEntityId,
TaskType taskType,
Map<String, Object> serviceSettings,
ChunkingSettings chunkingSettings,
Map<String, Object> secretSettings,
String failureMessage,
ConfigurationParseContext context
);
protected abstract HuggingFaceModel createModel(HuggingFaceModelParameters input);

@Override
public void doInfer(
Expand Down
Loading