@@ -424,6 +424,7 @@ impl PostgresSinkWriter {
424424 & self . pk_indices ,
425425 parameters,
426426 remaining,
427+ true ,
427428 )
428429 . await ?;
429430 transaction. commit ( ) . await ?;
@@ -463,6 +464,7 @@ impl PostgresSinkWriter {
463464 & self . pk_indices ,
464465 delete_parameters,
465466 delete_remaining_parameter,
467+ false ,
466468 )
467469 . await ?;
468470 let ( insert_parameters, insert_remaining_parameter) = insert_parameter_buffer. into_parts ( ) ;
@@ -474,6 +476,7 @@ impl PostgresSinkWriter {
474476 & self . pk_indices ,
475477 insert_parameters,
476478 insert_remaining_parameter,
479+ false ,
477480 )
478481 . await ?;
479482 transaction. commit ( ) . await ?;
@@ -489,6 +492,7 @@ impl PostgresSinkWriter {
489492 pk_indices : & [ usize ] ,
490493 parameters : Vec < Vec < Option < ScalarAdapter > > > ,
491494 remaining_parameter : Vec < Option < ScalarAdapter > > ,
495+ append_only : bool ,
492496 ) -> Result < ( ) > {
493497 let column_length = match op {
494498 Op :: Insert => schema. len ( ) ,
@@ -506,7 +510,13 @@ impl PostgresSinkWriter {
506510 schema. fields( ) ,
507511 ) ;
508512 let statement = match op {
509- Op :: Insert => create_insert_sql ( schema, table_name, rows_length) ,
513+ Op :: Insert => {
514+ if append_only {
515+ create_insert_sql ( schema, table_name, rows_length)
516+ } else {
517+ create_upsert_sql ( schema, table_name, pk_indices, rows_length)
518+ }
519+ }
510520 Op :: Delete => create_delete_sql ( schema, table_name, pk_indices, rows_length) ,
511521 _ => unreachable ! ( ) ,
512522 } ;
@@ -523,7 +533,13 @@ impl PostgresSinkWriter {
523533 "flattened parameters are unaligned"
524534 ) ;
525535 let statement = match op {
526- Op :: Insert => create_insert_sql ( schema, table_name, rows_length) ,
536+ Op :: Insert => {
537+ if append_only {
538+ create_insert_sql ( schema, table_name, rows_length)
539+ } else {
540+ create_upsert_sql ( schema, table_name, pk_indices, rows_length)
541+ }
542+ }
527543 Op :: Delete => create_delete_sql ( schema, table_name, pk_indices, rows_length) ,
528544 _ => unreachable ! ( ) ,
529545 } ;
@@ -602,6 +618,46 @@ fn create_delete_sql(
602618 format ! ( "DELETE FROM {table_name} WHERE {pk} in ({parameters})" )
603619}
604620
621+ fn create_upsert_sql (
622+ schema : & Schema ,
623+ table_name : & str ,
624+ pk_indices : & [ usize ] ,
625+ number_of_rows : usize ,
626+ ) -> String {
627+ let number_of_columns = schema. len ( ) ;
628+ let columns: String = schema
629+ . fields ( )
630+ . iter ( )
631+ . map ( |field| field. name . clone ( ) )
632+ . collect_vec ( )
633+ . join ( ", " ) ;
634+ let parameters: String = ( 0 ..number_of_rows)
635+ . map ( |i| {
636+ let row_parameters = ( 0 ..number_of_columns)
637+ . map ( |j| format ! ( "${}" , i * number_of_columns + j + 1 ) )
638+ . join ( ", " ) ;
639+ format ! ( "({row_parameters})" )
640+ } )
641+ . collect_vec ( )
642+ . join ( ", " ) ;
643+ let pk_columns = pk_indices
644+ . iter ( )
645+ . map ( |i| schema. fields ( ) [ * i] . name . clone ( ) )
646+ . collect_vec ( )
647+ . join ( ", " ) ;
648+ let update_parameters: String = ( 0 ..number_of_columns)
649+ . filter ( |i| !pk_indices. contains ( i) )
650+ . map ( |i| {
651+ let column = schema. fields ( ) [ i] . name . clone ( ) ;
652+ format ! ( "{column} = EXCLUDED.{column}" )
653+ } )
654+ . collect_vec ( )
655+ . join ( ", " ) ;
656+ format ! (
657+ "INSERT INTO {table_name} ({columns}) VALUES {parameters} on conflict ({pk_columns}) do update set {update_parameters}"
658+ )
659+ }
660+
605661#[ cfg( test) ]
606662mod tests {
607663 use std:: fmt:: Display ;
@@ -662,4 +718,26 @@ mod tests {
662718 expect ! [ "DELETE FROM test_table WHERE (a, b) in (($1, $2), ($3, $4), ($5, $6))" ] ,
663719 ) ;
664720 }
721+
722+ #[ test]
723+ fn test_create_upsert_sql ( ) {
724+ let schema = Schema :: new ( vec ! [
725+ Field {
726+ data_type: DataType :: Int32 ,
727+ name: "a" . to_owned( ) ,
728+ } ,
729+ Field {
730+ data_type: DataType :: Int32 ,
731+ name: "b" . to_owned( ) ,
732+ } ,
733+ ] ) ;
734+ let table_name = "test_table" ;
735+ let sql = create_upsert_sql ( & schema, table_name, & [ 1 ] , 3 ) ;
736+ check (
737+ sql,
738+ expect ! [
739+ "INSERT INTO test_table (a, b) VALUES ($1, $2), ($3, $4), ($5, $6) on conflict (b) do update set a = EXCLUDED.a"
740+ ] ,
741+ ) ;
742+ }
665743}
0 commit comments