1212namespace Symfony \AI \Chat \Bridge \Doctrine ;
1313
1414use Doctrine \DBAL \Connection as DBALConnection ;
15- use Doctrine \DBAL \Platforms \OraclePlatform ;
1615use Doctrine \DBAL \Schema \ComparatorConfig ;
17- use Doctrine \DBAL \Schema \Name \Identifier ;
18- use Doctrine \DBAL \Schema \Name \UnqualifiedName ;
19- use Doctrine \DBAL \Schema \PrimaryKeyConstraint ;
2016use Doctrine \DBAL \Schema \Schema ;
2117use Doctrine \DBAL \Types \Types ;
2218use Psr \Clock \ClockInterface ;
@@ -85,43 +81,61 @@ public function drop(): void
8581 return ;
8682 }
8783
88- $ queryBuilder = $ this ->dbalConnection ->createQueryBuilder ()
89- ->delete ($ this ->tableName );
90-
91- $ this ->dbalConnection ->executeStatement ($ queryBuilder ->getSQL ());
84+ $ this ->dbalConnection ->executeStatement (
85+ $ this ->dbalConnection ->createQueryBuilder ()
86+ ->delete ($ this ->tableName )
87+ ->getSQL ()
88+ );
9289 }
9390
9491 public function save (MessageBag $ messages ): void
9592 {
96- $ queryBuilder = $ this ->dbalConnection ->createQueryBuilder ()
97- ->insert ($ this ->tableName )
98- ->values ([
99- 'messages ' => '? ' ,
100- 'added_at ' => '? ' ,
101- ]);
102-
103- $ this ->dbalConnection ->executeStatement ($ queryBuilder ->getSQL (), [
104- $ this ->serializer ->serialize ($ messages ->getMessages (), 'json ' ),
105- $ this ->clock ->now ()->getTimestamp (),
106- ]);
93+ $ serialized = $ this ->serializer ->serialize ($ messages ->getMessages (), 'json ' );
94+ $ now = $ this ->clock ->now ()->getTimestamp ();
95+
96+ $ rowCount = (int ) $ this ->dbalConnection ->executeQuery (
97+ $ this ->dbalConnection ->createQueryBuilder ()
98+ ->select ('COUNT(*) ' )
99+ ->from ($ this ->tableName )
100+ ->getSQL ()
101+ )->fetchOne ();
102+
103+ if (0 === $ rowCount ) {
104+ $ this ->dbalConnection ->executeStatement (
105+ $ this ->dbalConnection ->createQueryBuilder ()
106+ ->insert ($ this ->tableName )
107+ ->values (['messages ' => '? ' , 'updated_at ' => '? ' ])
108+ ->getSQL (),
109+ [$ serialized , $ now ],
110+ );
111+ } else {
112+ $ this ->dbalConnection ->executeStatement (
113+ $ this ->dbalConnection ->createQueryBuilder ()
114+ ->update ($ this ->tableName )
115+ ->set ('messages ' , '? ' )
116+ ->set ('updated_at ' , '? ' )
117+ ->getSQL (),
118+ [$ serialized , $ now ],
119+ );
120+ }
107121 }
108122
109123 public function load (): MessageBag
110124 {
111- $ queryBuilder = $ this ->dbalConnection ->createQueryBuilder ()
112- ->select ('messages ' )
113- ->from ($ this ->tableName )
114- ->orderBy ('added_at ' , 'ASC ' )
115- ;
116-
117- $ result = $ this ->dbalConnection ->executeQuery ($ queryBuilder ->getSQL ());
125+ $ payload = $ this ->dbalConnection ->executeQuery (
126+ $ this ->dbalConnection ->createQueryBuilder ()
127+ ->select ('messages ' )
128+ ->from ($ this ->tableName )
129+ ->getSQL ()
130+ )->fetchAssociative ();
131+
132+ if (false === $ payload ) {
133+ return new MessageBag ();
134+ }
118135
119- $ messages = array_map (
120- fn (array $ payload ): array => $ this ->serializer ->deserialize ($ payload ['messages ' ], MessageInterface::class.'[] ' , 'json ' ),
121- $ result ->fetchAllAssociative (),
136+ return new MessageBag (
137+ ...$ this ->serializer ->deserialize ($ payload ['messages ' ], MessageInterface::class.'[] ' , 'json ' ),
122138 );
123-
124- return new MessageBag (...array_merge (...$ messages ));
125139 }
126140
127141 private function addTableToSchema (Schema $ currentSchema ): Schema
@@ -130,31 +144,8 @@ private function addTableToSchema(Schema $currentSchema): Schema
130144
131145 $ table = $ schema ->createTable ($ this ->tableName );
132146 $ table ->addOption ('_symfony_ai_chat_table_name ' , $ this ->tableName );
133- $ idColumn = $ table ->addColumn ('id ' , Types::BIGINT )
134- ->setAutoincrement (true )
135- ->setNotnull (true );
136- $ table ->addColumn ('messages ' , Types::TEXT )
137- ->setNotnull (true );
138- $ table ->addColumn ('added_at ' , Types::INTEGER )
139- ->setNotnull (true );
140- if (class_exists (PrimaryKeyConstraint::class)) {
141- $ table ->addPrimaryKeyConstraint (new PrimaryKeyConstraint (null , [
142- new UnqualifiedName (Identifier::unquoted ('id ' )),
143- ], true ));
144- } else {
145- $ table ->setPrimaryKey (['id ' ]);
146- }
147-
148- // We need to create a sequence for Oracle and set the id column to get the correct nextval
149- if ($ this ->dbalConnection ->getDatabasePlatform () instanceof OraclePlatform) {
150- $ serverVersion = $ this ->dbalConnection ->executeQuery ("SELECT version FROM product_component_version WHERE product LIKE 'Oracle Database%' " )->fetchOne ();
151- if (version_compare ($ serverVersion , '12.1.0 ' , '>= ' )) {
152- $ idColumn ->setAutoincrement (false ); // disable the creation of SEQUENCE and TRIGGER
153- $ idColumn ->setDefault ($ this ->tableName .'_seq.nextval ' );
154-
155- $ schema ->createSequence ($ this ->tableName .'_seq ' );
156- }
157- }
147+ $ table ->addColumn ('messages ' , Types::TEXT )->setNotnull (true );
148+ $ table ->addColumn ('updated_at ' , Types::INTEGER )->setNotnull (true );
158149
159150 return $ schema ;
160151 }
0 commit comments