diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props index 672b15d3ddf0..96f27f73c88e 100644 --- a/dotnet/Directory.Packages.props +++ b/dotnet/Directory.Packages.props @@ -155,6 +155,7 @@ + diff --git a/dotnet/SK-dotnet.slnx b/dotnet/SK-dotnet.slnx index 687e09529260..1b95d7421e07 100644 --- a/dotnet/SK-dotnet.slnx +++ b/dotnet/SK-dotnet.slnx @@ -149,6 +149,7 @@ + @@ -324,6 +325,7 @@ + diff --git a/dotnet/samples/GettingStartedWithVectorStores/GettingStartedWithVectorStores.csproj b/dotnet/samples/GettingStartedWithVectorStores/GettingStartedWithVectorStores.csproj index 5c1fc5edb9cc..36707e4dd0fb 100644 --- a/dotnet/samples/GettingStartedWithVectorStores/GettingStartedWithVectorStores.csproj +++ b/dotnet/samples/GettingStartedWithVectorStores/GettingStartedWithVectorStores.csproj @@ -42,6 +42,7 @@ + diff --git a/dotnet/samples/GettingStartedWithVectorStores/README.md b/dotnet/samples/GettingStartedWithVectorStores/README.md index 36efdea60e78..4e72b84346dd 100644 --- a/dotnet/samples/GettingStartedWithVectorStores/README.md +++ b/dotnet/samples/GettingStartedWithVectorStores/README.md @@ -4,6 +4,10 @@ This project contains a step by step guide to get started using Vector Stores wi The examples can be run as integration tests but their code can also be copied to stand-alone programs. +## Step 5 — LiteDB advanced scenario + +`Step5_LiteDb_AdvancedScenario` shows how to combine transactional batch upserts, per-property embedding generators and metrics, and the extended filter surface (`Contains`, `StartsWith`, `EndsWith`, `IN`) when working with the LiteDB connector. + ## Configuring Secrets Most of the examples will require secrets and credentials, to access OpenAI, Azure OpenAI, diff --git a/dotnet/samples/GettingStartedWithVectorStores/Step5_LiteDb_AdvancedScenario.cs b/dotnet/samples/GettingStartedWithVectorStores/Step5_LiteDb_AdvancedScenario.cs new file mode 100644 index 000000000000..77232318c6c2 --- /dev/null +++ b/dotnet/samples/GettingStartedWithVectorStores/Step5_LiteDb_AdvancedScenario.cs @@ -0,0 +1,173 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.LiteDb; + +namespace GettingStartedWithVectorStores; + +/// +/// Demonstrates advanced LiteDB usage: per-property generators and metrics, transactional batch upserts, and expressive filters. +/// +public sealed class Step5_LiteDb_AdvancedScenario(ITestOutputHelper output) : BaseTest(output) +{ + private static readonly string[] FilterCities = new[] { "Seattle", "Portland" }; + + [Fact] + public async Task RunEndToEndScenarioAsync() + { + var services = new ServiceCollection(); + services.AddSingleton>>(new SampleEmbeddingGenerator("store")); + services.AddLiteDbVectorStore(_ => new LiteDbVectorStoreOptions + { + DisposeDatabase = false, + DistanceMetric = LiteDbDistanceMetric.Cosine, + AutoCreateVectorIndexes = true, + CollectionNamePrefix = "sample_" + }); + services.AddLiteDbCollection("hotels", _ => new VectorStoreCollectionDefinition + { + Properties = + { + new VectorStoreKeyProperty(nameof(SampleHotel.HotelId), typeof(string)), + new VectorStoreDataProperty(nameof(SampleHotel.City), typeof(string)), + new VectorStoreDataProperty(nameof(SampleHotel.Tags), typeof(List)), + new VectorStoreDataProperty(nameof(SampleHotel.Description), typeof(string)), + new VectorStoreVectorProperty(nameof(SampleHotel.Overview), typeof(string), 3) + { + DistanceFunction = DistanceFunction.DotProductSimilarity, + EmbeddingGenerator = new SampleEmbeddingGenerator("overview") + }, + new VectorStoreVectorProperty(nameof(SampleHotel.Amenities), typeof(string), 3) + { + DistanceFunction = DistanceFunction.EuclideanDistance, + EmbeddingGenerator = new SampleEmbeddingGenerator("amenities") + } + } + }); + + using var provider = services.BuildServiceProvider(); + var collection = provider.GetRequiredService>(); + await collection.EnsureCollectionExistsAsync(); + + var hotels = new[] + { + new SampleHotel + { + HotelId = "alpha", + City = "Seattle", + Description = "Waterfront spa resort", + Tags = new List { "spa", "rooftop" }, + Overview = "0.9,0.05,0.05", + Amenities = "0.2,0.5,0.3" + }, + new SampleHotel + { + HotelId = "beta", + City = "Portland", + Description = "Modern downtown escape", + Tags = new List { "boutique", "spa" }, + Overview = "0.7,0.1,0.2", + Amenities = "0.6,0.2,0.2" + }, + new SampleHotel + { + HotelId = "gamma", + City = "San Francisco", + Description = "Historic harbor view", + Tags = new List { "historic", "view" }, + Overview = "0.1,0.8,0.1", + Amenities = "0.3,0.1,0.6" + } + }; + + // Batch upsert executes inside a transaction – either all records are stored or none are. + await collection.UpsertAsync(hotels); + +#pragma warning disable CA1866 // the literal "t" is clearer for sample filtering. + var searchOptions = new VectorSearchOptions + { + VectorProperty = h => h.Overview, + Filter = h => (h.Tags.Contains("spa") || h.City!.StartsWith("Sea")) + && FilterCities.Contains(h.City!) + && h.Description!.EndsWith("t"), + IncludeVectors = false + }; +#pragma warning restore CA1866 + + var results = await collection.SearchAsync("0.8,0.1,0.1", top: 2, searchOptions).ToListAsync(); + + foreach (var result in results) + { + this.WriteLine($"Hotel: {result.Record.HotelId} in {result.Record.City} (score {result.Score:F3})"); + } + + Assert.Single(results); + Assert.Equal("alpha", results[0].Record.HotelId); + } + + private sealed class SampleHotel + { + [VectorStoreKey] + public string? HotelId { get; set; } + + [VectorStoreData] + public string? City { get; set; } + + [VectorStoreData] + public string? Description { get; set; } + + [VectorStoreData] + public List Tags { get; set; } = new(); + + [VectorStoreVector(Dimensions: 3)] + public string? Overview { get; set; } + + [VectorStoreVector(Dimensions: 3)] + public string? Amenities { get; set; } + } + + private sealed class SampleEmbeddingGenerator(string tag) : IEmbeddingGenerator> + { + public Task>> GenerateAsync(IEnumerable values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) + { + var embeddings = new GeneratedEmbeddings>(); + foreach (var value in values) + { + embeddings.Add(new Embedding(ParseVector(value, tag))); + } + + return Task.FromResult(embeddings); + } + + public Task> GenerateAsync(string value, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) + => Task.FromResult(new Embedding(ParseVector(value, tag))); + + public object? GetService(Type serviceType, object? serviceKey = null) + => null; + + public void Dispose() + { + } + + private static float[] ParseVector(string value, string tagPrefix) + { + var values = value.Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries) + .Select(float.Parse) + .ToArray(); + + // Different prefixes allow callers to verify generator precedence in diagnostics. + return tagPrefix switch + { + "overview" => values, + _ => values.Select(v => Math.Clamp(v + 0.05f, 0f, 1f)).ToArray() + }; + } + } +} diff --git a/dotnet/src/VectorData/LiteDb/LiteDb.csproj b/dotnet/src/VectorData/LiteDb/LiteDb.csproj new file mode 100644 index 000000000000..1a7714e9cf51 --- /dev/null +++ b/dotnet/src/VectorData/LiteDb/LiteDb.csproj @@ -0,0 +1,37 @@ + + + + Microsoft.SemanticKernel.Connectors.LiteDb + $(AssemblyName) + net8.0;netstandard2.1 + preview + + + + + + LiteDB provider for Microsoft.Extensions.VectorData + LiteDB provider for Microsoft.Extensions.VectorData by Semantic Kernel + VECTORDATA-CONNECTORS-NUGET.md + + + + + + + + + + + + + + + + + + + + + + diff --git a/dotnet/src/VectorData/LiteDb/LiteDbCollection.cs b/dotnet/src/VectorData/LiteDb/LiteDbCollection.cs new file mode 100644 index 000000000000..60a6fd90cc23 --- /dev/null +++ b/dotnet/src/VectorData/LiteDb/LiteDbCollection.cs @@ -0,0 +1,404 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Threading; +using System.Threading.Tasks; +using LiteDB; +using LiteDB.Vector; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ProviderServices; + +namespace Microsoft.SemanticKernel.Connectors.LiteDb; + +#pragma warning disable CA1711 // Identifiers should not have incorrect suffix +/// +/// Service for storing and retrieving vector records backed by LiteDB. +/// +/// The data type of the record key. +/// The record data model used when interacting with the collection. +public class LiteDbCollection : VectorStoreCollection + where TKey : notnull + where TRecord : class +#pragma warning restore CA1711 +{ + private readonly LiteDatabase _database; + private readonly ILiteCollection _collection; + private readonly CollectionModel _model; + private readonly LiteDbMapper _mapper; + private readonly LiteDbFilterTranslator _filterTranslator = new(); + private readonly LiteDbVectorStoreOptions _storeOptions; + private readonly LiteDbCollectionOptions _options; + private readonly VectorStoreCollectionMetadata _collectionMetadata; + private readonly IReadOnlyList _vectorProperties; + private readonly LiteDbDistanceMetric _defaultMetric; + + /// + public override string Name { get; } + + internal LiteDbCollection( + LiteDatabase database, + string name, + LiteDbVectorStoreOptions storeOptions, + LiteDbCollectionOptions? options, + Func modelFactory, + string connectionIdentifier) + { + ArgumentNullException.ThrowIfNull(database); + ArgumentNullException.ThrowIfNull(storeOptions); + if (string.IsNullOrWhiteSpace(name)) + { + throw new ArgumentException("Collection name cannot be null or whitespace.", nameof(name)); + } + + this._database = database; + this.Name = name; + this._storeOptions = storeOptions; + this._options = options ?? LiteDbCollectionOptions.Default; + + this._model = modelFactory(this._options); + if (typeof(TKey) != typeof(string) && typeof(TKey) != typeof(object)) + { + throw new NotSupportedException("LiteDB connector currently supports string keys."); + } + + this._collection = this._database.GetCollection(name); + this._mapper = new LiteDbMapper(this._model); + this._vectorProperties = this._model.VectorProperties; + this._defaultMetric = this._options.DistanceMetric ?? storeOptions.DistanceMetric; + + this._collectionMetadata = new() + { + VectorStoreSystemName = LiteDbConstants.VectorStoreSystemName, + VectorStoreName = connectionIdentifier, + CollectionName = name + }; + } + + /// + public override Task CollectionExistsAsync(CancellationToken cancellationToken = default) + { + var names = this._database.GetCollectionNames(); + var exists = names.Contains(this.Name, StringComparer.OrdinalIgnoreCase); + return Task.FromResult(exists); + } + + /// + public override Task EnsureCollectionExistsAsync(CancellationToken cancellationToken = default) + { + if (this._storeOptions.AutoEnsureVectorIndex && this._vectorProperties.Count > 0) + { + foreach (var vectorProperty in this._vectorProperties) + { + var dimensions = this._options.VectorDimensions ?? vectorProperty.Dimensions; + if (dimensions <= 0) + { + throw new InvalidOperationException($"Vector property '{vectorProperty.ModelName}' must specify dimensions when creating a LiteDB collection."); + } + + var metric = ResolveMetric(vectorProperty.DistanceFunction, this._options.DistanceMetric ?? this._defaultMetric); + var path = BsonExpression.Create($"$.{vectorProperty.StorageName}"); + var indexOptions = new VectorIndexOptions((ushort)dimensions, MapMetric(metric)); + this._collection.EnsureIndex(vectorProperty.StorageName, path, indexOptions); + } + } + + return Task.CompletedTask; + } + + /// + public override Task EnsureCollectionDeletedAsync(CancellationToken cancellationToken = default) + { + this._database.DropCollection(this.Name); + return Task.CompletedTask; + } + + /// + public override Task GetAsync(TKey key, RecordRetrievalOptions? options = default, CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(key); + options ??= new RecordRetrievalOptions(); + + if (options.IncludeVectors && this._model.EmbeddingGenerationRequired) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } + + var bsonKey = new BsonValue(key); + var document = this._collection.FindById(bsonKey); + if (document is null) + { + return Task.FromResult(null); + } + + var record = this._mapper.MapToRecord(document, options.IncludeVectors); + return Task.FromResult(record); + } + + /// + public override Task DeleteAsync(TKey key, CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(key); + this._collection.Delete(new BsonValue(key)); + return Task.CompletedTask; + } + + /// + public override Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default) + => this.UpsertAsync([record], cancellationToken); + + /// + public override async Task UpsertAsync(IEnumerable records, CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(records); + + var materialized = records as IList ?? records.ToList(); + if (materialized.Count == 0) + { + return; + } + + IReadOnlyDictionary>? generatedVectors = null; + if (this._model.EmbeddingGenerationRequired) + { + generatedVectors = await this.GenerateEmbeddingsAsync(materialized, cancellationToken).ConfigureAwait(false); + } + + var documents = new List(materialized.Count); + for (var i = 0; i < materialized.Count; i++) + { + cancellationToken.ThrowIfCancellationRequested(); + var document = this._mapper.MapToDocument(materialized[i], generatedVectors, i); + this.ValidateVectorDimensions(document); + documents.Add(document); + } + + if (documents.Count == 1) + { + this._collection.Upsert(documents[0]); + return; + } + + var startedTransaction = this._database.BeginTrans(); + try + { + this._collection.Upsert(documents); + if (startedTransaction) + { + this._database.Commit(); + } + } + catch + { + if (startedTransaction) + { + this._database.Rollback(); + } + + throw; + } + } + + /// + public override IAsyncEnumerable GetAsync(Expression> filter, int top, FilteredRecordRetrievalOptions? options = null, CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(filter); + if (top <= 0) + { + throw new ArgumentOutOfRangeException(nameof(top)); + } + + options ??= new FilteredRecordRetrievalOptions(); + if (options.IncludeVectors && this._model.EmbeddingGenerationRequired) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } + + return this.ExecuteQueryAsync(filter, top, options, cancellationToken); + } + + private async IAsyncEnumerable ExecuteQueryAsync(Expression> filter, int top, FilteredRecordRetrievalOptions options, [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken) + { + var query = this._collection.Query(); + var (expression, parameters) = this._filterTranslator.Translate(filter, this._model); + query = query.Where(expression, parameters); + + var result = query.Limit(top + options.Skip); + if (options.Skip > 0) + { + result = result.Skip(options.Skip); + } + + foreach (var document in result.ToEnumerable()) + { + cancellationToken.ThrowIfCancellationRequested(); + yield return this._mapper.MapToRecord(document, options.IncludeVectors); + } + } + + /// + public override async IAsyncEnumerable> SearchAsync(TInput searchValue, int top, VectorSearchOptions? options = null, [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken = default) + { + if (top < 1) + { + throw new ArgumentOutOfRangeException(nameof(top)); + } + options ??= new VectorSearchOptions(); + + if (options.IncludeVectors && this._model.EmbeddingGenerationRequired) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } + + var vectorProperty = this._model.GetVectorPropertyOrSingle(options); + var searchVector = await this.ResolveSearchVectorAsync(searchValue, vectorProperty, cancellationToken).ConfigureAwait(false); + + if (searchVector.Length == 0) + { + yield break; + } + + var metric = ResolveMetric(vectorProperty.DistanceFunction, this._options.DistanceMetric ?? this._defaultMetric); + var query = this._collection.Query(); + + if (options.Filter is not null) + { + var (expression, parameters) = this._filterTranslator.Translate(options.Filter, this._model); + query = query.Where(expression, parameters); + } + + var vectorArray = searchVector.ToArray(); + var take = top + options.Skip; + var results = query.TopKNear($"$.{vectorProperty.StorageName}", vectorArray, take).ToEnumerable(); + + var index = 0; + + foreach (var document in results) + { + cancellationToken.ThrowIfCancellationRequested(); + if (index++ < options.Skip) + { + continue; + } + var record = this._mapper.MapToRecord(document, options.IncludeVectors); + + if (!document.TryGetValue(vectorProperty.StorageName, out var value) || value is not BsonVector candidate) + { + continue; + } + + var score = LiteDbVectorMath.Compare(vectorArray, candidate.Values, metric); + yield return new VectorSearchResult(record, score); + } + } + + /// + public override object? GetService(Type serviceType, object? serviceKey = null) + { + ArgumentNullException.ThrowIfNull(serviceType); + + return serviceKey is not null + ? null + : serviceType == typeof(VectorStoreCollectionMetadata) + ? this._collectionMetadata + : serviceType.IsInstanceOfType(this) + ? this + : null; + } + + private async Task>?> GenerateEmbeddingsAsync(IList records, CancellationToken cancellationToken) + { + if (this._vectorProperties.Count == 0) + { + return null; + } + + Dictionary>? generated = null; + + foreach (var property in this._vectorProperties) + { + var existingValue = property.GetValueAsObject(records[0]); + if (existingValue is not null && LiteDbMapper.TryConvertVector(existingValue, out _)) + { + continue; + } + + if (property.TryGenerateEmbeddings>(records, cancellationToken, out var task)) + { + var embeddings = (IReadOnlyList>)await task.ConfigureAwait(false); + generated ??= new Dictionary>(StringComparer.Ordinal); + generated[property.ModelName] = embeddings.Select(e => e.Vector.ToArray()).ToList(); + } + else + { + throw new InvalidOperationException(VectorDataStrings.IncompatibleEmbeddingGeneratorWasConfiguredForInputType(typeof(TRecord), property.EmbeddingGenerator?.GetType() ?? typeof(object))); + } + } + + return generated; + } + + private async Task> ResolveSearchVectorAsync(TInput value, VectorPropertyModel property, CancellationToken cancellationToken) + { + switch (value) + { + case ReadOnlyMemory memory: + return memory; + case float[] array: + return new ReadOnlyMemory(array); + case Embedding embedding: + return embedding.Vector; + default: + if (property.EmbeddingGenerator is IEmbeddingGenerator> generator) + { + return await generator.GenerateVectorAsync(value, cancellationToken: cancellationToken).ConfigureAwait(false); + } + + throw new NotSupportedException(VectorDataStrings.InvalidSearchInputAndNoEmbeddingGeneratorWasConfigured(value?.GetType() ?? typeof(object), LiteDbModelBuilder.SupportedVectorTypes)); + } + } + + private static LiteDbDistanceMetric ResolveMetric(string? distanceFunction, LiteDbDistanceMetric fallback) + => distanceFunction switch + { + DistanceFunction.CosineSimilarity => LiteDbDistanceMetric.Cosine, + DistanceFunction.CosineDistance => LiteDbDistanceMetric.Cosine, + DistanceFunction.DotProductSimilarity => LiteDbDistanceMetric.DotProduct, + DistanceFunction.EuclideanDistance => LiteDbDistanceMetric.Euclidean, + _ => fallback + }; + + private static VectorDistanceMetric MapMetric(LiteDbDistanceMetric metric) + => metric switch + { + LiteDbDistanceMetric.Cosine => VectorDistanceMetric.Cosine, + LiteDbDistanceMetric.DotProduct => VectorDistanceMetric.DotProduct, + LiteDbDistanceMetric.Euclidean => VectorDistanceMetric.Euclidean, + _ => throw new NotSupportedException($"Unsupported distance metric '{metric}'.") + }; + + private void ValidateVectorDimensions(BsonDocument document) + { + foreach (var property in this._vectorProperties) + { + if (!document.TryGetValue(property.StorageName, out var value) || value is not BsonVector vector) + { + continue; + } + + var expectedDimensions = this._options.VectorDimensions ?? property.Dimensions; + if (expectedDimensions <= 0) + { + continue; + } + + if (vector.Values.Length != expectedDimensions) + { + throw new InvalidOperationException($"Vector property '{property.ModelName}' expects {expectedDimensions} dimensions but received {vector.Values.Length}."); + } + } + } +} diff --git a/dotnet/src/VectorData/LiteDb/LiteDbCollectionOptions.cs b/dotnet/src/VectorData/LiteDb/LiteDbCollectionOptions.cs new file mode 100644 index 000000000000..9ed37a9abf11 --- /dev/null +++ b/dotnet/src/VectorData/LiteDb/LiteDbCollectionOptions.cs @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft. All rights reserved. + +using LiteDB; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.VectorData; + +namespace Microsoft.SemanticKernel.Connectors.LiteDb; + +/// +/// Options for configuring a LiteDB vector collection. +/// +public sealed class LiteDbCollectionOptions +{ + internal static readonly LiteDbCollectionOptions Default = new(); + + /// + /// Initializes a new instance of the class. + /// + public LiteDbCollectionOptions() + { + } + + internal LiteDbCollectionOptions(LiteDbCollectionOptions? source) + { + if (source is null) + { + return; + } + + this.Definition = source.Definition; + this.EmbeddingGenerator = source.EmbeddingGenerator; + this.VectorDimensions = source.VectorDimensions; + this.DistanceMetric = source.DistanceMetric; + this.ConnectionString = source.ConnectionString; + this.CollectionNamePrefix = source.CollectionNamePrefix; + } + + /// + /// Gets or sets the schema definition for dynamic collections. + /// + public VectorStoreCollectionDefinition? Definition { get; set; } + + /// + /// Gets or sets the embedding generator used to populate vector properties for this collection. + /// + public IEmbeddingGenerator? EmbeddingGenerator { get; set; } + + /// + /// Gets or sets the dimensionality of the stored vectors. + /// + public int? VectorDimensions { get; set; } + + /// + /// Gets or sets the distance metric to use for the vector index. + /// + public LiteDbDistanceMetric? DistanceMetric { get; set; } + + /// + /// Gets or sets the LiteDB connection string override. + /// + public string? ConnectionString { get; set; } + + /// + /// Gets or sets the collection name prefix applied when creating tables. + /// + public string? CollectionNamePrefix { get; set; } +} diff --git a/dotnet/src/VectorData/LiteDb/LiteDbConstants.cs b/dotnet/src/VectorData/LiteDb/LiteDbConstants.cs new file mode 100644 index 000000000000..e0de736d0ae6 --- /dev/null +++ b/dotnet/src/VectorData/LiteDb/LiteDbConstants.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace Microsoft.SemanticKernel.Connectors.LiteDb; + +internal static class LiteDbConstants +{ + internal const string VectorStoreSystemName = "LiteDB"; + internal const string DefaultConnectionString = "Filename=LiteDbVectorStore.db;Connection=shared"; + internal const string DefaultKeyField = "_id"; +} diff --git a/dotnet/src/VectorData/LiteDb/LiteDbDynamicCollection.cs b/dotnet/src/VectorData/LiteDb/LiteDbDynamicCollection.cs new file mode 100644 index 000000000000..6c72db6a9c05 --- /dev/null +++ b/dotnet/src/VectorData/LiteDb/LiteDbDynamicCollection.cs @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using LiteDB; + +namespace Microsoft.SemanticKernel.Connectors.LiteDb; + +#pragma warning disable CA1711 +/// +/// Represents a LiteDB collection mapped to dynamic records. +/// +public sealed class LiteDbDynamicCollection : LiteDbCollection> +#pragma warning restore CA1711 +{ + /// + /// Initializes a new instance of the class. + /// + public LiteDbDynamicCollection( + LiteDatabase database, + string name, + LiteDbVectorStoreOptions storeOptions, + LiteDbCollectionOptions options, + string connectionIdentifier) + : base( + database, + name, + storeOptions, + options, + static opts => new LiteDbModelBuilder() + .BuildDynamic(opts.Definition ?? throw new ArgumentException("Definition is required for dynamic collections"), opts.EmbeddingGenerator), + connectionIdentifier) + { + } +} diff --git a/dotnet/src/VectorData/LiteDb/LiteDbFilterTranslator.cs b/dotnet/src/VectorData/LiteDb/LiteDbFilterTranslator.cs new file mode 100644 index 000000000000..ee60451cc6a8 --- /dev/null +++ b/dotnet/src/VectorData/LiteDb/LiteDbFilterTranslator.cs @@ -0,0 +1,400 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using LiteDB; +using Microsoft.Extensions.VectorData.ProviderServices; +using Microsoft.Extensions.VectorData.ProviderServices.Filter; + +namespace Microsoft.SemanticKernel.Connectors.LiteDb; + +internal sealed class LiteDbFilterTranslator +{ + private CollectionModel _model = null!; + private ParameterExpression _parameter = null!; + private readonly List _parameters = new(); + + internal (string Expression, BsonValue[] Parameters) Translate(LambdaExpression expression, CollectionModel model) + { + this._model = model; + this._parameter = expression.Parameters[0]; + this._parameters.Clear(); + + var preprocessor = new FilterTranslationPreprocessor { SupportsParameterization = true }; + var preprocessed = preprocessor.Preprocess(expression.Body); + + var exprText = this.TranslateNode(preprocessed); + return (exprText, this._parameters.ToArray()); + } + + private string TranslateNode(Expression? node) + => node switch + { + null => "true", + ConstantExpression constant when constant.Value is bool boolValue => boolValue ? "true" : "false", + MethodCallExpression method => this.TranslateMethodCall(method), + BinaryExpression binary when IsComparison(binary.NodeType) => this.TranslateComparison(binary), + BinaryExpression binary when binary.NodeType is ExpressionType.AndAlso or ExpressionType.OrElse + => this.TranslateLogical(binary), + UnaryExpression { NodeType: ExpressionType.Not } not => $"NOT ({this.TranslateNode(not.Operand)})", + UnaryExpression { NodeType: ExpressionType.Convert } convert => this.TranslateNode(convert.Operand), + Expression expr when expr.Type == typeof(bool) && this.TryBindProperty(expr, out var property) + => this.GenerateComparison(property, true, ExpressionType.Equal), + _ => throw new NotSupportedException($"Unsupported expression node '{node?.NodeType}'.") + }; + + private string TranslateComparison(BinaryExpression binary) + { + if (this.TryBindProperty(binary.Left, out var property) && this.TryExtractConstant(binary.Right, out var constant)) + { + return this.GenerateComparison(property, constant, binary.NodeType); + } + + if (this.TryBindProperty(binary.Right, out property) && this.TryExtractConstant(binary.Left, out var leftConstant)) + { + return this.GenerateComparison(property, leftConstant, binary.NodeType); + } + + throw new NotSupportedException("LiteDB filter translation expects member-to-constant comparisons."); + } + + private string TranslateLogical(BinaryExpression binary) + { + var left = this.TranslateNode(binary.Left); + var right = this.TranslateNode(binary.Right); + var op = binary.NodeType == ExpressionType.AndAlso ? "AND" : "OR"; + return $"({left} {op} {right})"; + } + + private string GenerateComparison(PropertyModel property, object? value, ExpressionType nodeType) + { + var field = $"$.{property.StorageName}"; + if (value is null) + { + return nodeType switch + { + ExpressionType.Equal => $"({field} IS NULL)", + ExpressionType.NotEqual => $"({field} IS NOT NULL)", + _ => throw new NotSupportedException("Null comparisons are only supported for equality checks.") + }; + } + + var placeholder = this.AddParameter(this.CreateParameterValue(value)); + var op = nodeType switch + { + ExpressionType.Equal => "=", + ExpressionType.NotEqual => "!=", + ExpressionType.GreaterThan => ">", + ExpressionType.GreaterThanOrEqual => ">=", + ExpressionType.LessThan => "<", + ExpressionType.LessThanOrEqual => "<=", + _ => throw new NotSupportedException($"Comparison operator '{nodeType}' is not supported.") + }; + + return $"({field} {op} {placeholder})"; + } + + private string TranslateMethodCall(MethodCallExpression methodCall) + { + if (methodCall.Method.DeclaringType == typeof(string) + && methodCall.Method.Name is nameof(string.Contains) or nameof(string.StartsWith) or nameof(string.EndsWith) + && methodCall.Object is { } stringObject + && this.TryBindProperty(stringObject, out var stringProperty)) + { + return this.TranslateStringMethod(methodCall, stringProperty); + } + + if (methodCall is { Method.Name: nameof(Enumerable.Contains), Method.DeclaringType: var declaringType } enumerableCall + && declaringType == typeof(Enumerable)) + { + return this.TranslateEnumerableContains(enumerableCall); + } + + if (methodCall.Method.Name == nameof(List.Contains) + && methodCall.Object is not null + && this.TryBindProperty(methodCall.Object, out var collectionProperty)) + { + return this.TranslateCollectionContains(methodCall, collectionProperty); + } + + if (methodCall.Method.Name == "Contains" + && methodCall.Object is not null + && methodCall.Method.DeclaringType is { } collectionDeclaringType + && collectionDeclaringType != typeof(string) + && this.TryBindProperty(methodCall.Object, out collectionProperty)) + { + return this.TranslateCollectionContains(methodCall, collectionProperty); + } + + throw new NotSupportedException($"Unsupported method call '{methodCall.Method.DeclaringType?.Name}.{methodCall.Method.Name}' in LiteDB filter expression."); + } + + private string TranslateStringMethod(MethodCallExpression methodCall, PropertyModel property) + { + var propertyType = Nullable.GetUnderlyingType(property.Type) ?? property.Type; + if (propertyType != typeof(string)) + { + throw new NotSupportedException($"String method '{methodCall.Method.Name}' is only supported on string properties, but '{property.ModelName}' is of type '{property.Type.Name}'."); + } + + if (methodCall.Arguments.Count == 0) + { + throw new NotSupportedException($"Method '{methodCall.Method.Name}' on property '{property.ModelName}' must specify a value argument."); + } + + if (methodCall.Arguments.Count > 1) + { + throw new NotSupportedException($"LiteDB filters only support the overload of '{methodCall.Method.Name}' with a single value argument."); + } + + if (!this.TryExtractConstant(methodCall.Arguments[0], out var value) || value is not string stringValue) + { + throw new NotSupportedException($"LiteDB filters require '{methodCall.Method.Name}' arguments to be constant strings."); + } + + var pattern = methodCall.Method.Name switch + { + nameof(string.Contains) => $"%{EscapeLikePattern(stringValue)}%", + nameof(string.StartsWith) => $"{EscapeLikePattern(stringValue)}%", + nameof(string.EndsWith) => $"%{EscapeLikePattern(stringValue)}", + _ => throw new NotSupportedException($"Unsupported string method '{methodCall.Method.Name}'.") + }; + + var placeholder = this.AddParameter(new BsonValue(pattern)); + var field = GetField(property); + return $"({field} LIKE {placeholder})"; + } + + private string TranslateCollectionContains(MethodCallExpression methodCall, PropertyModel property) + { + if (!typeof(IEnumerable).IsAssignableFrom((Nullable.GetUnderlyingType(property.Type) ?? property.Type))) + { + throw new NotSupportedException($"Collection.Contains is only supported for enumerable properties. Property '{property.ModelName}' has type '{property.Type.Name}'."); + } + + if (methodCall.Arguments.Count != 1) + { + throw new NotSupportedException($"Method '{methodCall.Method.Name}' must have exactly one argument in LiteDB filters."); + } + + if (!this.TryExtractConstant(methodCall.Arguments[0], out var value)) + { + throw new NotSupportedException("LiteDB filters require collection.Contains arguments to be constant values."); + } + + var placeholder = this.AddParameter(this.CreateParameterValue(value)); + var field = GetField(property); + return $"({field} ANY = {placeholder})"; + } + + private string TranslateEnumerableContains(MethodCallExpression methodCall) + { + if (methodCall.Arguments.Count != 2) + { + throw new NotSupportedException("Enumerable.Contains must specify the source and the item to compare."); + } + + var source = methodCall.Arguments[0]; + var item = methodCall.Arguments[1]; + + if (this.TryBindProperty(source, out var collectionProperty)) + { + var propertyType = Nullable.GetUnderlyingType(collectionProperty.Type) ?? collectionProperty.Type; + if (!typeof(IEnumerable).IsAssignableFrom(propertyType)) + { + throw new NotSupportedException($"Enumerable.Contains is only supported on enumerable properties. Property '{collectionProperty.ModelName}' has type '{collectionProperty.Type.Name}'."); + } + + if (!this.TryExtractConstant(item, out var element)) + { + throw new NotSupportedException("LiteDB filters require Enumerable.Contains item arguments to be constant values when the source is a record property."); + } + + var collectionPlaceholder = this.AddParameter(this.CreateParameterValue(element)); + var collectionField = GetField(collectionProperty); + return $"({collectionField} ANY = {collectionPlaceholder})"; + } + + if (!this.TryExtractConstant(source, out var values)) + { + try + { + values = Expression.Lambda(source).Compile().DynamicInvoke(); + } + catch + { + throw new NotSupportedException("LiteDB filters require Enumerable.Contains sources to be constant sequences."); + } + } + + if (values is not IEnumerable enumerable) + { + throw new NotSupportedException("LiteDB filters require Enumerable.Contains sources to be constant sequences."); + } + + if (!this.TryBindProperty(item, out var property)) + { + throw new NotSupportedException("LiteDB filters support Enumerable.Contains only when comparing against a record property."); + } + + var placeholder = this.AddParameter(this.CreateParameterValue(enumerable)); + var field = GetField(property); + return $"({field} IN {placeholder})"; + } + + private string AddParameter(BsonValue value) + { + var index = this._parameters.Count; + this._parameters.Add(value); + return $"@{index}"; + } + + private bool TryExtractConstant(Expression expression, out object? value) + { + switch (expression) + { + case ConstantExpression constant: + value = constant.Value; + return true; + case QueryParameterExpression queryParameter: + value = queryParameter.Value; + return true; + case MemberExpression { Expression: QueryParameterExpression queryParameter, Member: PropertyInfo property } + when property.CanRead: + value = property.GetValue(queryParameter); + return true; + case MemberExpression { Expression: { } instanceExpression } member + when this.TryExtractConstant(instanceExpression, out var instance) + && TryGetMemberValue(member.Member, instance, out value): + return true; + case MemberExpression { Expression: null } member when TryGetMemberValue(member.Member, null, out value): + return true; + case UnaryExpression { NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked } unary + when this.TryExtractConstant(unary.Operand, out var operand): + value = operand; + return true; + case NewArrayExpression newArray: + { + var items = new object?[newArray.Expressions.Count]; + for (var i = 0; i < newArray.Expressions.Count; i++) + { + if (!this.TryExtractConstant(newArray.Expressions[i], out var element)) + { + value = null; + return false; + } + + items[i] = element; + } + + value = items; + return true; + } + default: + value = null; + return false; + } + } + + private BsonValue CreateParameterValue(object? value) + { + switch (value) + { + case null: + return BsonValue.Null; + case BsonValue bson: + return bson; + case IEnumerable enumerable when value is not string: + { + var array = new BsonArray(); + foreach (var element in enumerable) + { + array.Add(this.CreateParameterValue(element)); + } + + return array; + } + default: + return new BsonValue(value); + } + } + + private static string GetField(PropertyModel property) + => $"$.{property.StorageName}"; + + private static string EscapeLikePattern(string value) + => value.Replace("%", "[%]").Replace("_", "[_]"); + + private static bool IsComparison(ExpressionType type) + => type is ExpressionType.Equal or ExpressionType.NotEqual + or ExpressionType.GreaterThan or ExpressionType.GreaterThanOrEqual + or ExpressionType.LessThan or ExpressionType.LessThanOrEqual; + + private static bool TryGetMemberValue(MemberInfo member, object? instance, out object? value) + { + switch (member) + { + case FieldInfo field: + value = field.GetValue(instance); + return true; + case PropertyInfo { CanRead: true } property: + value = property.GetValue(instance); + return true; + default: + value = null; + return false; + } + } + + private bool TryBindProperty(Expression expression, [NotNullWhen(true)] out PropertyModel? property) + { + var unwrapped = expression; + while (unwrapped is UnaryExpression { NodeType: ExpressionType.Convert } convert) + { + unwrapped = convert.Operand; + } + + string? modelName = unwrapped switch + { + MemberExpression member when member.Expression == this._parameter => member.Member.Name, + MethodCallExpression + { + Method: { Name: "get_Item", DeclaringType: var declaringType }, + Arguments: [ConstantExpression { Value: string key }] + } call when call.Object == this._parameter && declaringType == typeof(Dictionary) + => key, + _ => null + }; + + if (modelName is null) + { + property = null; + return false; + } + + if (!this._model.PropertyMap.TryGetValue(modelName, out property)) + { + throw new InvalidOperationException($"Property name '{modelName}' provided as part of the filter clause is not a valid property name."); + } + + var expectedType = Nullable.GetUnderlyingType(property.Type) ?? property.Type; + unwrapped = expression; + while (unwrapped is UnaryExpression { NodeType: ExpressionType.Convert } convert) + { + var convertType = Nullable.GetUnderlyingType(convert.Type) ?? convert.Type; + if (convertType != expectedType && convertType != typeof(object)) + { + throw new InvalidCastException($"Property '{property.ModelName}' is being cast to type '{convert.Type.Name}', but its configured type is '{property.Type.Name}'."); + } + + unwrapped = convert.Operand; + } + + return true; + } +} diff --git a/dotnet/src/VectorData/LiteDb/LiteDbMapper.cs b/dotnet/src/VectorData/LiteDb/LiteDbMapper.cs new file mode 100644 index 000000000000..38adbf32ad01 --- /dev/null +++ b/dotnet/src/VectorData/LiteDb/LiteDbMapper.cs @@ -0,0 +1,231 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections; +using System.Collections.Generic; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.VectorData.ProviderServices; +using LiteDB; + +namespace Microsoft.SemanticKernel.Connectors.LiteDb; + +internal sealed class LiteDbMapper(CollectionModel model) + where TRecord : class +{ + private readonly CollectionModel _model = model; + + public BsonDocument MapToDocument( + TRecord record, + IReadOnlyDictionary>? generatedVectors, + int recordIndex) + { + var document = new BsonDocument(); + + var keyValue = this._model.KeyProperty.GetValueAsObject(record) + ?? throw new InvalidOperationException($"Key property '{this._model.KeyProperty.ModelName}' cannot be null."); + document[LiteDbConstants.DefaultKeyField] = new BsonValue(keyValue); + + foreach (var property in this._model.DataProperties) + { + var value = property.GetValueAsObject(record); + if (value is null) + { + continue; + } + + if (value is IEnumerable enumerable && value is not string) + { + var array = new BsonArray(); + foreach (var element in enumerable) + { + array.Add(element is null ? BsonValue.Null : new BsonValue(element)); + } + + document[property.StorageName] = array; + continue; + } + + document[property.StorageName] = new BsonValue(value); + } + + foreach (var property in this._model.VectorProperties) + { + float[]? vector = null; + var value = property.GetValueAsObject(record); + if (value is not null && TryConvertVector(value, out vector)) + { + // vector assigned + } + else if (generatedVectors is not null && generatedVectors.TryGetValue(property.ModelName, out var generated)) + { + vector = generated[recordIndex]; + } + + if (vector is not null) + { + document[property.StorageName] = new BsonVector(vector); + } + } + + return document; + } + + public TRecord MapToRecord(BsonDocument document, bool includeVectors) + { + var record = this._model.CreateRecord(); + + var keyValue = document[LiteDbConstants.DefaultKeyField]; + this._model.KeyProperty.SetValueAsObject(record, ConvertValue(keyValue, this._model.KeyProperty.Type)); + + foreach (var property in this._model.DataProperties) + { + if (!document.TryGetValue(property.StorageName, out var bsonValue) || bsonValue.IsNull) + { + continue; + } + + property.SetValueAsObject(record, ConvertValue(bsonValue, property.Type)); + } + + if (includeVectors) + { + foreach (var property in this._model.VectorProperties) + { + if (!document.TryGetValue(property.StorageName, out var bsonValue) || bsonValue.IsNull) + { + continue; + } + + var floats = bsonValue is BsonVector vector ? vector.Values : null; + if (floats is null) + { + continue; + } + + var targetType = Nullable.GetUnderlyingType(property.Type) ?? property.Type; + object vectorValue = targetType switch + { + { } t when t == typeof(ReadOnlyMemory) => new ReadOnlyMemory(floats), + { } t when t == typeof(Embedding) => new Embedding(floats), + { } t when t == typeof(float[]) => floats, + _ => throw new NotSupportedException($"Vector property '{property.ModelName}' has unsupported type '{property.Type}'.") + }; + + property.SetValueAsObject(record, vectorValue); + } + } + + return record; + } + + internal static bool TryConvertVector(object value, out float[]? vector) + { + switch (value) + { + case float[] floats: + vector = floats; + return true; + case ReadOnlyMemory memory: + vector = memory.ToArray(); + return true; + case Embedding embedding: + vector = embedding.Vector.ToArray(); + return true; + default: + vector = null; + return false; + } + } + + private static object? ConvertValue(BsonValue value, Type targetType) + { + var underlyingType = Nullable.GetUnderlyingType(targetType) ?? targetType; + if (value.IsNull) + { + return null; + } + + if (underlyingType == typeof(string)) + { + return value.AsString; + } + + if (value.IsArray) + { + var array = value.AsArray; + + if (underlyingType.IsArray) + { + var elementType = underlyingType.GetElementType() ?? typeof(object); + var result = Array.CreateInstance(elementType, array.Count); + for (var i = 0; i < array.Count; i++) + { + result.SetValue(ConvertValue(array[i], elementType), i); + } + + return result; + } + + if (typeof(IList).IsAssignableFrom(underlyingType) && underlyingType.IsGenericType) + { + var elementType = underlyingType.GetGenericArguments()[0]; + var listType = typeof(List<>).MakeGenericType(elementType); + var list = (IList)Activator.CreateInstance(listType)!; + foreach (var item in array) + { + list.Add(ConvertValue(item, elementType)); + } + + return list; + } + } + + if (underlyingType == typeof(int)) + { + return value.AsInt32; + } + + if (underlyingType == typeof(long)) + { + return value.AsInt64; + } + + if (underlyingType == typeof(double)) + { + return value.AsDouble; + } + + if (underlyingType == typeof(float)) + { + return (float)value.AsDouble; + } + + if (underlyingType == typeof(bool)) + { + return value.AsBoolean; + } + + if (underlyingType == typeof(DateTime)) + { + return value.AsDateTime; + } + + if (underlyingType == typeof(Guid)) + { + return value.AsGuid; + } + + if (underlyingType == typeof(Dictionary)) + { + var dictionary = new Dictionary(); + foreach (var element in value.AsDocument) + { + dictionary[element.Key] = element.Value.RawValue; + } + + return dictionary; + } + + return value.RawValue; + } +} diff --git a/dotnet/src/VectorData/LiteDb/LiteDbModelBuilder.cs b/dotnet/src/VectorData/LiteDb/LiteDbModelBuilder.cs new file mode 100644 index 000000000000..7ebb63678c8f --- /dev/null +++ b/dotnet/src/VectorData/LiteDb/LiteDbModelBuilder.cs @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Diagnostics.CodeAnalysis; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.VectorData.ProviderServices; + +namespace Microsoft.SemanticKernel.Connectors.LiteDb; + +internal sealed class LiteDbModelBuilder() : CollectionModelBuilder(ValidationOptions) +{ + private const string SupportedKeyTypes = "string"; + internal const string SupportedVectorTypes = "ReadOnlyMemory, Embedding, float[]"; + + internal static readonly CollectionModelBuildingOptions ValidationOptions = new() + { + SupportsMultipleKeys = false, + SupportsMultipleVectors = true, + RequiresAtLeastOneVector = false + }; + + protected override bool IsKeyPropertyTypeValid(Type type, [NotNullWhen(false)] out string? supportedTypes) + { + supportedTypes = SupportedKeyTypes; + return type == typeof(string); + } + + protected override bool IsDataPropertyTypeValid(Type type, [NotNullWhen(false)] out string? supportedTypes) + { + supportedTypes = string.Empty; + return true; + } + + protected override bool IsVectorPropertyTypeValid(Type type, [NotNullWhen(false)] out string? supportedTypes) + { + supportedTypes = SupportedVectorTypes; + return type == typeof(ReadOnlyMemory) + || type == typeof(ReadOnlyMemory?) + || type == typeof(Embedding) + || type == typeof(float[]); + } +} diff --git a/dotnet/src/VectorData/LiteDb/LiteDbServiceCollectionExtensions.cs b/dotnet/src/VectorData/LiteDb/LiteDbServiceCollectionExtensions.cs new file mode 100644 index 000000000000..453712b63383 --- /dev/null +++ b/dotnet/src/VectorData/LiteDb/LiteDbServiceCollectionExtensions.cs @@ -0,0 +1,251 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.LiteDb; + +namespace Microsoft.Extensions.DependencyInjection; + +/// +/// Extension methods to register LiteDB instances on an . +/// +public static class LiteDbServiceCollectionExtensions +{ + private const string DynamicCodeMessage = "This method is incompatible with NativeAOT, consult the documentation for adding collections in a way that's compatible with NativeAOT."; + private const string UnreferencedCodeMessage = "This method is incompatible with trimming, consult the documentation for adding collections in a way that's compatible with NativeAOT."; + + /// + /// Registers a as using the specified connection string. + /// + /// + public static IServiceCollection AddLiteDbVectorStore( + this IServiceCollection services, + string connectionString, + LiteDbVectorStoreOptions? options = null, + ServiceLifetime lifetime = ServiceLifetime.Singleton) + => AddKeyedLiteDbVectorStore(services, serviceKey: null, connectionString, options, lifetime); + + /// + /// Registers a keyed as using the specified connection string. + /// + /// The to register the store on. + /// The key with which to associate the store. + /// The LiteDB connection string. + /// Additional options to configure the store. + /// The service lifetime for the store. Defaults to . + /// The service collection. + public static IServiceCollection AddKeyedLiteDbVectorStore( + this IServiceCollection services, + object? serviceKey, + string connectionString, + LiteDbVectorStoreOptions? options = null, + ServiceLifetime lifetime = ServiceLifetime.Singleton) + { + ArgumentNullException.ThrowIfNull(services); + if (string.IsNullOrWhiteSpace(connectionString)) + { + throw new ArgumentException("Connection string cannot be null or whitespace.", nameof(connectionString)); + } + + if (options is not null && (options.Database is not null || options.DatabaseFactory is not null)) + { + throw new ArgumentException("When providing a connection string do not supply Database or DatabaseFactory within the options.", nameof(options)); + } + + services.Add(new ServiceDescriptor(typeof(LiteDbVectorStore), serviceKey, (serviceProvider, _) => + { + var storeOptions = PrepareStoreOptions(serviceProvider, options, connectionString); + return new LiteDbVectorStore(storeOptions); + }, lifetime)); + + services.Add(new ServiceDescriptor(typeof(VectorStore), serviceKey, + static (sp, key) => sp.GetRequiredKeyedService(key), lifetime)); + + return services; + } + + /// + /// Registers a using an options factory. + /// + /// + public static IServiceCollection AddLiteDbVectorStore( + this IServiceCollection services, + Func? optionsProvider, + ServiceLifetime lifetime = ServiceLifetime.Singleton) + => AddKeyedLiteDbVectorStore(services, serviceKey: null, optionsProvider, lifetime); + + /// + /// Registers a keyed using an options factory. + /// + /// The to register the store on. + /// The key with which to associate the store. + /// Factory that produces the . + /// The service lifetime for the store. Defaults to . + /// The service collection. + public static IServiceCollection AddKeyedLiteDbVectorStore( + this IServiceCollection services, + object? serviceKey, + Func? optionsProvider, + ServiceLifetime lifetime = ServiceLifetime.Singleton) + { + ArgumentNullException.ThrowIfNull(services); + + services.Add(new ServiceDescriptor(typeof(LiteDbVectorStore), serviceKey, (serviceProvider, _) => + { + var storeOptions = GetStoreOptions(serviceProvider, optionsProvider); + return new LiteDbVectorStore(storeOptions); + }, lifetime)); + + services.Add(new ServiceDescriptor(typeof(VectorStore), serviceKey, + static (sp, key) => sp.GetRequiredKeyedService(key), lifetime)); + + return services; + } + + /// + /// Registers a backed by LiteDB using a previously registered store. + /// + /// + [RequiresDynamicCode(DynamicCodeMessage)] + [RequiresUnreferencedCode(UnreferencedCodeMessage)] + public static IServiceCollection AddLiteDbCollection( + this IServiceCollection services, + string name, + Func? definitionProvider = null, + ServiceLifetime lifetime = ServiceLifetime.Singleton) + where TKey : notnull + where TRecord : class + => AddKeyedLiteDbCollection(services, serviceKey: null, name, definitionProvider, lifetime); + + /// + /// Registers a keyed backed by LiteDB using a previously registered store. + /// + /// The to register the collection on. + /// The key with which to associate the collection (and matching store). + /// The logical name of the collection. + /// Optional provider supplying the . + /// The service lifetime for the collection. Defaults to . + /// The service collection. + [RequiresDynamicCode(DynamicCodeMessage)] + [RequiresUnreferencedCode(UnreferencedCodeMessage)] + public static IServiceCollection AddKeyedLiteDbCollection( + this IServiceCollection services, + object? serviceKey, + string name, + Func? definitionProvider = null, + ServiceLifetime lifetime = ServiceLifetime.Singleton) + where TKey : notnull + where TRecord : class + { + ArgumentNullException.ThrowIfNull(services); + if (string.IsNullOrWhiteSpace(name)) + { + throw new ArgumentException("Collection name cannot be null or whitespace.", nameof(name)); + } + + services.Add(new ServiceDescriptor(typeof(VectorStoreCollection), serviceKey, (serviceProvider, key) => + { + var store = serviceProvider.GetRequiredKeyedService(key); + var definition = definitionProvider?.Invoke(serviceProvider); + return store.GetCollection(name, definition); + }, lifetime)); + + services.Add(new ServiceDescriptor(typeof(IVectorSearchable), serviceKey, + static (sp, key) => sp.GetRequiredKeyedService>(key), lifetime)); + + return services; + } + + /// + /// Registers a dynamic backed by LiteDB using a previously registered store. + /// + /// + [RequiresDynamicCode(DynamicCodeMessage)] + [RequiresUnreferencedCode(UnreferencedCodeMessage)] + public static IServiceCollection AddLiteDbDynamicCollection( + this IServiceCollection services, + string name, + Func definitionProvider, + ServiceLifetime lifetime = ServiceLifetime.Singleton) + => AddKeyedLiteDbDynamicCollection(services, serviceKey: null, name, definitionProvider, lifetime); + + /// + /// Registers a keyed dynamic backed by LiteDB using a previously registered store. + /// + /// The to register the collection on. + /// The key with which to associate the collection (and matching store). + /// The logical name of the collection. + /// Provider supplying the . + /// The service lifetime for the collection. Defaults to . + /// The service collection. + [RequiresDynamicCode(DynamicCodeMessage)] + [RequiresUnreferencedCode(UnreferencedCodeMessage)] + public static IServiceCollection AddKeyedLiteDbDynamicCollection( + this IServiceCollection services, + object? serviceKey, + string name, + Func definitionProvider, + ServiceLifetime lifetime = ServiceLifetime.Singleton) + { + ArgumentNullException.ThrowIfNull(services); + if (string.IsNullOrWhiteSpace(name)) + { + throw new ArgumentException("Collection name cannot be null or whitespace.", nameof(name)); + } + ArgumentNullException.ThrowIfNull(definitionProvider); + + services.Add(new ServiceDescriptor(typeof(VectorStoreCollection>), serviceKey, (serviceProvider, key) => + { + var store = serviceProvider.GetRequiredKeyedService(key); + var definition = definitionProvider(serviceProvider); + return store.GetDynamicCollection(name, definition); + }, lifetime)); + + services.Add(new ServiceDescriptor(typeof(IVectorSearchable>), serviceKey, + static (sp, key) => sp.GetRequiredKeyedService>>(key), lifetime)); + + return services; + } + + private static LiteDbVectorStoreOptions PrepareStoreOptions(IServiceProvider serviceProvider, LiteDbVectorStoreOptions? options, string connectionString) + { + var storeOptions = new LiteDbVectorStoreOptions(options) + { + ConnectionString = connectionString + }; + + ApplyEmbeddingGeneratorIfMissing(serviceProvider, storeOptions); + return storeOptions; + } + + private static LiteDbVectorStoreOptions GetStoreOptions(IServiceProvider serviceProvider, Func? optionsProvider) + { + var provided = optionsProvider?.Invoke(serviceProvider); + var storeOptions = new LiteDbVectorStoreOptions(provided); + + if (string.IsNullOrWhiteSpace(storeOptions.ConnectionString) && + storeOptions.Database is null && + storeOptions.DatabaseFactory is null) + { + storeOptions.ConnectionString = LiteDbConstants.DefaultConnectionString; + } + + ApplyEmbeddingGeneratorIfMissing(serviceProvider, storeOptions); + return storeOptions; + } + + private static void ApplyEmbeddingGeneratorIfMissing(IServiceProvider serviceProvider, LiteDbVectorStoreOptions storeOptions) + { + if (storeOptions.EmbeddingGenerator is null) + { + var embeddingGenerator = serviceProvider.GetService(); + if (embeddingGenerator is not null) + { + storeOptions.EmbeddingGenerator = embeddingGenerator; + } + } + } +} diff --git a/dotnet/src/VectorData/LiteDb/LiteDbVectorMath.cs b/dotnet/src/VectorData/LiteDb/LiteDbVectorMath.cs new file mode 100644 index 000000000000..0f88b961a837 --- /dev/null +++ b/dotnet/src/VectorData/LiteDb/LiteDbVectorMath.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Numerics.Tensors; + +namespace Microsoft.SemanticKernel.Connectors.LiteDb; + +internal static class LiteDbVectorMath +{ + public static float Compare(ReadOnlySpan x, ReadOnlySpan y, LiteDbDistanceMetric metric) + => metric switch + { + LiteDbDistanceMetric.Cosine => TensorPrimitives.CosineSimilarity(x, y), + LiteDbDistanceMetric.DotProduct => TensorPrimitives.Dot(x, y), + LiteDbDistanceMetric.Euclidean => TensorPrimitives.Distance(x, y), + _ => throw new NotSupportedException($"Unsupported distance metric '{metric}'.") + }; + + public static bool ShouldSortDescending(LiteDbDistanceMetric metric) + => metric switch + { + LiteDbDistanceMetric.Euclidean => false, + LiteDbDistanceMetric.Cosine => true, + LiteDbDistanceMetric.DotProduct => true, + _ => throw new NotSupportedException($"Unsupported distance metric '{metric}'.") + }; +} diff --git a/dotnet/src/VectorData/LiteDb/LiteDbVectorStore.cs b/dotnet/src/VectorData/LiteDb/LiteDbVectorStore.cs new file mode 100644 index 000000000000..73a605a6aaa1 --- /dev/null +++ b/dotnet/src/VectorData/LiteDb/LiteDbVectorStore.cs @@ -0,0 +1,255 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using System.Linq; +using LiteDB; +using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ProviderServices; + +namespace Microsoft.SemanticKernel.Connectors.LiteDb; + +/// +/// Vector store implementation backed by LiteDB. +/// +public sealed class LiteDbVectorStore : VectorStore +{ + private readonly LiteDatabase _database; + private readonly bool _ownsDatabase; + private readonly LiteDbVectorStoreOptions _options; + private readonly VectorStoreMetadata _metadata; + private readonly string _connectionIdentifier; + + /// + /// Initializes a new instance of the class using the provided connection string. + /// + public LiteDbVectorStore(string connectionString, LiteDbVectorStoreOptions? options = null) + : this(NormalizeConnectionOptions(connectionString, options)) + { + } + + /// + /// Initializes a new instance of the class using an existing . + /// + public LiteDbVectorStore(LiteDatabase database, LiteDbVectorStoreOptions? options = null) + : this(NormalizeDatabaseOptions(database, options)) + { + } + + /// + /// Initializes a new instance of the class using the supplied options. + /// + public LiteDbVectorStore(LiteDbVectorStoreOptions options) + { + ArgumentNullException.ThrowIfNull(options); + + this._options = options; + var (database, ownsDatabase, connectionIdentifier) = ResolveDatabase(options); + + this._database = database; + this._ownsDatabase = ownsDatabase; + this._connectionIdentifier = connectionIdentifier; + this._metadata = new VectorStoreMetadata + { + VectorStoreSystemName = LiteDbConstants.VectorStoreSystemName, + VectorStoreName = connectionIdentifier + }; + } + + /// + [RequiresDynamicCode("This overload of GetCollection() is incompatible with NativeAOT. For dynamic mapping via Dictionary call GetDynamicCollection() instead.")] + [RequiresUnreferencedCode("This overload of GetCollection() is incompatible with trimming. For dynamic mapping via Dictionary call GetDynamicCollection() instead.")] +#if NET8_0_OR_GREATER + public override LiteDbCollection GetCollection(string name, VectorStoreCollectionDefinition? definition = null) +#else + public override VectorStoreCollection GetCollection(string name, VectorStoreCollectionDefinition? definition = null) +#endif + { + if (string.IsNullOrWhiteSpace(name)) + { + throw new ArgumentException("Collection name cannot be null or whitespace.", nameof(name)); + } + var collectionOptions = new LiteDbCollectionOptions + { + Definition = definition, + EmbeddingGenerator = this._options.EmbeddingGenerator, + CollectionNamePrefix = this._options.CollectionNamePrefix + }; + + return new LiteDbCollection( + this._database, + this.ResolveCollectionName(name, collectionOptions), + this._options, + collectionOptions, + static opts => typeof(TRecord) == typeof(Dictionary) + ? throw new NotSupportedException(VectorDataStrings.NonDynamicCollectionWithDictionaryNotSupported(typeof(LiteDbDynamicCollection))) + : new LiteDbModelBuilder().Build(typeof(TRecord), opts.Definition, opts.EmbeddingGenerator), + this._connectionIdentifier); + } + +#if NET8_0_OR_GREATER + /// + public override LiteDbCollection> GetDynamicCollection(string name, VectorStoreCollectionDefinition definition) +#else + /// + public override VectorStoreCollection> GetDynamicCollection(string name, VectorStoreCollectionDefinition definition) +#endif + { + if (string.IsNullOrWhiteSpace(name)) + { + throw new ArgumentException("Collection name cannot be null or whitespace.", nameof(name)); + } + ArgumentNullException.ThrowIfNull(definition); + + var collectionOptions = new LiteDbCollectionOptions + { + Definition = definition, + EmbeddingGenerator = this._options.EmbeddingGenerator, + CollectionNamePrefix = this._options.CollectionNamePrefix + }; + + return new LiteDbDynamicCollection( + this._database, + this.ResolveCollectionName(name, collectionOptions), + this._options, + collectionOptions, + this._connectionIdentifier); + } + + /// + public override async IAsyncEnumerable ListCollectionNamesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + foreach (var storageName in this._database.GetCollectionNames()) + { + cancellationToken.ThrowIfCancellationRequested(); + + if (!this.TryNormalizeCollectionName(storageName, out var logicalName)) + { + continue; + } + + yield return logicalName; + await Task.Yield(); + } + } + + /// + public override Task CollectionExistsAsync(string name, CancellationToken cancellationToken = default) + { + var storageName = this.ResolveCollectionName(name); + var exists = this._database.GetCollectionNames().Contains(storageName, StringComparer.OrdinalIgnoreCase); + return Task.FromResult(exists); + } + + /// + public override Task EnsureCollectionDeletedAsync(string name, CancellationToken cancellationToken = default) + { + var storageName = this.ResolveCollectionName(name); + this._database.DropCollection(storageName); + return Task.CompletedTask; + } + + /// + public override object? GetService(Type serviceType, object? serviceKey = null) + { + ArgumentNullException.ThrowIfNull(serviceType); + + return serviceKey is not null + ? null + : serviceType == typeof(VectorStoreMetadata) + ? this._metadata + : serviceType.IsInstanceOfType(this) + ? this + : null; + } + + /// + protected override void Dispose(bool disposing) + { + if (disposing && this._ownsDatabase) + { + this._database.Dispose(); + } + + base.Dispose(disposing); + } + + private static LiteDbVectorStoreOptions NormalizeConnectionOptions(string connectionString, LiteDbVectorStoreOptions? options) + { + if (string.IsNullOrWhiteSpace(connectionString)) + { + throw new ArgumentException("Connection string cannot be null or whitespace.", nameof(connectionString)); + } + + var normalized = new LiteDbVectorStoreOptions(options); + normalized.ConnectionString = connectionString; + return normalized; + } + + private static LiteDbVectorStoreOptions NormalizeDatabaseOptions(LiteDatabase database, LiteDbVectorStoreOptions? options) + { + ArgumentNullException.ThrowIfNull(database); + + var normalized = new LiteDbVectorStoreOptions(options) + { + Database = database + }; + + return normalized; + } + + private static (LiteDatabase Database, bool OwnsDatabase, string ConnectionIdentifier) ResolveDatabase(LiteDbVectorStoreOptions options) + { + if (options.DatabaseFactory is not null) + { + var database = options.DatabaseFactory(); + if (database is null) + { + throw new InvalidOperationException("The LiteDbVectorStoreOptions.DatabaseFactory returned null."); + } + + return (database, options.DisposeDatabase, LiteDbConstants.VectorStoreSystemName); + } + + if (options.Database is not null) + { + return (options.Database, options.DisposeDatabase, LiteDbConstants.VectorStoreSystemName); + } + + var connectionString = string.IsNullOrWhiteSpace(options.ConnectionString) + ? LiteDbConstants.DefaultConnectionString + : options.ConnectionString; + + var databaseInstance = new LiteDatabase(connectionString); + return (databaseInstance, options.DisposeDatabase, connectionString); + } + + private string ResolveCollectionName(string name, LiteDbCollectionOptions? collectionOptions = null) + { + var prefix = collectionOptions?.CollectionNamePrefix ?? this._options.CollectionNamePrefix; + return string.IsNullOrEmpty(prefix) ? name : string.Concat(prefix, name); + } + + private bool TryNormalizeCollectionName(string storageName, out string logicalName) + { + var prefix = this._options.CollectionNamePrefix; + if (string.IsNullOrEmpty(prefix)) + { + logicalName = storageName; + return true; + } + + if (!storageName.StartsWith(prefix, StringComparison.Ordinal)) + { + logicalName = string.Empty; + return false; + } + + logicalName = storageName[prefix.Length..]; + return true; + } +} diff --git a/dotnet/src/VectorData/LiteDb/LiteDbVectorStoreOptions.cs b/dotnet/src/VectorData/LiteDb/LiteDbVectorStoreOptions.cs new file mode 100644 index 000000000000..47fd83e4c07e --- /dev/null +++ b/dotnet/src/VectorData/LiteDb/LiteDbVectorStoreOptions.cs @@ -0,0 +1,130 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using LiteDB; +using Microsoft.Extensions.AI; + +namespace Microsoft.SemanticKernel.Connectors.LiteDb; + +/// +/// Options when creating a . +/// +public sealed class LiteDbVectorStoreOptions +{ + internal static readonly LiteDbVectorStoreOptions Default = new(); + private bool _autoEnsureVectorIndex = true; + private bool? _autoCreateVectorIndexes; + + /// + /// Initializes a new instance of the class. + /// + public LiteDbVectorStoreOptions() + { + } + + internal LiteDbVectorStoreOptions(LiteDbVectorStoreOptions? source) + { + if (source is null) + { + return; + } + + this.ConnectionString = source.ConnectionString; + this.Database = source.Database; + this.DatabaseFactory = source.DatabaseFactory; + this.DisposeDatabase = source.DisposeDatabase; + this.DistanceMetric = source.DistanceMetric; + this.AutoEnsureVectorIndex = source.AutoEnsureVectorIndex; + this.AutoCreateVectorIndexes = source.AutoCreateVectorIndexes; + this.EmbeddingGenerator = source.EmbeddingGenerator; + this.CollectionNamePrefix = source.CollectionNamePrefix; + } + + /// + /// Gets or sets the connection string used to open the LiteDB database. + /// + /// + /// When not provided, a shared-file connection string is derived automatically. + /// + public string? ConnectionString { get; set; } + + /// + /// Gets or sets the instance used by the store. + /// + /// + /// When provided, the caller retains ownership unless is set to . + /// + public LiteDatabase? Database { get; set; } + + /// + /// Gets or sets the factory used to create a instance when constructing the store. + /// + /// + /// When provided, and are ignored. + /// + public Func? DatabaseFactory { get; set; } + + /// + /// Gets or sets a value indicating whether the store should dispose the supplied database when disposed. + /// + public bool DisposeDatabase { get; set; } = true; + + /// + /// Gets or sets the default distance metric to use when building vector indexes. + /// + public LiteDbDistanceMetric DistanceMetric { get; set; } = LiteDbDistanceMetric.Cosine; + + /// + /// Gets or sets a value indicating whether collections created through this store automatically ensure their vector index. + /// + public bool AutoEnsureVectorIndex + { + get => this._autoEnsureVectorIndex; + set + { + this._autoEnsureVectorIndex = value; + this._autoCreateVectorIndexes = value; + } + } + + /// + /// Gets or sets a value indicating whether collections created through this store automatically ensure their vector index. + /// + public bool? AutoCreateVectorIndexes + { + get => this._autoCreateVectorIndexes; + set + { + this._autoCreateVectorIndexes = value; + if (value.HasValue) + { + this._autoEnsureVectorIndex = value.Value; + } + } + } + + /// + /// Gets or sets the default embedding generator to use when generating embeddings for vector properties. + /// + public IEmbeddingGenerator? EmbeddingGenerator { get; set; } + + /// + /// Gets or sets the prefix applied to collection names created through this store. + /// + public string? CollectionNamePrefix { get; set; } +} + +/// +/// LiteDB distance metrics supported by the connector. +/// +public enum LiteDbDistanceMetric +{ + /// Cosine similarity. + Cosine, + + /// L2 (Euclidean) distance. + Euclidean, + + /// Dot product similarity. + DotProduct, +} diff --git a/dotnet/src/VectorData/LiteDb/README.md b/dotnet/src/VectorData/LiteDb/README.md new file mode 100644 index 000000000000..af7ed158941f --- /dev/null +++ b/dotnet/src/VectorData/LiteDb/README.md @@ -0,0 +1,145 @@ +# LiteDB Vector Store Connector + +The LiteDB connector provides an embedded, single-file option for working with the Semantic Kernel vector store abstractions. It uses LiteDB v6's native `BsonVector` storage and HNSW indexes so that applications can run entirely on managed runtimes without external dependencies. + +## Prerequisites + +- Install the `LiteDB` NuGet package version `6.0.0-prerelease.63`. +- Reference the `Microsoft.SemanticKernel.Connectors.LiteDb` project or package in your .NET solution. +- Ensure the target framework is at least .NET Standard 2.1 or .NET 8.0. + +## Creating a vector store + +```csharp +using LiteDB; +using Microsoft.SemanticKernel.Connectors.LiteDb; + +using var database = new LiteDatabase("Filename=sk.db;Mode=Exclusive"); +var store = new LiteDbVectorStore(database); +``` + +The store can also open a LiteDB database using a connection string. When you pass an existing `LiteDatabase` instance you control its lifecycle via `LiteDbVectorStoreOptions.DisposeDatabase`. + +When you need the connector to manage the database lifecycle, create the store with `LiteDbVectorStoreOptions`. The options surface supports supplying a database factory and a collection name prefix. + +Connection sources are considered in the following order: a configured `DatabaseFactory` is used first, then an existing `LiteDatabase` instance via `Database`, and finally the `ConnectionString` (or the embedded default when none is supplied). The `AutoCreateVectorIndexes` property is an alias for `AutoEnsureVectorIndex` so that existing configuration snippets continue to work unchanged. + +```csharp +var options = new LiteDbVectorStoreOptions +{ + DatabaseFactory = () => new LiteDatabase("Filename=sk.db;Connection=shared"), + CollectionNamePrefix = "sk_" +}; + +using var store = new LiteDbVectorStore(options); +``` + +Collections created through this store are materialized as `sk_*` tables, while APIs such as `CollectionExistsAsync` and `ListCollectionNamesAsync` use the logical names you pass to `GetCollection`. + +## Defining a collection + +Collections map strongly-typed record models to LiteDB BSON documents. Use the Semantic Kernel attributes to annotate key, data, and vector fields: + +```csharp +public sealed class Hotel +{ + [VectorStoreKey] + public string Id { get; set; } = string.Empty; + + [VectorStoreData(IsIndexed = true)] + public string Name { get; set; } = string.Empty; + + [VectorStoreVector(Dimensions: 3, DistanceFunction = DistanceFunction.CosineSimilarity)] + public ReadOnlyMemory? DescriptionEmbedding { get; set; } +} + +var collection = store.GetCollection("hotels"); +await collection.EnsureCollectionExistsAsync(); +``` + +The connector automatically provisions HNSW vector indexes when `EnsureCollectionExistsAsync` is called and a vector property is present. Distance metrics default to cosine similarity but can be overridden through attributes or `LiteDbCollectionOptions`. + +## Filtering + +LiteDB filter translation covers the standard comparison operators along with string and set membership helpers. The following table shows a subset of the mappings: + +| Semantic Kernel filter | LiteDB predicate | +| --- | --- | +| `r => r.Rating >= 4` | `($.Rating >= @0)` | +| `r => r.City.StartsWith("Sea")` | `($.City LIKE @0)` | +| `r => r.Description.EndsWith("Inn")` | `($.Description LIKE @0)` | +| `r => r.Tags.Contains("spa")` | `($.Tags ANY = @0)` | +| `r => new[] { "Seattle", "Portland" }.Contains(r.City)` | `($.City IN @0)` | + +Captured variables are parameterized, and unsupported constructs (such as the `StringComparison` overloads) throw `NotSupportedException` so that filters never silently fall back to client-side evaluation. + +## Generating embeddings on the fly + +When a record's vector property is a non-vector type (for example `string` or `DataContent`), configure an `IEmbeddingGenerator` so that LiteDB receives vectors during `UpsertAsync` and vector search: + +```csharp +var options = new LiteDbVectorStoreOptions +{ + EmbeddingGenerator = myTextEmbeddingGenerator +}; + +var store = new LiteDbVectorStore("Filename=sk.db", options); +var collection = store.GetCollection("articles"); +``` + +The store-level generator is inherited by collections, but you can override it per collection via `LiteDbCollectionOptions` or per property through a `VectorStoreCollectionDefinition`. + +## Batch upserts and transactions + +`UpsertAsync(IEnumerable)` executes inside a LiteDB transaction. Either all documents in the batch are written or an exception is thrown and the collection is left unchanged. Single-record upserts skip the transaction for better throughput. + +## Dependency injection + +`Microsoft.Extensions.DependencyInjection` helpers are available so that the LiteDB connector participates in common hosting patterns: + +```csharp +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel.Connectors.LiteDb; + +var services = new ServiceCollection(); + +services.AddLiteDbVectorStore(sp => new LiteDbVectorStoreOptions +{ + DatabaseFactory = () => new LiteDatabase("Filename=sk.db;Connection=shared") +}); + +services.AddLiteDbCollection("hotels"); +services.AddLiteDbDynamicCollection("snippets", _ => definition); +``` + +Registered collections automatically resolve a `VectorStoreCollection` and `IVectorSearchable` backed by the keyed store, and they inherit the embedding generator registered in the DI container when one is not provided in the options. + +## Performance considerations + +- Index creation happens when `EnsureCollectionExistsAsync` runs. For large collections, schedule this ahead of ingesting data or during maintenance windows. +- Filters that leverage string operators and `IN` clauses are translated to server-side predicates; prefer them over manual in-memory filtering for better selectivity. +- Vectors are validated at write time so that accidental dimension mismatches are caught before they reach storage. + +## Dynamic collections + +Dynamic collections let you work with `Dictionary` payloads without defining a CLR type: + +```csharp +var definition = new VectorStoreCollectionDefinition +{ + Properties = + { + new VectorStoreKeyProperty("Id", typeof(string)), + new VectorStoreDataProperty("Category", typeof(string)), + new VectorStoreVectorProperty("Embedding", typeof(ReadOnlyMemory), dimensions: 3) + } +}; + +var dynamicCollection = store.GetDynamicCollection("snippets", definition); +``` + +## Limitations + +- LiteDB's vector APIs are synchronous; high-throughput scenarios should run on background threads or batch operations. +- LiteDB databases are single-process; avoid opening the same file from multiple processes simultaneously. (although supported) +- `IncludeVectors` cannot be enabled on retrieval operations when embedding generation is configured, matching the behavior of other Semantic Kernel connectors. diff --git a/dotnet/test/VectorData/LiteDb.UnitTests/LiteDb.UnitTests.csproj b/dotnet/test/VectorData/LiteDb.UnitTests/LiteDb.UnitTests.csproj new file mode 100644 index 000000000000..2ca8c56e816a --- /dev/null +++ b/dotnet/test/VectorData/LiteDb.UnitTests/LiteDb.UnitTests.csproj @@ -0,0 +1,34 @@ + + + + SemanticKernel.Connectors.LiteDb.UnitTests + SemanticKernel.Connectors.LiteDb.UnitTests + net8.0 + true + enable + disable + false + $(NoWarn);SKEXP0001,VSTHRD111,CA2007,CS1591 + $(NoWarn);MEVD9001 + + + + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + diff --git a/dotnet/test/VectorData/LiteDb.UnitTests/LiteDbFilterTranslatorTests.cs b/dotnet/test/VectorData/LiteDb.UnitTests/LiteDbFilterTranslatorTests.cs new file mode 100644 index 000000000000..46a5418389f0 --- /dev/null +++ b/dotnet/test/VectorData/LiteDb.UnitTests/LiteDbFilterTranslatorTests.cs @@ -0,0 +1,72 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.LiteDb; +using Xunit; + +namespace SemanticKernel.Connectors.LiteDb.UnitTests; + +public sealed class LiteDbFilterTranslatorTests +{ + [Fact] + public void TranslatesStringAndMembershipOperators() + { + var builder = new LiteDbModelBuilder(); + var model = builder.Build(typeof(FilterHotel), definition: null, defaultEmbeddingGenerator: null); + Expression> filter = h => + (h.Tags.Contains("spa") || h.City.StartsWith("Sea")) + && new[] { "Seattle", "Portland" }.Contains(h.City) + && h.Description.EndsWith("Inn"); + + var translator = new LiteDbFilterTranslator(); + var (expression, parameters) = translator.Translate(filter, model); + + Assert.Equal("(((($.Tags ANY = @0) OR ($.City LIKE @1)) AND ($.City IN @2)) AND ($.Description LIKE @3))", expression); + Assert.Collection(parameters, + p => Assert.Equal("spa", p.AsString), + p => Assert.Equal("Sea%", p.AsString), + p => + { + var array = p.AsArray; + Assert.Equal(2, array.Count); + Assert.Equal("Seattle", array[0].AsString); + Assert.Equal("Portland", array[1].AsString); + }, + p => Assert.Equal("%Inn", p.AsString)); + } + + [Fact] + public void ThrowsHelpfulErrorForUnsupportedStringComparison() + { + var builder = new LiteDbModelBuilder(); + var model = builder.Build(typeof(FilterHotel), definition: null, defaultEmbeddingGenerator: null); + Expression> filter = h => h.City.StartsWith("Sea", StringComparison.OrdinalIgnoreCase); + + var translator = new LiteDbFilterTranslator(); + var exception = Assert.Throws(() => translator.Translate(filter, model)); + Assert.Contains("single value argument", exception.Message); + } + + [System.Diagnostics.CodeAnalysis.SuppressMessage("Performance", "CA1812:Avoid uninstantiated internal classes", Justification = "Instantiated via reflection during model binding.")] + private sealed class FilterHotel + { + [VectorStoreKey] + public string Id { get; set; } = string.Empty; + + [VectorStoreData] + public string[] Tags { get; set; } = Array.Empty(); + + [VectorStoreData] + public string City { get; set; } = string.Empty; + + [VectorStoreData] + public string Description { get; set; } = string.Empty; + + [VectorStoreVector(Dimensions: 3)] + public ReadOnlyMemory? Embedding { get; set; } + } +} diff --git a/dotnet/test/VectorData/LiteDb.UnitTests/LiteDbVectorStoreTests.cs b/dotnet/test/VectorData/LiteDb.UnitTests/LiteDbVectorStoreTests.cs new file mode 100644 index 000000000000..ef287deb965a --- /dev/null +++ b/dotnet/test/VectorData/LiteDb.UnitTests/LiteDbVectorStoreTests.cs @@ -0,0 +1,981 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using LiteDB; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel.Connectors.LiteDb; +using Xunit; + +namespace SemanticKernel.Connectors.LiteDb.UnitTests; + +public sealed class LiteDbVectorStoreTests +{ + [Fact] + public async Task UpsertAndGetRecordWithVectorsAsync() + { + using var database = new LiteDatabase(new MemoryStream()); + using var store = new LiteDbVectorStore(database, new LiteDbVectorStoreOptions { DisposeDatabase = false }); + + var collection = store.GetCollection("hotels"); + await collection.EnsureCollectionExistsAsync(); + + var record = new TestHotel + { + HotelId = "alpha", + HotelName = "Alpha", + DescriptionEmbedding = new ReadOnlyMemory(new[] { 1f, 0f, 0f }) + }; + + await collection.UpsertAsync(record); + var fetched = await collection.GetAsync("alpha", new RecordRetrievalOptions { IncludeVectors = true }); + + Assert.NotNull(fetched); + Assert.Equal("Alpha", fetched!.HotelName); + Assert.True(fetched.DescriptionEmbedding.HasValue); + Assert.Equal(record.DescriptionEmbedding!.Value.ToArray(), fetched.DescriptionEmbedding!.Value.ToArray()); + } + + [Fact] + public async Task VectorSearchReturnsNearestNeighborsAsync() + { + using var database = new LiteDatabase(new MemoryStream()); + using var store = new LiteDbVectorStore(database, new LiteDbVectorStoreOptions { DisposeDatabase = false }); + + var collection = store.GetCollection("hotels"); + await collection.EnsureCollectionExistsAsync(); + + await collection.UpsertAsync(new[] + { + new TestHotel { HotelId = "alpha", HotelName = "Alpha", DescriptionEmbedding = new ReadOnlyMemory(new[] { 1f, 0f, 0f }) }, + new TestHotel { HotelId = "beta", HotelName = "Beta", DescriptionEmbedding = new ReadOnlyMemory(new[] { 0f, 1f, 0f }) }, + new TestHotel { HotelId = "gamma", HotelName = "Gamma", DescriptionEmbedding = new ReadOnlyMemory(new[] { 0f, 0f, 1f }) } + }); + + var results = new List>(); + await foreach (var result in collection.SearchAsync(new ReadOnlyMemory(new[] { 0f, 1f, 0f }), top: 2)) + { + results.Add(result); + } + + Assert.Equal(2, results.Count); + var first = results[0]; + var second = results[1]; + Assert.NotNull(first.Record); + var firstRecord = first.Record!; + Assert.NotNull(firstRecord.HotelId); + var firstHotelId = firstRecord.HotelId!; + Assert.Equal("beta", firstHotelId); + Assert.True(first.Score >= second.Score); + } + + [Fact] + public async Task FilteredQueryReturnsExpectedRecordsAsync() + { + using var database = new LiteDatabase(new MemoryStream()); + using var store = new LiteDbVectorStore(database, new LiteDbVectorStoreOptions { DisposeDatabase = false }); + + var collection = store.GetCollection("hotels"); + await collection.EnsureCollectionExistsAsync(); + + await collection.UpsertAsync(new[] + { + new TestHotel { HotelId = "alpha", HotelName = "Alpha", Rating = 4, City = "Seattle", DescriptionEmbedding = new ReadOnlyMemory(new[] { 1f, 0f, 0f }) }, + new TestHotel { HotelId = "beta", HotelName = "Beta", Rating = 3, City = "Portland", DescriptionEmbedding = new ReadOnlyMemory(new[] { 0f, 1f, 0f }) }, + new TestHotel { HotelId = "gamma", HotelName = "Gamma", Rating = 5, City = "Seattle", DescriptionEmbedding = new ReadOnlyMemory(new[] { 0f, 0f, 1f }) } + }); + + var matches = new List(); + await foreach (var item in collection.GetAsync(h => h.Rating >= 4, top: 5)) + { + matches.Add(item); + } + + Assert.Equal(2, matches.Count); + Assert.Contains(matches, m => m.HotelId == "alpha"); + Assert.Contains(matches, m => m.HotelId == "gamma"); + } + + [Fact] + public async Task EmbeddingGeneratorPopulatesVectorsAsync() + { + using var database = new LiteDatabase(new MemoryStream()); + using var store = new LiteDbVectorStore(database, new LiteDbVectorStoreOptions + { + DisposeDatabase = false, + EmbeddingGenerator = new StringEmbeddingGenerator() + }); + + var collection = store.GetCollection("generated_hotels"); + await collection.EnsureCollectionExistsAsync(); + + var record = new GeneratedHotel + { + HotelId = "alpha", + Description = "1,0,0" + }; + + await collection.UpsertAsync(record); + + var document = database.GetCollection("generated_hotels").FindById("alpha"); + + Assert.NotNull(document); + Assert.True(document.TryGetValue(nameof(GeneratedHotel.Description), out var value)); + Assert.IsType(value); + Assert.Equal(new[] { 1f, 0f, 0f }, ((BsonVector)value).Values); + + var searchResults = new List>(); + await foreach (var result in collection.SearchAsync("1,0,0", top: 1)) + { + searchResults.Add(result); + } + + Assert.Single(searchResults); + Assert.Equal("alpha", searchResults[0].Record.HotelId); + } + + [Fact] + public async Task SearchHonorsDistanceMetricsAsync() + { + using var database = new LiteDatabase(new MemoryStream()); + using var store = new LiteDbVectorStore(database, new LiteDbVectorStoreOptions { DisposeDatabase = false }); + + var dotCollection = store.GetCollection("dot_hotels"); + await dotCollection.EnsureCollectionExistsAsync(); + + await dotCollection.UpsertAsync(new[] + { + new DotProductHotel { HotelId = "alpha", Embedding = new ReadOnlyMemory(new[] { 1f, 0f, 0f }) }, + new DotProductHotel { HotelId = "beta", Embedding = new ReadOnlyMemory(new[] { 0f, 1f, 0f }) } + }); + + var dotResults = new List>(); + await foreach (var result in dotCollection.SearchAsync(new ReadOnlyMemory(new[] { 0f, 1f, 0f }), top: 2)) + { + dotResults.Add(result); + } + + Assert.Equal("beta", dotResults[0].Record.HotelId); + Assert.True(dotResults[0].Score >= dotResults[1].Score); + + var euclideanCollection = store.GetCollection("euclidean_hotels"); + await euclideanCollection.EnsureCollectionExistsAsync(); + + await euclideanCollection.UpsertAsync(new[] + { + new EuclideanHotel { HotelId = "near", Embedding = new ReadOnlyMemory(new[] { 1f, 0f, 0f }) }, + new EuclideanHotel { HotelId = "far", Embedding = new ReadOnlyMemory(new[] { 2f, 0f, 0f }) } + }); + + var euclideanResults = new List>(); + await foreach (var result in euclideanCollection.SearchAsync(new ReadOnlyMemory(new[] { 0f, 0f, 0f }), top: 2)) + { + euclideanResults.Add(result); + } + + Assert.Equal("near", euclideanResults[0].Record.HotelId); + Assert.True(euclideanResults[0].Score <= euclideanResults[1].Score); + } + + [Fact] + public async Task FilterSupportsLogicalOperatorsAndSkipAsync() + { + using var database = new LiteDatabase(new MemoryStream()); + using var store = new LiteDbVectorStore(database, new LiteDbVectorStoreOptions { DisposeDatabase = false }); + + var collection = store.GetCollection("hotels"); + await collection.EnsureCollectionExistsAsync(); + + await collection.UpsertAsync(new[] + { + new TestHotel { HotelId = "alpha", HotelName = "Alpha", Rating = 5, City = "Seattle", DescriptionEmbedding = new ReadOnlyMemory(new[] { 1f, 0f, 0f }) }, + new TestHotel { HotelId = "beta", HotelName = "Beta", Rating = 4, City = "Seattle", DescriptionEmbedding = new ReadOnlyMemory(new[] { 0f, 1f, 0f }) }, + new TestHotel { HotelId = "gamma", HotelName = "Gamma", Rating = 3, City = "Portland", DescriptionEmbedding = new ReadOnlyMemory(new[] { 0f, 0f, 1f }) }, + new TestHotel { HotelId = "delta", HotelName = "Delta", Rating = 4, City = "Portland", DescriptionEmbedding = new ReadOnlyMemory(new[] { 0.5f, 0.5f, 0f }) } + }); + + var options = new FilteredRecordRetrievalOptions + { + Skip = 1 + }; + + var filtered = new List(); + await foreach (var item in collection.GetAsync(h => (h.City == "Seattle" && h.Rating >= 4) || h.City == "Portland", top: 3, options)) + { + filtered.Add(item); + } + + Assert.Equal(3, filtered.Count); + Assert.DoesNotContain(filtered, h => h.HotelId == "alpha"); + Assert.Contains(filtered, h => h.HotelId == "beta"); + Assert.Contains(filtered, h => h.HotelId == "gamma"); + Assert.Contains(filtered, h => h.HotelId == "delta"); + } + + [Fact] + public async Task DynamicCollectionRoundtripsRecordsAsync() + { + using var database = new LiteDatabase(new MemoryStream()); + using var store = new LiteDbVectorStore(database, new LiteDbVectorStoreOptions { DisposeDatabase = false }); + + var definition = new VectorStoreCollectionDefinition + { + Properties = + { + new VectorStoreKeyProperty("Id", typeof(string)), + new VectorStoreDataProperty("Category", typeof(string)), + new VectorStoreVectorProperty("Embedding", typeof(ReadOnlyMemory), 3) + } + }; + + var collection = store.GetDynamicCollection("dynamic_hotels", definition); + await collection.EnsureCollectionExistsAsync(); + + var record = new Dictionary + { + ["Id"] = "alpha", + ["Category"] = "city", + ["Embedding"] = new ReadOnlyMemory(new[] { 1f, 0f, 0f }) + }; + + await collection.UpsertAsync(record); + + var fetched = await collection.GetAsync("alpha", new RecordRetrievalOptions { IncludeVectors = true }); + + Assert.NotNull(fetched); + Assert.Equal("city", fetched!["Category"]); + var embedding = (ReadOnlyMemory)fetched["Embedding"]!; + Assert.Equal(new[] { 1f, 0f, 0f }, embedding.ToArray()); + } + + [Fact] + public async Task CollectionPrefixIsAppliedToStorageNameAsync() + { + using var database = new LiteDatabase(new MemoryStream()); + database.GetCollection("legacy").Insert(new BsonDocument { ["_id"] = "legacy" }); + + var options = new LiteDbVectorStoreOptions + { + DisposeDatabase = false, + CollectionNamePrefix = "sk_" + }; + + using var store = new LiteDbVectorStore(database, options); + + var collection = store.GetCollection("hotels"); + await collection.EnsureCollectionExistsAsync(); + + var names = database.GetCollectionNames().ToArray(); + Assert.Contains("sk_hotels", names); + Assert.Equal("sk_hotels", collection.Name); + + var logicalNames = new List(); + await foreach (var name in store.ListCollectionNamesAsync()) + { + logicalNames.Add(name); + } + + Assert.Contains("hotels", logicalNames); + Assert.DoesNotContain("sk_hotels", logicalNames); + Assert.DoesNotContain("legacy", logicalNames); + Assert.True(await store.CollectionExistsAsync("hotels")); + + await store.EnsureCollectionDeletedAsync("hotels"); + names = database.GetCollectionNames().ToArray(); + Assert.DoesNotContain("sk_hotels", names); + } + + [Fact] + public async Task DatabaseFactoryRespectsDisposeFlagAsync() + { + var stream = new MemoryStream(); + var callCount = 0; + + var options = new LiteDbVectorStoreOptions + { + DisposeDatabase = false, + DatabaseFactory = () => + { + callCount++; + return new LiteDatabase(stream); + } + }; + + using (var store = new LiteDbVectorStore(options)) + { + var collection = store.GetCollection("hotels"); + await collection.EnsureCollectionExistsAsync(); + + await collection.UpsertAsync(new TestHotel + { + HotelId = "alpha", + HotelName = "Alpha", + DescriptionEmbedding = new ReadOnlyMemory(new[] { 1f, 0f, 0f }) + }); + + var fetched = await collection.GetAsync("alpha"); + Assert.NotNull(fetched); + } + + Assert.Equal(1, callCount); + Assert.True(stream.CanRead); + } + + [Fact] + public async Task ServiceCollectionRegistersLiteDbStoreAndCollectionsAsync() + { + var services = new ServiceCollection(); + services.AddSingleton(); + services.AddLiteDbVectorStore(sp => new LiteDbVectorStoreOptions + { + DisposeDatabase = false, + DatabaseFactory = () => new LiteDatabase(new MemoryStream()) + }); + services.AddLiteDbCollection("hotels"); + services.AddLiteDbCollection("generated_hotels"); + + using var provider = services.BuildServiceProvider(); + + var store = provider.GetRequiredService(); + var liteStore = Assert.IsType(store); + + var hotelCollection = provider.GetRequiredService>(); + await hotelCollection.EnsureCollectionExistsAsync(); + + await hotelCollection.UpsertAsync(new TestHotel + { + HotelId = "alpha", + HotelName = "Alpha", + DescriptionEmbedding = new ReadOnlyMemory(new[] { 1f, 0f, 0f }) + }); + + Assert.NotNull(await hotelCollection.GetAsync("alpha")); + + var generatedCollection = provider.GetRequiredService>(); + await generatedCollection.EnsureCollectionExistsAsync(); + + await generatedCollection.UpsertAsync(new GeneratedHotel + { + HotelId = "beta", + Description = "1,0,0" + }); + + var searchResults = new List>(); + await foreach (var result in generatedCollection.SearchAsync("1,0,0", top: 1)) + { + searchResults.Add(result); + } + + Assert.Single(searchResults); + Assert.Equal("beta", searchResults[0].Record.HotelId); + + var collectionNames = new List(); + await foreach (var name in liteStore.ListCollectionNamesAsync()) + { + collectionNames.Add(name); + } + + Assert.Contains("hotels", collectionNames); + Assert.Contains("generated_hotels", collectionNames); + } + + [Fact] + public async Task FilterSupportsStringMembershipAndNestedGroupsAsync() + { + using var database = new LiteDatabase(new MemoryStream()); + using var store = new LiteDbVectorStore(database, new LiteDbVectorStoreOptions { DisposeDatabase = false }); + + var collection = store.GetCollection("tagged_hotels"); + await collection.EnsureCollectionExistsAsync(); + + await collection.UpsertAsync(new[] + { + new TaggedHotel + { + HotelId = "alpha", + City = "Seattle", + Rating = 5, + Tags = new[] { "spa", "downtown" }, + Description = "Coastline Inn", + DescriptionEmbedding = new ReadOnlyMemory(new[] { 1f, 0f, 0f }) + }, + new TaggedHotel + { + HotelId = "beta", + City = "Portland", + Rating = 4, + Tags = new[] { "business" }, + Description = "City Center Hotel", + DescriptionEmbedding = new ReadOnlyMemory(new[] { 0f, 1f, 0f }) + }, + new TaggedHotel + { + HotelId = "gamma", + City = "Seattle", + Rating = 3, + Tags = new[] { "historic" }, + Description = "Harbor Inn", + DescriptionEmbedding = new ReadOnlyMemory(new[] { 0f, 0f, 1f }) + } + }); + + var results = new List(); + var allowedCities = new[] { "Seattle", "Portland" }; + await foreach (var record in collection.GetAsync( + h => (h.Tags.Contains("spa") || h.City!.StartsWith("Sea")) + && allowedCities.Contains(h.City!) + && h.Description!.EndsWith("Inn") + && h.Rating >= 4, + top: 5)) + { + results.Add(record); + } + + var result = Assert.Single(results); + Assert.NotNull(result.HotelId); + var hotelId = result.HotelId!; + Assert.Equal("alpha", hotelId); + } + + [Fact] + public async Task BatchUpsertRollsBackOnFailureAsync() + { + using var database = new LiteDatabase(new MemoryStream()); + using var store = new LiteDbVectorStore(database, new LiteDbVectorStoreOptions { DisposeDatabase = false }); + + var collection = store.GetCollection("hotels"); + await collection.EnsureCollectionExistsAsync(); + + var rawCollection = database.GetCollection("hotels"); + rawCollection.EnsureIndex(nameof(TestHotel.HotelName), unique: true); + + var records = new[] + { + new TestHotel + { + HotelId = "alpha", + HotelName = "Duplicate", + Rating = 4, + City = "Seattle", + DescriptionEmbedding = new ReadOnlyMemory(new[] { 1f, 0f, 0f }) + }, + new TestHotel + { + HotelId = "beta", + HotelName = "Duplicate", + Rating = 5, + City = "Portland", + DescriptionEmbedding = new ReadOnlyMemory(new[] { 0f, 1f, 0f }) + } + }; + + await Assert.ThrowsAsync(() => collection.UpsertAsync(records)); + + Assert.Empty(rawCollection.FindAll()); + } + + [Fact] + public async Task ThrowsWhenVectorDimensionsMismatchAsync() + { + using var database = new LiteDatabase(new MemoryStream()); + using var store = new LiteDbVectorStore(database, new LiteDbVectorStoreOptions { DisposeDatabase = false }); + + var collection = store.GetCollection("hotels"); + await collection.EnsureCollectionExistsAsync(); + + var invalid = new TestHotel + { + HotelId = "alpha", + HotelName = "Alpha", + City = "Seattle", + Rating = 4, + DescriptionEmbedding = new ReadOnlyMemory(new[] { 1f, 0f }) + }; + + var exception = await Assert.ThrowsAsync(() => collection.UpsertAsync(invalid)); + Assert.Contains("expects 3 dimensions", exception.Message, StringComparison.Ordinal); + } + + [Fact] + public async Task PropertyAndCollectionMetricsTakePrecedenceOverStoreDefaultsAsync() + { + using var database = new LiteDatabase(new MemoryStream()); + + var storeOptions = new LiteDbVectorStoreOptions + { + DisposeDatabase = false, + DistanceMetric = LiteDbDistanceMetric.Cosine + }; + + var collectionOptions = new LiteDbCollectionOptions + { + DistanceMetric = LiteDbDistanceMetric.Euclidean + }; + + using var collection = new LiteDbCollection( + database, + "multi_hotels", + storeOptions, + collectionOptions, + opts => new LiteDbModelBuilder().Build(typeof(MultiVectorHotel), opts.Definition, opts.EmbeddingGenerator), + connectionIdentifier: "test"); + + await collection.EnsureCollectionExistsAsync(); + + await collection.UpsertAsync(new[] + { + new MultiVectorHotel + { + HotelId = "alpha", + AmenitiesEmbedding = new ReadOnlyMemory(new[] { 1f, 0f }), + LocationEmbedding = new ReadOnlyMemory(new[] { 0f, 10f }) + }, + new MultiVectorHotel + { + HotelId = "beta", + AmenitiesEmbedding = new ReadOnlyMemory(new[] { 0f, 1f }), + LocationEmbedding = new ReadOnlyMemory(new[] { 0f, 2f }) + } + }); + + var amenityResults = new List>(); + var amenityOptions = new VectorSearchOptions + { + VectorProperty = h => h.AmenitiesEmbedding + }; + + await foreach (var result in collection.SearchAsync(new ReadOnlyMemory(new[] { 0f, 1f }), top: 1, amenityOptions)) + { + amenityResults.Add(result); + } + + Assert.Single(amenityResults); + Assert.Equal("beta", amenityResults[0].Record.HotelId); + + var locationResults = new List>(); + var locationOptions = new VectorSearchOptions + { + VectorProperty = h => h.LocationEmbedding + }; + + await foreach (var result in collection.SearchAsync(new ReadOnlyMemory(new[] { 0f, 0f }), top: 1, locationOptions)) + { + locationResults.Add(result); + } + + Assert.Single(locationResults); + Assert.Equal("beta", locationResults[0].Record.HotelId); + } + + [Fact] + public async Task PropertySpecificEmbeddingGeneratorOverridesDefaultsAsync() + { + using var database = new LiteDatabase(new MemoryStream()); + + var defaultGenerator = new TrackingStringEmbeddingGenerator(value => new[] { 9f, 9f, 9f }); + var overrideGenerator = new TrackingStringEmbeddingGenerator(ParseVector); + + var storeOptions = new LiteDbVectorStoreOptions + { + DisposeDatabase = false, + EmbeddingGenerator = defaultGenerator + }; + + using var store = new LiteDbVectorStore(database, storeOptions); + + var definition = new VectorStoreCollectionDefinition + { + Properties = + { + new VectorStoreKeyProperty(nameof(OverrideHotel.HotelId), typeof(string)), + new VectorStoreVectorProperty(nameof(OverrideHotel.Overview), 3) + { + EmbeddingGenerator = overrideGenerator + }, + new VectorStoreVectorProperty(nameof(OverrideHotel.Amenities), 3) + } + }; + + var collection = store.GetCollection("override_hotels", definition); + await collection.EnsureCollectionExistsAsync(); + + var record = new OverrideHotel + { + HotelId = "alpha", + Overview = "1,0,0", + Amenities = "0,1,0" + }; + + await collection.UpsertAsync(record); + + var raw = database.GetCollection("override_hotels").FindById("alpha"); + Assert.NotNull(raw); + Assert.Equal(new[] { 1f, 0f, 0f }, ((BsonVector)raw![nameof(OverrideHotel.Overview)]).Values); + Assert.Equal(new[] { 9f, 9f, 9f }, ((BsonVector)raw![nameof(OverrideHotel.Amenities)]).Values); + + Assert.Equal(1, overrideGenerator.BatchCalls); + Assert.Single(overrideGenerator.BatchInputs); + Assert.Equal("1,0,0", overrideGenerator.BatchInputs[0]); + Assert.Equal(1, defaultGenerator.BatchCalls); + Assert.Single(defaultGenerator.BatchInputs); + Assert.Equal("0,1,0", defaultGenerator.BatchInputs[0]); + + var options = new VectorSearchOptions + { + VectorProperty = h => h.Overview + }; + + var searchResults = new List>(); + await foreach (var result in collection.SearchAsync("1,0,0", top: 1, options)) + { + searchResults.Add(result); + } + + Assert.Single(searchResults); + Assert.Equal("alpha", searchResults[0].Record.HotelId); + } + + [Fact] + public async Task CancellationStopsEmbeddingGenerationOnUpsertAsync() + { + using var database = new LiteDatabase(new MemoryStream()); + var generator = new CancellableEmbeddingGenerator(); + + using var store = new LiteDbVectorStore(database, new LiteDbVectorStoreOptions + { + DisposeDatabase = false, + EmbeddingGenerator = generator + }); + + var collection = store.GetCollection("generated_hotels"); + await collection.EnsureCollectionExistsAsync(); + + using var cts = new CancellationTokenSource(); + cts.Cancel(); + + await Assert.ThrowsAsync(() => collection.UpsertAsync(new GeneratedHotel + { + HotelId = "alpha", + Description = "1,0,0" + }, cts.Token)); + + Assert.Null(database.GetCollection("generated_hotels").FindById("alpha")); + } + + [Fact] + public async Task CancellationStopsEmbeddingGenerationOnSearchAsync() + { + using var database = new LiteDatabase(new MemoryStream()); + var generator = new CancellableEmbeddingGenerator(); + + using var store = new LiteDbVectorStore(database, new LiteDbVectorStoreOptions + { + DisposeDatabase = false, + EmbeddingGenerator = generator + }); + + var collection = store.GetCollection("generated_hotels"); + await collection.EnsureCollectionExistsAsync(); + + await collection.UpsertAsync(new GeneratedHotel + { + HotelId = "alpha", + Description = "1,0,0" + }); + + using var cts = new CancellationTokenSource(); + cts.Cancel(); + + await Assert.ThrowsAsync(async () => + { + await foreach (var _ in collection.SearchAsync("1,0,0", top: 1, cancellationToken: cts.Token)) + { + } + }); + } + + [Fact] + public void AutoCreateVectorIndexesAliasKeepsOptionsInSync() + { + var options = new LiteDbVectorStoreOptions + { + AutoCreateVectorIndexes = false + }; + + Assert.False(options.AutoEnsureVectorIndex); + + options.AutoEnsureVectorIndex = true; + Assert.True(options.AutoCreateVectorIndexes); + + options.AutoCreateVectorIndexes = false; + Assert.False(options.AutoEnsureVectorIndex); + } + + [Fact] + public async Task DatabaseFactoryHasHighestPrecedenceAsync() + { + var providedStream = new MemoryStream(); + using var providedDatabase = new LiteDatabase(providedStream); + + LiteDatabase? factoryDatabase = null; + var options = new LiteDbVectorStoreOptions + { + DisposeDatabase = false, + Database = providedDatabase, + DatabaseFactory = () => + { + factoryDatabase = new LiteDatabase(new MemoryStream()); + return factoryDatabase; + } + }; + + using (var store = new LiteDbVectorStore(options)) + { + var collection = store.GetCollection("hotels"); + await collection.EnsureCollectionExistsAsync(); + await collection.UpsertAsync(new TestHotel + { + HotelId = "alpha", + HotelName = "Alpha", + City = "Seattle", + Rating = 4, + DescriptionEmbedding = new ReadOnlyMemory(new[] { 1f, 0f, 0f }) + }); + } + + Assert.NotNull(factoryDatabase); + Assert.NotNull(factoryDatabase!.GetCollection("hotels").FindById("alpha")); + Assert.Null(providedDatabase.GetCollection("hotels").FindById("alpha")); + } + + [Fact] + public async Task DatabaseInstanceOverridesConnectionStringAsync() + { + using var providedDatabase = new LiteDatabase(new MemoryStream()); + var options = new LiteDbVectorStoreOptions + { + DisposeDatabase = false, + Database = providedDatabase, + ConnectionString = "Filename=ignored.db" + }; + + using (var store = new LiteDbVectorStore("Filename=also-ignored.db", options)) + { + var collection = store.GetCollection("hotels"); + await collection.EnsureCollectionExistsAsync(); + await collection.UpsertAsync(new TestHotel + { + HotelId = "alpha", + HotelName = "Alpha", + City = "Seattle", + Rating = 4, + DescriptionEmbedding = new ReadOnlyMemory(new[] { 1f, 0f, 0f }) + }); + } + + Assert.NotNull(providedDatabase.GetCollection("hotels").FindById("alpha")); + } + + private static float[] ParseVector(string value) + => value.Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries) + .Select(float.Parse) + .ToArray(); + + private sealed class TaggedHotel + { + [VectorStoreKey] + public string? HotelId { get; set; } + + [VectorStoreData] + public string[] Tags { get; set; } = Array.Empty(); + + [VectorStoreData] + public string? City { get; set; } + + [VectorStoreData] + public int Rating { get; set; } + + [VectorStoreData] + public string? Description { get; set; } + + [VectorStoreVector(Dimensions: 3)] + public ReadOnlyMemory? DescriptionEmbedding { get; set; } + } + + private sealed class MultiVectorHotel + { + [VectorStoreKey] + public string? HotelId { get; set; } + + [VectorStoreVector(Dimensions: 2, DistanceFunction = DistanceFunction.DotProductSimilarity)] + public ReadOnlyMemory? AmenitiesEmbedding { get; set; } + + [VectorStoreVector(Dimensions: 2)] + public ReadOnlyMemory? LocationEmbedding { get; set; } + } + + private sealed class OverrideHotel + { + [VectorStoreKey] + public string? HotelId { get; set; } + + public string? Overview { get; set; } + + public string? Amenities { get; set; } + } + + private sealed class TrackingStringEmbeddingGenerator : IEmbeddingGenerator> + { + private readonly Func _projection; + + internal TrackingStringEmbeddingGenerator(Func projection) + { + this._projection = projection; + } + + public int BatchCalls { get; private set; } + + public int SingleCalls { get; private set; } + + public List BatchInputs { get; } = new(); + + public List SingleInputs { get; } = new(); + + public Task>> GenerateAsync(IEnumerable values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested(); + this.BatchCalls++; + + var embeddings = new GeneratedEmbeddings>(); + foreach (var value in values) + { + cancellationToken.ThrowIfCancellationRequested(); + this.BatchInputs.Add(value); + embeddings.Add(new Embedding(this._projection(value))); + } + + return Task.FromResult(embeddings); + } + + public Task> GenerateAsync(string value, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested(); + this.SingleCalls++; + this.SingleInputs.Add(value); + return Task.FromResult(new Embedding(this._projection(value))); + } + + public object? GetService(Type serviceType, object? serviceKey = null) + => null; + + public void Dispose() + { + } + } + + private sealed class CancellableEmbeddingGenerator : IEmbeddingGenerator> + { + public Task>> GenerateAsync(IEnumerable values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested(); + var embeddings = new GeneratedEmbeddings>(); + foreach (var value in values) + { + cancellationToken.ThrowIfCancellationRequested(); + embeddings.Add(new Embedding(ParseVector(value))); + } + + return Task.FromResult(embeddings); + } + + public Task> GenerateAsync(string value, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested(); + return Task.FromResult(new Embedding(ParseVector(value))); + } + + public object? GetService(Type serviceType, object? serviceKey = null) + => null; + + public void Dispose() + { + } + } + + private sealed class TestHotel + { + [VectorStoreKey] + public string? HotelId { get; set; } + + [VectorStoreData(IsIndexed = true)] + public string? HotelName { get; set; } + + [VectorStoreData] + public int Rating { get; set; } + + [VectorStoreData] + public string? City { get; set; } + + [VectorStoreVector(Dimensions: 3, DistanceFunction = DistanceFunction.CosineSimilarity)] + public ReadOnlyMemory? DescriptionEmbedding { get; set; } + } + + private sealed class GeneratedHotel + { + [VectorStoreKey] + public string? HotelId { get; set; } + + [VectorStoreVector(Dimensions: 3)] + public string? Description { get; set; } + } + + private sealed class DotProductHotel + { + [VectorStoreKey] + public string? HotelId { get; set; } + + [VectorStoreVector(Dimensions: 3, DistanceFunction = DistanceFunction.DotProductSimilarity)] + public ReadOnlyMemory? Embedding { get; set; } + } + + private sealed class EuclideanHotel + { + [VectorStoreKey] + public string? HotelId { get; set; } + + [VectorStoreVector(Dimensions: 3, DistanceFunction = DistanceFunction.EuclideanDistance)] + public ReadOnlyMemory? Embedding { get; set; } + } + + private sealed class StringEmbeddingGenerator : IEmbeddingGenerator> + { + public Task>> GenerateAsync(IEnumerable values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) + { + var embeddings = new GeneratedEmbeddings>(); + + foreach (var value in values) + { + var vector = value.Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries) + .Select(float.Parse) + .ToArray(); + + embeddings.Add(new Embedding(vector)); + } + + return Task.FromResult(embeddings); + } + + public object? GetService(Type serviceType, object? serviceKey = null) + => null; + + public void Dispose() + { + } + } +}