Skip to content

Commit 215d7d9

Browse files
Use math.fsum for duration aggregation in SimpleProfiler (#21525)
* Update test_mlflow.py * test_mlflow.py * Optimize SimpleProfiler duration aggregation * re run ci * Add tests for SimpleProfiler extended report functionality --------- Co-authored-by: bhimrazy <bhimrajyadav977@gmail.com>
1 parent 7d2de87 commit 215d7d9

2 files changed

Lines changed: 26 additions & 5 deletions

File tree

src/lightning/pytorch/profilers/simple.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@
1414
"""Profiler to check if there are any bottlenecks in your code."""
1515

1616
import logging
17+
import math
1718
import os
1819
import time
1920
from collections import defaultdict
2021
from pathlib import Path
2122
from typing import Optional, Union
2223

23-
import torch
2424
from typing_extensions import override
2525

2626
from lightning.pytorch.profilers.profiler import Profiler
@@ -86,9 +86,8 @@ def _make_report_extended(self) -> tuple[_TABLE_DATA_EXTENDED, float, float]:
8686
report = []
8787

8888
for a, d in self.recorded_durations.items():
89-
d_tensor = torch.tensor(d)
9089
len_d = len(d)
91-
sum_d = torch.sum(d_tensor).item()
90+
sum_d = math.fsum(d)
9291
percentage_d = 100.0 * sum_d / total_duration
9392

9493
report.append((a, sum_d / len_d, len_d, sum_d, percentage_d))
@@ -100,8 +99,7 @@ def _make_report_extended(self) -> tuple[_TABLE_DATA_EXTENDED, float, float]:
10099
def _make_report(self) -> _TABLE_DATA:
101100
report = []
102101
for action, d in self.recorded_durations.items():
103-
d_tensor = torch.tensor(d)
104-
sum_d = torch.sum(d_tensor).item()
102+
sum_d = math.fsum(d)
105103

106104
report.append((action, sum_d / len(d), sum_d))
107105

tests/tests_pytorch/profilers/test_profiler.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,29 @@ def test_simple_profiler_logs(tmp_path, caplog, simple_profiler):
194194
assert caplog.text.count("Profiler Report") == 2
195195

196196

197+
def test_simple_profiler_uses_math_fsum(monkeypatch):
198+
profiler = SimpleProfiler()
199+
profiler.recorded_durations["action"] = [1.0, 2.0, 3.0]
200+
profiler.start_time = 0.0
201+
202+
fsum_calls: list[list[float]] = []
203+
204+
def _fake_fsum(values):
205+
fsum_calls.append(list(values))
206+
return sum(values)
207+
208+
monkeypatch.setattr("lightning.pytorch.profilers.simple.math.fsum", _fake_fsum)
209+
210+
# Test non-extended report
211+
profiler._make_report()
212+
assert fsum_calls == [[1.0, 2.0, 3.0]]
213+
214+
# Test extended report
215+
fsum_calls.clear()
216+
profiler._make_report_extended()
217+
assert fsum_calls == [[1.0, 2.0, 3.0]]
218+
219+
197220
@pytest.mark.parametrize("extended", [True, False])
198221
@patch("time.perf_counter", return_value=70)
199222
def test_simple_profiler_summary(tmp_path, extended):

0 commit comments

Comments
 (0)