Add Hugging Face Rerank support #127966

Open: Evgenii-Kazannik wants to merge 8 commits into base: main

@Evgenii-Kazannik Evgenii-Kazannik commented May 9, 2025

  • Have you signed the contributor license agreement?
  • Have you followed the contributor guidelines?
  • If submitting code, have you built your formula locally prior to submission with gradle check?
  • If submitting code, is your pull request against main? Unless there is a good reason otherwise, we prefer pull requests against main and will backport as needed.
  • If submitting code, have you checked that your submission is for an OS and architecture that we support?
  • If you are submitting this code for a class then read our policy for that.

The CLA has been signed.
Ran the following successfully:

gradlew :x-pack:plugin:inference:check
gradlew.bat :x-pack:plugin:inference:spotlessApply

Tested via the API:

PUT {{base-url}}/_inference/rerank/bge-reranker-base-mkn
{
  "service": "hugging_face",
  "service_settings": {
    "api_key": "{{hf-api-key}}",
    "url": "{{hf-bge-reranker-url}}"
  },
  "task_settings": {
    "return_text": false,
    "top_n": 4
  }
}

POST {{base-url}}/_inference/rerank/bge-reranker-base-mkn
{
  "input": ["luke", "like", "leia", "chewy", "r2d2", "star", "wars"],
  "query": "star wars main character",
  "top_n": 2,
  "return_documents": false
}


{
  "rerank": [
    {
      "index": 6,
      "relevance_score": 0.50955844
    },
    {
      "index": 5,
      "relevance_score": 0.084341794
    }
  ]
}

POST {{base-url}}/_inference/rerank/bge-reranker-base-mkn
{
  "input": ["luke", "like", "leia", "chewy", "r2d2", "star", "wars"],
  "query": "star wars main character",
  "top_n": 3,
  "return_documents": true
}


{
  "rerank": [
    {
      "index": 6,
      "relevance_score": 0.5089636,
      "text": "wars"
    },
    {
      "index": 5,
      "relevance_score": 0.08449275,
      "text": "star"
    },
    {
      "index": 3,
      "relevance_score": 0.0045032725,
      "text": "chewy"
    }
  ]
}
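
For illustration only, here is a minimal Java sketch of how top_n and return_documents shape the response in the two calls above. It is not code from this PR: RankedDoc, RerankTopNSketch, and the scores in main are hypothetical, and in the real integration the scoring is done by the Hugging Face endpoint.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Hypothetical result type; "text" stays null unless documents are requested.
record RankedDoc(int index, double relevanceScore, String text) {}

final class RerankTopNSketch {
    // Sorts by score (descending), keeps the first topN entries,
    // and attaches the original text only when returnDocuments is true.
    static List<RankedDoc> rank(List<String> input, double[] scores, int topN, boolean returnDocuments) {
        List<RankedDoc> ranked = new ArrayList<>();
        for (int i = 0; i < input.size(); i++) {
            ranked.add(new RankedDoc(i, scores[i], returnDocuments ? input.get(i) : null));
        }
        ranked.sort(Comparator.<RankedDoc>comparingDouble(RankedDoc::relevanceScore).reversed());
        return ranked.subList(0, Math.min(topN, ranked.size()));
    }

    public static void main(String[] args) {
        List<String> input = List.of("luke", "like", "leia", "chewy", "r2d2", "star", "wars");
        // Made-up scores, chosen only so that "wars" and "star" rank first.
        double[] scores = { 0.001, 0.0005, 0.002, 0.0045, 0.003, 0.084, 0.51 };
        System.out.println(rank(input, scores, 2, false));
    }
}
```

In the first call above, the request-level top_n of 2 yields two results even though the endpoint was created with top_n set to 4, which suggests the per-request value takes precedence.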

The following additional HF task settings were also implemented:
raw_scores, truncate, truncation_direction
For now they have been removed from this PR and saved in a separate branch,
in case we decide to make them part of the inference API.

@jonathan-buttner @Jan-Kazlouski-elastic
Apologies for the delay. I meant to create this much sooner.
Thanks for your patience

@elasticsearchmachine elasticsearchmachine added the needs:triage, v9.1.0, and external-contributor labels May 9, 2025
@@ -171,6 +171,22 @@ static String mockDenseServiceModelConfig() {
""";
}

static String mockRerankServiceModelConfig() {

I'm wondering whether the methods you've added to this class are actually used anywhere. The methods you used as a reference are called; the ones you've added are not.

Author

@Evgenii-Kazannik Evgenii-Kazannik May 13, 2025

Thanks for noticing. It's used now

@@ -484,6 +500,10 @@ private String jsonBody(List<String> input, @Nullable String query) {
@SuppressWarnings("unchecked")
protected void assertNonEmptyInferenceResults(Map<String, Object> resultMap, int expectedNumberOfResults, TaskType taskType) {
switch (taskType) {
case RERANK -> {

It looks like this method is never called with TaskType.RERANK, meaning this assertion is never triggered.

@@ -92,14 +98,15 @@ public HuggingFaceModel parsePersistedConfigWithSecrets(
Map<String, Object> secrets
) {
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);

Correct me if I'm wrong, but won't that throw an exception if there are no task settings in the config? If so, doesn't that affect other integrations that don't require TASK_SETTINGS to be present?

Author

I added a rerank task type check to ensure the method isn't used for other tasks.
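
A rough sketch of the concern and the fix discussed here, under stated assumptions: removeOrThrow and removeOrEmpty are hypothetical stand-ins for the strict and lenient map helpers, and the task type is a plain string rather than the TaskType enum. It is not the code from this PR.

```java
import java.util.HashMap;
import java.util.Map;

final class TaskSettingsParsingSketch {
    // Stand-in for a strict helper: throws when the key is absent.
    @SuppressWarnings("unchecked")
    static Map<String, Object> removeOrThrow(Map<String, Object> config, String key) {
        Object value = config.remove(key);
        if (value == null) {
            throw new IllegalArgumentException("Missing required setting [" + key + "]");
        }
        return (Map<String, Object>) value;
    }

    // Stand-in for a lenient helper: returns an empty map when the key is absent.
    @SuppressWarnings("unchecked")
    static Map<String, Object> removeOrEmpty(Map<String, Object> config, String key) {
        Object value = config.remove(key);
        return value == null ? new HashMap<>() : (Map<String, Object>) value;
    }

    // Only rerank uses the strict variant, so configs for other task types
    // that carry no task_settings still parse without an exception.
    static Map<String, Object> parseTaskSettings(String taskType, Map<String, Object> config) {
        return "rerank".equals(taskType)
            ? removeOrThrow(config, "task_settings")
            : removeOrEmpty(config, "task_settings");
    }
}
```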

}

@Override
public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config) {
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);

Same question as above

Author

Added a type check before using the method. Thanks

@Jan-Kazlouski-elastic Jan-Kazlouski-elastic left a comment

Left a few comments.


@Override
public boolean[] getTruncationInfo() {
return null;

Can we have a comment here, explaining why null is returned?

Contributor

Yeah truncation is only used in some services that support text embedding. Just say something like "Not applicable for rerank, only used in text embedding requests".

Author

Added as suggested: Not applicable for rerank, only used in text embedding requests
Thanks all
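
For reference, the resulting method might look roughly like the sketch below. TruncatableRequest and RerankRequestSketch are hypothetical stand-ins so the snippet compiles on its own; only the comment text and the null return come from this thread.

```java
// Hypothetical stand-in for the request interface that declares getTruncationInfo().
interface TruncatableRequest {
    boolean[] getTruncationInfo();
}

final class RerankRequestSketch implements TruncatableRequest {
    @Override
    public boolean[] getTruncationInfo() {
        // Not applicable for rerank, only used in text embedding requests
        return null;
    }
}
```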


@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.V_8_12_0;

Please check comments related to TransportVersions left by @jonathan-buttner to this PR: #127254
They would apply here as well.

Author

Did it, read it through, updated. I'm going to update the versions once more before the merge


@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.V_8_14_0;

Same thing here regarding the TransportVersion comments.

Author

Thanks, applied the change

Contributor

@jonathan-buttner jonathan-buttner left a comment

Looking good! I left a few suggestions.

this.returnDocuments = returnDocuments;
this.topN = topN;
taskSettings = model.getTaskSettings();
this.model = model;
Contributor

Since we're saving a reference to the model how about we remove the taskSettings and inferenceEntityId references and just use the model.

Author

Thank you. Used them from the model


import java.util.Map;

public class HuggingFaceModelInput {
Contributor

How about we make this a record and maybe rename it to HuggingFaceModelParameters

Author

Yep. The record fits better. Thanks. Done
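
A minimal sketch of what the record form could look like. The component list is an assumption pieced together from the fields visible in this thread (failureMessage, context) and the parse method signatures; ParseContext is a hypothetical stand-in for ConfigurationParseContext, and the task type is a plain string here.

```java
import java.util.Map;

// Hypothetical stand-in for ConfigurationParseContext.
enum ParseContext { REQUEST, PERSISTENT }

// Sketch of the parameter holder as a record; the component list is an assumption.
// A record gives a canonical constructor, accessors, equals, and hashCode for free,
// so no separate builder class is needed.
record HuggingFaceModelParameters(
    String inferenceEntityId,
    String taskType,
    Map<String, Object> serviceSettings,
    Map<String, Object> taskSettings, // may be null for tasks without task settings
    Map<String, Object> secrets,
    String failureMessage,
    ParseContext context
) {}
```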

private final String failureMessage;
private final ConfigurationParseContext context;

public HuggingFaceModelInput(Builder builder) {
Contributor

Should we make this private? We probably want the instantiation done through the builder.

Author

The builder was replaced with the record as suggested so not needed anymore.
Though thank you for pointing that out

@@ -128,17 +140,13 @@ public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType
parsePersistedConfigErrorMsg(inferenceEntityId, name()),
ConfigurationParseContext.PERSISTENT
);

return createModel(
TaskType.RERANK.equals(taskType) ? modelBuilder.withTaskSettings(taskSettingsMap).build() : modelBuilder.build()
Contributor

Looks like the builder accepts null task settings so how about we just pass in the task settings map, regardless of it being null or not. That way we don't need to check for rerank here.

Author

@Evgenii-Kazannik Evgenii-Kazannik May 13, 2025

The models accept the task settings map as is now. Thanks


return RERANK_TOKEN_LIMIT;
}

// model is not defined in the service settings.
Contributor

Since we encountered situations where the model id was required for chat completion, have we done any testing to see if the serverless style endpoint requires the model id?

Author

The thing is that HF currently does not provide serverless for Rerank models. We cannot test it now

@Override
protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
builder.field(URL, uri.toString());
builder.field(MAX_INPUT_TOKENS, RERANK_TOKEN_LIMIT);
Contributor

Let's remove this, since we don't use it.

Author

Removed. Thank you

@@ -0,0 +1,123 @@
/*
Contributor

We're trying to move away from this style of parsing and instead use an ObjectParser or ConstructingObjectParser. How about we switch this implementation to use ConstructingObjectParser? Here's an example: https://github.com/elastic/elasticsearch/blob/main/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/response/OpenAiEmbeddingsResponseEntity.java

import static org.elasticsearch.common.xcontent.XContentParserUtils.throwUnknownToken;
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;

public class HuggingFaceRerankResponseEntity extends ErrorResponse {
Contributor

Hmm, typically we separate the valid response from the error response. Does the HuggingFaceErrorResponseEntity suffice?

"Failed to send Hugging Face %s request from inference entity id [%s]";
static final ResponseHandler RERANK_HANDLER = new HuggingFaceResponseHandler(
"hugging face rerank",
(request, response) -> HuggingFaceRerankResponseEntity.fromResponse((HuggingFaceRerankRequest) request, response)
Contributor

It'd be unlikely but can we do an instanceof check for request being a HuggingFaceRerankRequest? And throw an IllegalArgumentException if it's invalid.

Author

Good point. Thank you Jonathan. Explicit check was added
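
The explicit check could be sketched as follows. Request, RerankRequest, EmbeddingsRequest, and requireRerankRequest are hypothetical names so the snippet stands alone; the message format matches the INVALID_REQUEST_TYPE_MESSAGE constant quoted later in this thread. This is an illustration of the idea, not the PR's implementation.

```java
// Hypothetical stand-ins for the request types.
interface Request {}

final class RerankRequest implements Request {}

final class EmbeddingsRequest implements Request {}

final class RequestTypeCheckSketch {
    static final String INVALID_REQUEST_TYPE_MESSAGE = "Invalid request type: expected HuggingFace %s request but got %s";

    // Returns the request as a RerankRequest, or fails fast with a descriptive
    // message instead of letting an unchecked cast throw a ClassCastException.
    static RerankRequest requireRerankRequest(Request request) {
        if (request instanceof RerankRequest rerankRequest) {
            return rerankRequest;
        }
        String actualType = request == null ? "null" : request.getClass().getSimpleName();
        throw new IllegalArgumentException(String.format(INVALID_REQUEST_TYPE_MESSAGE, "rerank", actualType));
    }
}
```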

# Conflicts:
#	server/src/main/java/org/elasticsearch/TransportVersions.java
@PeteGillinElastic PeteGillinElastic added the :ml label and removed the needs:triage label May 13, 2025
@elasticsearchmachine elasticsearchmachine added the Team:ML label May 13, 2025
@elasticsearchmachine
Collaborator

Pinging @elastic/ml-core (Team:ML)

private static final String FAILED_TO_SEND_REQUEST_ERROR_MESSAGE =
"Failed to send Hugging Face %s request from inference entity id [%s]";
private static final String INVALID_REQUEST_TYPE_MESSAGE = "Invalid request type: expected HuggingFace %s request but got %s";
static final ResponseHandler RERANK_HANDLER = new HuggingFaceResponseHandler("hugging face rerank", (request, response) -> {

Why is it package private and not just private?

Labels: external-contributor, :ml, Team:ML, v9.1.0
5 participants