import pytest

import dbldatagen as dg


class TestSimulatedServerless:
7+ """Serverless operation and other forms of shared spark cloud operation often have restrictions on what
8+ features may be used.
9+
10+ In this set of tests, we'll simulate some of the common restrictions found in Databricks serverless and shared
11+ environments to ensure that common operations still work.
12+
13+ Serverless operations have some of the following restrictions:
14+
15+ - Spark config settings cannot be written
16+
17+ """

    @pytest.fixture(scope="class")
    def serverlessSpark(self):
        from unittest.mock import MagicMock

        sparkSession = dg.SparkSingleton.getLocalInstance("unit tests")

        # Replace `conf.set` and `conf.get` with mocks that raise, simulating the
        # prohibition on writing and reading Spark config settings in serverless environments.
        oldSetMethod = sparkSession.conf.set
        oldGetMethod = sparkSession.conf.get
        sparkSession.conf.set = MagicMock(
            side_effect=ValueError("Setting value prohibited in simulated serverless env."))
        sparkSession.conf.get = MagicMock(
            side_effect=ValueError("Getting value prohibited in simulated serverless env."))

        yield sparkSession

        # Restore the original methods so later tests get an unpatched session.
        sparkSession.conf.set = oldSetMethod
        sparkSession.conf.get = oldGetMethod
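
    # Sanity-check sketch: verify that the fixture's simulated restriction actually fires
    # when Spark config settings are written or read. The test name and the config key
    # below are illustrative assumptions, not drawn from the original test suite.
    def test_conf_access_prohibited(self, serverlessSpark):
        with pytest.raises(ValueError):
            serverlessSpark.conf.set("spark.sql.shuffle.partitions", "4")
        with pytest.raises(ValueError):
            serverlessSpark.conf.get("spark.sql.shuffle.partitions")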

    def test_basic_data(self, serverlessSpark):
        from pyspark.sql.types import FloatType, IntegerType, StringType

        row_count = 1000 * 100
        column_count = 10
        testDataSpec = (
            dg.DataGenerator(serverlessSpark, name="test_data_set1", rows=row_count, partitions=4)
            .withIdOutput()
            .withColumn(
                "r",
                FloatType(),
                expr="floor(rand() * 350) * (86400 + 3600)",
                numColumns=column_count,
            )
            .withColumn("code1", IntegerType(), minValue=100, maxValue=200)
            .withColumn("code2", "integer", minValue=0, maxValue=10, random=True)
            .withColumn("code3", StringType(), values=["online", "offline", "unknown"])
            .withColumn(
                "code4", StringType(), values=["a", "b", "c"], random=True, percentNulls=0.05
            )
            .withColumn(
                "code5", "string", values=["a", "b", "c"], random=True, weights=[9, 1, 1]
            )
        )

        dfTestData = testDataSpec.build()

        # Building should succeed without writing any Spark config settings.
        assert dfTestData.count() == row_count

    @pytest.mark.parametrize("providerName, providerOptions", [
        ("basic/user", {"rows": 50, "partitions": 4, "random": False, "dummyValues": 0}),
        ("basic/user", {"rows": 100, "partitions": -1, "random": True, "dummyValues": 0})
    ])
    def test_basic_user_table_retrieval(self, providerName, providerOptions, serverlessSpark):
        ds = dg.Datasets(serverlessSpark, providerName).get(**providerOptions)
        assert ds is not None, f"""expected to get dataset specification for provider `{providerName}`
            with options: {providerOptions}
            """
        df = ds.build()

        assert df.count() >= 0