1
1
use super :: execution_unit:: QueryHash ;
2
2
use super :: module_subscription_manager:: { Plan , SubscriptionGaugeStats , SubscriptionManager } ;
3
- use super :: query:: compile_read_only_query ;
3
+ use super :: query:: compile_query_with_hashes ;
4
4
use super :: tx:: DeltaTx ;
5
5
use super :: { collect_table_update, record_exec_metrics, TableUpdateType } ;
6
6
use crate :: client:: messages:: {
@@ -16,8 +16,8 @@ use crate::estimation::estimate_rows_scanned;
16
16
use crate :: execution_context:: { Workload , WorkloadType } ;
17
17
use crate :: host:: module_host:: { DatabaseUpdate , EventStatus , ModuleEvent } ;
18
18
use crate :: messages:: websocket:: Subscribe ;
19
- use crate :: sql:: ast:: SchemaViewer ;
20
19
use crate :: subscription:: execute_plans;
20
+ use crate :: subscription:: query:: is_subscribe_to_all_tables;
21
21
use crate :: vm:: check_row_limit;
22
22
use crate :: worker_metrics:: WORKER_METRICS ;
23
23
use parking_lot:: RwLock ;
@@ -27,8 +27,6 @@ use spacetimedb_client_api_messages::websocket::{
27
27
UnsubscribeMulti ,
28
28
} ;
29
29
use spacetimedb_execution:: pipelined:: PipelinedProject ;
30
- use spacetimedb_expr:: check:: parse_and_type_sub;
31
- use spacetimedb_expr:: errors:: TypingError ;
32
30
use spacetimedb_lib:: identity:: AuthCtx ;
33
31
use spacetimedb_lib:: metrics:: ExecutionMetrics ;
34
32
use spacetimedb_lib:: Identity ;
@@ -105,6 +103,20 @@ type FullSubscriptionUpdate = FormatSwitch<ws::DatabaseUpdate<BsatnFormat>, ws::
105
103
106
104
/// A utility for sending an error message to a client and returning early
107
105
macro_rules! return_on_err {
106
+ ( $expr: expr, $handler: expr) => {
107
+ match $expr {
108
+ Ok ( val) => val,
109
+ Err ( e) => {
110
+ // TODO: Handle errors sending messages.
111
+ let _ = $handler( e. to_string( ) . into( ) ) ;
112
+ return Ok ( ( ) ) ;
113
+ }
114
+ }
115
+ } ;
116
+ }
117
+
118
+ /// A utility for sending an error message to a client and returning early
119
+ macro_rules! return_on_err_with_sql {
108
120
( $expr: expr, $sql: expr, $handler: expr) => {
109
121
match $expr. map_err( |err| DBError :: WithSql {
110
122
sql: $sql. into( ) ,
@@ -120,12 +132,6 @@ macro_rules! return_on_err {
120
132
} ;
121
133
}
122
134
123
- /// Hash a sql query, using the caller's identity if necessary
124
- fn hash_query ( sql : & str , tx : & TxId , auth : & AuthCtx ) -> Result < QueryHash , TypingError > {
125
- parse_and_type_sub ( sql, & SchemaViewer :: new ( tx, auth) , auth)
126
- . map ( |( _, has_param) | QueryHash :: from_string ( sql, auth. caller , has_param) )
127
- }
128
-
129
135
impl ModuleSubscriptions {
130
136
pub fn new ( relational_db : Arc < RelationalDB > , subscriptions : Subscriptions , owner_identity : Identity ) -> Self {
131
137
let stats = Box :: new ( SubscriptionGauges :: new ( & relational_db. database_identity ( ) ) ) ;
@@ -248,29 +254,34 @@ impl ModuleSubscriptions {
248
254
} )
249
255
} ;
250
256
257
+ let sql = request. query ;
258
+ let auth = AuthCtx :: new ( self . owner_identity , sender. id . identity ) ;
259
+ let hash = QueryHash :: from_string ( & sql, auth. caller , false ) ;
260
+ let hash_with_param = QueryHash :: from_string ( & sql, auth. caller , true ) ;
261
+
251
262
let tx = scopeguard:: guard ( self . relational_db . begin_tx ( Workload :: Subscribe ) , |tx| {
252
263
self . relational_db . release_tx ( tx) ;
253
264
} ) ;
254
- let auth = AuthCtx :: new ( self . owner_identity , sender. id . identity ) ;
255
- let query = super :: query:: WHITESPACE . replace_all ( & request. query , " " ) ;
256
- let sql = query. trim ( ) ;
257
-
258
- let hash = return_on_err ! ( hash_query( sql, & tx, & auth) , sql, send_err_msg) ;
259
265
260
266
let existing_query = {
261
267
let guard = self . subscriptions . read ( ) ;
262
268
guard. query ( & hash)
263
269
} ;
264
270
265
- let query = return_on_err ! (
266
- existing_query
267
- . map( Ok )
268
- . unwrap_or_else( || compile_read_only_query( & auth, & tx, sql) . map( Arc :: new) ) ,
271
+ let query = return_on_err_with_sql ! (
272
+ existing_query. map( Ok ) . unwrap_or_else( || compile_query_with_hashes(
273
+ & auth,
274
+ & tx,
275
+ & sql,
276
+ hash,
277
+ hash_with_param
278
+ )
279
+ . map( Arc :: new) ) ,
269
280
sql,
270
281
send_err_msg
271
282
) ;
272
283
273
- let ( table_rows, metrics) = return_on_err ! (
284
+ let ( table_rows, metrics) = return_on_err_with_sql ! (
274
285
self . evaluate_initial_subscription( sender. clone( ) , query. clone( ) , & tx, & auth, TableUpdateType :: Subscribe ) ,
275
286
query. sql( ) ,
276
287
send_err_msg
@@ -356,7 +367,7 @@ impl ModuleSubscriptions {
356
367
self . relational_db . release_tx ( tx) ;
357
368
} ) ;
358
369
let auth = AuthCtx :: new ( self . owner_identity , sender. id . identity ) ;
359
- let ( table_rows, metrics) = return_on_err ! (
370
+ let ( table_rows, metrics) = return_on_err_with_sql ! (
360
371
self . evaluate_initial_subscription( sender. clone( ) , query. clone( ) , & tx, & auth, TableUpdateType :: Unsubscribe ) ,
361
372
query. sql( ) ,
362
373
send_err_msg
@@ -452,6 +463,74 @@ impl ModuleSubscriptions {
452
463
Ok ( ( ) )
453
464
}
454
465
466
+ /// Compiles the queries in a [Subscribe] or [SubscribeMulti] message.
467
+ ///
468
+ /// Note, we hash queries to avoid recompilation,
469
+ /// but we need to know if a query is parameterized in order to hash it correctly.
470
+ /// This requires that we type check which in turn requires that we start a tx.
471
+ ///
472
+ /// Unfortunately parsing with sqlparser is quite expensive,
473
+ /// so we'd like to avoid that cost while holding the tx lock,
474
+ /// especially since all we're trying to do is generate a hash.
475
+ ///
476
+ /// Instead we generate two hashes and outside of the tx lock.
477
+ /// If either one is currently tracked, we can avoid recompilation.
478
+ fn compile_queries (
479
+ & self ,
480
+ sender : Identity ,
481
+ queries : impl IntoIterator < Item = Box < str > > ,
482
+ num_queries : usize ,
483
+ ) -> Result < ( Vec < Arc < Plan > > , AuthCtx , TxId ) , DBError > {
484
+ let mut subscribe_to_all_tables = false ;
485
+ let mut plans = Vec :: with_capacity ( num_queries) ;
486
+ let mut query_hashes = Vec :: with_capacity ( num_queries) ;
487
+
488
+ for sql in queries {
489
+ if is_subscribe_to_all_tables ( & sql) {
490
+ subscribe_to_all_tables = true ;
491
+ continue ;
492
+ }
493
+ let hash = QueryHash :: from_string ( & sql, sender, false ) ;
494
+ let hash_with_param = QueryHash :: from_string ( & sql, sender, true ) ;
495
+ query_hashes. push ( ( sql, hash, hash_with_param) ) ;
496
+ }
497
+
498
+ let auth = AuthCtx :: new ( self . owner_identity , sender) ;
499
+
500
+ // We always get the db lock before the subscription lock to avoid deadlocks.
501
+ let tx = scopeguard:: guard ( self . relational_db . begin_tx ( Workload :: Subscribe ) , |tx| {
502
+ self . relational_db . release_tx ( tx) ;
503
+ } ) ;
504
+ let guard = self . subscriptions . read ( ) ;
505
+
506
+ if subscribe_to_all_tables {
507
+ plans. extend (
508
+ super :: subscription:: get_all ( & self . relational_db , & tx, & auth) ?
509
+ . into_iter ( )
510
+ . map ( Arc :: new) ,
511
+ ) ;
512
+ }
513
+
514
+ for ( sql, hash, hash_with_param) in query_hashes {
515
+ if let Some ( unit) = guard. query ( & hash) {
516
+ plans. push ( unit) ;
517
+ } else if let Some ( unit) = guard. query ( & hash_with_param) {
518
+ plans. push ( unit) ;
519
+ } else {
520
+ plans. push ( Arc :: new (
521
+ compile_query_with_hashes ( & auth, & tx, & sql, hash, hash_with_param) . map_err ( |err| {
522
+ DBError :: WithSql {
523
+ error : Box :: new ( DBError :: Other ( err. into ( ) ) ) ,
524
+ sql,
525
+ }
526
+ } ) ?,
527
+ ) ) ;
528
+ }
529
+ }
530
+
531
+ Ok ( ( plans, auth, scopeguard:: ScopeGuard :: into_inner ( tx) ) )
532
+ }
533
+
455
534
#[ tracing:: instrument( level = "trace" , skip_all) ]
456
535
pub fn add_multi_subscription (
457
536
& self ,
@@ -473,39 +552,14 @@ impl ModuleSubscriptions {
473
552
} ) ;
474
553
} ;
475
554
476
- // We always get the db lock before the subscription lock to avoid deadlocks.
477
- let tx = scopeguard:: guard ( self . relational_db . begin_tx ( Workload :: Subscribe ) , |tx| {
555
+ let num_queries = request. query_strings . len ( ) ;
556
+ let ( queries, auth, tx) = return_on_err ! (
557
+ self . compile_queries( sender. id. identity, request. query_strings, num_queries) ,
558
+ send_err_msg
559
+ ) ;
560
+ let tx = scopeguard:: guard ( tx, |tx| {
478
561
self . relational_db . release_tx ( tx) ;
479
562
} ) ;
480
- let auth = AuthCtx :: new ( self . owner_identity , sender. id . identity ) ;
481
- let mut queries = vec ! [ ] ;
482
- let guard = self . subscriptions . read ( ) ;
483
- for sql in request
484
- . query_strings
485
- . iter ( )
486
- . map ( |sql| super :: query:: WHITESPACE . replace_all ( sql, " " ) )
487
- {
488
- let sql = sql. trim ( ) ;
489
- if sql == super :: query:: SUBSCRIBE_TO_ALL_QUERY {
490
- queries. extend (
491
- super :: subscription:: get_all ( & self . relational_db , & tx, & auth) ?
492
- . into_iter ( )
493
- . map ( Arc :: new) ,
494
- ) ;
495
- continue ;
496
- }
497
-
498
- let hash = return_on_err ! ( hash_query( sql, & tx, & auth) , sql, send_err_msg) ;
499
-
500
- if let Some ( unit) = guard. query ( & hash) {
501
- queries. push ( unit) ;
502
- } else {
503
- let compiled = return_on_err ! ( compile_read_only_query( & auth, & tx, sql) , sql, send_err_msg) ;
504
- queries. push ( Arc :: new ( compiled) ) ;
505
- }
506
- }
507
-
508
- drop ( guard) ;
509
563
510
564
// We minimize locking so that other clients can add subscriptions concurrently.
511
565
// We are protected from race conditions with broadcasts, because we have the db lock,
@@ -561,40 +615,11 @@ impl ModuleSubscriptions {
561
615
timer : Instant ,
562
616
_assert : Option < AssertTxFn > ,
563
617
) -> Result < ( ) , DBError > {
564
- let tx = scopeguard:: guard ( self . relational_db . begin_tx ( Workload :: Subscribe ) , |tx| {
618
+ let num_queries = subscription. query_strings . len ( ) ;
619
+ let ( queries, auth, tx) = self . compile_queries ( sender. id . identity , subscription. query_strings , num_queries) ?;
620
+ let tx = scopeguard:: guard ( tx, |tx| {
565
621
self . relational_db . release_tx ( tx) ;
566
622
} ) ;
567
- let request_id = subscription. request_id ;
568
- let auth = AuthCtx :: new ( self . owner_identity , sender. id . identity ) ;
569
- let mut queries = vec ! [ ] ;
570
-
571
- let guard = self . subscriptions . read ( ) ;
572
-
573
- for sql in subscription
574
- . query_strings
575
- . iter ( )
576
- . map ( |sql| super :: query:: WHITESPACE . replace_all ( sql, " " ) )
577
- {
578
- let sql = sql. trim ( ) ;
579
- if sql == super :: query:: SUBSCRIBE_TO_ALL_QUERY {
580
- queries. extend (
581
- super :: subscription:: get_all ( & self . relational_db , & tx, & auth) ?
582
- . into_iter ( )
583
- . map ( Arc :: new) ,
584
- ) ;
585
- continue ;
586
- }
587
-
588
- let hash = hash_query ( sql, & tx, & auth) ?;
589
- if let Some ( unit) = guard. query ( & hash) {
590
- queries. push ( unit) ;
591
- } else {
592
- let compiled = compile_read_only_query ( & auth, & tx, sql) ?;
593
- queries. push ( Arc :: new ( compiled) ) ;
594
- }
595
- }
596
-
597
- drop ( guard) ;
598
623
599
624
check_row_limit (
600
625
& queries,
@@ -639,7 +664,7 @@ impl ModuleSubscriptions {
639
664
// on the wire
640
665
let _ = sender. send_message ( SubscriptionUpdateMessage {
641
666
database_update,
642
- request_id : Some ( request_id) ,
667
+ request_id : Some ( subscription . request_id ) ,
643
668
timer : Some ( timer) ,
644
669
} ) ;
645
670
Ok ( ( ) )
0 commit comments