diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index efef2fb0..d0c28c6f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,8 +29,8 @@ repos: args: [--show-source, --statistics] - repo: https://github.com/pre-commit/mirrors-mypy - rev: 'v0.971' + rev: 'v1.2.0' hooks: - id: mypy - args: [--no-strict-optional, --ignore-missing-imports, --scripts-are-modules, --pretty] - additional_dependencies: [numpy==1.21.5] + args: [--no-strict-optional, --ignore-missing-imports, --explicit-package-bases, --scripts-are-modules, --pretty] + additional_dependencies: [numpy==1.23.5, types-requests] diff --git a/CHANGELOG.md b/CHANGELOG.md index 196f8a26..841dc223 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ Versioning](https://semver.org/spec/v2.0.0.html). ## Unreleased #### Added +- Update pyproject.toml to work around missing stub packages for yaml. - Add trace format validator - Added multiple trace filter classes and demos. - Added enhanced trace call stack graph implementation. 
diff --git a/hta/configs/event_args_yaml_parser.py b/hta/configs/event_args_yaml_parser.py index 58a0f94c..cb06984e 100644 --- a/hta/configs/event_args_yaml_parser.py +++ b/hta/configs/event_args_yaml_parser.py @@ -3,9 +3,9 @@ # pyre-strict -import importlib.resources import re from functools import lru_cache +from pathlib import Path from typing import Callable, Dict, List, NamedTuple import yaml @@ -34,6 +34,216 @@ def from_string(version_str: str) -> "YamlVersion": # Yaml version will be mapped to the yaml files defined under the "event_args_formats" folder v1_0_0: YamlVersion = YamlVersion(1, 0, 0) +DEFAULT_YAML_ARGS_FORMAT: str = """version: 1.0.0 + +AVAILABLE_ARGS: + index::ev_idx: + name: ev_idx + raw_name: Ev Idx + value_type: Int + default_value: -1 + index::external_id: + name: external_id + raw_name: External id + value_type: Int + default_value: -1 + cpu_op::concrete_inputs: + name: concrete_inputs + raw_name: Concrete Inputs + value_type: Object + default_value: "[]" + cpu_op::fwd_thread: + name: fwd_thread_id + raw_name: Fwd thread id + value_type: Int + default_value: -1 + cpu_op::input_dims: + name: input_dims + raw_name: Input Dims + value_type: Object + default_value: "-1" + cpu_op::input_type: + name: input_type + raw_name: Input type + value_type: Object + default_value: "-1" + cpu_op::input_strides: + name: input_strides + raw_name: Input Strides + value_type: Object + default_value: "-1" + cpu_op::sequence_number: + name: sequence + raw_name: Sequence number + value_type: Int + default_value: -1 + cpu_op::kernel_backend: + name: kernel_backend + raw_name: kernel_backend + value_type: String + default_value: "" + correlation::cbid: + name: cbid + raw_name: cbid + value_type: Int + default_value: -1 + correlation::cpu_gpu: + name: correlation + raw_name: correlation + value_type: Int + default_value: -1 + sm::blocks: + name: blocks_per_sm + raw_name: blocks per SM + value_type: Object + default_value: "[]" + sm::occupancy: + name: est_occupancy 
+ raw_name: est. achieved occupancy % + value_type: Int + default_value: -1 + sm::warps: + name: warps_per_sm + raw_name: warps per SM + value_type: Float + default_value: 0.0 + data::bytes: + name: bytes + raw_name: bytes + value_type: Int + default_value: -1 + data::bandwidth: + name: memory_bw_gbps + raw_name: memory bandwidth (GB/s) + value_type: Float + default_value: 0.0 + cuda::context: + name: context + raw_name: context + value_type: Int + default_value: -1 + cuda::device: + name: device + raw_name: device + value_type: Int + default_value: -1 + cuda::stream: + name: stream + raw_name: stream + value_type: Int + default_value: -1 + kernel::queued: + name: queued + raw_name: queued + value_type: Int + default_value: -1 + kernel::shared_memory: + name: shared_memory + raw_name: shared memory + value_type: Int + default_value: -1 + threads::block: + name: block + raw_name: block + value_type: Object + default_value: "[]" + threads::grid: + name: grid + raw_name: grid + value_type: Object + default_value: "[]" + threads::registers: + name: registers_per_thread + raw_name: registers per thread + value_type: Int + default_value: -1 + cuda_sync::stream: + name: wait_on_stream + raw_name: wait_on_stream + value_type: Int + default_value: -1 + cuda_sync::event: + name: wait_on_cuda_event_record_corr_id + raw_name: wait_on_cuda_event_record_corr_id + value_type: Int + default_value: -1 + info::labels: + name: labels + raw_name: labels + value_type: String + default_value: "" + info::name: + name: name + raw_name: name + value_type: Int + default_value: -1 + info::op_count: + name: op_count + raw_name: Op count + value_type: Int + default_value: -1 + info::sort_index: + name: sort_index + raw_name: sort_index + value_type: Int + default_value: -1 + nccl::collective_name: + name: collective_name + raw_name: Collective name + value_type: String + default_value: "" + nccl::in_msg_nelems: + name: in_msg_nelems + raw_name: In msg nelems + value_type: Int + default_value: 
0 + nccl::out_msg_nelems: + name: out_msg_nelems + raw_name: Out msg nelems + value_type: Int + default_value: 0 + nccl::group_size: + name: group_size + raw_name: Group size + value_type: Int + default_value: 0 + nccl::dtype: + name: msg_dtype + raw_name: dtype + value_type: String + default_value: "" + nccl::in_split_size: + name: in_split_size + raw_name: In split size + value_type: Object + default_value: "[]" + nccl::out_split_size: + name: out_split_size + raw_name: Out split size + value_type: Object + default_value: "[]" + nccl::process_group_name: + name: process_group_name + raw_name: Process Group Name + value_type: String + default_value: "" + nccl::process_group_desc: + name: process_group_desc + raw_name: Process Group Description + value_type: String + default_value: "" + nccl::process_group_ranks: + name: process_group_ranks + raw_name: Process Group Ranks + value_type: Object + default_value: "[]" + nccl::rank: + name: process_rank + raw_name: Rank + value_type: Int + default_value: -1 +""" + ARGS_INPUT_SHAPE_FUNC: Callable[[Dict[str, AttributeSpec]], List[AttributeSpec]] = ( lambda available_args: [ @@ -94,13 +304,15 @@ def from_string(version_str: str) -> "YamlVersion": @lru_cache() def parse_event_args_yaml(version: YamlVersion) -> EventArgs: - local_yaml_data_filepath = str( - importlib.resources.files(__package__).joinpath( - f"event_args_{version.get_version_str()}.yaml", - ) - ) - with open(local_yaml_data_filepath, "r") as f: - yaml_content = yaml.safe_load(f) + pkg_path: Path = Path(__file__).parent + yaml_file = f"event_args_{version.get_version_str()}.yaml" + local_yaml_data_filepath = str(pkg_path.joinpath("event_args_formats", yaml_file)) + + if Path(local_yaml_data_filepath).exists(): + with open(local_yaml_data_filepath, "r") as f: + yaml_content = yaml.safe_load(f) + else: + yaml_content = yaml.safe_load(DEFAULT_YAML_ARGS_FORMAT) def parse_value_type(value: str) -> ValueType: return ValueType[value] diff --git a/pyproject.toml 
b/pyproject.toml index c998386e..51418765 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,3 +19,7 @@ include_trailing_comma = true use_parentheses = true src_paths = ["hta", "tests"] skip_glob = ["examples/*"] + +[[tool.mypy.overrides]] +module = "yaml" +ignore_missing_imports = true