diff --git a/hta/common/call_stack.py b/hta/common/call_stack.py index 81e6a4d..2b637c5 100644 --- a/hta/common/call_stack.py +++ b/hta/common/call_stack.py @@ -461,6 +461,7 @@ def __init__( trace: Trace, ranks: Optional[List[int]] = None, filter_func: Optional[Filter] = None, + thread_merge_func: Optional[Callable[[int, int], int]] = None, ) -> None: """Construct a CallGraph from a Trace object @@ -468,6 +469,8 @@ def __init__( trace (Trace): the trace data used to construct this CallGraph object. ranks (List[int]) : filter the traces using the given set of ranks. Using all ranks if None. filter_func (Callable) : used to preprocess the trace events and filter events out. Please see filters in hta/common/trace_filter.py for details. + thread_merge_func (Callable) : used to merge threads in the traces. Takes in a tuple of (rank, thread_id) and returns the target thread id for use in the graph + Raises: ValueError: the trace data is invalid. """ @@ -476,10 +479,13 @@ def __init__( self.call_stacks: List[CallStackGraph] = [] _ranks = [k for k in trace.get_all_traces()] if ranks is None else ranks - self._construct_call_graph(_ranks, filter_func) + self._construct_call_graph(_ranks, filter_func, thread_merge_func) def _construct_call_graph( - self, ranks: List[int], filter_func: Optional[Filter] + self, + ranks: List[int], + filter_func: Optional[Filter], + thread_remap_func: Optional[Callable[[int, int], int]] = None, ) -> None: """ Construct the call graph from the traces of a distributed training job. @@ -490,10 +496,19 @@ def _construct_call_graph( """ call_stack_ids: List[CallStackIdentity] = [] t0 = perf_counter() + + groupby_key = ["pid", "tid"] + # construct a call stack graph for each thread/stream for rank in ranks: df = self.trace_data.get_trace(rank) - for (pid, tid), df_thread in df.groupby(["pid", "tid"]): + if thread_remap_func: + df.loc[:, "tid"] = df["tid"].map( + lambda x, rank=rank: thread_remap_func(rank, x) + ) + for row_group, df_thread in df.groupby(groupby_key): + pid, tid = row_group + if df_thread.stream.gt(0).any(): # Filter out gpu annotations and sync events df_thread = df_thread[df_thread["stream"].gt(0)] diff --git a/hta/configs/env_options.py b/hta/configs/env_options.py index a3c026f..a9d8a20 100644 --- a/hta/configs/env_options.py +++ b/hta/configs/env_options.py @@ -3,8 +3,7 @@ # LICENSE file in the root directory of this source tree. import os - -from typing import Optional +from typing import Dict, Optional """ HTA provides a set of options to modify behavior of the analyzers using environmenent variables. @@ -28,46 +27,129 @@ CP_STRICT_NEG_WEIGHT_CHECK_ENV = "CRITICAL_PATH_STRICT_NEGATIVE_WEIGHT_CHECKS" -def _get_env(name: str) -> Optional[str]: - """Checks for env or returns None""" - return os.environ.get(name) - - -def _check_env_flag(name: str, default: str = "0") -> bool: - """Checks if env flag is "1" """ - if (value := _get_env(name)) is None: - value = default - return value == "1" +class HTAEnvOptions: + """Singleton class that manages HTA environment options. + + This class reads environment variables when initialized and provides + methods to access and modify the options. Use the instance() method + to get the singleton instance. + """ + + _instance = None + + def __init__(self): + """Initialize options from environment variables.""" + # Read environment variables + self._options: Dict[str, bool] = {} + self._initialize_options() + + def _initialize_options(self) -> None: + """Initialize options from environment variables.""" + self._options = { + HTA_DISABLE_NS_ROUNDING_ENV: self._check_env_flag( + HTA_DISABLE_NS_ROUNDING_ENV, "0" + ), + HTA_DISABLE_CG_DEPTH_ENV: self._check_env_flag( + HTA_DISABLE_CG_DEPTH_ENV, "0" + ), + CP_LAUNCH_EDGE_ENV: self._check_env_flag(CP_LAUNCH_EDGE_ENV, "0"), + CP_LAUNCH_EDGE_SHOW_ENV: self._check_env_flag(CP_LAUNCH_EDGE_SHOW_ENV, "0"), + CP_STRICT_NEG_WEIGHT_CHECK_ENV: self._check_env_flag( + CP_STRICT_NEG_WEIGHT_CHECK_ENV, "0" + ), + } + + @classmethod + def instance(cls) -> "HTAEnvOptions": + """Get the singleton instance of HTAEnvOptions.""" + if cls._instance is None: + cls._instance = HTAEnvOptions() + return cls._instance + + def _get_env(self, name: str) -> Optional[str]: + """Checks for env or returns None""" + return os.environ.get(name) + + def _check_env_flag(self, name: str, default: str = "0") -> bool: + """Checks if env flag is "1" """ + if (value := self._get_env(name)) is None: + value = default + return value == "1" + + def disable_ns_rounding(self) -> bool: + """Check if nanosecond rounding is disabled.""" + return self._options[HTA_DISABLE_NS_ROUNDING_ENV] + + def set_disable_ns_rounding(self, value: bool) -> None: + """Set whether nanosecond rounding is disabled.""" + self._options[HTA_DISABLE_NS_ROUNDING_ENV] = value + + def disable_call_graph_depth(self) -> bool: + """Check if call graph depth is disabled.""" + return self._options[HTA_DISABLE_CG_DEPTH_ENV] + + def set_disable_call_graph_depth(self, value: bool) -> None: + """Set whether call graph depth is disabled.""" + self._options[HTA_DISABLE_CG_DEPTH_ENV] = value + + def critical_path_add_zero_weight_launch_edges(self) -> bool: + """Check if zero weight launch edges should be added for critical path analysis.""" + return self._options[CP_LAUNCH_EDGE_ENV] + + def set_critical_path_add_zero_weight_launch_edges(self, value: bool) -> None: + """Set whether zero weight launch edges should be added for critical path analysis.""" + self._options[CP_LAUNCH_EDGE_ENV] = value + + def critical_path_show_zero_weight_launch_edges(self) -> bool: + """Check if zero weight launch edges should be shown in overlaid trace.""" + return self._options[CP_LAUNCH_EDGE_SHOW_ENV] + + def set_critical_path_show_zero_weight_launch_edges(self, value: bool) -> None: + """Set whether zero weight launch edges should be shown in overlaid trace.""" + self._options[CP_LAUNCH_EDGE_SHOW_ENV] = value + + def critical_path_strict_negative_weight_check(self) -> bool: + """Check if strict negative weight checking is enabled for critical path analysis.""" + return self._options[CP_STRICT_NEG_WEIGHT_CHECK_ENV] + + def set_critical_path_strict_negative_weight_check(self, value: bool) -> None: + """Set whether strict negative weight checking is enabled for critical path analysis.""" + self._options[CP_STRICT_NEG_WEIGHT_CHECK_ENV] = value + + def get_options_str(self) -> str: + """Get a string representation of all options.""" + + def get_env(name: str) -> str: + return self._get_env(name) or "unset" + + return f""" +disable_ns_rounding={self.disable_ns_rounding()}, HTA_DISABLE_NS_ROUNDING_ENV={get_env(HTA_DISABLE_NS_ROUNDING_ENV)} +disable_call_graph_depth={self.disable_call_graph_depth()}, HTA_DISABLE_CG_DEPTH_ENV={get_env(HTA_DISABLE_CG_DEPTH_ENV)} +critical_path_add_zero_weight_launch_edges={self.critical_path_add_zero_weight_launch_edges()}, CP_LAUNCH_EDGE_ENV={get_env(CP_LAUNCH_EDGE_ENV)} +critical_path_show_zero_weight_launch_edges={self.critical_path_show_zero_weight_launch_edges()}, CP_LAUNCH_EDGE_SHOW_ENV={get_env(CP_LAUNCH_EDGE_SHOW_ENV)} +critical_path_strict_negative_weight_check={self.critical_path_strict_negative_weight_check()}, CP_STRICT_NEG_WEIGHT_CHECK_ENV={get_env(CP_STRICT_NEG_WEIGHT_CHECK_ENV)} +""" def disable_ns_rounding() -> bool: - return _check_env_flag(HTA_DISABLE_NS_ROUNDING_ENV, "0") + return HTAEnvOptions.instance().disable_ns_rounding() def disable_call_graph_depth() -> bool: - return _check_env_flag(HTA_DISABLE_CG_DEPTH_ENV, "0") + return HTAEnvOptions.instance().disable_call_graph_depth() def critical_path_add_zero_weight_launch_edges() -> bool: - return _check_env_flag(CP_LAUNCH_EDGE_ENV, "0") + return HTAEnvOptions.instance().critical_path_add_zero_weight_launch_edges() def critical_path_show_zero_weight_launch_edges() -> bool: - return _check_env_flag(CP_LAUNCH_EDGE_SHOW_ENV, "0") + return HTAEnvOptions.instance().critical_path_show_zero_weight_launch_edges() def critical_path_strict_negative_weight_check() -> bool: - return _check_env_flag(CP_STRICT_NEG_WEIGHT_CHECK_ENV, "0") + return HTAEnvOptions.instance().critical_path_strict_negative_weight_check() def get_options() -> str: - def get_env(name: str) -> str: - return _get_env(name) or "unset" - - return f""" -disable_ns_rounding={disable_ns_rounding()}, HTA_DISABLE_NS_ROUNDING_ENV={get_env(HTA_DISABLE_NS_ROUNDING_ENV)} -disable_call_graph_depth={disable_call_graph_depth()}, HTA_DISABLE_CG_DEPTH_ENV={get_env(HTA_DISABLE_CG_DEPTH_ENV)} -critical_path_add_zero_weight_launch_edges={critical_path_add_zero_weight_launch_edges()}, CP_LAUNCH_EDGE_ENV={get_env(CP_LAUNCH_EDGE_ENV)} -critical_path_show_zero_weight_launch_edges={critical_path_show_zero_weight_launch_edges()}, CP_LAUNCH_EDGE_SHOW_ENV={get_env(CP_LAUNCH_EDGE_SHOW_ENV)} -critical_path_strict_negative_weight_check={critical_path_strict_negative_weight_check()}, CP_STRICT_NEG_WEIGHT_CHECK_ENV={get_env(CP_STRICT_NEG_WEIGHT_CHECK_ENV)} -""" + return HTAEnvOptions.instance().get_options_str() diff --git a/tests/test_config.py b/tests/test_config.py index fa35423..7919199 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -3,12 +3,19 @@ # LICENSE file in the root directory of this source tree. import json +import os import unittest from pathlib import Path +from unittest.mock import patch from hta.configs.config import HtaConfig from hta.configs.default_values import DEFAULT_CONFIG_FILENAME -from hta.configs.env_options import get_options +from hta.configs.env_options import ( + CP_LAUNCH_EDGE_ENV, + get_options, + HTA_DISABLE_NS_ROUNDING_ENV, + HTAEnvOptions, +) class HtaConfigTestCase(unittest.TestCase): @@ -28,7 +35,7 @@ def test_get_default_paths(self): len(paths), 3, f"expect the default file paths to be 3 but got {len(paths)}" ) self.assertTrue( - all([str(path).endswith(DEFAULT_CONFIG_FILENAME) for path in paths]) + all(str(path).endswith(DEFAULT_CONFIG_FILENAME) for path in paths) ) def test_constructor_no_config_file(self): @@ -79,5 +86,91 @@ def test_get_test_data_path(self): self.assertTrue(Path(data_path).exists()) +class HTAEnvOptionsTestCase(unittest.TestCase): + def setUp(self) -> None: + # Reset the singleton instance before each test + HTAEnvOptions._instance = None + # Save original environment variables + self.original_env = os.environ.copy() + + def tearDown(self) -> None: + # Reset the singleton instance after each test + HTAEnvOptions._instance = None + # Restore original environment variables + os.environ.clear() + os.environ.update(self.original_env) + + def test_singleton_behavior(self): + """Test that instance() always returns the same instance.""" + instance1 = HTAEnvOptions.instance() + instance2 = HTAEnvOptions.instance() + self.assertIs(instance1, instance2, "instance() should return the same object") + + def test_get_set_options(self): + """Test getting and setting options.""" + options = HTAEnvOptions.instance() + + # Test default values + self.assertFalse(options.disable_ns_rounding()) + self.assertFalse(options.disable_call_graph_depth()) + self.assertFalse(options.critical_path_add_zero_weight_launch_edges()) + self.assertFalse(options.critical_path_show_zero_weight_launch_edges()) + self.assertFalse(options.critical_path_strict_negative_weight_check()) + + # Test setting values + options.set_disable_ns_rounding(True) + self.assertTrue(options.disable_ns_rounding()) + + options.set_critical_path_add_zero_weight_launch_edges(True) + self.assertTrue(options.critical_path_add_zero_weight_launch_edges()) + + # Test that other values remain unchanged + self.assertFalse(options.disable_call_graph_depth()) + self.assertFalse(options.critical_path_show_zero_weight_launch_edges()) + self.assertFalse(options.critical_path_strict_negative_weight_check()) + + def test_environment_variable_reading(self): + """Test that environment variables are correctly read.""" + # Set environment variables + os.environ[HTA_DISABLE_NS_ROUNDING_ENV] = "1" + os.environ[CP_LAUNCH_EDGE_ENV] = "1" + + # Create a new instance that should read these environment variables + HTAEnvOptions._instance = None + options = HTAEnvOptions.instance() + + # Check that the environment variables were correctly read + self.assertTrue(options.disable_ns_rounding()) + self.assertTrue(options.critical_path_add_zero_weight_launch_edges()) + self.assertFalse(options.disable_call_graph_depth()) # Default value + + def test_get_options_str(self): + """Test the get_options_str method.""" + options = HTAEnvOptions.instance() + options_str = options.get_options_str() + + # Check that the string contains all option names + self.assertIn("disable_ns_rounding", options_str) + self.assertIn("disable_call_graph_depth", options_str) + self.assertIn("critical_path_add_zero_weight_launch_edges", options_str) + self.assertIn("critical_path_show_zero_weight_launch_edges", options_str) + self.assertIn("critical_path_strict_negative_weight_check", options_str) + + @patch.dict(os.environ, {HTA_DISABLE_NS_ROUNDING_ENV: "1"}) + def test_legacy_functions(self): + """Test that legacy functions use the singleton instance.""" + from hta.configs.env_options import ( + disable_call_graph_depth, + disable_ns_rounding, + ) + + # Reset the singleton to ensure it reads the patched environment + HTAEnvOptions._instance = None + + # Check that legacy functions return the correct values + self.assertTrue(disable_ns_rounding()) + self.assertFalse(disable_call_graph_depth()) + + if __name__ == "__main__": # pragma: no cover unittest.main()