From 4f2000f3d6c131a3e8b8df0875e9d218e40b597d Mon Sep 17 00:00:00 2001 From: ahalev Date: Wed, 17 May 2023 17:27:02 -0700 Subject: [PATCH 1/2] exclude forecasts key --- src/pymgrid/envs/base/base.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/pymgrid/envs/base/base.py b/src/pymgrid/envs/base/base.py index 6b89a769..a6da5d38 100644 --- a/src/pymgrid/envs/base/base.py +++ b/src/pymgrid/envs/base/base.py @@ -128,6 +128,11 @@ def _validate_observation_keys(self, keys): keys = np.array(keys) possible_keys = self.potential_observation_keys() + + if keys[0] == 'exclude_forecasts': + raise RuntimeError('This behavior is currently not working correctly.') + return possible_keys[~possible_keys.str.contains('forecast')].to_list() + bad_keys = [key for key in keys if key not in possible_keys] if bad_keys: From 7d659960f47b484740e7deb47a1945f11be1b9e1 Mon Sep 17 00:00:00 2001 From: ahalev Date: Wed, 17 May 2023 17:27:17 -0700 Subject: [PATCH 2/2] add test that is currently failing --- tests/envs/test_discrete.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/envs/test_discrete.py b/tests/envs/test_discrete.py index 5a472795..d36d1a1e 100644 --- a/tests/envs/test_discrete.py +++ b/tests/envs/test_discrete.py @@ -164,6 +164,31 @@ def test_set_initial_step(self): self.assertEqual(initial_step, 1) + def test_exclude_forecast(self): + env_with_forecasts = DiscreteMicrogridEnv.from_scenario(self.microgrid_number, + observation_keys='exclude_forecasts') + env_without_forecasts = DiscreteMicrogridEnv.from_scenario(self.microgrid_number) + + env_with_forecasts.modules.set_attrs(forecast_horizon=23) + env_without_forecasts.modules.set_attrs(forecast_horizon=0) + + for j in range(4): + + action = env_with_forecasts.action_space.sample() + + out_with_forecasts = env_with_forecasts.step(action) + out_without_forecasts = env_without_forecasts.step(action) + keys = 'obs', 'reward' 'done' 'info' + + for key, obj_1, obj_2 in zip(keys, out_with_forecasts, out_without_forecasts): + with self.subTest(key=key, step=j): + self.assertEqual(obj_1, obj_2) + + + print('here') + print('x') + + class TestDiscreteEnvScenario1(TestDiscreteEnvScenario):