1313
1414# Skip cloud tests if TEST_SPICE_CLOUD is not set to true
1515def skip_cloud ():
16- # skip = os.environ.get("TEST_SPICE_CLOUD") != "true"
17- # Skipping all cloud tests for now
18- skip = True
19- return pytest .mark .skipif (skip , reason = "Cloud tests disabled" )
16+ skip = os .environ .get ("TEST_SPICE_CLOUD" ) != "true"
17+ return pytest .mark .skipif (skip , reason = "Cloud tests disabled (set TEST_SPICE_CLOUD=true)" )
2018
2119
2220def get_cloud_client ():
23- api_key = os .environ [ " API_KEY"]
21+ api_key = os .environ . get ( "SPICE_API_KEY" , os . environ . get ( " API_KEY", "" ))
2422 return Client (api_key = api_key , flight_url = "grpc+tls://flight.spiceai.io" )
2523
2624
@@ -35,23 +33,25 @@ def test_user_agent_is_populated():
3533 assert re .match (matching_regex , SPICE_USER_AGENT )
3634
3735
36+ @pytest .mark .cloud
3837@skip_cloud ()
3938def test_flight_recent_blocks ():
4039 client = get_cloud_client ()
41- data = client .query ("SELECT * FROM eth.recent_blocks LIMIT 10" )
40+ data = client .query ("SELECT * FROM tpch.lineitem LIMIT 10" )
4241 pandas_data = data .read_pandas ()
4342 assert len (pandas_data ) == 10
4443
4544
45+ @pytest .mark .cloud
4646@skip_cloud ()
4747def test_flight_streaming ():
4848 client = get_cloud_client ()
4949 query = """
50- SELECT number ,
51- "timestamp" ,
52- base_fee_per_gas ,
53- base_fee_per_gas / 1e9 AS base_fee_per_gas_gwei
54- FROM eth.blocks limit 2000
50+ SELECT o_orderkey ,
51+ o_custkey ,
52+ o_orderstatus ,
53+ o_totalprice
54+ FROM tpch.orders LIMIT 2000
5555 """
5656 reader = client .query (query )
5757
@@ -72,20 +72,17 @@ def test_flight_streaming():
7272 assert num_batches > 1
7373
7474
75+ @pytest .mark .cloud
7576@skip_cloud ()
7677def test_flight_timeout ():
7778 client = get_cloud_client ()
78- query = """SELECT block_number,
79- TO_TIMESTAMP(block_timestamp) as block_timestamp,
80- avg(gas) as avg_gas_used,
81- avg(max_priority_fee_per_gas) as avg_max_priority_fee_per_gas,
82- avg(gas_price) as avg_gas_price,
83- avg(gas_price / 1e9) AS avg_gas_price_in_gwei,
84- avg(gas * (gas_price / 1e18)) AS avg_fee_in_eth
85- FROM eth.transactions
86- WHERE block_timestamp > UNIX_TIMESTAMP()-60*60*24*30 -- last 30 days
87- GROUP BY block_number, block_timestamp
88- ORDER BY block_number DESC"""
79+ query = """SELECT o_orderstatus,
80+ COUNT(*) as order_count,
81+ AVG(o_totalprice) as avg_price,
82+ SUM(o_totalprice) as total_price
83+ FROM tpch.orders
84+ GROUP BY o_orderstatus
85+ ORDER BY total_price DESC"""
8986 try :
9087 prev_time = time .time ()
9188 _ = client .query (query , timeout = 1 )
@@ -414,3 +411,95 @@ def test_parameterized_query_no_params():
414411 total_rows += batch .num_rows
415412
416413 assert total_rows == 5
414+
415+ # ============== Cloud Parameterized Query Tests ==============
416+
417+
418+ @pytest .mark .cloud
419+ @skip_cloud ()
420+ @pytest .mark .skipif (skip_if_no_adbc (), reason = "ADBC driver not installed" )
421+ def test_cloud_parameterized_query_basic ():
422+ """Test parameterized query with Spice Cloud TPCH dataset."""
423+ client = get_cloud_client ()
424+
425+ # Test with integer parameter
426+ reader = client .query_with_params (
427+ "SELECT l_orderkey, l_quantity, l_extendedprice FROM tpch.lineitem WHERE l_quantity > $1 LIMIT 10" ,
428+ [40 ],
429+ )
430+
431+ total_rows = 0
432+ for batch in reader :
433+ total_rows += batch .num_rows
434+ # Validate l_quantity > 40
435+ quantity = batch .column ("l_quantity" )
436+ for i in range (batch .num_rows ):
437+ assert quantity [i ].as_py () > 40
438+
439+ assert total_rows > 0
440+ assert total_rows <= 10
441+
442+
443+ @pytest .mark .cloud
444+ @skip_cloud ()
445+ @pytest .mark .skipif (skip_if_no_adbc (), reason = "ADBC driver not installed" )
446+ def test_cloud_parameterized_query_multiple_params ():
447+ """Test parameterized query with multiple parameters on Spice Cloud."""
448+ client = get_cloud_client ()
449+
450+ reader = client .query_with_params (
451+ "SELECT o_orderkey, o_totalprice, o_orderstatus FROM tpch.orders WHERE o_totalprice > $1 AND o_orderstatus = $2 LIMIT 10" ,
452+ [100000.0 , "O" ],
453+ )
454+
455+ total_rows = 0
456+ for batch in reader :
457+ total_rows += batch .num_rows
458+ totalprice = batch .column ("o_totalprice" )
459+ status = batch .column ("o_orderstatus" )
460+ for i in range (batch .num_rows ):
461+ assert totalprice [i ].as_py () > 100000.0
462+ assert status [i ].as_py () == "O"
463+
464+ assert total_rows <= 10
465+
466+
467+ @pytest .mark .cloud
468+ @skip_cloud ()
469+ @pytest .mark .skipif (skip_if_no_adbc (), reason = "ADBC driver not installed" )
470+ def test_cloud_parameterized_query_with_explicit_types ():
471+ """Test parameterized query with explicit Param types on Spice Cloud."""
472+ client = get_cloud_client ()
473+
474+ reader = client .query_with_params (
475+ "SELECT c_custkey, c_name, c_acctbal FROM tpch.customer WHERE c_acctbal > $1 LIMIT 10" ,
476+ [Param .float64 (5000.0 )],
477+ )
478+
479+ total_rows = 0
480+ for batch in reader :
481+ total_rows += batch .num_rows
482+ acctbal = batch .column ("c_acctbal" )
483+ for i in range (batch .num_rows ):
484+ assert acctbal [i ].as_py () > 5000.0
485+
486+ assert total_rows <= 10
487+
488+
489+ @pytest .mark .cloud
490+ @skip_cloud ()
491+ @pytest .mark .skipif (skip_if_no_adbc (), reason = "ADBC driver not installed" )
492+ def test_cloud_parameterized_query_empty_params ():
493+ """Test parameterized query with empty params on Spice Cloud."""
494+ client = get_cloud_client ()
495+
496+ reader = client .query_with_params (
497+ "SELECT n_nationkey, n_name, n_regionkey FROM tpch.nation LIMIT 5" ,
498+ [],
499+ )
500+
501+ total_rows = 0
502+ for batch in reader :
503+ total_rows += batch .num_rows
504+
505+ assert total_rows == 5
0 commit comments