Skip to content

Commit 0132eba

Browse files
committed
Add Chrome trace support
1 parent 0ca5a49 commit 0132eba

File tree

7 files changed

+188
-110
lines changed

7 files changed

+188
-110
lines changed

graphsignal/__init__.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -38,20 +38,22 @@ def _parse_env_param(name: str, value: Any, expected_type: Type) -> Any:
3838
raise ValueError(f"Invalid type for {name}: expected {expected_type.__name__}, got {type(value).__name__}")
3939

4040

41-
def _read_config_param(name: str, expected_type: Type, provided_value: Optional[Any] = None, required: bool = False) -> Any:
41+
def _read_config_param(name: str, expected_type: Type, provided_value: Optional[Any] = None, default_value: Optional[Any] = None, required: bool = False) -> Any:
4242
# Check if the value was provided as an argument
4343
if provided_value is not None:
4444
return provided_value
4545

4646
# Check if the value was provided as an environment variable
4747
env_value = os.getenv(f'GRAPHSIGNAL_{name.upper()}')
48-
if env_value is None:
49-
if required:
50-
raise ValueError(f"Missing required argument: {name}")
51-
return None
48+
if env_value is not None:
49+
parsed_env_value = _parse_env_param(name, env_value, expected_type)
50+
if parsed_env_value is not None:
51+
return parsed_env_value
52+
53+
if required:
54+
raise ValueError(f"Missing required argument: {name}")
5255

53-
# Parse the environment variable
54-
return _parse_env_param(name, env_value, expected_type)
56+
return default_value
5557

5658

5759
def _read_config_tags(provided_value: Optional[dict] = None, prefix: str = "GRAPHSIGNAL_TAG_") -> Dict[str, str]:
@@ -68,10 +70,10 @@ def configure(
6870
api_url: Optional[str] = None,
6971
deployment: Optional[str] = None,
7072
tags: Optional[Dict[str, str]] = None,
71-
auto_instrument: Optional[bool] = True,
72-
record_payloads: Optional[bool] = True,
73-
profiling_rate: Optional[float] = 0.1,
74-
debug_mode: Optional[bool] = False
73+
auto_instrument: Optional[bool] = None,
74+
record_payloads: Optional[bool] = None,
75+
profiling_rate: Optional[float] = None,
76+
debug_mode: Optional[bool] = None
7577
) -> None:
7678
global _tracer
7779

@@ -82,10 +84,10 @@ def configure(
8284
api_key = _read_config_param("api_key", str, api_key, required=True)
8385
api_url = _read_config_param("api_url", str, api_url)
8486
tags = _read_config_tags(tags)
85-
auto_instrument = _read_config_param("auto_instrument", bool, auto_instrument)
86-
record_payloads = _read_config_param("record_payloads", bool, record_payloads)
87-
profiling_rate = _read_config_param("profiling_rate", float, profiling_rate)
88-
debug_mode = _read_config_param("debug_mode", bool, debug_mode)
87+
auto_instrument = _read_config_param("auto_instrument", bool, auto_instrument, default_value=True)
88+
record_payloads = _read_config_param("record_payloads", bool, record_payloads, default_value=True)
89+
profiling_rate = _read_config_param("profiling_rate", float, profiling_rate, default_value=0.1)
90+
debug_mode = _read_config_param("debug_mode", bool, debug_mode, default_value=False)
8991

9092
# left for compatibility
9193
if deployment and isinstance(deployment, str):
graphsignal/recorders/profiler_utils.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
import glob
import gzip
import logging
import os
import shutil
import tempfile

logger = logging.getLogger('graphsignal')


def create_log_dir():
    """Create a new temporary directory and return its path.

    The directory is created with ``tempfile.mkdtemp`` using the
    ``graphsignal-`` prefix. The caller is responsible for removing it,
    e.g. with ``remove_log_dir``.
    """
    log_dir = tempfile.mkdtemp(prefix='graphsignal-')
    logger.debug('Created temporary log directory %s', log_dir)
    return log_dir


def remove_log_dir(log_dir):
    """Recursively delete ``log_dir`` and everything inside it.

    Errors from ``shutil.rmtree`` (for example, a missing directory)
    propagate to the caller.
    """
    shutil.rmtree(log_dir)
    logger.debug('Removed temporary log directory %s', log_dir)


def find_and_read(log_dir, file_pattern, decompress=True, max_size=None):
    """Read the contents of a file in ``log_dir`` matching ``file_pattern``.

    Args:
        log_dir: Directory to search.
        file_pattern: Glob pattern, joined onto ``log_dir``.
        decompress: When true, a match whose name ends in ``.gz`` is read
            through ``gzip`` and returned decompressed.
        max_size: Optional limit, in bytes, on the on-disk size of the
            matched file (for ``.gz`` files this is the compressed size).
            A falsy value (``None`` or ``0``) disables the check.

    Returns:
        The file contents as ``bytes``, or ``None`` when nothing matches.

    Raises:
        Exception: If the matched file is larger than ``max_size``.
    """
    pattern = os.path.join(log_dir, file_pattern)

    # glob.glob() returns matches in arbitrary order; sort so that the
    # file picked below is deterministic (lexicographically last path).
    file_paths = sorted(glob.glob(pattern))
    if not file_paths:
        logger.debug('Files are not found at %s', pattern)
        return None

    found_path = file_paths[-1]

    if max_size:
        file_size = os.path.getsize(found_path)
        if file_size > max_size:
            raise Exception('File is too big: {0}'.format(file_size))

    opener = gzip.open if decompress and found_path.endswith('.gz') else open
    # The context manager guarantees the handle is closed even if read() fails.
    with opener(found_path, "rb") as last_file:
        return last_file.read()

graphsignal/recorders/pytorch_recorder.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
11
import logging
2-
import torch
2+
import os
33
import random
44
import json
5+
import time
6+
import torch
57

68
import graphsignal
79
from graphsignal.recorders.base_recorder import BaseRecorder
10+
from graphsignal.recorders.profiler_utils import create_log_dir, remove_log_dir
811

912
logger = logging.getLogger('graphsignal')
1013

1114
class PyTorchRecorder(BaseRecorder):
1215
def __init__(self):
1316
self._torch_prof = None
17+
self._log_dir = None
1418

1519
def setup(self):
1620
pass
@@ -87,17 +91,43 @@ def on_span_read(self, span, context):
8791
count = 1,
8892
duration_ns = _ns(kernel.duration)
8993
)
90-
94+
9195
device_profile = kernel_index.values()
9296
if len(device_profile) > 0:
9397
span.set_profile('device-profile', 'event-averages', json.dumps(device_profile))
9498
span.set_tag('profile_type', 'device') # override cpu value
9599

100+
chrome_trace = self._export_chrome_trace()
101+
if chrome_trace:
102+
span.set_profile('event-timeline', 'chrome-trace', chrome_trace)
103+
96104
if len(cpu_profile) > 0 or len(device_profile) > 0:
97105
span.set_tag('profiler', f'pytorch-{torch.__version__}')
98106
finally:
99107
self._torch_prof = None
100108

def _export_chrome_trace(self):
    """Export the profiler's Chrome trace and return it as a string.

    The trace is written by ``self._torch_prof.export_chrome_trace`` to
    ``trace.json`` inside a fresh temporary directory, read back, and the
    directory is removed again whether or not the export succeeded.

    Returns:
        The contents of the exported trace file.

    Raises:
        Exception: If the trace file is larger than 50 * 1e6 bytes. Any
            error raised while creating the directory, exporting, or
            reading the file also propagates to the caller.
    """
    max_trace_size = 50 * 1e6
    read_start_time = time.time()

    # Reset before creating the directory: if create_log_dir() fails, the
    # cleanup below must not try to remove None or a path left over from
    # an earlier export, which would mask the original error.
    self._log_dir = None
    try:
        self._log_dir = create_log_dir()

        trace_path = os.path.join(self._log_dir, 'trace.json')
        self._torch_prof.export_chrome_trace(trace_path)

        trace_file_size = os.path.getsize(trace_path)
        logger.debug('Chrome trace size: %s', trace_file_size)
        if trace_file_size > max_trace_size:
            raise Exception('Trace file too big: {0}'.format(trace_file_size))

        with open(trace_path, "r") as f:
            return f.read()
    finally:
        if self._log_dir is not None:
            remove_log_dir(self._log_dir)
            self._log_dir = None
        logger.debug('Chrome trace export time: %s', time.time() - read_start_time)
130+
101131
def _ns(val):
102132
return int(max(val, 0) * 1e3)
103133

graphsignal/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.16.2'
1+
__version__ = '0.16.3'

0 commit comments

Comments
 (0)