Skip to content

Commit

Permalink
Fix metrics saver when a file location is specified (#8644)
Browse files Browse the repository at this point in the history
  • Loading branch information
yaochengji authored Jan 29, 2025
1 parent 0183997 commit 7243b80
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 1 deletion.
4 changes: 4 additions & 0 deletions test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,10 @@ function run_xla_op_tests1 {
run_test "$CDIR/test_python_ops.py"
run_test "$CDIR/test_ops.py"
run_test "$CDIR/test_metrics.py"
if [ -f "/tmp/metrics.txt" ] ; then
rm /tmp/metrics.txt
fi
XLA_METRICS_FILE=/tmp/metrics.txt run_test "$CDIR/test_metrics.py"
run_test "$CDIR/test_deprecation.py"
run_test "$CDIR/dynamo/test_dynamo_integrations_util.py"
run_test "$CDIR/dynamo/test_dynamo_aliasing.py"
Expand Down
8 changes: 8 additions & 0 deletions test/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@ def XLAExperimentalContains(feat):
return feat in experimental


def check_metrics_file():
metrics_file = os.environ.get("XLA_METRICS_FILE", None)
if metrics_file is None:
return True
return os.path.exists(metrics_file)


class MetricsTest(unittest.TestCase):

def test_clear_counters(self):
Expand Down Expand Up @@ -90,6 +97,7 @@ def test_short_metrics_report_default_list(self):
xm.mark_step()
short_report = met.short_metrics_report()
self.assertIn("CachedCompile", short_report)
assert check_metrics_file()

def test_short_metrics_report_custom_list(self):
xla_device = xm.xla_device()
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/debug/metrics_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def _extract_metrics_file():
import torch_xla.core.xla_model as xm
metrics_file = os.environ.get('XLA_METRICS_FILE', None)
if metrics_file is not None:
ordinal = xm.get_local_ordinal(defval=-1)
ordinal = xm.get_local_ordinal()
if ordinal >= 0 and xr.world_size() > 1:
metrics_file = '{}.{}'.format(metrics_file, ordinal)
return metrics_file
Expand Down

0 comments on commit 7243b80

Please sign in to comment.