Add Hugging Face Rerank support #127966
Conversation
@@ -171,6 +171,22 @@ static String mockDenseServiceModelConfig() {
            """;
    }

    static String mockRerankServiceModelConfig() {
I'm wondering whether the methods you've added to this class are actually used anywhere. The methods you took as a reference are being called; the ones you've added are not.
Thanks for noticing. It's used now
@@ -484,6 +500,10 @@ private String jsonBody(List<String> input, @Nullable String query) {
    @SuppressWarnings("unchecked")
    protected void assertNonEmptyInferenceResults(Map<String, Object> resultMap, int expectedNumberOfResults, TaskType taskType) {
        switch (taskType) {
            case RERANK -> {
It looks like this method is never called with the TaskType.RERANK parameter anywhere, meaning the assertion is never triggered.
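For context, a minimal sketch of what that RERANK branch presumably asserts, based on the rerank output shown later in this conversation; the exact map key and matcher usage are assumptions, not the PR's code:

```java
case RERANK -> {
    // The rerank task returns a list of ranked entries under the "rerank" key,
    // e.g. [{"index": 6, "relevance_score": 0.51}, ...].
    List<Map<String, Object>> rankedDocs = (List<Map<String, Object>>) resultMap.get("rerank");
    assertThat(rankedDocs, hasSize(expectedNumberOfResults));
}
```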
@@ -92,14 +98,15 @@ public HuggingFaceModel parsePersistedConfigWithSecrets(
        Map<String, Object> secrets
    ) {
        Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
        Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
Correct me if I'm wrong, but won't that throw an exception if there are no task settings in the config? If so, doesn't that affect other integrations that don't require TASK_SETTINGS to be present?
I added a Rerank type check to ensure the method isn't used for other tasks.
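A minimal sketch of the kind of guard described here (not the exact PR code; the surrounding method structure is an assumption): the throwing extraction only runs for rerank endpoints, so other Hugging Face task types that omit TASK_SETTINGS keep their current behaviour.

```java
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = null;
if (TaskType.RERANK.equals(taskType)) {
    // Only rerank endpoints require task settings, so only they hit the throwing helper.
    taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
}
```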
    }

    @Override
    public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config) {
        Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
        Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
Same question as above
Added a type check before using the method. Thanks.
Left a few comments.
    @Override
    public boolean[] getTruncationInfo() {
        return null;
Can we have a comment here, explaining why null is returned?
Yeah truncation is only used in some services that support text embedding. Just say something like "Not applicable for rerank, only used in text embedding requests".
Added as suggested: "Not applicable for rerank, only used in text embedding requests". Thanks all.
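For reference, a minimal sketch of what the resulting method presumably looks like after adding that comment:

```java
@Override
public boolean[] getTruncationInfo() {
    // Not applicable for rerank, only used in text embedding requests.
    return null;
}
```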
    @Override
    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.V_8_12_0;
Please check the comments related to TransportVersions left by @jonathan-buttner on this PR: #127254.
They would apply here as well.
Done: read it through and updated. I'm going to update the versions once more before the merge.
    @Override
    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.V_8_14_0;
Same thing here, related to the comments about TransportVersions.
Thanks, applied the change
Looking good! I left a few suggestions.
        this.returnDocuments = returnDocuments;
        this.topN = topN;
        taskSettings = model.getTaskSettings();
        this.model = model;
Since we're saving a reference to the model, how about we remove the taskSettings and inferenceEntityId references and just use the model.
Thank you. They're now read from the model.
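A minimal sketch of that change as described (the class names and accessors here are assumptions, not the exact PR code): keep only the model field and read the derived values through it.

```java
// Keep just the model reference; derive the task settings and the inference entity id
// from it on demand instead of storing separate fields.
private final HuggingFaceRerankModel model;

private HuggingFaceRerankTaskSettings taskSettings() {
    return model.getTaskSettings();
}

private String inferenceEntityId() {
    return model.getInferenceEntityId();
}
```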
import java.util.Map;

public class HuggingFaceModelInput {
How about we make this a record and maybe rename it to HuggingFaceModelParameters?
Yep. The record fits better. Thanks. Done
    private final String failureMessage;
    private final ConfigurationParseContext context;

    public HuggingFaceModelInput(Builder builder) {
Should we make this private? We probably want the instantiation done through the builder.
The builder was replaced with the record as suggested, so this isn't needed anymore. Thank you for pointing it out, though.
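A minimal sketch of the suggested record; the field list is an assumption based on the fields visible in this diff, and the actual PR may carry additional parameters:

```java
// A record provides compact construction and accessors, which removes the need for the builder.
public record HuggingFaceModelParameters(
    String inferenceEntityId,
    TaskType taskType,
    Map<String, Object> serviceSettings,
    Map<String, Object> taskSettings,
    String failureMessage,
    ConfigurationParseContext context
) {}
```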
@@ -128,17 +140,13 @@ public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType
            parsePersistedConfigErrorMsg(inferenceEntityId, name()),
            ConfigurationParseContext.PERSISTENT
        );

        return createModel(
            TaskType.RERANK.equals(taskType) ? modelBuilder.withTaskSettings(taskSettingsMap).build() : modelBuilder.build()
Looks like the builder accepts null task settings so how about we just pass in the task settings map, regardless of it being null or not. That way we don't need to check for rerank here.
The models accept the task settings map as-is now. Thanks.
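A minimal sketch of the simplification being discussed, assuming (as noted above) that the builder tolerates a null task settings map:

```java
// No rerank-specific branch: pass the map through whether or not it is null.
return createModel(modelBuilder.withTaskSettings(taskSettingsMap).build());
```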
    @Override
    public boolean[] getTruncationInfo() {
        return null;
Yeah truncation is only used in some services that support text embedding. Just say something like "Not applicable for rerank, only used in text embedding requests".
        return RERANK_TOKEN_LIMIT;
    }

    // model is not defined in the service settings.
Since we encountered situations where the model id was required for chat completion, have we done any testing to see if the serverless style endpoint requires the model id?
The thing is that HF currently does not provide serverless endpoints for rerank models, so we cannot test that now.
    @Override
    protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
        builder.field(URL, uri.toString());
        builder.field(MAX_INPUT_TOKENS, RERANK_TOKEN_LIMIT);
Let's remove this, since we don't use it.
Removed. Thank you
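A minimal sketch of the method after that removal; the trailing return statement is an assumption about the surrounding code:

```java
@Override
protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
    // MAX_INPUT_TOKENS is no longer written out because the limit isn't used.
    builder.field(URL, uri.toString());
    return builder;
}
```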
@@ -0,0 +1,123 @@
/*
We're trying to move away from this style of parsing and instead use an ObjectParser or ConstructingObjectParser. How about we switch this implementation to use ConstructingObjectParser? Here's an example: https://github.com/elastic/elasticsearch/blob/main/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/response/OpenAiEmbeddingsResponseEntity.java
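A minimal sketch of what a ConstructingObjectParser-based parser for one rerank result entry could look like. This is not the PR's code; the field names (index, score, text) are assumptions about the Hugging Face rerank payload, and the containing class is simplified.

```java
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentParser;

import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;

// One entry of the rerank response, parsed declaratively instead of walking tokens by hand.
record RankedDoc(int index, float score, String text) {

    static final ConstructingObjectParser<RankedDoc, Void> PARSER = new ConstructingObjectParser<>(
        "hugging_face_rerank_entry",
        true,
        args -> new RankedDoc((Integer) args[0], (Float) args[1], (String) args[2])
    );

    static {
        PARSER.declareInt(constructorArg(), new ParseField("index"));
        PARSER.declareFloat(constructorArg(), new ParseField("score"));
        PARSER.declareString(optionalConstructorArg(), new ParseField("text"));
    }

    static RankedDoc fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }
}
```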
import static org.elasticsearch.common.xcontent.XContentParserUtils.throwUnknownToken;
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;

public class HuggingFaceRerankResponseEntity extends ErrorResponse {
Hmm, typically we separate the valid response from the error response. Does the HuggingFaceErrorResponseEntity suffice?
"Failed to send Hugging Face %s request from inference entity id [%s]"; | ||
static final ResponseHandler RERANK_HANDLER = new HuggingFaceResponseHandler( | ||
"hugging face rerank", | ||
(request, response) -> HuggingFaceRerankResponseEntity.fromResponse((HuggingFaceRerankRequest) request, response) |
It'd be unlikely, but can we do an instanceof check for request being a HuggingFaceRerankRequest, and throw an IllegalArgumentException if it's invalid?
Good point, thank you Jonathan. An explicit check was added.
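A minimal sketch of the suggested guard, roughly matching the later diff in this conversation (the exact error-message wiring is an assumption):

```java
static final ResponseHandler RERANK_HANDLER = new HuggingFaceResponseHandler("hugging face rerank", (request, response) -> {
    // Guard against a mismatched request type instead of blindly casting.
    if (request instanceof HuggingFaceRerankRequest rerankRequest) {
        return HuggingFaceRerankResponseEntity.fromResponse(rerankRequest, response);
    }
    throw new IllegalArgumentException(
        "Invalid request type: expected HuggingFace rerank request but got " + request.getClass().getSimpleName()
    );
});
```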
# Conflicts:
#	server/src/main/java/org/elasticsearch/TransportVersions.java
Pinging @elastic/ml-core (Team:ML)
    private static final String FAILED_TO_SEND_REQUEST_ERROR_MESSAGE =
        "Failed to send Hugging Face %s request from inference entity id [%s]";
    private static final String INVALID_REQUEST_TYPE_MESSAGE = "Invalid request type: expected HuggingFace %s request but got %s";
    static final ResponseHandler RERANK_HANDLER = new HuggingFaceResponseHandler("hugging face rerank", (request, response) -> {
Why is it package private and not just private?
gradle check?
The CLA has been signed.
Used the following with success:
gradlew :x-pack:plugin:inference:check
gradlew.bat :x-pack:plugin:inference:spotlessApply
Tested via API:

Request:
PUT {{base-url}}/_inference/rerank/bge-reranker-base-mkn
{
    "service": "hugging_face",
    "service_settings": {
        "api_key": "{{hf-api-key}}",
        "url": "{{hf-bge-reranker-url}}"
    },
    "task_settings": {
        "return_text": false,
        "top_n": 4
    }
}

Request:
POST {{base-url}}/_inference/rerank/bge-reranker-base-mkn
{
    "input": ["luke", "like", "leia", "chewy", "r2d2", "star", "wars"],
    "query": "star wars main character",
    "top_n": 2,
    "return_documents": false
}

Response:
{
    "rerank": [
        {
            "index": 6,
            "relevance_score": 0.50955844
        },
        {
            "index": 5,
            "relevance_score": 0.084341794
        }
    ]
}

Request:
POST {{base-url}}/_inference/rerank/bge-reranker-base-mkn
{
    "input": ["luke", "like", "leia", "chewy", "r2d2", "star", "wars"],
    "query": "star wars main character",
    "top_n": 3,
    "return_documents": true
}

Response:
{
    "rerank": [
        {
            "index": 6,
            "relevance_score": 0.5089636,
            "text": "wars"
        },
        {
            "index": 5,
            "relevance_score": 0.08449275,
            "text": "star"
        },
        {
            "index": 3,
            "relevance_score": 0.0045032725,
            "text": "chewy"
        }
    ]
}
Additionally, the following HF task settings were integrated: raw_scores, truncate, truncation_direction.
For now they have been removed from the PR and saved in a separate branch, in case we decide to make them part of the inference API.
@jonathan-buttner @Jan-Kazlouski-elastic
Apologies for the delay. I meant to create this much sooner.
Thanks for your patience