from transformers import PreTrainedTokenizer

from .datasets import CacheDataset
- from ..models.cache import make_prompt_cache, KVCache
+ from ..models.cache import KVCache, make_prompt_cache


def reset_prompt_cache(cache):
@@ -74,7 +74,7 @@ class TrainingArgs:
        default=False,
        metadata={"help": "Use gradient checkpointing to reduce memory use."},
    )
-     seq_step_size: Optional[int] = field(
+     seq_step_size: Optional[int] = field(
        default=None,
        metadata={"help": "The examples are processed in seq_step_size chunks."},
    )
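
The new seq_step_size option tells the trainer to process each batch's sequence dimension in fixed-size chunks instead of in a single pass, which bounds activation memory for long sequences. A minimal sketch of the chunking pattern, assuming a (tokens, lengths) batch tuple like the one sliced in the hunks below (the helper name iter_seq_chunks is hypothetical, not part of the commit):

def iter_seq_chunks(tokens, lengths, seq_step_size):
    # tokens: array of shape (batch, seq_len); lengths rides along unchanged,
    # matching the (batch[0], batch[1]) tuples built in the diffs below.
    seq_length = tokens.shape[1]
    for s in range(0, seq_length, seq_step_size):
        yield tokens[:, s : s + seq_step_size], lengths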
@@ -196,7 +196,7 @@ def evaluate(
    ):
        seq_length = batch[0].shape[1]
        for s in range(0, seq_length, seq_step_size):
-             local_batch = (batch[0][:, s : s + seq_step_size], batch[1])
+             local_batch = (batch[0][:, s : s + seq_step_size], batch[1])
            losses, toks = loss(model, *local_batch, cache)
            all_losses += losses * toks
            ntokens += toks
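
The evaluate hunk above accumulates a token-weighted loss over the sequence chunks. A sketch of the same reduction, where chunk_loss is a stand-in for the loss(model, *local_batch, cache) call and is not part of the commit:

def chunked_eval_loss(chunk_loss, local_batches):
    # chunk_loss returns (mean loss over the chunk, number of tokens in the chunk).
    all_losses, ntokens = 0.0, 0
    for local_batch in local_batches:
        losses, toks = chunk_loss(*local_batch)
        all_losses += losses * toks  # undo the per-chunk mean by re-weighting with toks
        ntokens += toks
    return all_losses / ntokens  # overall per-token loss across all chunks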
@@ -273,7 +273,7 @@ def seq_split_step(batch):
        seq_length = batch[0].shape[1]
        grad_accum = None
        for s in range(0, seq_length, seq_step_size):
-             local_batch = (batch[0][:, s : s + seq_step_size], batch[1])
+             local_batch = (batch[0][:, s : s + seq_step_size], batch[1])
            (lvalue, toks), grad = loss_value_and_grad(model, *local_batch, cache)
            prev_n_tokens = n_tokens
            losses += toks * lvalue
@@ -284,8 +284,9 @@ def seq_split_step(batch):
            else:
                scale_g = toks / n_tokens
                scale_acc = prev_n_tokens / n_tokens
-                 grad_accum = tree_map(lambda g, acc: scale_g * g + scale_acc * acc, grad, grad_accum)
- 
+                 grad_accum = tree_map(
+                     lambda g, acc: scale_g * g + scale_acc * acc, grad, grad_accum
+                 )

            # Let go of the prompt cache before the last eval
            if s + seq_step_size >= seq_length:
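
The scale_g / scale_acc update keeps grad_accum equal to the token-weighted mean of the per-chunk gradients seen so far, so no extra normalization is needed after the loop. A scalar sketch of the same recurrence, using plain floats as stand-ins for the gradient trees passed to tree_map:

def running_weighted_mean(grads_and_token_counts):
    # After each iteration, accum == sum(g_i * t_i) / sum(t_i) over chunks seen so far.
    accum, n_tokens = None, 0
    for grad, toks in grads_and_token_counts:
        prev_n_tokens = n_tokens
        n_tokens += toks
        if accum is None:
            accum = grad
        else:
            scale_g = toks / n_tokens
            scale_acc = prev_n_tokens / n_tokens
            accum = scale_g * grad + scale_acc * accum
    return accum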