@@ -111,7 +111,8 @@ def plot_trajectories(obs: pl.DataFrame, pred: pl.DataFrame, config: Dict[str, A
111111 # merge observed data with prediction by the combination of models and forecast starts
112112 model_forecast_starts = pred .select (["model" , "forecast_start" ]).unique ()
113113 plot_obs = obs .join (model_forecast_starts , how = "cross" ).filter (
114- pl .col (factor ).is_in (pred [factor ].unique ()) for factor in config ["groups" ]
114+ pl .col (factor ).is_in (pred [factor ].unique ().implode ())
115+ for factor in config ["groups" ]
115116 )
116117
117118 groupings = ["model" , "forecast_start" , "time_end" ] + config ["groups" ]
@@ -150,7 +151,7 @@ def plot_trajectories(obs: pl.DataFrame, pred: pl.DataFrame, config: Dict[str, A
150151
151152 chart = layer_with_facets ([obs_chart , pred_chart ], encodings )
152153
153- st .altair_chart (chart , use_container_width = True )
154+ st .altair_chart (chart )
154155
155156
156157def plot_summary (obs : pl .DataFrame , pred : pl .DataFrame , config : Dict [str , Any ]):
@@ -172,7 +173,8 @@ def plot_summary(obs: pl.DataFrame, pred: pl.DataFrame, config: Dict[str, Any]):
172173 # data process: merge observed data with prediction by combinations of model and forecast start #
173174 forecast_starts = pred .select (["model" , "forecast_start" ]).unique ()
174175 plot_obs = obs .join (forecast_starts , how = "cross" ).filter (
175- pl .col (factor ).is_in (pred [factor ].unique ()) for factor in config ["groups" ]
176+ pl .col (factor ).is_in (pred [factor ].unique ().implode ())
177+ for factor in config ["groups" ]
176178 )
177179
178180 # summarize sample predictions by grouping factors #
@@ -290,7 +292,7 @@ def plot_summary(obs: pl.DataFrame, pred: pl.DataFrame, config: Dict[str, Any]):
290292 chart_list = [interval_chart , obs_chart , pred_chart ]
291293 chart = layer_with_facets (chart_list , encodings )
292294
293- st .altair_chart (chart , use_container_width = True )
295+ st .altair_chart (chart )
294296
295297
296298def plot_evaluation (scores : pl .DataFrame , config : Dict [str , Any ]):
@@ -379,7 +381,7 @@ def plot_evaluation(scores: pl.DataFrame, config: Dict[str, Any]):
379381 .resolve_scale (y = "independent" )
380382 )
381383
382- st .altair_chart (chart , use_container_width = True )
384+ st .altair_chart (chart )
383385
384386
385387## helper: feed correct argument to altair ##
0 commit comments