diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java index 6b5d708c0a503..f996c559a8779 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java @@ -59,7 +59,9 @@ record CmdLineArgs( double writerBufferSizeInMb, int writerMaxBufferedDocs, int forceMergeMaxNumSegments, - boolean onDiskRescore + boolean onDiskRescore, + boolean doPrecondition, + int preconditioningDims ) implements ToXContentObject { static final ParseField DOC_VECTORS_FIELD = new ParseField("doc_vectors"); @@ -92,6 +94,8 @@ record CmdLineArgs( static final ParseField WRITER_BUFFER_MB_FIELD = new ParseField("writer_buffer_mb"); static final ParseField WRITER_BUFFER_DOCS_FIELD = new ParseField("writer_buffer_docs"); static final ParseField ON_DISK_RESCORE_FIELD = new ParseField("on_disk_rescore"); + static final ParseField DO_PRECONDITION = new ParseField("do_precondition"); + static final ParseField PRECONDITIONING_DIMS = new ParseField("preconditioning_dims"); /** By default, in ES the default writer buffer size is 10% of the heap space * (see {@code IndexingMemoryController.INDEX_BUFFER_SIZE_SETTING}). @@ -138,6 +142,8 @@ static CmdLineArgs fromXContent(XContentParser parser) throws IOException { PARSER.declareInt(Builder::setWriterMaxBufferedDocs, WRITER_BUFFER_DOCS_FIELD); PARSER.declareInt(Builder::setForceMergeMaxNumSegments, FORCE_MERGE_MAX_NUM_SEGMENTS_FIELD); PARSER.declareBoolean(Builder::setOnDiskRescore, ON_DISK_RESCORE_FIELD); + PARSER.declareBoolean(Builder::setDoPrecondition, DO_PRECONDITION); + PARSER.declareInt(Builder::setPreconditioningDims, PRECONDITIONING_DIMS); } @Override @@ -213,6 +219,8 @@ static class Builder { private KnnIndexTester.MergePolicyType mergePolicy = null; private double writerBufferSizeInMb = DEFAULT_WRITER_BUFFER_MB; private boolean onDiskRescore = false; + private boolean doPrecondition = false; + private int preconditioningDims = 64; /** * Elasticsearch does not set this explicitly, and in Lucene this setting is @@ -369,6 +377,16 @@ public Builder setOnDiskRescore(boolean onDiskRescore) { return this; } + public Builder setDoPrecondition(boolean doPrecondition) { + this.doPrecondition = doPrecondition; + return this; + } + + public Builder setPreconditioningDims(int preconditioningDims) { + this.preconditioningDims = preconditioningDims; + return this; + } + public CmdLineArgs build() { if (docVectors == null) { throw new IllegalArgumentException("Document vectors path must be provided"); @@ -407,7 +425,9 @@ public CmdLineArgs build() { writerBufferSizeInMb, writerMaxBufferedDocs, forceMergeMaxNumSegments, - onDiskRescore + onDiskRescore, + doPrecondition, + preconditioningDims ); } } diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java index cc007809f258e..68bdc439cf89c 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java @@ -133,7 +133,9 @@ static Codec createCodec(CmdLineArgs args) { args.ivfClusterSize(), ES920DiskBBQVectorsFormat.DEFAULT_CENTROIDS_PER_PARENT_CLUSTER, DenseVectorFieldMapper.ElementType.FLOAT, - args.onDiskRescore() + args.onDiskRescore(), + args.doPrecondition(), + args.preconditioningDims() ); } else if (args.indexType() == IndexType.GPU_HNSW) { if 
(quantizeBits == 32) { diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsReader.java index ef2bd98419eb0..09b6323871513 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsReader.java @@ -174,6 +174,12 @@ protected FieldEntry doReadField( ); } + @Override + protected float[] preconditionVector(FieldInfo fieldInfo, float[] vector) { + // no-op + return vector; + } + private static CentroidIterator getCentroidIteratorNoParent( FieldInfo fieldInfo, IndexInput centroids, diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsWriter.java index f017700d86c78..869f532aaef67 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsWriter.java @@ -378,6 +378,12 @@ public CentroidSupplier createCentroidSupplier( return new OffHeapCentroidSupplier(centroidsInput, numCentroids, fieldInfo); } + @Override + public FloatVectorValues preconditionVectors(FloatVectorValues floatVectorValues) throws IOException { + // no-op + return floatVectorValues; + } + @Override public void writeCentroids( FieldInfo fieldInfo, diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsReader.java index a122440fc58c0..d150a587a1f9d 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsReader.java @@ -275,6 +275,8 @@ public final ByteVectorValues getByteVectorValues(String field) throws IOExcepti return getReaderForField(field).getByteVectorValues(field); } + protected abstract float[] preconditionVector(FieldInfo fieldInfo, float[] vector); + @Override public final void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException { final FieldInfo fieldInfo = state.fieldInfos.fieldInfo(field); @@ -317,6 +319,9 @@ public final void search(String field, float[] target, KnnCollector knnCollector // clip so we visit at least one vector visitRatio = estimated / numVectors; } + // precondition the query vector if necessary + target = preconditionVector(fieldInfo, target); + // we account for soar vectors here. We can potentially visit a vector twice so we multiply by 2 here. 
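+        // (SOAR assigns each vector to a secondary "spilled" posting list in addition to its primary
+        // cluster, so a single vector can be encountered and scored up to twice during a scan.)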
long maxVectorVisited = (long) (2.0 * visitRatio * numVectors); IndexInput postListSlice = entry.postingListSlice(ivfClusters); @@ -332,6 +337,7 @@ public final void search(String field, float[] target, KnnCollector knnCollector visitRatio ); Bits acceptDocsBits = acceptDocs.bits(); + PostingVisitor scorer = getPostingVisitor(fieldInfo, postListSlice, target, acceptDocsBits); long expectedDocs = 0; long actualDocs = 0; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsWriter.java index 42b288162ac3e..c88a22ac2b52a 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsWriter.java @@ -185,6 +185,8 @@ public abstract CentroidSupplier createCentroidSupplier( float[] globalCentroid ) throws IOException; + public abstract FloatVectorValues preconditionVectors(FloatVectorValues floatVectorValues) throws IOException; + @Override public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { rawVectorDelegate.flush(maxDoc, sortMap); @@ -195,7 +197,9 @@ public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { continue; } // build a float vector values with random access - final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldWriter.fieldInfo, fieldWriter.delegate, maxDoc); + FloatVectorValues floatVectorValues = getFloatVectorValues(fieldWriter.fieldInfo, fieldWriter.delegate, maxDoc); + // precondition the vectors if necessary + floatVectorValues = preconditionVectors(floatVectorValues); // build centroids final CentroidAssignments centroidAssignments = calculateCentroids(fieldWriter.fieldInfo, floatVectorValues); // wrap centroids with a supplier @@ -378,7 +382,9 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws ? 
null : mergeState.segmentInfo.dir.openInput(docsFileName, IOContext.DEFAULT.withHints(DataAccessHint.SEQUENTIAL)) ) { - final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldInfo, docs, vectors, numVectors); + FloatVectorValues floatVectorValues = getFloatVectorValues(fieldInfo, docs, vectors, numVectors); + // precondition vectors if necessary + floatVectorValues = preconditionVectors(floatVectorValues); final long centroidOffset; final long centroidLength; @@ -394,11 +400,7 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws try { centroidTemp = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "civf_", IOContext.DEFAULT); centroidTempName = centroidTemp.getName(); - CentroidAssignments centroidAssignments = calculateCentroids( - fieldInfo, - getFloatVectorValues(fieldInfo, docs, vectors, numVectors), - mergeState - ); + CentroidAssignments centroidAssignments = calculateCentroids(fieldInfo, floatVectorValues, mergeState); // write the centroids to a temporary file so we are not holding them on heap final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); for (float[] centroid : centroidAssignments.centroids()) { diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/PreconditioningProvider.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/PreconditioningProvider.java new file mode 100644 index 0000000000000..347998f7a2d62 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/PreconditioningProvider.java @@ -0,0 +1,301 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors.diskbbq; + +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.NumericUtils; +import org.apache.lucene.util.VectorUtil; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.FloatBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Random; + +// TODO: apply to other formats +// TODO: instead of manually having to indicate preconditioning add the ability to decide when to use it given the data on the segment +// TODO: consider global version of preconditioning? 
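+/**
+ * Computes and applies a fixed, random, block-diagonal orthogonal transform ("preconditioning") to raw
+ * float vectors ahead of quantization. Dimensions are first permuted into blocks (by default greedily,
+ * so that the total variance assigned to each block is roughly equal), and each block is then rotated
+ * by a random orthogonal matrix built by running modified Gram-Schmidt over Gaussian samples. An
+ * orthogonal transform preserves inner products and Euclidean distances, so scores computed in the
+ * transformed space match the original space up to floating-point error; the intended effect is to
+ * spread variance more evenly across dimensions before quantization. The same transform must be applied
+ * to query vectors at search time, so it is serialized with the field metadata via
+ * {@link #write(IndexOutput)} and {@link #read(IndexInput)}.
+ */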
+ +public class PreconditioningProvider { + + final int blockDim; + final int[][] permutationMatrix; + final float[][][] blocks; + + public PreconditioningProvider(int blockDim, FloatVectorValues vectors) throws IOException { + this.blockDim = blockDim; + int dim = vectors.dimension(); + Random random = new Random(42L); + blocks = PreconditioningProvider.generateRandomOrthogonalMatrix(dim, blockDim, random); + int[] dimBlocks = new int[blocks.length]; + for (int i = 0; i < blocks.length; i++) { + dimBlocks[i] = blocks[i].length; + } + // TODO: test random permutation matrix vs variance based + permutationMatrix = PreconditioningProvider.createPermutationMatrixWEqualVariance(dimBlocks, vectors); + // permutationMatrix = PreconditioningProvider.createPermutationMatrixRandomly(dim, dimBlocks, random); + } + + private PreconditioningProvider(int blockDim, float[][][] blocks, int[][] permutationMatrix) { + this.blockDim = blockDim; + this.permutationMatrix = permutationMatrix; + this.blocks = blocks; + } + + public float[] applyPreconditioningTransform(float[] vector) { + assert vector != null; + + float[] out = new float[vector.length]; + + if (blocks.length == 1) { + matrixVectorMultiply(blocks[0], vector, out); + return out; + } + + int blockIdx = 0; + float[] x = new float[blockDim]; + float[] blockOut = new float[blockDim]; + for (int j = 0; j < blocks.length; j++) { + float[][] block = blocks[j]; + int blockDim = blocks[j].length; + // blockDim is only ever smaller for the tail + if (blockDim != this.blockDim) { + x = new float[blockDim]; + blockOut = new float[blockDim]; + } + for (int k = 0; k < permutationMatrix[j].length; k++) { + int idx = permutationMatrix[j][k]; + x[k] = vector[idx]; + } + matrixVectorMultiply(block, x, blockOut); + System.arraycopy(blockOut, 0, out, blockIdx, blockDim); + blockIdx += blockDim; + } + + return out; + } + + public void write(IndexOutput out) throws IOException { + int rem = blockDim; + if (blocks[blocks.length - 1].length != blockDim) { + rem = blocks[blocks.length - 1].length; + } + + out.writeInt(blocks.length); + out.writeInt(blockDim); + out.writeInt(rem); + out.writeInt(permutationMatrix.length); + + final ByteBuffer blockBuffer = ByteBuffer.allocate( + (blocks.length - 1) * blockDim * blockDim * Float.BYTES + rem * rem * Float.BYTES + ).order(ByteOrder.LITTLE_ENDIAN); + FloatBuffer floatBuffer = blockBuffer.asFloatBuffer(); + for (int i = 0; i < blocks.length; i++) { + for (int j = 0; j < blocks[i].length; j++) { + floatBuffer.put(blocks[i][j]); + } + } + out.writeBytes(blockBuffer.array(), blockBuffer.array().length); + + for (int i = 0; i < permutationMatrix.length; i++) { + out.writeInt(permutationMatrix[i].length); + final ByteBuffer permBuffer = ByteBuffer.allocate(permutationMatrix[i].length * Integer.BYTES).order(ByteOrder.LITTLE_ENDIAN); + permBuffer.asIntBuffer().put(permutationMatrix[i]); + out.writeBytes(permBuffer.array(), permBuffer.array().length); + } + } + + public static PreconditioningProvider read(IndexInput input) throws IOException { + int blocksLen = input.readInt(); + int blockDim = input.readInt(); + int rem = input.readInt(); + int permutationMatrixLen = input.readInt(); + + float[][][] blocks = new float[blocksLen][][]; + int[][] permutationMatrix = new int[permutationMatrixLen][]; + + for (int i = 0; i < blocksLen; i++) { + int blockLen = blocksLen - 1 == i ? 
rem : blockDim;
+            blocks[i] = new float[blockLen][blockLen];
+            for (int j = 0; j < blockLen; j++) {
+                input.readFloats(blocks[i][j], 0, blockLen);
+            }
+        }
+
+        for (int i = 0; i < permutationMatrix.length; i++) {
+            int permutationMatrixSubLen = input.readInt();
+            permutationMatrix[i] = new int[permutationMatrixSubLen];
+            input.readInts(permutationMatrix[i], 0, permutationMatrixSubLen);
+        }
+
+        return new PreconditioningProvider(blockDim, blocks, permutationMatrix);
+    }
+
+    // TODO: write Panama version of this
+    static void modifiedGramSchmidt(float[][] m) {
+        assert m.length == m[0].length;
+        int dim = m.length;
+        for (int i = 0; i < dim; i++) {
+            double norm = 0.0;
+            for (float v : m[i]) {
+                norm += v * v;
+            }
+            norm = Math.sqrt(norm);
+            if (norm == 0.0f) {
+                continue;
+            }
+            for (int j = 0; j < dim; j++) {
+                m[i][j] /= (float) norm;
+            }
+            for (int k = i + 1; k < dim; k++) {
+                double dotik = 0.0;
+                for (int j = 0; j < dim; j++) {
+                    dotik += m[i][j] * m[k][j];
+                }
+                for (int j = 0; j < dim; j++) {
+                    m[k][j] -= (float) (dotik * m[i][j]);
+                }
+            }
+        }
+    }
+
+    private static void randomFill(Random random, float[][] m) {
+        for (int i = 0; i < m.length; ++i) {
+            for (int j = 0; j < m[i].length; ++j) {
+                m[i][j] = (float) random.nextGaussian();
+            }
+        }
+    }
+
+    private static float[][][] generateRandomOrthogonalMatrix(int dim, int blockDim, Random random) {
+        blockDim = Math.min(dim, blockDim);
+        int nBlocks = dim / blockDim;
+        int rem = dim % blockDim;
+
+        float[][][] blocks = new float[nBlocks + (rem > 0 ? 1 : 0)][][];
+
+        for (int i = 0; i < nBlocks; i++) {
+            float[][] m = new float[blockDim][blockDim];
+            randomFill(random, m);
+            modifiedGramSchmidt(m);
+            blocks[i] = m;
+        }
+
+        if (rem != 0) {
+            float[][] m = new float[rem][rem];
+            randomFill(random, m);
+            modifiedGramSchmidt(m);
+            blocks[nBlocks] = m;
+        }
+
+        return blocks;
+    }
+
+    private static void matrixVectorMultiply(float[][] m, float[] x, float[] out) {
+        assert m.length == x.length;
+        assert m.length == out.length;
+        int dim = out.length;
+        for (int i = 0; i < dim; i++) {
+            out[i] = VectorUtil.dotProduct(m[i], x);
+        }
+    }
+
+    private static int minElementIndex(float[] array) {
+        int minIndex = 0;
+        float minValue = array[0];
+        for (int i = 1; i < array.length; i++) {
+            if (array[i] < minValue) {
+                minValue = array[i];
+                minIndex = i;
+            }
+        }
+        return minIndex;
+    }
+
+    private static int[][] createPermutationMatrixRandomly(int dim, int[] dimBlocks, Random random) {
+        // Randomly assign dimensions to blocks.
+        List<Integer> indices = new ArrayList<>(dim);
+        for (int i = 0; i < dim; i++) {
+            indices.add(i);
+        }
+        Collections.shuffle(indices, random);
+
+        int[][] permutationMatrix = new int[dimBlocks.length][];
+        int pos = 0;
+        for (int i = 0; i < dimBlocks.length; i++) {
+            permutationMatrix[i] = new int[dimBlocks[i]];
+            for (int j = 0; j < dimBlocks[i]; j++) {
+                permutationMatrix[i][j] = indices.get(pos++);
+            }
+            Arrays.sort(permutationMatrix[i]);
+        }
+
+        return permutationMatrix;
+    }
+
+    private static int[][] createPermutationMatrixWEqualVariance(int[] dimBlocks, FloatVectorValues vectors) throws IOException {
+        int dim = vectors.dimension();
+
+        if (dimBlocks.length == 1) {
+            int[] indices = new int[dim];
+            for (int i = 0; i < indices.length; i++) {
+                indices[i] = i;
+            }
+            return new int[][] { indices };
+        }
+
+        // Use a greedy approach to assign dimensions to blocks so that the per-block variances come out roughly equal.
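+        // Single pass over the data using Welford's online algorithm: after the loop, means[j] holds the
+        // running mean of dimension j and variances[j] holds the accumulated sum of squared deviations
+        // (M2), i.e. n * variance. Skipping the 1/n normalization is fine here: every dimension sees the
+        // same count, so the un-normalized values compare and sum the same way as true variances.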
+ float[] means = new float[dim]; + float[] variances = new float[dim]; + int[] n = new int[dim]; + + // TODO: write Panama version of this + for (int i = 0; i < vectors.size(); ++i) { + float[] vector = vectors.vectorValue(i); + for (int j = 0; j < dim; j++) { + float value = vector[j]; + n[j]++; + double delta = value - means[j]; + means[j] += (float) (delta / n[j]); + variances[j] += (float) (delta * (value - means[j])); + } + } + + int[] indices = new int[dim]; + for (int i = 0; i < indices.length; i++) { + indices[i] = i; + } + new IntSorter(indices, i -> NumericUtils.floatToSortableInt(variances[i])).sort(0, indices.length); + + int[][] permutationMatrix = new int[dimBlocks.length][]; + for (int i = 0; i < permutationMatrix.length; i++) { + permutationMatrix[i] = new int[dimBlocks[i]]; + } + float[] cumulativeVariances = new float[dimBlocks.length]; + int[] jthIdx = new int[permutationMatrix.length]; + for (int i : indices) { + int j = minElementIndex(cumulativeVariances); + permutationMatrix[j][jthIdx[j]++] = i; + cumulativeVariances[j] = (jthIdx[j] == dimBlocks[j] ? Float.MAX_VALUE : cumulativeVariances[j] + variances[i]); + } + for (int[] matrix : permutationMatrix) { + Arrays.sort(matrix); + } + + return permutationMatrix; + } + +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsFormat.java index 0ecc63eb1a5d7..8e10de8a223f9 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsFormat.java @@ -72,6 +72,9 @@ public class ESNextDiskBBQVectorsFormat extends KnnVectorsFormat { public static final int DEFAULT_CENTROIDS_PER_PARENT_CLUSTER = 16; public static final int MIN_CENTROIDS_PER_PARENT_CLUSTER = 2; public static final int MAX_CENTROIDS_PER_PARENT_CLUSTER = 1 << 8; // 256 + public static final int DEFAULT_PRECONDITIONING_BLOCK_DIMENSION = 32; + public static final int MIN_PRECONDITIONING_BLOCK_DIMS = 8; + public static final int MAX_PRECONDITIONING_BLOCK_DIMS = 384; public enum QuantEncoding { ONE_BIT_4BIT_QUERY(0, (byte) 1, (byte) 4) { @@ -213,13 +216,23 @@ public static QuantEncoding fromId(int id) { private final int centroidsPerParentCluster; private final boolean useDirectIO; private final DirectIOCapableFlatVectorsFormat rawVectorFormat; + private final boolean doPrecondition; + private final int preconditioningBlockDimension; public ESNextDiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentCluster) { this(QuantEncoding.ONE_BIT_4BIT_QUERY, vectorPerCluster, centroidsPerParentCluster); } public ESNextDiskBBQVectorsFormat(QuantEncoding quantEncoding, int vectorPerCluster, int centroidsPerParentCluster) { - this(quantEncoding, vectorPerCluster, centroidsPerParentCluster, DenseVectorFieldMapper.ElementType.FLOAT, false); + this( + quantEncoding, + vectorPerCluster, + centroidsPerParentCluster, + DenseVectorFieldMapper.ElementType.FLOAT, + false, + false, + DEFAULT_PRECONDITIONING_BLOCK_DIMENSION + ); } public ESNextDiskBBQVectorsFormat( @@ -227,7 +240,9 @@ public ESNextDiskBBQVectorsFormat( int vectorPerCluster, int centroidsPerParentCluster, DenseVectorFieldMapper.ElementType elementType, - boolean useDirectIO + boolean useDirectIO, + boolean doPrecondition, + int preconditioningBlockDimension ) { super(NAME); if (vectorPerCluster < 
MIN_VECTORS_PER_CLUSTER || vectorPerCluster > MAX_VECTORS_PER_CLUSTER) { @@ -259,6 +274,21 @@ public ESNextDiskBBQVectorsFormat( default -> throw new IllegalArgumentException("Unsupported element type " + elementType); }; this.useDirectIO = useDirectIO; + + if (preconditioningBlockDimension < MIN_PRECONDITIONING_BLOCK_DIMS + || preconditioningBlockDimension > MAX_PRECONDITIONING_BLOCK_DIMS) { + throw new IllegalArgumentException( + "preconditioningBlockDimension must be between " + + MIN_PRECONDITIONING_BLOCK_DIMS + + " and " + + MAX_PRECONDITIONING_BLOCK_DIMS + + ", got: " + + preconditioningBlockDimension + ); + } + // TODO: make these settable via DenseVectorFieldMapper + this.preconditioningBlockDimension = preconditioningBlockDimension; + this.doPrecondition = doPrecondition; } /** Constructs a format using the given graph construction parameters and scalar quantization. */ @@ -275,7 +305,9 @@ public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException rawVectorFormat.fieldsWriter(state), quantEncoding, vectorPerCluster, - centroidsPerParentCluster + centroidsPerParentCluster, + doPrecondition, + preconditioningBlockDimension ); } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsReader.java index fe587a35a6874..46283e5dd72c5 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsReader.java @@ -31,6 +31,7 @@ import org.elasticsearch.index.codec.vectors.cluster.NeighborQueue; import org.elasticsearch.index.codec.vectors.diskbbq.DocIdsWriter; import org.elasticsearch.index.codec.vectors.diskbbq.IVFVectorsReader; +import org.elasticsearch.index.codec.vectors.diskbbq.PreconditioningProvider; import org.elasticsearch.simdvec.ES92Int7VectorsScorer; import org.elasticsearch.simdvec.ESNextOSQVectorsScorer; import org.elasticsearch.simdvec.ESVectorUtil; @@ -187,6 +188,18 @@ public CentroidIterator getCentroidIterator( return getPostingListPrefetchIterator(centroidIterator, postingListSlice); } + @Override + protected float[] preconditionVector(FieldInfo fieldInfo, float[] vector) { + FieldEntry entry = fields.get(fieldInfo.number); + PreconditioningProvider preconditioningProvider = ((NextFieldEntry) entry).preconditioningProvider(); + // only precondition if during writing preconditioning was enabled + if (preconditioningProvider != null) { + // have to copy so we don't modify the original search vector + return preconditioningProvider.applyPreconditioningTransform(vector); + } + return vector; + } + @Override protected FieldEntry doReadField( IndexInput input, @@ -203,6 +216,10 @@ protected FieldEntry doReadField( float globalCentroidDp ) throws IOException { ESNextDiskBBQVectorsFormat.QuantEncoding quantEncoding = ESNextDiskBBQVectorsFormat.QuantEncoding.fromId(input.readInt()); + PreconditioningProvider preconditioningProvider = null; + if (input.readByte() == (byte) 1) { + preconditioningProvider = PreconditioningProvider.read(input); + } return new NextFieldEntry( rawVectorFormat, useDirectIOReads, @@ -215,12 +232,14 @@ protected FieldEntry doReadField( postingListLength, globalCentroid, globalCentroidDp, - quantEncoding + quantEncoding, + preconditioningProvider ); } static class NextFieldEntry extends FieldEntry { private final 
ESNextDiskBBQVectorsFormat.QuantEncoding quantEncoding; + private final PreconditioningProvider preconditioningProvider; NextFieldEntry( String rawVectorFormat, @@ -234,7 +253,8 @@ static class NextFieldEntry extends FieldEntry { long postingListLength, float[] globalCentroid, float globalCentroidDp, - ESNextDiskBBQVectorsFormat.QuantEncoding quantEncoding + ESNextDiskBBQVectorsFormat.QuantEncoding quantEncoding, + PreconditioningProvider preconditioningProvider ) { super( rawVectorFormat, @@ -250,11 +270,16 @@ static class NextFieldEntry extends FieldEntry { globalCentroidDp ); this.quantEncoding = quantEncoding; + this.preconditioningProvider = preconditioningProvider; } public ESNextDiskBBQVectorsFormat.QuantEncoding quantEncoding() { return quantEncoding; } + + public PreconditioningProvider preconditioningProvider() { + return preconditioningProvider; + } } private static CentroidIterator getCentroidIteratorNoParent( diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsWriter.java index 49d0554a84d80..459014982a72a 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsWriter.java @@ -34,6 +34,7 @@ import org.elasticsearch.index.codec.vectors.diskbbq.IVFVectorsWriter; import org.elasticsearch.index.codec.vectors.diskbbq.IntSorter; import org.elasticsearch.index.codec.vectors.diskbbq.IntToBooleanFunction; +import org.elasticsearch.index.codec.vectors.diskbbq.PreconditioningProvider; import org.elasticsearch.index.codec.vectors.diskbbq.QuantizedVectorValues; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; @@ -62,6 +63,8 @@ public class ESNextDiskBBQVectorsWriter extends IVFVectorsWriter { private final int vectorPerCluster; private final int centroidsPerParentCluster; private final ESNextDiskBBQVectorsFormat.QuantEncoding quantEncoding; + private final boolean doPrecondition; + private final int preconditioningBlockDimension; public ESNextDiskBBQVectorsWriter( SegmentWriteState state, @@ -70,12 +73,57 @@ public ESNextDiskBBQVectorsWriter( FlatVectorsWriter rawVectorDelegate, ESNextDiskBBQVectorsFormat.QuantEncoding encoding, int vectorPerCluster, - int centroidsPerParentCluster + int centroidsPerParentCluster, + boolean doPrecondition, + int preconditioningBlockDimension ) throws IOException { super(state, rawVectorFormatName, useDirectIOReads, rawVectorDelegate, ESNextDiskBBQVectorsFormat.VERSION_CURRENT); this.vectorPerCluster = vectorPerCluster; this.centroidsPerParentCluster = centroidsPerParentCluster; this.quantEncoding = encoding; + this.doPrecondition = doPrecondition; + this.preconditioningBlockDimension = preconditioningBlockDimension; + } + + private PreconditioningProvider preconditioningProvider; + + @Override + public FloatVectorValues preconditionVectors(FloatVectorValues vectors) throws IOException { + if (doPrecondition) { + preconditioningProvider = new PreconditioningProvider(preconditioningBlockDimension, vectors); + return new FloatVectorValues() { + float[] vectorValue; + int cachedOrd = -1; + + @Override + public float[] vectorValue(int ord) throws IOException { + assert ord != -1; + if (ord != cachedOrd) { + float[] vectorValue = vectors.vectorValue(ord); + this.vectorValue = 
preconditioningProvider.applyPreconditioningTransform(vectorValue); + cachedOrd = ord; + } + return this.vectorValue; + } + + @Override + public FloatVectorValues copy() throws IOException { + return vectors.copy(); + } + + @Override + public int dimension() { + return vectors.dimension(); + } + + @Override + public int size() { + return vectors.size(); + } + }; + } else { + return vectors; + } } @Override @@ -376,6 +424,12 @@ public CentroidSupplier createCentroidSupplier( @Override protected void doWriteMeta(IndexOutput metaOutput, FieldInfo field, int numCentroids) throws IOException { metaOutput.writeInt(quantEncoding.id()); + if (preconditioningProvider != null) { + metaOutput.writeByte((byte) 1); + preconditioningProvider.write(metaOutput); + } else { + metaOutput.writeByte((byte) 0); + } } @Override diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index aa02950e460b9..27a93a3fee78b 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -2386,7 +2386,9 @@ KnnVectorsFormat getVectorsFormat(ElementType elementType) { clusterSize, ES920DiskBBQVectorsFormat.DEFAULT_CENTROIDS_PER_PARENT_CLUSTER, elementType, - onDiskRescore + onDiskRescore, + false, + ESNextDiskBBQVectorsFormat.DEFAULT_PRECONDITIONING_BLOCK_DIMENSION ); } return new ES920DiskBBQVectorsFormat( diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/PreconditioningProviderTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/PreconditioningProviderTests.java new file mode 100644 index 0000000000000..3f1970757953e --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/PreconditioningProviderTests.java @@ -0,0 +1,81 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */
+
+package org.elasticsearch.index.codec.vectors.diskbbq;
+
+import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.tests.util.LuceneTestCase;
+
+import java.io.IOException;
+
+public class PreconditioningProviderTests extends LuceneTestCase {
+    public void testRandomProviderConfigurations() throws IOException {
+        int dim = random().nextInt(128, 1024);
+
+        int corpusLen = random().nextInt(100, 200);
+        float[][] corpus = new float[corpusLen][];
+        for (int i = 0; i < corpusLen; i++) {
+            corpus[i] = new float[dim];
+            for (int j = 0; j < dim; j++) {
+                if (j > 320) {
+                    corpus[i][j] = 0f;
+                } else {
+                    corpus[i][j] = random().nextFloat();
+                }
+            }
+        }
+
+        float[] query = new float[dim];
+        for (int i = 0; i < dim; i++) {
+            query[i] = random().nextFloat();
+        }
+
+        // The assertions below assume a partially filled tail block, so reject block sizes that divide dim evenly.
+        int blockDim;
+        do {
+            blockDim = random().nextInt(8, 384);
+        } while (dim % blockDim == 0);
+
+        PreconditioningProvider preconditioningProvider = new PreconditioningProvider(blockDim, new FloatVectorValues() {
+            @Override
+            public int size() {
+                return corpus.length;
+            }
+
+            @Override
+            public int dimension() {
+                return dim;
+            }
+
+            @Override
+            public float[] vectorValue(int targetOrd) {
+                return corpus[targetOrd];
+            }
+
+            @Override
+            public FloatVectorValues copy() {
+                return this;
+            }
+
+            @Override
+            public DocIndexIterator iterator() {
+                return createDenseIterator();
+            }
+        });
+
+        // smoke test: applying the transform must not throw
+        preconditioningProvider.applyPreconditioningTransform(query);
+
+        assertEquals(blockDim, preconditioningProvider.blockDim);
+        assertEquals(dim / blockDim + 1, preconditioningProvider.permutationMatrix.length);
+        assertEquals(Math.min(blockDim, dim), preconditioningProvider.permutationMatrix[0].length);
+        assertEquals(
+            dim - (long) (dim / blockDim) * blockDim,
+            preconditioningProvider.permutationMatrix[preconditioningProvider.permutationMatrix.length - 1].length
+        );
+        assertEquals(dim / blockDim + 1, preconditioningProvider.blocks.length);
+        assertEquals(Math.min(blockDim, dim), preconditioningProvider.blocks[0].length);
+        assertEquals(Math.min(blockDim, dim), preconditioningProvider.blocks[0][0].length);
+    }
+}
diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQBFloat16VectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQBFloat16VectorsFormatTests.java
index 89fdc382874a1..452c13883f046 100644
--- a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQBFloat16VectorsFormatTests.java
+++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQBFloat16VectorsFormatTests.java
@@ -21,7 +21,6 @@
 import org.apache.lucene.tests.util.TestUtil;
 import org.elasticsearch.common.logging.LogConfigurator;
 import org.elasticsearch.index.codec.vectors.BaseBFloat16KnnVectorsFormatTestCase;
-import org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
 import org.junit.AssumptionViolatedException;
 import org.junit.Before;
@@ -29,8 +28,13 @@
 import java.io.IOException;
 import java.util.List;
 
-import static org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat.MIN_CENTROIDS_PER_PARENT_CLUSTER;
-import static org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat.MIN_VECTORS_PER_CLUSTER;
+import static org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat.DEFAULT_PRECONDITIONING_BLOCK_DIMENSION;
+import static
org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat.MAX_CENTROIDS_PER_PARENT_CLUSTER; +import static org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat.MAX_PRECONDITIONING_BLOCK_DIMS; +import static org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat.MAX_VECTORS_PER_CLUSTER; +import static org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat.MIN_CENTROIDS_PER_PARENT_CLUSTER; +import static org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat.MIN_PRECONDITIONING_BLOCK_DIMS; +import static org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat.MIN_VECTORS_PER_CLUSTER; import static org.hamcrest.Matchers.anEmptyMap; import static org.hamcrest.Matchers.equalTo; @@ -52,10 +56,22 @@ public void setUp() throws Exception { if (rarely()) { format = new ESNextDiskBBQVectorsFormat( encoding, - random().nextInt(2 * MIN_VECTORS_PER_CLUSTER, ES920DiskBBQVectorsFormat.MAX_VECTORS_PER_CLUSTER), - random().nextInt(8, ES920DiskBBQVectorsFormat.MAX_CENTROIDS_PER_PARENT_CLUSTER), + random().nextInt(2 * MIN_VECTORS_PER_CLUSTER, MAX_VECTORS_PER_CLUSTER), + random().nextInt(8, MAX_CENTROIDS_PER_PARENT_CLUSTER), DenseVectorFieldMapper.ElementType.BFLOAT16, - random().nextBoolean() + random().nextBoolean(), + false, + DEFAULT_PRECONDITIONING_BLOCK_DIMENSION + ); + } else if (rarely()) { + format = new ESNextDiskBBQVectorsFormat( + encoding, + random().nextInt(MIN_VECTORS_PER_CLUSTER, MAX_VECTORS_PER_CLUSTER), + random().nextInt(MIN_CENTROIDS_PER_PARENT_CLUSTER, MAX_CENTROIDS_PER_PARENT_CLUSTER), + DenseVectorFieldMapper.ElementType.BFLOAT16, + false, + true, + random().nextInt(MIN_PRECONDITIONING_BLOCK_DIMS, MAX_PRECONDITIONING_BLOCK_DIMS) ); } else { // run with low numbers to force many clusters with parents @@ -64,7 +80,9 @@ public void setUp() throws Exception { random().nextInt(MIN_VECTORS_PER_CLUSTER, 2 * MIN_VECTORS_PER_CLUSTER), random().nextInt(MIN_CENTROIDS_PER_PARENT_CLUSTER, 8), DenseVectorFieldMapper.ElementType.BFLOAT16, - random().nextBoolean() + random().nextBoolean(), + false, + DEFAULT_PRECONDITIONING_BLOCK_DIMENSION ); } super.setUp(); diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsFormatTests.java index 04e1416c9c3c4..c407a63a0b2b6 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsFormatTests.java @@ -35,6 +35,7 @@ import org.apache.lucene.tests.util.TestUtil; import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.junit.Before; import java.io.IOException; @@ -46,8 +47,10 @@ import static java.lang.String.format; import static org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat.MAX_CENTROIDS_PER_PARENT_CLUSTER; +import static org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat.MAX_PRECONDITIONING_BLOCK_DIMS; import static org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat.MAX_VECTORS_PER_CLUSTER; import static 
org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat.MIN_CENTROIDS_PER_PARENT_CLUSTER; +import static org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat.MIN_PRECONDITIONING_BLOCK_DIMS; import static org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat.MIN_VECTORS_PER_CLUSTER; import static org.hamcrest.Matchers.anEmptyMap; import static org.hamcrest.Matchers.equalTo; @@ -71,8 +74,18 @@ public void setUp() throws Exception { if (rarely()) { format = new ESNextDiskBBQVectorsFormat( encoding, - random().nextInt(2 * MIN_VECTORS_PER_CLUSTER, ESNextDiskBBQVectorsFormat.MAX_VECTORS_PER_CLUSTER), - random().nextInt(8, ESNextDiskBBQVectorsFormat.MAX_CENTROIDS_PER_PARENT_CLUSTER) + random().nextInt(2 * MIN_VECTORS_PER_CLUSTER, MAX_VECTORS_PER_CLUSTER), + random().nextInt(8, MAX_CENTROIDS_PER_PARENT_CLUSTER) + ); + } else if (rarely()) { + format = new ESNextDiskBBQVectorsFormat( + encoding, + random().nextInt(MIN_VECTORS_PER_CLUSTER, MAX_VECTORS_PER_CLUSTER), + random().nextInt(MIN_CENTROIDS_PER_PARENT_CLUSTER, MAX_CENTROIDS_PER_PARENT_CLUSTER), + DenseVectorFieldMapper.ElementType.FLOAT, + false, + true, + random().nextInt(MIN_PRECONDITIONING_BLOCK_DIMS, MAX_PRECONDITIONING_BLOCK_DIMS) ); } else { // run with low numbers to force many clusters with parents
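For context on how the new knobs are wired up: in the qa/vector benchmark tool they are plain JSON config fields ("do_precondition", defaulting to false, and "preconditioning_dims", defaulting to 64), while in the codec they are constructor arguments; DenseVectorFieldMapper currently hard-codes doPrecondition to false (flagged as a TODO in the format change above), so only tests and the benchmark tool exercise the feature for now. Below is a minimal sketch of enabling preconditioning directly against the new constructor. It uses only types and constants visible in this change, except the vectors-per-cluster value of 384, which is just an illustrative value within the documented bounds:

import org.apache.lucene.codecs.KnnVectorsFormat;
import org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;

// Construct the next-gen DiskBBQ format with preconditioning enabled.
KnnVectorsFormat format = new ESNextDiskBBQVectorsFormat(
    ESNextDiskBBQVectorsFormat.QuantEncoding.ONE_BIT_4BIT_QUERY,
    384,                                                                 // vectorPerCluster (illustrative value)
    ESNextDiskBBQVectorsFormat.DEFAULT_CENTROIDS_PER_PARENT_CLUSTER,     // 16
    DenseVectorFieldMapper.ElementType.FLOAT,
    false,                                                               // useDirectIO
    true,                                                                // doPrecondition
    ESNextDiskBBQVectorsFormat.DEFAULT_PRECONDITIONING_BLOCK_DIMENSION   // 32; accepted range is [8, 384]
);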