From ea724a17cd4a6e7ab898a8bd080cab7dcd7f2dfe Mon Sep 17 00:00:00 2001 From: FengWen Date: Tue, 11 Jun 2024 11:07:39 +0800 Subject: [PATCH] Fix comfy Ci error --- .../modules/oneflow/hijack_samplers.py | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/onediff_comfy_nodes/modules/oneflow/hijack_samplers.py b/onediff_comfy_nodes/modules/oneflow/hijack_samplers.py index 47bfae8b4..f6e8d6875 100644 --- a/onediff_comfy_nodes/modules/oneflow/hijack_samplers.py +++ b/onediff_comfy_nodes/modules/oneflow/hijack_samplers.py @@ -15,7 +15,7 @@ def calc_cond_batch_of(orig_func, model, conds, x_in, timestep, model_options): out_conds = [] out_counts = [] to_run = [] - + for i in range(len(conds)): out_conds.append(torch.zeros_like(x_in)) out_counts.append(torch.ones_like(x_in) * 1e-37) @@ -40,6 +40,7 @@ def calc_cond_batch_of(orig_func, model, conds, x_in, timestep, model_options): to_batch_temp.reverse() # to_batch = to_batch_temp[:1] to_batch = to_batch_temp + # free_memory = model_management.get_free_memory(x_in.device) # for i in range(1, len(to_batch_temp) + 1): # batch_amount = to_batch_temp[:len(to_batch_temp)//i] @@ -93,27 +94,37 @@ def calc_cond_batch_of(orig_func, model, conds, x_in, timestep, model_options): transformer_options["cond_or_uncond"] = cond_or_uncond[:] diff_model = model.diffusion_model - if create_patch_executor(PatchType.CachedCrossAttentionPatch).check_patch(diff_model): transformer_options["sigmas"] = timestep[0].item() patch_executor = create_patch_executor(PatchType.UNetExtraInputOptions) transformer_options["_attn2"] = patch_executor.get_patch(diff_model)["attn2"] else: transformer_options["sigmas"] = timestep + # transformer_options["sigmas"] = timestep - c['transformer_options'] = transformer_options + if 'model_function_wrapper' in model_options: output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) else: output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks) - for o in range(batch_chunks): cond_index = cond_or_uncond[o] - out_conds[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] - out_counts[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] - + a = area[o] + if a is None: + out_conds[cond_index] += output[o] * mult[o] + out_counts[cond_index] += mult[o] + else: + out_c = out_conds[cond_index] + out_cts = out_counts[cond_index] + dims = len(a) // 2 + for i in range(dims): + out_c = out_c.narrow(i + 2, a[i + dims], a[i]) + out_cts = out_cts.narrow(i + 2, a[i + dims], a[i]) + out_c += output[o] * mult[o] + out_cts += mult[o] + for i in range(len(out_conds)): out_conds[i] /= out_counts[i]