diff --git a/src/openstack_billing_db/billing.py b/src/openstack_billing_db/billing.py index a93f02b..6eb9ca3 100644 --- a/src/openstack_billing_db/billing.py +++ b/src/openstack_billing_db/billing.py @@ -94,8 +94,8 @@ def get_runtime_for_instance( runtime = instance.get_runtime_during(start, end) for interval_start, interval_end in excluded_intervals: excluded_runtime = instance.get_runtime_during( - start_time=interval_start, - end_time=interval_end, + start_time=interval_start.replace(tzinfo=None), + end_time=interval_end.replace(tzinfo=None), ) runtime = runtime - excluded_runtime diff --git a/src/openstack_billing_db/tests/unit/test_billing.py b/src/openstack_billing_db/tests/unit/test_billing.py index 9e975df..9c20f76 100644 --- a/src/openstack_billing_db/tests/unit/test_billing.py +++ b/src/openstack_billing_db/tests/unit/test_billing.py @@ -1,5 +1,5 @@ import uuid -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone import pytest from openstack_billing_db import billing @@ -36,6 +36,49 @@ def test_instance_simple_runtime(): assert r.total_seconds_stopped == 0 +def test_instance_runtime_timezone_mismatch(): + """Test that mismatched timezones between start/end and excluded_intervals raises an error.""" + time = datetime(year=2000, month=1, day=1, hour=0, minute=0, second=0) + events = [ + InstanceEvent(time=time, name="create", message=""), + InstanceEvent(time=time + timedelta(days=15), name="delete", message=""), + ] + i = Instance( + uuid=uuid.uuid4().hex, name=uuid.uuid4().hex, flavor=FLAVORS[1], events=events + ) + + # Start and end are timezone-naive + start = datetime(year=2000, month=1, day=1, hour=0, minute=0, second=0) + end = datetime(year=2000, month=2, day=1, hour=0, minute=0, second=0) + + # excluded_intervals are timezone-aware + excluded_intervals = [ + ( + datetime( + year=2000, + month=1, + day=7, + hour=0, + minute=0, + second=0, + tzinfo=timezone.utc, + ), + datetime( + year=2000, + month=1, + day=8, + hour=0, + minute=0, + second=0, + tzinfo=timezone.utc, + ), + ), + ] + + # Runs without error + billing.get_runtime_for_instance(i, start, end, excluded_intervals) + + def test_billing_add_su_hours(): invoice = billing.ProjectInvoice( project_name="foo",