Skip to content

Commit 982bae7

Browse files
dipannita08copybara-github
authored andcommitted
Add API support for Interval Query
PiperOrigin-RevId: 704847633
1 parent 56a8115 commit 982bae7

File tree

3 files changed

+342
-29
lines changed

3 files changed

+342
-29
lines changed

ml_goodput_measurement/src/goodput.py

Lines changed: 134 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@
1212
from cloud_goodput.ml_goodput_measurement.src.checkpoint_badput_calculator import CheckpointBadputCalculator
1313
from cloud_goodput.ml_goodput_measurement.src.checkpoint_badput_calculator import CheckpointLoggerOptions
1414
from cloud_goodput.ml_goodput_measurement.src.goodput_cache import GoodputCache
15-
from cloud_goodput.ml_goodput_measurement.src.goodput_utils import BadputType
16-
from cloud_goodput.ml_goodput_measurement.src.goodput_utils import GoodputInfo
15+
from cloud_goodput.ml_goodput_measurement.src.goodput_utils import BadputType, GoodputInfo, get_timestamp_from_log_entry
1716
import numpy as np
1817
from scipy import stats
1918

@@ -390,6 +389,10 @@ def __init__(
390389
self._cloud_logger = _CloudLogger(job_name, logger_name)
391390
self._current_entries = []
392391
self._goodput_cache = GoodputCache()
392+
self._interval_entries = []
393+
self._interval_start_time = None
394+
self._interval_end_time = None
395+
self._number_of_interruptions = 0
393396

394397
def _get_total_productive_and_unproductive_time(
395398
self,
@@ -442,7 +445,7 @@ def _get_cached_productive_and_unproductive_time(
442445
return (0.0, {}, 0)
443446

444447
def _get_current_productive_and_unproductive_time(
445-
self,
448+
self, interval_query: Optional[bool] = False
446449
) -> tuple[float, dict[BadputType, float], int]:
447450
"""Helper function to compute the current productive training time, unproductive time and the last step recorded till now.
448451
@@ -582,7 +585,11 @@ def _accumulate_segment_unproductive_time(
582585
tpu_initialization_badput = 0.0
583586
training_prep_badput = 0.0
584587
data_loading_badput = 0.0
585-
for payload in self._current_entries:
588+
entries_to_process = (
589+
self._interval_entries if interval_query else self._current_entries
590+
)
591+
self._number_of_interruptions = 0
592+
for payload in entries_to_process:
586593
if _JOB_START_TIME in payload:
587594
# Keep track of the latest start to compute badput due to disruption.
588595
job_start_time = payload[_JOB_START_TIME]
@@ -623,6 +630,7 @@ def _accumulate_segment_unproductive_time(
623630
segment_unproductive_time[
624631
BadputType.WASTED_PROGRESS_FROM_DISRUPTION
625632
] = wasted_progress_from_disruption
633+
self._number_of_interruptions += 1
626634

627635
# The second bucket is individually computed either from recorded
628636
# logs (TPU initialization, training preparation, data loading) or
@@ -676,7 +684,7 @@ def _accumulate_segment_unproductive_time(
676684
checkpoint_badput_calculator = CheckpointBadputCalculator(
677685
checkpoint_logger_options
678686
)
679-
checkpoint_badput_calculator.entries = self._current_entries
687+
checkpoint_badput_calculator.entries = entries_to_process
680688
checkpoint_manager_save_stats = (
681689
checkpoint_badput_calculator.calculate_save_operation_checkpoint_manager_blocking_time()
682690
)
@@ -717,7 +725,15 @@ def _accumulate_segment_unproductive_time(
717725

718726
if job_end_time is not None:
719727
productive_training_time += job_end_time - step_start_data[last_step]
720-
else:
728+
elif (
729+
interval_query
730+
and self._interval_end_time
731+
and self._interval_end_time.timestamp() > step_start_data[last_step]
732+
):
733+
productive_training_time += (
734+
self._interval_end_time.timestamp() - step_start_data[last_step]
735+
)
736+
elif not interval_query:
721737
productive_training_time += (
722738
datetime.datetime.utcnow().timestamp() - step_start_data[last_step]
723739
)
@@ -795,6 +811,63 @@ def _update_log_entries(self):
795811
else:
796812
self._current_entries = self._cloud_logger.read_cloud_logging_entries()
797813

814+
def _get_interval_log_entries(
815+
self, start_time: datetime.datetime, end_time: datetime.datetime
816+
):
817+
"""Helper function to get log entries from an interval window."""
818+
if start_time is None or end_time is None:
819+
raise ValueError(
820+
'Start and end times are required to get log entries from an interval'
821+
' window.'
822+
)
823+
self._interval_entries = self._cloud_logger.read_cloud_logging_entries(
824+
start_time, end_time
825+
)
826+
logging.info(
827+
'Inspecting interval entries between %s and %s', start_time, end_time
828+
)
829+
830+
if not self._interval_entries:
831+
raise ValueError(
832+
'No log entries found within the interval window between %s and %s.'
833+
% (start_time, end_time)
834+
)
835+
836+
def _get_total_job_time_from_interval(
837+
self, start_interval: datetime.datetime, end_interval: datetime.datetime
838+
) -> float:
839+
"""Helper function to compute the total job runtime from interval entries."""
840+
# Get the first and last entry's timestamps in the window
841+
first_entry_timestamp = get_timestamp_from_log_entry(
842+
self._interval_entries[0]
843+
)
844+
last_entry_timestamp = get_timestamp_from_log_entry(
845+
self._interval_entries[-1]
846+
)
847+
848+
# Calculate effective start_time and end_time
849+
self._interval_start_time = (
850+
max(start_interval, first_entry_timestamp)
851+
if first_entry_timestamp
852+
else start_interval
853+
)
854+
self._interval_end_time = (
855+
min(end_interval, last_entry_timestamp)
856+
if last_entry_timestamp
857+
else end_interval
858+
)
859+
860+
# Ensure start_time is not after end_time
861+
if self._interval_start_time >= self._interval_end_time:
862+
raise ValueError(
863+
'Start time is on or after end time, cannot compute total job time.'
864+
)
865+
866+
return (
867+
self._interval_end_time.timestamp()
868+
- self._interval_start_time.timestamp()
869+
)
870+
798871
def get_job_goodput(
799872
self, include_badput_breakdown=False
800873
) -> tuple[float, dict[BadputType, float], int]:
@@ -870,12 +943,18 @@ def get_job_goodput(
870943
)
871944
return job_goodput, job_badput_breakdown, last_step
872945

873-
def get_job_goodput_interval(self, interval_start, interval_end):
874-
"""Method to get the Goodput of the job within an interval window.
946+
def get_job_goodput_interval(
947+
self, interval_start: datetime.datetime, interval_end: datetime.datetime
948+
) -> tuple[float, dict[BadputType, float], int, float, int]:
949+
"""Method to get the Goodput and Badput breakdown of the job within an interval window.
875950
876951
If the application is interested in retrieving the Goodput of the job within
877-
a specific window of time, this method provides the singular Goodput
878-
computation between the start and end of this window.
952+
a specific window of time, this method provides the metrics computed between
953+
the start and end of this window.
954+
955+
Additionaly, this method returns the last step recorded for the job. This is
956+
primarily used for improving monitoring and observability of the job's
957+
overall Goodput as a function of number of executed steps.
879958
880959
Args:
881960
interval_start: The start time of the window for which Goodput is to be
@@ -884,9 +963,52 @@ def get_job_goodput_interval(self, interval_start, interval_end):
884963
computed.
885964
886965
Returns:
887-
Goodput percentage of the job within specified time window.
966+
A tuple containing:
967+
- The job's Goodput percentage with respect to the total job time within
968+
the interval window.
969+
- The Badput Breakdown percentages with respect to the total job time
970+
within the interval window.
971+
- The last step recorded for the job within the interval window.
972+
- The total job time within the interval window.
973+
- The number of disruptions within the interval window.
974+
975+
Raises:
976+
ValueError if computed total job time is zero. In this case, Goodput
977+
cannot be computed.
978+
ValueError if productive training or unproductive time is invalid.
888979
"""
889-
pass
980+
981+
# Get the logs for the interval and validate the interval window.
982+
self._get_interval_log_entries(interval_start, interval_end)
983+
984+
total_job_time = self._get_total_job_time_from_interval(
985+
interval_start, interval_end
986+
)
987+
988+
productive_training_time, total_unproductive_time, last_step = (
989+
self._get_current_productive_and_unproductive_time(interval_query=True)
990+
)
991+
if (
992+
productive_training_time < 0.0
993+
or productive_training_time > total_job_time
994+
):
995+
raise ValueError(
996+
'Productive training time is invalid. Please fix the logging entries.'
997+
)
998+
# Return a tuple of calculated Goodput & Badput of the job till now and the
999+
# last recorded step.
1000+
job_goodput = (float(productive_training_time) / total_job_time) * 100
1001+
job_badput_breakdown = self._get_job_badput_breakdown(
1002+
productive_training_time, total_unproductive_time, total_job_time
1003+
)
1004+
1005+
return (
1006+
job_goodput,
1007+
job_badput_breakdown,
1008+
last_step,
1009+
total_job_time,
1010+
self._number_of_interruptions,
1011+
)
8901012

8911013
def _get_job_badput_breakdown(
8921014
self, total_productive_time, total_unproductive_time, total_job_time

ml_goodput_measurement/src/goodput_utils.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
import datetime
44
import enum
5-
from typing import Optional
5+
from typing import Any, Optional
6+
7+
_TIME_ENTRY = 'time'
68

79

810
class BadputType(enum.Enum):
@@ -32,3 +34,19 @@ def __init__(
3234
self.total_elapsed_time_since_start = total_elapsed_time_since_start
3335
self.total_unproductive_time = total_unproductive_time
3436
self.last_recorded_step = last_recorded_step
37+
38+
39+
def get_timestamp_from_log_entry(
40+
entry: dict[str, Any],
41+
) -> datetime.datetime | None:
42+
"""Helper function to get the timestamp from a log entry."""
43+
timestamp_posix_time = [
44+
entry_value
45+
for entry_label, entry_value in entry.items()
46+
if _TIME_ENTRY in entry_label
47+
]
48+
if timestamp_posix_time:
49+
return datetime.datetime.fromtimestamp(
50+
timestamp_posix_time[0], datetime.timezone.utc
51+
)
52+
return None

0 commit comments

Comments
 (0)