Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/chat/src/Bridge/Doctrine/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
CHANGELOG
=========

0.8
---

* [BC BREAK] `DoctrineDbalMessageStore::save()` now upserts a single row instead of inserting a new row on every call; the table schema changed from `(id, messages, added_at)` to `(messages, updated_at)`

0.1
---

Expand Down
103 changes: 47 additions & 56 deletions src/chat/src/Bridge/Doctrine/DoctrineDbalMessageStore.php
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,7 @@
namespace Symfony\AI\Chat\Bridge\Doctrine;

use Doctrine\DBAL\Connection as DBALConnection;
use Doctrine\DBAL\Platforms\OraclePlatform;
use Doctrine\DBAL\Schema\ComparatorConfig;
use Doctrine\DBAL\Schema\Name\Identifier;
use Doctrine\DBAL\Schema\Name\UnqualifiedName;
use Doctrine\DBAL\Schema\PrimaryKeyConstraint;
use Doctrine\DBAL\Schema\Schema;
use Doctrine\DBAL\Types\Types;
use Psr\Clock\ClockInterface;
Expand Down Expand Up @@ -85,43 +81,61 @@ public function drop(): void
return;
}

$queryBuilder = $this->dbalConnection->createQueryBuilder()
->delete($this->tableName);

$this->dbalConnection->executeStatement($queryBuilder->getSQL());
$this->dbalConnection->executeStatement(
$this->dbalConnection->createQueryBuilder()
->delete($this->tableName)
->getSQL()
);
}

public function save(MessageBag $messages): void
{
$queryBuilder = $this->dbalConnection->createQueryBuilder()
->insert($this->tableName)
->values([
'messages' => '?',
'added_at' => '?',
]);

$this->dbalConnection->executeStatement($queryBuilder->getSQL(), [
$this->serializer->serialize($messages->getMessages(), 'json'),
$this->clock->now()->getTimestamp(),
]);
$serialized = $this->serializer->serialize($messages->getMessages(), 'json');
$now = $this->clock->now()->getTimestamp();

$rowCount = (int) $this->dbalConnection->executeQuery(
$this->dbalConnection->createQueryBuilder()
->select('COUNT(*)')
->from($this->tableName)
->getSQL()
)->fetchOne();

if (0 === $rowCount) {
$this->dbalConnection->executeStatement(
$this->dbalConnection->createQueryBuilder()
->insert($this->tableName)
->values(['messages' => '?', 'updated_at' => '?'])
->getSQL(),
[$serialized, $now],
);
} else {
$this->dbalConnection->executeStatement(
$this->dbalConnection->createQueryBuilder()
->update($this->tableName)
->set('messages', '?')
->set('updated_at', '?')
->getSQL(),
[$serialized, $now],
);
}
}

public function load(): MessageBag
{
$queryBuilder = $this->dbalConnection->createQueryBuilder()
->select('messages')
->from($this->tableName)
->orderBy('added_at', 'ASC')
;

$result = $this->dbalConnection->executeQuery($queryBuilder->getSQL());
$payload = $this->dbalConnection->executeQuery(
$this->dbalConnection->createQueryBuilder()
->select('messages')
->from($this->tableName)
->getSQL()
)->fetchAssociative();

if (false === $payload) {
return new MessageBag();
}

$messages = array_map(
fn (array $payload): array => $this->serializer->deserialize($payload['messages'], MessageInterface::class.'[]', 'json'),
$result->fetchAllAssociative(),
return new MessageBag(
...$this->serializer->deserialize($payload['messages'], MessageInterface::class.'[]', 'json'),
);

return new MessageBag(...array_merge(...$messages));
}

private function addTableToSchema(Schema $currentSchema): Schema
Expand All @@ -130,31 +144,8 @@ private function addTableToSchema(Schema $currentSchema): Schema

$table = $schema->createTable($this->tableName);
$table->addOption('_symfony_ai_chat_table_name', $this->tableName);
$idColumn = $table->addColumn('id', Types::BIGINT)
->setAutoincrement(true)
->setNotnull(true);
$table->addColumn('messages', Types::TEXT)
->setNotnull(true);
$table->addColumn('added_at', Types::INTEGER)
->setNotnull(true);
if (class_exists(PrimaryKeyConstraint::class)) {
$table->addPrimaryKeyConstraint(new PrimaryKeyConstraint(null, [
new UnqualifiedName(Identifier::unquoted('id')),
], true));
} else {
$table->setPrimaryKey(['id']);
}

// We need to create a sequence for Oracle and set the id column to get the correct nextval
if ($this->dbalConnection->getDatabasePlatform() instanceof OraclePlatform) {
$serverVersion = $this->dbalConnection->executeQuery("SELECT version FROM product_component_version WHERE product LIKE 'Oracle Database%'")->fetchOne();
if (version_compare($serverVersion, '12.1.0', '>=')) {
$idColumn->setAutoincrement(false); // disable the creation of SEQUENCE and TRIGGER
$idColumn->setDefault($this->tableName.'_seq.nextval');

$schema->createSequence($this->tableName.'_seq');
}
}
$table->addColumn('messages', Types::TEXT)->setNotnull(true);
$table->addColumn('updated_at', Types::INTEGER)->setNotnull(true);

return $schema;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,21 @@ public function testMessageBagCanBeSaved()
$this->assertCount(1, $messages);
}

public function testSaveOverwritesPreviousMessages()
{
$connection = DriverManager::getConnection(['driver' => 'pdo_sqlite', 'memory' => true]);

$messageStore = new DoctrineDbalMessageStore('foo', $connection);
$messageStore->setup();

$messageStore->save(new MessageBag(Message::ofUser('First message')));
$messageStore->save(new MessageBag(Message::ofUser('Second message')));

$messages = $messageStore->load();
$this->assertCount(1, $messages);
$this->assertSame('Second message', $messages->getUserMessage()->asText());
}

public function testMessageBagCanBeLoaded()
{
$serializer = new Serializer([
Expand Down
Loading