Skip to content

Commit f038977

Browse files
committed
use SingletonMeta for _TensorboardAdapter
1 parent 5b697f6 commit f038977

File tree

5 files changed

+12
-20
lines changed

5 files changed

+12
-20
lines changed

docs/en/get_started/quick_start.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -515,7 +515,7 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
515515
# 7. Fill and return Sample object
516516
sample.response = full_response
517517
sample.tokens = ...
518-
sample.loss_masks = loss_masks
518+
sample.loss_mask = loss_masks
519519
return sample
520520
```
521521

docs/zh/get_started/quick_start.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,7 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
521521
# 7. 填充并返回 Sample 对象
522522
sample.response = full_response
523523
sample.tokens = ...
524-
sample.loss_masks = loss_masks
524+
sample.loss_mask = loss_masks
525525
return sample
526526
```
527527

examples/retool/generate_with_retool.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
300300
sample.tokens = prompt_tokens_ids + response_token_ids
301301
sample.response_length = len(response_token_ids)
302302
sample.response = response
303-
sample.loss_masks = loss_masks
303+
sample.loss_mask = loss_masks
304304

305305
# Store payload information for wandb logging
306306
sample.payload_text = prompt + response

slime/utils/tensorboard_utils.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import datetime
22
import os
3+
from slime.utils.misc import SingletonMeta
34

45
try:
56
from torch.utils.tensorboard import SummaryWriter
67
except:
78
SummaryWriter = None
89

910

10-
class _TensorboardAdapter:
11+
class _TensorboardAdapter(metaclass=SingletonMeta):
1112
_instance = None
1213
_writer = None
1314

@@ -22,20 +23,16 @@ class _TensorboardAdapter:
2223
# tb.log({"Accuracy": 0.9}, step=1)
2324
"""
2425

25-
def __new__(cls, args):
26+
def __init__(self, args):
2627
assert args.use_tensorboard, f"{args.use_tensorboard=}"
2728
tb_project_name = args.tb_project_name
2829
tb_experiment_name = args.tb_experiment_name
29-
if cls._instance is None:
30-
cls._instance = super(_TensorboardAdapter, cls).__new__(cls)
31-
# Initialize if parameters are provided during first creation
32-
if tb_project_name is not None or os.environ.get("TENSORBOARD_DIR", None):
33-
if tb_project_name is not None and tb_experiment_name is None:
34-
tb_experiment_name = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
35-
cls._instance._initialize(tb_project_name, tb_experiment_name)
36-
else:
37-
raise ValueError("tb_project_name and tb_experiment_name, or TENSORBOARD_DIR are required")
38-
return cls._instance
30+
if tb_project_name is not None or os.environ.get("TENSORBOARD_DIR", None):
31+
if tb_project_name is not None and tb_experiment_name is None:
32+
tb_experiment_name = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
33+
self._instance._initialize(tb_project_name, tb_experiment_name)
34+
else:
35+
raise ValueError("tb_project_name and tb_experiment_name, or TENSORBOARD_DIR are required")
3936

4037
def _initialize(self, tb_project_name, tb_experiment_name):
4138
"""Actual initialization logic"""

train.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,6 @@ def train(args):
1111
pgs = create_placement_groups(args)
1212
wandb_run_id = init_wandb_primary(args)
1313

14-
if args.use_tensorboard:
15-
from slime.utils.tensorboard_utils import _TensorboardAdapter
16-
17-
_TensorboardAdapter(args)
18-
1914
# create the rollout manager, with sglang engines inside.
2015
# need to initialize rollout manager first to calculate num_rollout
2116
rollout_manager, num_rollout_per_epoch = create_rollout_manager(args, pgs["rollout"], wandb_run_id=wandb_run_id)

0 commit comments

Comments
 (0)