Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions solr/modules/language-models/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ dependencies {
implementation project(':solr:solrj')

implementation libs.apache.lucene.core
implementation libs.apache.lucene.highlighter

implementation libs.langchain4j.core
runtimeOnly libs.langchain4j.cohere
Expand Down
2 changes: 1 addition & 1 deletion solr/modules/language-models/gradle.lockfile
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ org.apache.lucene:lucene-core:10.3.2=compileClasspath,jarValidation,runtimeClass
org.apache.lucene:lucene-expressions:10.3.2=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testRuntimeClasspath
org.apache.lucene:lucene-facet:10.3.2=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testRuntimeClasspath
org.apache.lucene:lucene-grouping:10.3.2=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testRuntimeClasspath
org.apache.lucene:lucene-highlighter:10.3.2=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testRuntimeClasspath
org.apache.lucene:lucene-highlighter:10.3.2=compileClasspath,jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testCompileClasspath,testRuntimeClasspath
org.apache.lucene:lucene-join:10.3.2=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testRuntimeClasspath
org.apache.lucene:lucene-memory:10.3.2=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testRuntimeClasspath
org.apache.lucene:lucene-misc:10.3.2=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testRuntimeClasspath
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.languagemodels.textvectorisation.model;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.Response;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.apache.solr.common.util.CollectionUtil;

/**
 * An {@link EmbeddingModel} that serves embeddings from a fixed, in-memory lookup table.
 *
 * <p>A text is looked up by exact match in {@code customEmbeddings}; any text without an entry
 * receives {@code defaultEmbedding}. The builder guarantees that every table entry has the same
 * length as the default embedding, which is also the value reported by {@link #dimension()}.
 *
 * <p>{@code customUrl} and {@code customVersion} are accepted and stored, but nothing in this class
 * reads them.
 */
public class CustomModel implements EmbeddingModel {

  private final String customUrl;
  private final Long customVersion;
  private final float[] defaultEmbedding;
  private final Map<String, float[]> customEmbeddings;

  private CustomModel(
      String customUrl,
      Long customVersion,
      float[] defaultEmbedding,
      Map<String, float[]> customEmbeddings) {
    this.customUrl = customUrl;
    this.customVersion = customVersion;
    this.defaultEmbedding = defaultEmbedding;
    this.customEmbeddings = customEmbeddings;
  }

  /**
   * Resolves the embedding for {@code text}.
   *
   * @param text the text to look up (exact match against the configured keys)
   * @return a new {@link Embedding} holding the matching vector, or the default vector when there
   *     is no entry for {@code text}
   */
  private Embedding embedding(String text) {
    final float[] stored = customEmbeddings.getOrDefault(text, defaultEmbedding);
    // Hand out a copy so the returned Embedding never aliases this model's stored vectors;
    // otherwise an in-place change to a returned vector would leak into every later lookup.
    return new Embedding(stored.clone());
  }

  @Override
  public Response<Embedding> embed(String text) {
    return new Response<>(embedding(text));
  }

  @Override
  public Response<Embedding> embed(TextSegment textSegment) {
    return new Response<>(embedding(textSegment.text()));
  }

  /**
   * Embeds every segment independently, preserving the input order.
   *
   * @param textSegments the segments to embed
   * @return one embedding per segment, in the same order as {@code textSegments}
   */
  @Override
  public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
    final List<Embedding> embeddings = new ArrayList<>(textSegments.size());
    for (TextSegment textSegment : textSegments) {
      embeddings.add(embedding(textSegment.text()));
    }
    return new Response<>(embeddings);
  }

  /** Returns the length of the default embedding, which all custom embeddings share. */
  @Override
  public int dimension() {
    return this.defaultEmbedding.length;
  }

  public static CustomModelBuilder builder() {
    return new CustomModelBuilder();
  }

  /**
   * Converts a list of numbers into a primitive {@code float} array.
   *
   * <p>Elements are read as {@link Number} rather than {@link Double}, so a list that (through
   * erasure) actually holds other {@code Number} subtypes is converted instead of failing with a
   * {@link ClassCastException}.
   *
   * @param name how to refer to the list in error messages
   * @param values the values to convert
   * @return a new array with one {@code float} per element of {@code values}
   * @throws IllegalArgumentException if {@code values} is null or contains a null element
   */
  private static float[] toFloatArray(String name, List<? extends Number> values) {
    if (values == null) {
      throw new IllegalArgumentException(name + " must not be null");
    }
    final float[] converted = new float[values.size()];
    for (int i = 0; i < converted.length; i++) {
      final Number value = values.get(i);
      if (value == null) {
        throw new IllegalArgumentException(name + " must not contain null (index " + i + ")");
      }
      converted[i] = value.floatValue();
    }
    return converted;
  }

  /**
   * Builder for {@link CustomModel}. {@code defaultEmbedding} and {@code customEmbeddings} are
   * required; {@code customUrl} and {@code customVersion} are optional.
   */
  public static class CustomModelBuilder {
    private String customUrl;
    private Long customVersion;
    private float[] defaultEmbedding;
    private Map<String, float[]> customEmbeddings;

    public CustomModelBuilder() {}

    public CustomModelBuilder customUrl(String customUrl) {
      this.customUrl = customUrl;
      return this;
    }

    /**
     * Sets the version, narrowed to a {@code long}.
     *
     * @param customVersion the version, or {@code null} to leave it unset
     */
    public CustomModelBuilder customVersion(Number customVersion) {
      this.customVersion =
          (customVersion == null) ? null : Long.valueOf(customVersion.longValue());
      return this;
    }

    /**
     * Sets the embedding returned for any text without a custom entry.
     *
     * @throws IllegalArgumentException if the list is null or contains a null element
     */
    public CustomModelBuilder defaultEmbedding(ArrayList<Double> defaultEmbedding) {
      this.defaultEmbedding = toFloatArray("defaultEmbedding", defaultEmbedding);
      return this;
    }

    /**
     * Sets the text-to-embedding lookup table.
     *
     * @throws IllegalArgumentException if the map is null, or any of its values is null or
     *     contains a null element
     */
    public CustomModelBuilder customEmbeddings(Map<String, ArrayList<Double>> customEmbeddings) {
      if (customEmbeddings == null) {
        throw new IllegalArgumentException("customEmbeddings must not be null");
      }
      final Map<String, float[]> converted = CollectionUtil.newHashMap(customEmbeddings.size());
      for (Map.Entry<String, ArrayList<Double>> entry : customEmbeddings.entrySet()) {
        final String key = entry.getKey();
        converted.put(key, toFloatArray("customEmbeddings['" + key + "']", entry.getValue()));
      }
      this.customEmbeddings = converted;
      return this;
    }

    /**
     * Validates the configuration and creates the model.
     *
     * @return the configured model
     * @throws IllegalArgumentException if {@code defaultEmbedding} or {@code customEmbeddings} was
     *     never set, or if any custom embedding differs in length from the default embedding
     */
    public CustomModel build() {
      if (defaultEmbedding == null) {
        throw new IllegalArgumentException("defaultEmbedding must not be null");
      }
      if (customEmbeddings == null) {
        throw new IllegalArgumentException("customEmbeddings must not be null");
      }
      final int expectedDimension = defaultEmbedding.length;
      for (Map.Entry<String, float[]> entry : customEmbeddings.entrySet()) {
        final int actualDimension = entry.getValue().length;
        if (actualDimension != expectedDimension) {
          throw new IllegalArgumentException(
              "Custom embedding for key '"
                  + entry.getKey()
                  + "' has dimension "
                  + actualDimension
                  + ", expected "
                  + expectedDimension);
        }
      }
      return new CustomModel(customUrl, customVersion, defaultEmbedding, customEmbeddings);
    }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
{
"class": "org.apache.solr.languagemodels.textvectorisation.model.CustomModel",
"name": "custom-1",
"params": {
"customUrl": "https://custom.api/text-custom",
"customVersion": 42,
"defaultEmbedding": [
0.0, 0.0, 0.0, 0.0, 0.0
],
"customEmbeddings": {
"the queen bee": [
0.1, 0.2, 0.3, 0.4, 0.5
],
"the hardest working bee": [
0.5, 0.6, 0.7, 0.8, 0.9
],
"the bee that": [
0.3, 0.4, 0.0, 0.6, 0.7
]
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
<field name="2048_byte_vector" type="high_dimensional_byte_knn_vector" indexed="true" stored="true" />
<field name="2048_float_vector" type="high_dimensional_float_knn_vector" indexed="true" stored="true" />
<field name="string_field" type="string" indexed="true" stored="true" multiValued="false" required="false"/>
<field name="text_field" type="text_general" indexed="true" stored="true" multiValued="false" required="false"/>

<field name="_version_" type="plong" indexed="true" stored="true" multiValued="false" />
<field name="_text_" type="text_general" indexed="true" stored="false" multiValued="true"/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,6 @@
<processor class="solr.RunUpdateProcessorFactory"/>
</updateRequestProcessorChain>

<searchComponent class="org.apache.solr.languagemodels.textvectorisation.handler.component.SemanticHighlightComponent" name="highlight"/>

</config>
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.languagemodels.textvectorisation.handler.component;

import java.util.Comparator;
import org.apache.lucene.search.uhighlight.Passage;
import org.apache.lucene.search.uhighlight.UnifiedHighlighter;
import org.apache.lucene.util.BytesRef;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.params.HighlightParams;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.core.SolrCore;
import org.apache.solr.handler.component.HighlightComponent;
import org.apache.solr.highlight.SolrHighlighter;
import org.apache.solr.highlight.UnifiedSolrHighlighter;
import org.apache.solr.languagemodels.textvectorisation.model.SolrTextToVectorModel;
import org.apache.solr.languagemodels.textvectorisation.store.rest.ManagedTextToVectorModelStore;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.rest.ManagedResource;
import org.apache.solr.rest.ManagedResourceObserver;
import org.apache.solr.util.plugin.SolrCoreAware;

/**
 * A {@link HighlightComponent} that adds a {@code unified_with_semantic} highlighting method.
 *
 * <p>When {@code hl.method=unified_with_semantic} is requested, passages are ordered by the
 * Euclidean distance between a caller-supplied query vector ({@code
 * hl.unified_with_semantic.vector}, comma-separated floats) and the vector that the model named by
 * {@code hl.unified_with_semantic.model} produces for each passage's matched terms. Without a
 * query vector, passages are ordered by the text of their matched terms. Any other method is
 * delegated to the parent component.
 */
public class SemanticHighlightComponent extends HighlightComponent
    implements SolrCoreAware, ManagedResourceObserver {

  private static final String SEMANTIC_METHOD = "unified_with_semantic";
  private static final String MODEL_PARAM = "hl.unified_with_semantic.model";
  private static final String VECTOR_PARAM = "hl.unified_with_semantic.vector";

  // Set once the managed model store has been initialised. Highlight requests do not read this
  // field; they look the store up from the request's core instead.
  private ManagedTextToVectorModelStore modelStore = null;

  @Override
  public void inform(SolrCore core) {
    super.inform(core);
    ManagedTextToVectorModelStore.registerManagedTextToVectorModelStore(
        core.getResourceLoader(), this);
  }

  @Override
  public void onManagedResourceInitialized(NamedList<?> args, ManagedResource res)
      throws SolrException {
    if (res instanceof ManagedTextToVectorModelStore) {
      modelStore = (ManagedTextToVectorModelStore) res;
    }
    if (modelStore != null) {
      modelStore.loadStoredModels();
    }
  }

  /**
   * Returns the highlighter for the requested {@code hl.method}.
   *
   * @param params the request parameters
   * @return a semantic-ordering unified highlighter for {@code unified_with_semantic}, otherwise
   *     whatever the parent component returns
   */
  @Override
  public SolrHighlighter getHighlighter(SolrParams params) {
    if (!SEMANTIC_METHOD.equals(params.get(HighlightParams.METHOD))) {
      return super.getHighlighter(params);
    }

    return new UnifiedSolrHighlighter() {
      @Override
      protected UnifiedHighlighter getHighlighter(SolrQueryRequest req) {

        final ManagedTextToVectorModelStore modelStore =
            ManagedTextToVectorModelStore.getManagedModelStore(req.getCore());

        final String modelName = req.getParams().get(MODEL_PARAM);
        final SolrTextToVectorModel textToVector = modelStore.getModel(modelName);
        if (textToVector == null) {
          throw new SolrException(
              SolrException.ErrorCode.BAD_REQUEST,
              "The model requested '" + modelName + "' can't be found.");
        }

        final float[] queryVector = parseQueryVector(req.getParams().get(VECTOR_PARAM));

        return new SolrExtendedUnifiedHighlighter(req) {
          @Override
          protected Comparator<Passage> getPassageSortComparator(String field) {
            // NOTE: every comparison vectorises both passages again, so the model is invoked
            // repeatedly for the same passage over the course of a sort.
            return (a, b) -> {
              final String aText = textFromPassage(a);
              final String bText = textFromPassage(b);
              if (queryVector == null) {
                return aText.compareTo(bText);
              }
              final float[] aVector = textToVector.vectorise(aText);
              final float[] bVector = textToVector.vectorise(bText);
              final double aDistance = euclideanDistance(aVector, queryVector);
              final double bDistance = euclideanDistance(bVector, queryVector);
              return Double.compare(aDistance, bDistance);
            };
          }
        };
      }
    };
  }

  /**
   * Parses a comma-separated list of floats.
   *
   * @param queryVectorString the raw parameter value, or {@code null} if it was not supplied
   * @return the parsed vector, or {@code null} when {@code queryVectorString} is {@code null}
   * @throws SolrException with {@code BAD_REQUEST} if any component is not a valid float
   */
  private static float[] parseQueryVector(String queryVectorString) {
    if (queryVectorString == null) {
      return null;
    }
    final String[] components = queryVectorString.split(",");
    final float[] vector = new float[components.length];
    for (int i = 0; i < components.length; i++) {
      try {
        vector[i] = Float.parseFloat(components[i]);
      } catch (NumberFormatException e) {
        // A malformed vector is a client error, not a server failure.
        throw new SolrException(
            SolrException.ErrorCode.BAD_REQUEST,
            "Invalid value '"
                + components[i]
                + "' at position "
                + i
                + " of "
                + VECTOR_PARAM
                + "; expected comma-separated floats.",
            e);
      }
    }
    return vector;
  }

  /**
   * Joins the matched terms of {@code passage} with single spaces.
   *
   * @param passage the passage whose matched terms are read
   * @return the space-separated matched terms; empty if the passage has none
   */
  private static String textFromPassage(Passage passage) {
    final BytesRef[] terms = passage.getMatchTerms();
    // Only the first getNumMatches() slots belong to this passage; the backing array may be
    // larger and hold entries that are not part of it.
    final int numMatches = passage.getNumMatches();
    final StringBuilder sb = new StringBuilder();
    String delimiter = "";
    for (int i = 0; i < numMatches; i++) {
      final BytesRef term = terms[i];
      if (term != null) {
        sb.append(delimiter).append(term.utf8ToString());
        delimiter = " ";
      }
    }
    return sb.toString();
  }

  /**
   * Computes the Euclidean distance between two vectors of equal length.
   *
   * @param a the first vector
   * @param b the second vector
   * @return the square root of the summed squared component differences
   * @throws SolrException with {@code BAD_REQUEST} if the vectors differ in length
   */
  private static double euclideanDistance(float[] a, float[] b) {
    if (a.length != b.length) {
      throw new SolrException(
          SolrException.ErrorCode.BAD_REQUEST,
          "Vector dimension mismatch: " + a.length + " vs " + b.length + ".");
    }
    double sumSquared = 0.0;
    for (int i = 0; i < a.length; i++) {
      final double diff = a[i] - b[i];
      sumSquared += diff * diff;
    }
    return Math.sqrt(sumSquared);
  }
}
Loading
Loading