diff --git a/routee/transit/thermal_energy.py b/routee/transit/thermal_energy.py index 66d0968..9d252a5 100644 --- a/routee/transit/thermal_energy.py +++ b/routee/transit/thermal_energy.py @@ -213,7 +213,7 @@ def add_HVAC_energy( Trips on selected date and route, including deadhead trips. output_dir : Path or None Directory used to store downloaded TMY weather files (in a ``TMY/`` - subdirectory). Must not be None. + subdirectory). If None, defaults to ``~/cache/routee-transit/TMY``. Returns ------- @@ -225,7 +225,7 @@ def add_HVAC_energy( if output_dir is not None: tmy_dir = output_dir / "TMY" else: - raise Exception("Must specify output_dir if downloading TMY data") + tmy_dir = Path.home() / "cache" / "routee-transit" / "TMY" # Based on gtfs stops data, get counties served df_stops = feed.stops diff --git a/tests/test_thermal_energy.py b/tests/test_thermal_energy.py index 3aa478a..901dbe6 100644 --- a/tests/test_thermal_energy.py +++ b/tests/test_thermal_energy.py @@ -36,21 +36,13 @@ def test_compute_HVAC_energy(self) -> None: # np.trapezoid on [0, ..., 0.99] with constant 10 gives 9.9. self.assertAlmostEqual(energy[0], 9.9, places=1) - @patch("routee.transit.thermal_energy.fetch_counties_gdf") - @patch("routee.transit.thermal_energy.download_tmy_files") - @patch("routee.transit.thermal_energy.get_hourly_temperature") - def test_add_HVAC_energy( + def _make_mock_feed_and_trips( self, - mock_get_hourly: MagicMock, - mock_download: MagicMock, - mock_fetch_counties: MagicMock, - ) -> None: - # Setup mock Feed + ) -> tuple[MagicMock, pd.DataFrame, gpd.GeoDataFrame, pd.DataFrame]: mock_feed = MagicMock() mock_feed.stops = pd.DataFrame( {"stop_id": ["S1"], "stop_lat": [40.0], "stop_lon": [-105.0]} ) - # Mock stop_times for HVAC calculation (integration needs start/end) mock_feed.stop_times = pd.DataFrame( { "trip_id": ["T1", "T1"], @@ -58,26 +50,34 @@ def test_add_HVAC_energy( "stop_id": ["S1", "S1"], } ) - - # Setup mock trips trips_df = pd.DataFrame({"trip_id": ["T1"]}) - - # Mock dependencies mock_county_gdf = gpd.GeoDataFrame( { - "county_id": ["G0800130"], # Example FIPS + "county_id": ["G0800130"], "STATEFP": ["08"], "COUNTYFP": ["013"], "geometry": [Point(-105.0, 40.0).buffer(1.0)], }, crs="EPSG:4269", ) - mock_fetch_counties.return_value = mock_county_gdf - - # Mock hourly temperature and power mapping mock_hourly_temp = pd.DataFrame( {"hour": list(range(24)), "Dry Bulb Temperature [°C]": [20.0] * 24} ) + return mock_feed, trips_df, mock_county_gdf, mock_hourly_temp + + @patch("routee.transit.thermal_energy.fetch_counties_gdf") + @patch("routee.transit.thermal_energy.download_tmy_files") + @patch("routee.transit.thermal_energy.get_hourly_temperature") + def test_add_HVAC_energy( + self, + mock_get_hourly: MagicMock, + mock_download: MagicMock, + mock_fetch_counties: MagicMock, + ) -> None: + mock_feed, trips_df, mock_county_gdf, mock_hourly_temp = ( + self._make_mock_feed_and_trips() + ) + mock_fetch_counties.return_value = mock_county_gdf mock_get_hourly.return_value = mock_hourly_temp # use a temp directory for output @@ -90,6 +90,29 @@ def test_add_HVAC_energy( # Should have results for 3 scenarios: summer, winter, median self.assertEqual(len(result), 3) + @patch("routee.transit.thermal_energy.fetch_counties_gdf") + @patch("routee.transit.thermal_energy.download_tmy_files") + @patch("routee.transit.thermal_energy.get_hourly_temperature") + def test_add_HVAC_energy_no_output_dir( + self, + mock_get_hourly: MagicMock, + mock_download: MagicMock, + mock_fetch_counties: MagicMock, + ) -> None: + """add_HVAC_energy should work without output_dir using a default cache path.""" + mock_feed, trips_df, mock_county_gdf, mock_hourly_temp = ( + self._make_mock_feed_and_trips() + ) + mock_fetch_counties.return_value = mock_county_gdf + mock_get_hourly.return_value = mock_hourly_temp + + # Call without specifying output_dir (previously raised an exception) + result = add_HVAC_energy(mock_feed, trips_df, output_dir=None) + + self.assertIn("hvac_energy_kWh", result.columns) + self.assertIn("scenario", result.columns) + self.assertEqual(len(result), 3) + if __name__ == "__main__": unittest.main()