Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 159 additions & 4 deletions src/store/src/Bridge/Cache/Store.php
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,13 @@
use Symfony\AI\Store\Document\Metadata;
use Symfony\AI\Store\Document\VectorDocument;
use Symfony\AI\Store\Exception\InvalidArgumentException;
use Symfony\AI\Store\Exception\UnsupportedQueryTypeException;
use Symfony\AI\Store\ManagedStoreInterface;
use Symfony\AI\Store\Query\Filter\EqualFilter;
use Symfony\AI\Store\Query\HybridQuery;
use Symfony\AI\Store\Query\QueryInterface;
use Symfony\AI\Store\Query\TextQuery;
use Symfony\AI\Store\Query\VectorQuery;
use Symfony\AI\Store\StoreInterface;
use Symfony\Contracts\Cache\CacheInterface;

Expand Down Expand Up @@ -96,14 +102,44 @@ public function remove(string|array $ids, array $options = []): void
$this->cache->save($cacheItem);
}

public function supports(string $queryClass): bool
{
return \in_array($queryClass, [
VectorQuery::class,
TextQuery::class,
HybridQuery::class,
], true);
}

/**
* @param array{
* maxItems?: positive-int,
* filter?: callable(VectorDocument): bool
* } $options If maxItems is provided, only the top N results will be returned.
* If filter is provided, only documents matching the filter will be considered.
*/
public function query(Vector $vector, array $options = []): iterable
public function query(QueryInterface $query, array $options = []): iterable
{
return match (true) {
$query instanceof VectorQuery => $this->queryVector($query, $options),
$query instanceof TextQuery => $this->queryText($query, $options),
$query instanceof HybridQuery => $this->queryHybrid($query, $options),
default => throw new UnsupportedQueryTypeException($query->getType(), $this),
};
}

public function drop(array $options = []): void
{
$this->cache->clear();
}

/**
* @param array{
* maxItems?: positive-int,
* filter?: callable(VectorDocument): bool,
* } $options
*/
private function queryVector(VectorQuery $query, array $options): iterable
{
$documents = $this->cache->get($this->cacheKey, static fn (): array => []);

Expand All @@ -117,15 +153,134 @@ public function query(Vector $vector, array $options = []): iterable
metadata: new Metadata($document['metadata']),
), $documents);

$vectorDocuments = $this->applyFilter($vectorDocuments, $query->getFilter());

if (isset($options['filter'])) {
$vectorDocuments = array_values(array_filter($vectorDocuments, $options['filter']));
}

yield from $this->distanceCalculator->calculate($vectorDocuments, $vector, $options['maxItems'] ?? null);
yield from $this->distanceCalculator->calculate($vectorDocuments, $query->getVector(), $options['maxItems'] ?? null);
}

public function drop(array $options = []): void
/**
* @param array{
* maxItems?: positive-int,
* filter?: callable(VectorDocument): bool,
* } $options
*/
private function queryText(TextQuery $query, array $options): iterable
{
$this->cache->clear();
$documents = $this->cache->get($this->cacheKey, static fn (): array => []);

if ([] === $documents) {
return;
}

$vectorDocuments = array_map(static fn (array $document): VectorDocument => new VectorDocument(
id: $document['id'],
vector: new Vector($document['vector']),
metadata: new Metadata($document['metadata']),
), $documents);

if (isset($options['filter'])) {
$vectorDocuments = array_values(array_filter($vectorDocuments, $options['filter']));
}

$filteredDocuments = array_filter($vectorDocuments, function (VectorDocument $doc) use ($query) {
$text = $doc->metadata->getText() ?? '';

return str_contains(strtolower($text), strtolower($query->getText()));
});

$filteredDocuments = $this->applyFilter($filteredDocuments, $query->getFilter());

$maxItems = $options['maxItems'] ?? null;
$count = 0;

foreach ($filteredDocuments as $document) {
if (null !== $maxItems && $count >= $maxItems) {
break;
}

yield $document;
++$count;
}
}

/**
* @param array{
* maxItems?: positive-int,
* filter?: callable(VectorDocument): bool,
* } $options
*/
private function queryHybrid(HybridQuery $query, array $options): iterable
{
$vectorResults = iterator_to_array($this->queryVector(
new VectorQuery($query->getVector(), $query->getFilter()),
$options
));

$textResults = iterator_to_array($this->queryText(
new TextQuery($query->getText(), $query->getFilter()),
$options
));

$mergedResults = [];
$seenIds = [];

foreach ($vectorResults as $doc) {
$id = $doc->id->toRfc4122();
if (!isset($seenIds[$id])) {
$mergedResults[] = new VectorDocument(
id: $doc->id,
vector: $doc->vector,
metadata: $doc->metadata,
score: null !== $doc->score ? $doc->score * $query->getSemanticRatio() : null,
);
$seenIds[$id] = true;
}
}

foreach ($textResults as $doc) {
$id = $doc->id->toRfc4122();
if (!isset($seenIds[$id])) {
$mergedResults[] = $doc;
$seenIds[$id] = true;
}
}

if (isset($options['filter'])) {
$mergedResults = array_values(array_filter($mergedResults, $options['filter']));
}

$maxItems = $options['maxItems'] ?? null;
$count = 0;

foreach ($mergedResults as $document) {
if (null !== $maxItems && $count >= $maxItems) {
break;
}

yield $document;
++$count;
}
}

/**
* @param VectorDocument[] $documents
*
* @return VectorDocument[]
*/
private function applyFilter(array $documents, $filter): array
{
if (!$filter instanceof EqualFilter) {
return $documents;
}

return array_values(array_filter($documents, function (VectorDocument $doc) use ($filter) {
$metadata = $doc->metadata->getArrayCopy();

return isset($metadata[$filter->getField()]) && $metadata[$filter->getField()] === $filter->getValue();
}));
}
}
95 changes: 82 additions & 13 deletions src/store/src/Bridge/ChromaDb/Store.php
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@
use Symfony\AI\Platform\Vector\Vector;
use Symfony\AI\Store\Document\Metadata;
use Symfony\AI\Store\Document\VectorDocument;
use Symfony\AI\Store\Exception\UnsupportedQueryTypeException;
use Symfony\AI\Store\Query\Filter\EqualFilter;
use Symfony\AI\Store\Query\QueryInterface;
use Symfony\AI\Store\Query\TextQuery;
use Symfony\AI\Store\Query\VectorQuery;
use Symfony\AI\Store\StoreInterface;

/**
Expand Down Expand Up @@ -67,30 +72,94 @@ public function remove(string|array $ids, array $options = []): void
$collection->delete(ids: $ids);
}

public function supports(string $queryClass): bool
{
return \in_array($queryClass, [
VectorQuery::class,
TextQuery::class,
], true);
}

/**
* @param array{where?: array<string, string>, whereDocument?: array<string, mixed>, include?: array<string>, queryTexts?: array<string>} $options
* @param array{where?: array<string, string>, whereDocument?: array<string, mixed>, include?: array<string>} $options
*/
public function query(Vector $vector, array $options = []): iterable
public function query(QueryInterface $query, array $options = []): iterable
{
$include = null;
if ([] !== ($options['include'] ?? [])) {
$include = array_values(
array_unique(
array_merge(['embeddings', 'metadatas', 'distances'], $options['include'])
)
);
if (!$this->supports($query::class)) {
throw new UnsupportedQueryTypeException($query->getType(), $this);
}

return match (true) {
$query instanceof VectorQuery => $this->queryVector($query, $options),
$query instanceof TextQuery => $this->queryText($query, $options),
default => throw new UnsupportedQueryTypeException($query->getType(), $this),
};
}

/**
* @param array{where?: array<string, string>, whereDocument?: array<string, mixed>, include?: array<string>, limit?: positive-int} $options
*/
private function queryVector(VectorQuery $query, array $options): iterable
{
$include = $this->buildInclude($options);
$where = $this->buildWhere($query->getFilter(), $options);

$collection = $this->client->getOrCreateCollection($this->collectionName);
$queryResponse = $collection->query(
queryEmbeddings: [$query->getVector()->getData()],
nResults: $options['limit'] ?? 4,
where: $where,
whereDocument: $options['whereDocument'] ?? null,
include: $include,
);

yield from $this->transformResponse($queryResponse);
}

/**
* @param array{where?: array<string, string>, whereDocument?: array<string, mixed>, include?: array<string>, limit?: positive-int} $options
*/
private function queryText(TextQuery $query, array $options): iterable
{
$include = $this->buildInclude($options);
$where = $this->buildWhere($query->getFilter(), $options);

$collection = $this->client->getOrCreateCollection($this->collectionName);
$queryResponse = $collection->query(
queryEmbeddings: [$vector->getData()],
queryTexts: $options['queryTexts'] ?? null,
nResults: 4,
where: $options['where'] ?? null,
queryTexts: [$query->getText()],
nResults: $options['limit'] ?? 4,
where: $where,
whereDocument: $options['whereDocument'] ?? null,
include: $include,
);

yield from $this->transformResponse($queryResponse);
}

private function buildInclude(array $options): ?array
{
if ([] === ($options['include'] ?? [])) {
return null;
}

return array_values(
array_unique(
array_merge(['embeddings', 'metadatas', 'distances'], $options['include'])
)
);
}

private function buildWhere($filter, array $options): ?array
{
if (!$filter instanceof EqualFilter) {
return null;
}

return [$filter->getField() => ['$eq' => $filter->getValue()]];
}

private function transformResponse(object $queryResponse): iterable
{
$metaCount = \count($queryResponse->metadatas[0]);

for ($i = 0; $i < $metaCount; ++$i) {
Expand Down
40 changes: 36 additions & 4 deletions src/store/src/Bridge/Pinecone/Store.php
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
use Symfony\AI\Store\Document\Metadata;
use Symfony\AI\Store\Document\VectorDocument;
use Symfony\AI\Store\Exception\InvalidArgumentException;
use Symfony\AI\Store\Exception\UnsupportedQueryTypeException;
use Symfony\AI\Store\ManagedStoreInterface;
use Symfony\AI\Store\Query\Filter\EqualFilter;
use Symfony\AI\Store\Query\QueryInterface;
use Symfony\AI\Store\Query\VectorQuery;
use Symfony\AI\Store\StoreInterface;

/**
Expand Down Expand Up @@ -112,13 +116,24 @@ public function remove(string|array $ids, array $options = []): void
}
}

public function query(Vector $vector, array $options = []): iterable
public function supports(string $queryClass): bool
{
return VectorQuery::class === $queryClass;
}

public function query(QueryInterface $query, array $options = []): iterable
{
if (!$query instanceof VectorQuery) {
throw new UnsupportedQueryTypeException($query->getType(), $this);
}

$filter = $this->buildFilter($query->getFilter(), $options);

$result = $this->getVectors()->query(
vector: $vector->getData(),
vector: $query->getVector()->getData(),
namespace: $options['namespace'] ?? $this->namespace,
filter: $options['filter'] ?? $this->filter,
topK: $options['topK'] ?? $this->topK,
filter: $filter,
topK: $options['topK'] ?? $options['limit'] ?? $this->topK,
includeValues: true,
);

Expand All @@ -140,6 +155,23 @@ public function drop(array $options = []): void
->delete();
}

private function buildFilter($queryFilter, array $options): array
{
$filter = $this->filter;

if ($queryFilter instanceof EqualFilter) {
$filterCondition = [$queryFilter->getField() => ['$eq' => $queryFilter->getValue()]];

if ([] === $filter) {
$filter = $filterCondition;
} else {
$filter = ['$and' => [$filter, $filterCondition]];
}
}

return $filter;
}

private function getVectors(): VectorResource
{
return $this->pinecone->data()->vectors();
Expand Down
Loading
Loading