Skip to content

Commit d2eeff6

Browse files
committed
fix for rng input
Signed-off-by: Masahiro Tanaka <[email protected]>
1 parent e5bf185 commit d2eeff6

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

deepspeed/compile/backend.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -156,13 +156,14 @@ def make_fw_graph(gm, sample_inputs):
156156

157157
param_manager[graph_id] = DSGraphParamManager(gm.graph, real_inputs, param_indices)
158158

159+
real_inputs_with_rng = real_inputs + sample_inputs[len(real_inputs):]
159160
run_opt_passes(
160161
opt_passes=next_passes,
161162
gm=gm,
162163
graph_id=graph_id,
163164
graph_order=graph_order,
164165
profiling_results=profiling_results,
165-
create_inputs_fn=lambda: real_inputs,
166+
create_inputs_fn=lambda: real_inputs_with_rng,
166167
mem_budget=.0, # unused
167168
param_manager=param_manager,
168169
bwd=False,

0 commit comments

Comments
 (0)