Skip to content

Extend huggingface with rerank #127297

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@
public class SettingsConfigurationTestUtils {

public static SettingsConfiguration getRandomSettingsConfigurationField() {
return new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING)).setDefaultValue(
randomAlphaOfLength(10)
)
return new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK))
.setDefaultValue(randomAlphaOfLength(10))
.setDescription(randomAlphaOfLength(10))
.setLabel(randomAlphaOfLength(10))
.setRequired(randomBoolean())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,22 @@ static String mockDenseServiceModelConfig() {
""";
}

/**
 * Returns a JSON model configuration for a {@code rerank} task that targets the
 * {@code rerank_test_service} service.
 * <p>
 * The configuration carries a {@code model} and an {@code api_key} under
 * {@code service_settings}, and sets {@code return_documents} to {@code true} under
 * {@code task_settings}.
 * NOTE(review): presumably used as the request body when creating a test inference
 * endpoint — confirm against the callers.
 *
 * @return the configuration as a JSON string
 */
static String mockRerankServiceModelConfig() {
return """
{
"task_type": "rerank",
"service": "rerank_test_service",
"service_settings": {
"model": "rerank_model",
"api_key": "abc64"
},
"task_settings": {
"return_documents": true
}
}
""";
}

static void deleteModel(String modelId) throws IOException {
var request = new Request("DELETE", "_inference/" + modelId);
var response = client().performRequest(request);
Expand Down Expand Up @@ -484,6 +500,10 @@ private String jsonBody(List<String> input, @Nullable String query) {
@SuppressWarnings("unchecked")
protected void assertNonEmptyInferenceResults(Map<String, Object> resultMap, int expectedNumberOfResults, TaskType taskType) {
switch (taskType) {
case RERANK -> {
var results = (List<Map<String, Object>>) resultMap.get(TaskType.RERANK.toString());
assertThat(results, hasSize(expectedNumberOfResults));
}
case SPARSE_EMBEDDING -> {
var results = (List<Map<String, Object>>) resultMap.get(TaskType.SPARSE_EMBEDDING.toString());
assertThat(results, hasSize(expectedNumberOfResults));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
@SuppressWarnings("unchecked")
public void testGetServicesWithRerankTaskType() throws IOException {
List<Object> services = getServices(TaskType.RERANK);
assertThat(services.size(), equalTo(7));
assertThat(services.size(), equalTo(8));

String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Expand All @@ -106,8 +106,16 @@ public void testGetServicesWithRerankTaskType() throws IOException {
}

assertArrayEquals(
List.of("alibabacloud-ai-search", "cohere", "elasticsearch", "googlevertexai", "jinaai", "test_reranking_service", "voyageai")
.toArray(),
List.of(
"alibabacloud-ai-search",
"cohere",
"elasticsearch",
"googlevertexai",
"jinaai",
"test_reranking_service",
"voyageai",
"hugging_face"
).toArray(),
providers
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings;
Expand Down Expand Up @@ -353,6 +354,13 @@ private static void addHuggingFaceNamedWriteables(List<NamedWriteableRegistry.En
namedWriteables.add(
new NamedWriteableRegistry.Entry(ServiceSettings.class, HuggingFaceServiceSettings.NAME, HuggingFaceServiceSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
HuggingFaceRerankServiceSettings.NAME,
HuggingFaceRerankServiceSettings::new
)
);
}

private static void addGoogleAiStudioNamedWritables(List<NamedWriteableRegistry.Entry> namedWriteables) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,18 @@

import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionVisitor;
import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

import java.util.Objects;

public abstract class HuggingFaceModel extends Model {
public abstract class HuggingFaceModel extends RateLimitGroupingModel {
private final HuggingFaceRateLimitServiceSettings rateLimitServiceSettings;
private final SecureString apiKey;

Expand All @@ -38,6 +39,16 @@ public HuggingFaceRateLimitServiceSettings rateLimitServiceSettings() {
return rateLimitServiceSettings;
}

/**
 * Computes the rate limit grouping hash from the endpoint URI of the service settings
 * and the API key.
 * NOTE(review): presumably models that produce the same hash share one rate limit
 * group — confirm against {@code RateLimitGroupingModel}.
 *
 * @return the combined hash of the URI and the API key
 */
@Override
public int rateLimitGroupingHash() {
    var endpointUri = this.rateLimitServiceSettings.uri();
    return Objects.hash(endpointUri, this.apiKey);
}

/**
 * Exposes the rate limit settings held by this model's service settings.
 *
 * @return the {@link RateLimitSettings} reported by the service settings
 */
@Override
public RateLimitSettings rateLimitSettings() {
    final RateLimitSettings configuredLimits = this.rateLimitServiceSettings.rateLimitSettings();
    return configuredLimits;
}

/**
 * Returns the API key stored on this model.
 *
 * @return the API key as a {@link SecureString}
 */
public SecureString apiKey() {
    return this.apiKey;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.huggingface;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;

/**
 * Task settings for the Hugging Face rerank task.
 * <p>
 * Holds three optional values: {@code top_n}, {@code return_documents} and
 * {@code max_chunks_per_doc}. A {@code null} field means "not set"; when settings are
 * merged, unset request fields fall back to the stored value.
 * Instances are immutable.
 */
public class HuggingFaceRerankTaskSettings implements TaskSettings {

    public static final String NAME = "hugging_face_rerank_task_settings";
    public static final String RETURN_DOCUMENTS = "return_documents";
    public static final String TOP_N_DOCS_ONLY = "top_n";
    public static final String MAX_CHUNKS_PER_DOC = "max_chunks_per_doc";

    /** Shared instance with every field unset. */
    static final HuggingFaceRerankTaskSettings EMPTY_SETTINGS = new HuggingFaceRerankTaskSettings(null, null, null);

    /**
     * Parses task settings from a map.
     *
     * @param map the raw task settings; may be {@code null} or empty
     * @return {@link #EMPTY_SETTINGS} when {@code map} is {@code null} or empty, otherwise the parsed settings
     * @throws ValidationException if any field failed validation while being extracted
     */
    public static HuggingFaceRerankTaskSettings fromMap(Map<String, Object> map) {
        if (map == null || map.isEmpty()) {
            return EMPTY_SETTINGS;
        }

        // Only allocated once we know there is something to validate.
        ValidationException validationException = new ValidationException();

        Boolean returnDocuments = extractOptionalBoolean(map, RETURN_DOCUMENTS, validationException);
        Integer topNDocumentsOnly = extractOptionalPositiveInteger(
            map,
            TOP_N_DOCS_ONLY,
            ModelConfigurations.TASK_SETTINGS,
            validationException
        );
        Integer maxChunksPerDoc = extractOptionalPositiveInteger(
            map,
            MAX_CHUNKS_PER_DOC,
            ModelConfigurations.TASK_SETTINGS,
            validationException
        );

        if (validationException.validationErrors().isEmpty() == false) {
            throw validationException;
        }

        return of(topNDocumentsOnly, returnDocuments, maxChunksPerDoc);
    }

    /**
     * Creates a new {@link HuggingFaceRerankTaskSettings}
     * by preferring non-null fields from the request settings over the original settings.
     *
     * @param originalSettings    the settings stored as part of the inference entity configuration
     * @param requestTaskSettings the settings passed in within the task_settings field of the request
     * @return a constructed {@link HuggingFaceRerankTaskSettings}
     */
    public static HuggingFaceRerankTaskSettings of(
        HuggingFaceRerankTaskSettings originalSettings,
        HuggingFaceRerankTaskSettings requestTaskSettings
    ) {
        return new HuggingFaceRerankTaskSettings(
            firstNonNull(requestTaskSettings.getTopNDocumentsOnly(), originalSettings.getTopNDocumentsOnly()),
            firstNonNull(requestTaskSettings.getReturnDocuments(), originalSettings.getReturnDocuments()),
            firstNonNull(requestTaskSettings.getMaxChunksPerDoc(), originalSettings.getMaxChunksPerDoc())
        );
    }

    /**
     * Creates settings directly from the individual values; any of them may be {@code null}.
     *
     * @param topNDocumentsOnly the {@code top_n} value
     * @param returnDocuments   the {@code return_documents} value
     * @param maxChunksPerDoc   the {@code max_chunks_per_doc} value
     * @return a constructed {@link HuggingFaceRerankTaskSettings}
     */
    public static HuggingFaceRerankTaskSettings of(Integer topNDocumentsOnly, Boolean returnDocuments, Integer maxChunksPerDoc) {
        return new HuggingFaceRerankTaskSettings(topNDocumentsOnly, returnDocuments, maxChunksPerDoc);
    }

    /** Returns {@code preferred} unless it is {@code null}, in which case {@code fallback} (which may itself be {@code null}). */
    private static <T> T firstNonNull(T preferred, T fallback) {
        return preferred != null ? preferred : fallback;
    }

    private final Integer topNDocumentsOnly;
    private final Boolean returnDocuments;
    private final Integer maxChunksPerDoc;

    /**
     * Reads the settings from a stream, in the same field order that {@link #writeTo(StreamOutput)} uses.
     *
     * @param in the stream to read from
     * @throws IOException if reading from the stream fails
     */
    public HuggingFaceRerankTaskSettings(StreamInput in) throws IOException {
        this(in.readOptionalInt(), in.readOptionalBoolean(), in.readOptionalInt());
    }

    /**
     * @param topNDocumentsOnly the {@code top_n} value, or {@code null} if unset
     * @param doReturnDocuments the {@code return_documents} value, or {@code null} if unset
     * @param maxChunksPerDoc   the {@code max_chunks_per_doc} value, or {@code null} if unset
     */
    public HuggingFaceRerankTaskSettings(
        @Nullable Integer topNDocumentsOnly,
        @Nullable Boolean doReturnDocuments,
        @Nullable Integer maxChunksPerDoc
    ) {
        this.topNDocumentsOnly = topNDocumentsOnly;
        this.returnDocuments = doReturnDocuments;
        this.maxChunksPerDoc = maxChunksPerDoc;
    }

    /** Returns {@code true} when none of the three fields is set. */
    @Override
    public boolean isEmpty() {
        return topNDocumentsOnly == null && returnDocuments == null && maxChunksPerDoc == null;
    }

    /** Writes the settings as an object containing only the fields that are set. */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        if (topNDocumentsOnly != null) {
            builder.field(TOP_N_DOCS_ONLY, topNDocumentsOnly);
        }
        if (returnDocuments != null) {
            builder.field(RETURN_DOCUMENTS, returnDocuments);
        }
        if (maxChunksPerDoc != null) {
            builder.field(MAX_CHUNKS_PER_DOC, maxChunksPerDoc);
        }
        builder.endObject();
        return builder;
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    @Override
    public TransportVersion getMinimalSupportedVersion() {
        // NOTE(review): this writeable is newly introduced, yet it advertises V_8_14_0 as its minimal
        // supported version. Confirm whether a dedicated transport version should be used here instead.
        return TransportVersions.V_8_14_0;
    }

    /** Writes the fields in the order expected by the {@link StreamInput} constructor. */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeOptionalInt(topNDocumentsOnly);
        out.writeOptionalBoolean(returnDocuments);
        out.writeOptionalInt(maxChunksPerDoc);
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        HuggingFaceRerankTaskSettings that = (HuggingFaceRerankTaskSettings) o;
        return Objects.equals(returnDocuments, that.returnDocuments)
            && Objects.equals(topNDocumentsOnly, that.topNDocumentsOnly)
            && Objects.equals(maxChunksPerDoc, that.maxChunksPerDoc);
    }

    @Override
    public int hashCode() {
        return Objects.hash(returnDocuments, topNDocumentsOnly, maxChunksPerDoc);
    }

    /**
     * Alias of {@link #getReturnDocuments()}, kept so existing callers continue to compile.
     *
     * @return the {@code return_documents} value, or {@code null} if unset
     */
    public Boolean getDoesReturnDocuments() {
        return getReturnDocuments();
    }

    /** @return the {@code top_n} value, or {@code null} if unset */
    public Integer getTopNDocumentsOnly() {
        return topNDocumentsOnly;
    }

    /** @return the {@code return_documents} value, or {@code null} if unset */
    public Boolean getReturnDocuments() {
        return returnDocuments;
    }

    /** @return the {@code max_chunks_per_doc} value, or {@code null} if unset */
    public Integer getMaxChunksPerDoc() {
        return maxChunksPerDoc;
    }

    /**
     * Returns these settings overlaid with the values present in {@code newSettings}.
     *
     * @param newSettings the new raw task settings; {@code null} or empty leaves the settings unchanged
     * @return the merged settings, or this instance when there is nothing to apply
     * @throws ValidationException if {@code newSettings} contains an invalid value
     */
    @Override
    public TaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
        // new HashMap<>(null) would throw a NullPointerException; treat a missing map like an empty one.
        if (newSettings == null || newSettings.isEmpty()) {
            return this;
        }
        // Parse a copy so the caller's map is left untouched.
        // NOTE(review): presumably the extract helpers remove the keys they consume — confirm in ServiceUtils.
        HuggingFaceRerankTaskSettings updatedSettings = HuggingFaceRerankTaskSettings.fromMap(new HashMap<>(newSettings));
        return HuggingFaceRerankTaskSettings.of(this, updatedSettings);
    }
}
Loading