Skip to content

Commit 549682f

Browse files
committed
[Chat][Doctrine] Replace row-accumulation with single-row upsert in DoctrineDbalMessageStore
1 parent bc0a9a3 commit 549682f

2 files changed

Lines changed: 62 additions & 56 deletions

File tree

src/chat/src/Bridge/Doctrine/DoctrineDbalMessageStore.php

Lines changed: 47 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,7 @@
1212
namespace Symfony\AI\Chat\Bridge\Doctrine;
1313

1414
use Doctrine\DBAL\Connection as DBALConnection;
15-
use Doctrine\DBAL\Platforms\OraclePlatform;
1615
use Doctrine\DBAL\Schema\ComparatorConfig;
17-
use Doctrine\DBAL\Schema\Name\Identifier;
18-
use Doctrine\DBAL\Schema\Name\UnqualifiedName;
19-
use Doctrine\DBAL\Schema\PrimaryKeyConstraint;
2016
use Doctrine\DBAL\Schema\Schema;
2117
use Doctrine\DBAL\Types\Types;
2218
use Psr\Clock\ClockInterface;
@@ -85,43 +81,61 @@ public function drop(): void
8581
return;
8682
}
8783

88-
$queryBuilder = $this->dbalConnection->createQueryBuilder()
89-
->delete($this->tableName);
90-
91-
$this->dbalConnection->executeStatement($queryBuilder->getSQL());
84+
$this->dbalConnection->executeStatement(
85+
$this->dbalConnection->createQueryBuilder()
86+
->delete($this->tableName)
87+
->getSQL()
88+
);
9289
}
9390

9491
public function save(MessageBag $messages): void
9592
{
96-
$queryBuilder = $this->dbalConnection->createQueryBuilder()
97-
->insert($this->tableName)
98-
->values([
99-
'messages' => '?',
100-
'added_at' => '?',
101-
]);
102-
103-
$this->dbalConnection->executeStatement($queryBuilder->getSQL(), [
104-
$this->serializer->serialize($messages->getMessages(), 'json'),
105-
$this->clock->now()->getTimestamp(),
106-
]);
93+
$serialized = $this->serializer->serialize($messages->getMessages(), 'json');
94+
$now = $this->clock->now()->getTimestamp();
95+
96+
$rowCount = (int) $this->dbalConnection->executeQuery(
97+
$this->dbalConnection->createQueryBuilder()
98+
->select('COUNT(*)')
99+
->from($this->tableName)
100+
->getSQL()
101+
)->fetchOne();
102+
103+
if (0 === $rowCount) {
104+
$this->dbalConnection->executeStatement(
105+
$this->dbalConnection->createQueryBuilder()
106+
->insert($this->tableName)
107+
->values(['messages' => '?', 'updated_at' => '?'])
108+
->getSQL(),
109+
[$serialized, $now],
110+
);
111+
} else {
112+
$this->dbalConnection->executeStatement(
113+
$this->dbalConnection->createQueryBuilder()
114+
->update($this->tableName)
115+
->set('messages', '?')
116+
->set('updated_at', '?')
117+
->getSQL(),
118+
[$serialized, $now],
119+
);
120+
}
107121
}
108122

109123
public function load(): MessageBag
110124
{
111-
$queryBuilder = $this->dbalConnection->createQueryBuilder()
112-
->select('messages')
113-
->from($this->tableName)
114-
->orderBy('added_at', 'ASC')
115-
;
116-
117-
$result = $this->dbalConnection->executeQuery($queryBuilder->getSQL());
125+
$payload = $this->dbalConnection->executeQuery(
126+
$this->dbalConnection->createQueryBuilder()
127+
->select('messages')
128+
->from($this->tableName)
129+
->getSQL()
130+
)->fetchAssociative();
131+
132+
if (false === $payload) {
133+
return new MessageBag();
134+
}
118135

119-
$messages = array_map(
120-
fn (array $payload): array => $this->serializer->deserialize($payload['messages'], MessageInterface::class.'[]', 'json'),
121-
$result->fetchAllAssociative(),
136+
return new MessageBag(
137+
...$this->serializer->deserialize($payload['messages'], MessageInterface::class.'[]', 'json'),
122138
);
123-
124-
return new MessageBag(...array_merge(...$messages));
125139
}
126140

127141
private function addTableToSchema(Schema $currentSchema): Schema
@@ -130,31 +144,8 @@ private function addTableToSchema(Schema $currentSchema): Schema
130144

131145
$table = $schema->createTable($this->tableName);
132146
$table->addOption('_symfony_ai_chat_table_name', $this->tableName);
133-
$idColumn = $table->addColumn('id', Types::BIGINT)
134-
->setAutoincrement(true)
135-
->setNotnull(true);
136-
$table->addColumn('messages', Types::TEXT)
137-
->setNotnull(true);
138-
$table->addColumn('added_at', Types::INTEGER)
139-
->setNotnull(true);
140-
if (class_exists(PrimaryKeyConstraint::class)) {
141-
$table->addPrimaryKeyConstraint(new PrimaryKeyConstraint(null, [
142-
new UnqualifiedName(Identifier::unquoted('id')),
143-
], true));
144-
} else {
145-
$table->setPrimaryKey(['id']);
146-
}
147-
148-
// We need to create a sequence for Oracle and set the id column to get the correct nextval
149-
if ($this->dbalConnection->getDatabasePlatform() instanceof OraclePlatform) {
150-
$serverVersion = $this->dbalConnection->executeQuery("SELECT version FROM product_component_version WHERE product LIKE 'Oracle Database%'")->fetchOne();
151-
if (version_compare($serverVersion, '12.1.0', '>=')) {
152-
$idColumn->setAutoincrement(false); // disable the creation of SEQUENCE and TRIGGER
153-
$idColumn->setDefault($this->tableName.'_seq.nextval');
154-
155-
$schema->createSequence($this->tableName.'_seq');
156-
}
157-
}
147+
$table->addColumn('messages', Types::TEXT)->setNotnull(true);
148+
$table->addColumn('updated_at', Types::INTEGER)->setNotnull(true);
158149

159150
return $schema;
160151
}

src/chat/src/Bridge/Doctrine/Tests/DoctrineDbalMessageStoreTest.php

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,21 @@ public function testMessageBagCanBeSaved()
171171
$this->assertCount(1, $messages);
172172
}
173173

174+
public function testSaveOverwritesPreviousMessages()
175+
{
176+
$connection = DriverManager::getConnection(['driver' => 'pdo_sqlite', 'memory' => true]);
177+
178+
$messageStore = new DoctrineDbalMessageStore('foo', $connection);
179+
$messageStore->setup();
180+
181+
$messageStore->save(new MessageBag(Message::ofUser('First message')));
182+
$messageStore->save(new MessageBag(Message::ofUser('Second message')));
183+
184+
$messages = $messageStore->load();
185+
$this->assertCount(1, $messages);
186+
$this->assertSame('Second message', $messages->getUserMessage()->asText());
187+
}
188+
174189
public function testMessageBagCanBeLoaded()
175190
{
176191
$serializer = new Serializer([

0 commit comments

Comments
 (0)