Skip to content

Commit 4d80574

Browse files
committed
Quiet streamlit warnings
1 parent 621d1df commit 4d80574

1 file changed

Lines changed: 7 additions & 5 deletions

File tree

scripts/viz.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ def plot_trajectories(obs: pl.DataFrame, pred: pl.DataFrame, config: Dict[str, A
111111
# merge observed data with prediction by the combination of models and forecast starts
112112
model_forecast_starts = pred.select(["model", "forecast_start"]).unique()
113113
plot_obs = obs.join(model_forecast_starts, how="cross").filter(
114-
pl.col(factor).is_in(pred[factor].unique()) for factor in config["groups"]
114+
pl.col(factor).is_in(pred[factor].unique().implode())
115+
for factor in config["groups"]
115116
)
116117

117118
groupings = ["model", "forecast_start", "time_end"] + config["groups"]
@@ -150,7 +151,7 @@ def plot_trajectories(obs: pl.DataFrame, pred: pl.DataFrame, config: Dict[str, A
150151

151152
chart = layer_with_facets([obs_chart, pred_chart], encodings)
152153

153-
st.altair_chart(chart, use_container_width=True)
154+
st.altair_chart(chart)
154155

155156

156157
def plot_summary(obs: pl.DataFrame, pred: pl.DataFrame, config: Dict[str, Any]):
@@ -172,7 +173,8 @@ def plot_summary(obs: pl.DataFrame, pred: pl.DataFrame, config: Dict[str, Any]):
172173
# data process: merge observed data with prediction by combinations of model and forecast start #
173174
forecast_starts = pred.select(["model", "forecast_start"]).unique()
174175
plot_obs = obs.join(forecast_starts, how="cross").filter(
175-
pl.col(factor).is_in(pred[factor].unique()) for factor in config["groups"]
176+
pl.col(factor).is_in(pred[factor].unique().implode())
177+
for factor in config["groups"]
176178
)
177179

178180
# summarize sample predictions by grouping factors #
@@ -290,7 +292,7 @@ def plot_summary(obs: pl.DataFrame, pred: pl.DataFrame, config: Dict[str, Any]):
290292
chart_list = [interval_chart, obs_chart, pred_chart]
291293
chart = layer_with_facets(chart_list, encodings)
292294

293-
st.altair_chart(chart, use_container_width=True)
295+
st.altair_chart(chart)
294296

295297

296298
def plot_evaluation(scores: pl.DataFrame, config: Dict[str, Any]):
@@ -379,7 +381,7 @@ def plot_evaluation(scores: pl.DataFrame, config: Dict[str, Any]):
379381
.resolve_scale(y="independent")
380382
)
381383

382-
st.altair_chart(chart, use_container_width=True)
384+
st.altair_chart(chart)
383385

384386

385387
## helper: feed correct argument to altair ##

0 commit comments

Comments
 (0)