Skip to content

Commit 1161a86

Browse files
committed
added sample slides for demo presentation and added trainscript for demo model
Signed-off-by: Christoph Huy <christoph.huy@campus.tu-berlin.de>
1 parent 4b217dc commit 1161a86

File tree

2 files changed

+57
-0
lines changed

2 files changed

+57
-0
lines changed
654 KB
Binary file not shown.
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import sys
2+
3+
from pathlib import Path
4+
from pyspark.sql import SparkSession
5+
6+
DATA = Path(__file__).parent.parent / "data"
7+
8+
print(f"Data directory is located at: {DATA}")
9+
10+
11+
def setup_python_env():
12+
project_root = Path(__file__).parent.parent.parent.parent
13+
print(f"Project root directory is located at: {project_root}")
14+
15+
sdk_path = project_root / "src" / "sdk" / "python"
16+
sdk_path = sdk_path.resolve()
17+
18+
sys.path.insert(0, str(sdk_path))
19+
20+
21+
def main():
22+
23+
print("Setting up Python environment...")
24+
setup_python_env()
25+
from rtdip_sdk.pipelines.forecasting.spark.autogluon_timeseries import (
26+
AutoGluonTimeSeries,
27+
)
28+
29+
print("Starting Spark session...")
30+
spark = (
31+
SparkSession.builder.master("local[*]")
32+
.appName("SCADA-Forecasting")
33+
.config("spark.driver.memory", "8g")
34+
.config("spark.executor.memory", "8g")
35+
.config("spark.driver.maxResultSize", "2g")
36+
.config("spark.sql.shuffle.partitions", "50")
37+
.config("spark.sql.execution.arrow.pyspark.enabled", "true")
38+
.getOrCreate()
39+
)
40+
41+
print("Reading preprocessed SCADA data...")
42+
43+
data_path = DATA / "scada_prepro.parquet"
44+
assert data_path.exists(), f"Data file not found at {data_path}"
45+
df = spark.read.parquet(str(data_path))
46+
47+
print("Starting AutoGluon Training...")
48+
ag_model = AutoGluonTimeSeries()
49+
train_df, test_df = ag_model.split_data(df)
50+
res_dict = ag_model.train(train_df)
51+
52+
print("Saving test dataset...")
53+
test_df.write.mode("overwrite").parquet(str(DATA / "scada_test.parquet"))
54+
55+
56+
if __name__ == "__main__":
57+
exit(main())

0 commit comments

Comments
 (0)