Skip to content

Commit 22f9729

Browse files
lc5211The tunix Authors
authored and committed
[Tunix] improve trajectory logging to support numpy array and scalar types.
PiperOrigin-RevId: 875760462
1 parent 791d90c commit 22f9729

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

tests/utils/trajectory_logger_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import TypedDict
44

55
from absl.testing import absltest
6+
import numpy as np
67
import pandas as pd
78
from tunix.utils import trajectory_logger
89

@@ -52,6 +53,7 @@ def test_log_item_creates_and_writes_to_file(self):
5253
trajectory_id='t0',
5354
completion='c0',
5455
prompt='p0',
56+
value=np.int32(0),
5557
)
5658
trajectory_logger.log_item(temp_dir, item1)
5759

@@ -63,6 +65,7 @@ def test_log_item_creates_and_writes_to_file(self):
6365
trajectory_id='t1',
6466
completion='c1|pipe',
6567
prompt='p1</reasoning>',
68+
value=np.int32(1),
6669
)
6770
trajectory_logger.log_item(temp_dir, item2)
6871

@@ -71,6 +74,7 @@ def test_log_item_creates_and_writes_to_file(self):
7174
trajectory_id='t2',
7275
completion='a, "b", c',
7376
prompt='a prompt with\na newline',
77+
value=np.int32(2),
7478
)
7579
trajectory_logger.log_item(temp_dir, item3)
7680

@@ -80,6 +84,7 @@ def test_log_item_creates_and_writes_to_file(self):
8084
self.assertEqual(df['trajectory_id'].tolist(), ['t0', 't1', 't2'])
8185
self.assertEqual(df['completion'][2], 'a, "b", c')
8286
self.assertEqual(df['prompt'][2], 'a prompt with\na newline')
87+
self.assertEqual(df['value'].tolist(), [0, 1, 2])
8388

8489

8590
if __name__ == '__main__':

tunix/utils/trajectory_logger.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from etils import epath
2626
from google.protobuf import json_format
2727
from google.protobuf import message
28+
import numpy as np
2829
import pandas as pd
2930

3031
def _make_serializable(item: Any) -> Any:
@@ -39,6 +40,16 @@ def _make_serializable(item: Any) -> Any:
3940
return _make_serializable(dataclasses.asdict(item))
4041
elif isinstance(item, message.Message):
4142
return json_format.MessageToDict(item)
43+
elif isinstance(item, np.ndarray):
44+
return _make_serializable(item.tolist())
45+
elif isinstance(item, np.integer):
46+
return int(item)
47+
elif isinstance(item, np.floating):
48+
return float(item)
49+
elif isinstance(item, np.bool_):
50+
return bool(item)
51+
elif isinstance(item, np.str_):
52+
return str(item)
4253
elif isinstance(item, (float, int, bool, str)):
4354
return item
4455
else:

0 commit comments

Comments
 (0)