Skip to content

Commit 7ad9d63

Browse files
authored
[bug fix]vit training get wrong index
1 parent 18a3a1d commit 7ad9d63

File tree

1 file changed

+2
-12
lines changed

1 file changed

+2
-12
lines changed

ernie/callbacks/vit_trainable_callback.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -131,23 +131,13 @@ def extract_feature_wrapper(inner_self, images, grid_thw, second_fwd=False):
131131
not model.balanced_image_preprocess
132132
), "不支持balanced_image_preprocess"
133133

134-
def limao_huan_taizi_hook(
135-
p,
136-
):
137-
def hook(g):
138-
logger.info(f"limao hook called -- {p.name}")
139-
p._clear_dataptr()
140-
141-
return hook
142-
143134
def _prepare_pipeline_inputs_func_wrapper(inner_self, data):
144135
def wrap(micro_data):
145136
inputs, labels = micro_data
146-
if args.pipeline_parallel_rank == 0 and inputs[2] is not None:
147-
fea = inputs[2]
137+
if args.pipeline_parallel_rank == 0 and inputs[3] is not None:
138+
fea = inputs[3]
148139
self.images_features.append(fea)
149140
fea.stop_gradient = False
150-
# fea.register_hook(limao_huan_taizi_hook(fea))
151141
return (inputs, labels)
152142

153143
return (wrap(i) for i in ori_prepare_pipeline_inputs_func(data))

0 commit comments

Comments
 (0)