Skip to content

Commit f1ce889

Browse files
committed
Enhance LOCA2 data processing by introducing asset jobs and updating configurations. Refactor asset definitions to utilize new job structure, and modify sensor implementations to align with the updated asset dependencies. Update tests to reflect changes in asset return types and configurations.
1 parent 7bb9348 commit f1ce889

File tree

6 files changed

+140
-68
lines changed

6 files changed

+140
-68
lines changed

downscaled_climate_data/assets/loca2.py

Lines changed: 40 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -4,29 +4,29 @@
44
import requests
55
import s3fs
66
import xarray as xr
7-
from dagster import AssetExecutionContext, AssetIn, Config, EnvVar, asset
7+
import dagster as dg
88
from dagster_aws.s3 import S3Resource
99

1010
import downscaled_climate_data
1111

1212

13-
class Loca2Config(Config):
13+
class Loca2Config(dg.Config):
1414
s3_key: str
1515
url: str = "https://cirrus.ucsd.edu/~pierce/LOCA2/CONUS_regions_split/ACCESS-CM2/cent/0p0625deg/r2i1p1f1/historical/tasmax/tasmax.ACCESS-CM2.historical.r2i1p1f1.1950-2014.LOCA_16thdeg_v20220413.cent.nc" # NOQA E501
1616

1717

18-
@asset(
18+
@dg.asset(
1919
name="loca2_raw_netcdf",
2020
description="Raw LOCA2 data downloaded from the web",
2121
code_version=downscaled_climate_data.__version__,
2222
group_name="loca2"
2323
)
24-
def loca2_raw_netcdf(context: AssetExecutionContext,
24+
def loca2_raw_netcdf(context: dg.AssetExecutionContext,
2525
config: Loca2Config,
26-
s3: S3Resource) -> dict[str, str]:
26+
s3: S3Resource) -> dg.Output:
2727

28-
destination_bucket = EnvVar("LOCA2_BUCKET").get_value()
29-
destination_path_root = EnvVar("LOCA2_RAW_PATH_ROOT").get_value()
28+
destination_bucket = dg.EnvVar("LOCA2_BUCKET").get_value()
29+
destination_path_root = dg.EnvVar("LOCA2_RAW_PATH_ROOT").get_value()
3030

3131
with requests.get(config.url, stream=True) as response:
3232
# Raise an exception for bad HTTP responses
@@ -44,24 +44,34 @@ def loca2_raw_netcdf(context: AssetExecutionContext,
4444
)
4545

4646
context.log.info(f"Downloading data to {config.s3_key}")
47-
return {
48-
"bucket": destination_bucket,
49-
"s3_key": config.s3_key,
50-
}
47+
zarr_config = ZarrConfig(
48+
s3_key=config.s3_key,
49+
bucket=destination_bucket,
50+
)
51+
return dg.MaterializeResult(
52+
metadata={
53+
"zarr_config": dg.MetadataValue.json(zarr_config.__dict__),
54+
}
55+
)
5156

5257

53-
@asset(
58+
class ZarrConfig(dg.Config):
59+
s3_key: str
60+
bucket: str
61+
62+
63+
@dg.asset(
5464
name="loca2_zarr",
55-
ins={
56-
"loca2_raw_netcdf": AssetIn()
57-
},
65+
deps=["loca2_raw_netcdf"],
5866
group_name="loca2",
5967
description="LOCA2 data converted to Zarr format",
6068
code_version=downscaled_climate_data.__version__)
61-
def loca2_zarr(context,
62-
loca2_raw_netcdf,
63-
s3: S3Resource):
64-
context.log.info(f"Converting {loca2_raw_netcdf['s3_key']} to zarr")
69+
def loca2_zarr(context: dg.AssetExecutionContext, s3: S3Resource):
70+
upstream_metadata = context.instance.get_latest_materialization_event(
71+
dg.AssetKey("loca2_raw_netcdf")).asset_materialization.metadata
72+
73+
config = ZarrConfig(**upstream_metadata['zarr_config'].data)
74+
context.log.info(f"Converting {config.s3_key} to zarr")
6575

6676
# Initialize s3fs with the same credentials as the S3Resource
6777
fs = s3fs.S3FileSystem(
@@ -70,14 +80,14 @@ def loca2_zarr(context,
7080
endpoint_url=s3.endpoint_url
7181
)
7282

73-
raw_root = EnvVar("LOCA2_RAW_PATH_ROOT").get_value()
74-
zarr_root = EnvVar("LOCA2_ZARR_PATH_ROOT").get_value()
83+
raw_root = dg.EnvVar("LOCA2_RAW_PATH_ROOT").get_value()
84+
zarr_root = dg.EnvVar("LOCA2_ZARR_PATH_ROOT").get_value()
7585
# Construct S3 paths
76-
input_path = f"s3://{loca2_raw_netcdf['bucket']}/{raw_root}{loca2_raw_netcdf['s3_key']}" # NOQA E501
86+
input_path = f"s3://{config.bucket}/{raw_root}{config.s3_key}" # NOQA E501
7787
context.log.info(f"Reading from {input_path}")
7888

79-
zarr_key = loca2_raw_netcdf['s3_key'].replace('.nc', '.zarr')
80-
output_path = f"s3://{loca2_raw_netcdf['bucket']}/{zarr_root}{zarr_key}"
89+
zarr_key = config.s3_key.replace('.nc', '.zarr')
90+
output_path = f"s3://{config.bucket}/{zarr_root}{zarr_key}"
8191
context.log.info(f"Writing to {output_path}")
8292

8393
# Read NetCDF file from S3
@@ -103,7 +113,7 @@ def loca2_zarr(context,
103113
ds.close()
104114

105115

106-
class ESMCatalogConfig(Config):
116+
class ESMCatalogConfig(dg.Config):
107117
data_format: str = "zarr"
108118
id: str = "loca2_zarr_monthly_esm_catalog"
109119
description: str = "LOCA2 Zarr data catalog"
@@ -143,21 +153,21 @@ def parse_key(relative_path: str, bucket: str, full_key: str) -> dict[str, str]:
143153
}
144154

145155

146-
@asset(
156+
@dg.asset(
147157
name="loca2_esm_catalog",
148158
group_name="loca2",
149159
description="Generate an Intake-ESM Catalog for LOCA2 datasets",
150160
code_version=downscaled_climate_data.__version__)
151-
def loca2_esm_catalog(context: AssetExecutionContext,
161+
def loca2_esm_catalog(context: dg.AssetExecutionContext,
152162
config: ESMCatalogConfig,
153163
s3: S3Resource):
154164

155-
bucket = EnvVar("LOCA2_BUCKET").get_value()
165+
bucket = dg.EnvVar("LOCA2_BUCKET").get_value()
156166

157167
if config.is_zarr():
158-
prefix = EnvVar("LOCA2_ZARR_PATH_ROOT").get_value()
168+
prefix = dg.EnvVar("LOCA2_ZARR_PATH_ROOT").get_value()
159169
else:
160-
prefix = EnvVar("LOCA2_RAW_PATH_ROOT").get_value()
170+
prefix = dg.EnvVar("LOCA2_RAW_PATH_ROOT").get_value()
161171

162172
if config.is_monthly():
163173
prefix += "/monthly"

downscaled_climate_data/definitions.py

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
from dagster import Definitions, EnvVar
1+
from dagster import Definitions, EnvVar, define_asset_job
22
from dagster_aws.s3 import S3Resource
33

44
from downscaled_climate_data.assets.loca2 import loca2_zarr, loca2_raw_netcdf
@@ -12,8 +12,19 @@
1212
loca2_sensor_monthly_tasmax,
1313
loca2_sensor_tasmin)
1414

15+
all_assets = [loca2_raw_netcdf, loca2_zarr, loca2_esm_catalog]
16+
17+
loca2_data_job = define_asset_job(
18+
name="loca2_data_job",
19+
selection=[
20+
"loca2_raw_netcdf",
21+
"loca2_zarr"
22+
],
23+
)
24+
1525
defs = Definitions(
16-
assets=[loca2_raw_netcdf, loca2_zarr, loca2_esm_catalog],
26+
assets=all_assets,
27+
jobs=[loca2_data_job],
1728
sensors=[loca2_sensor_tasmax,
1829
loca2_sensor_tasmin,
1930
loca2_sensor_pr,

downscaled_climate_data/sensors/loca2_sensor.py

Lines changed: 10 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,6 @@
1212
RunConfig,
1313
)
1414

15-
from downscaled_climate_data.assets.loca2 import loca2_raw_netcdf, loca2_zarr
1615
from downscaled_climate_data.sensors.loca2_models import Loca2Models
1716

1817
# Give ourselves 2 hours to process a single model/scenario
@@ -21,8 +20,6 @@
2120
# For the smaller, monthly files, we can process them more frequently
2221
LOCA2_MONTHLY_SENSOR_FREQUENCY = 120
2322

24-
LOCA2_ASSETS = [loca2_raw_netcdf, loca2_zarr]
25-
2623

2724
class Loca2Datasets(ConfigurableResource):
2825
"""
@@ -113,15 +110,15 @@ def run_request(file: dict[str, str],
113110
:param monthly:
114111
:return:
115112
"""
113+
s3_key = "/monthly" + file["s3_key"] if monthly else file["s3_key"]
116114
return RunRequest(
117-
run_key=file["s3_key"],
115+
run_key=s3_key,
118116
run_config=RunConfig(
119117
{
120118
"loca2_raw_netcdf": {
121119
"config": {
122120
"url": file["url"],
123-
"s3_key": "/monthly" + file["s3_key"]
124-
if monthly else file["s3_key"],
121+
"s3_key": s3_key,
125122
}
126123
},
127124
}
@@ -155,7 +152,7 @@ def sensor_implementation(context, models,
155152

156153
# Now we can launch jobs for each of the files for this model/scenario combination
157154
for file in dataset_resource.get_downloadable_files(
158-
models, model, scenario, monthly=True
155+
models, model, scenario, monthly=monthly
159156
):
160157
context.log.info(f"Found file: {file['url']}")
161158
yield run_request(file, model, scenario, monthly=monthly)
@@ -165,7 +162,7 @@ def sensor_implementation(context, models,
165162

166163
@sensor(
167164
name="LOCA2_Sensor_tasmax",
168-
target=LOCA2_ASSETS,
165+
job_name="loca2_data_job",
169166
minimum_interval_seconds=LOCA2_SENSOR_FREQUENCY,
170167
tags={
171168
"variable": "tasmax",
@@ -182,7 +179,7 @@ def loca2_sensor_tasmax(
182179

183180
@sensor(
184181
name="LOCA2_Sensor_tasmin",
185-
target=LOCA2_ASSETS,
182+
job_name="loca2_data_job",
186183
minimum_interval_seconds=LOCA2_SENSOR_FREQUENCY,
187184
tags={
188185
"variable": "tasmin",
@@ -199,7 +196,7 @@ def loca2_sensor_tasmin(
199196

200197
@sensor(
201198
name="LOCA2_Sensor_pr",
202-
target=LOCA2_ASSETS,
199+
job_name="loca2_data_job",
203200
minimum_interval_seconds=LOCA2_SENSOR_FREQUENCY,
204201
tags={
205202
"variable": "pr",
@@ -216,7 +213,7 @@ def loca2_sensor_pr(
216213

217214
@sensor(
218215
name="LOCA2_Sensor_Monthly_tasmax",
219-
target=LOCA2_ASSETS,
216+
job_name="loca2_data_job",
220217
minimum_interval_seconds=LOCA2_MONTHLY_SENSOR_FREQUENCY,
221218
tags={
222219
"variable": "tasmax",
@@ -234,7 +231,7 @@ def loca2_sensor_monthly_tasmax(
234231

235232
@sensor(
236233
name="LOCA2_Sensor_Monthly_tasmin",
237-
target=LOCA2_ASSETS,
234+
job_name="loca2_data_job",
238235
minimum_interval_seconds=LOCA2_MONTHLY_SENSOR_FREQUENCY,
239236
tags={
240237
"variable": "tasmin",
@@ -252,7 +249,7 @@ def loca2_sensor_monthly_tasmin(
252249

253250
@sensor(
254251
name="LOCA2_Sensor_Monthly_pr",
255-
target=LOCA2_ASSETS,
252+
job_name="loca2_data_job",
256253
minimum_interval_seconds=LOCA2_MONTHLY_SENSOR_FREQUENCY,
257254
tags={
258255
"variable": "pr",

tests/assets/test_loca2_raw_netcdf.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
import os
22
from unittest.mock import patch
33

4-
from dagster import DagsterInstance, build_asset_context
4+
from dagster import DagsterInstance, MaterializeResult, build_asset_context
55

66
from downscaled_climate_data.assets.loca2 import Loca2Config, loca2_raw_netcdf
77

@@ -28,8 +28,9 @@ def test_loca2_raw(mocker):
2828
mock_get.return_value.__enter__.return_value = mock_response
2929

3030
results = loca2_raw_netcdf(context=ctx, config=config)
31-
32-
assert results == {
31+
assert type(results) == MaterializeResult
32+
assert 'zarr_config' in results.metadata
33+
assert results.metadata['zarr_config'].data == {
3334
"bucket": "test_bucket",
3435
"s3_key": "/loca2/cent.nc"
3536
}

tests/assets/test_loca2_zarr.py

Lines changed: 26 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
import os
22
from unittest.mock import patch
3-
from dagster import DagsterInstance, build_asset_context
3+
from dagster import DagsterInstance
44

55
from downscaled_climate_data.assets.loca2 import loca2_zarr
66

@@ -15,10 +15,18 @@ def test_as_zarr_asset(mock_s3fs, mock_xarray, mocker):
1515
s3.aws_secret_access_key = "test_secret"
1616
s3.endpoint_url = "https://test"
1717

18-
ctx = build_asset_context(instance=instance,
19-
resources={
20-
"s3": s3
21-
})
18+
# Mock the upstream asset materialization
19+
mock_materialization = mocker.MagicMock()
20+
mock_materialization.asset_materialization.metadata = {
21+
'zarr_config': mocker.MagicMock(data={
22+
'bucket': 'test_bucket',
23+
's3_key': '/hist/cent.nc'
24+
})
25+
}
26+
27+
# Mock the instance to return our mock materialization
28+
instance.get_latest_materialization_event = mocker.MagicMock()
29+
instance.get_latest_materialization_event.return_value = mock_materialization
2230

2331
mock_s3fs.S3FileSystem = mocker.MagicMock()
2432
mock_fs_open = mocker.MagicMock()
@@ -31,11 +39,19 @@ def test_as_zarr_asset(mock_s3fs, mock_xarray, mocker):
3139

3240
os.environ['LOCA2_RAW_PATH_ROOT'] = 'test'
3341
os.environ['LOCA2_ZARR_PATH_ROOT'] = 'test/zarr'
34-
35-
loca2_zarr(context=ctx, loca2_raw_netcdf={
36-
"bucket": "test_bucket",
37-
"s3_key": "/hist/cent.nc"
38-
})
42+
from dagster import materialize_to_memory
43+
result = materialize_to_memory(
44+
[loca2_zarr],
45+
instance=instance,
46+
resources={
47+
"s3": s3
48+
}
49+
)
50+
print(result)
51+
# loca2_zarr(context=ctx, loca2_raw_netcdf={
52+
# "bucket": "test_bucket",
53+
# "s3_key": "/hist/cent.nc"
54+
# })
3955

4056
mock_s3fs.S3FileSystem.assert_called_with(
4157
key="test_key",

0 commit comments

Comments (0)