
This method does not work with torch.compile #26

@Trgtuan10


I want to combine Distrifuser with torch.compile, but it doesn't work:
pipeline.pipeline.unet = torch.compile(pipeline.pipeline.unet)

Here is the log from both ranks:

[rank0]:W0202 03:40:07.939000 7752 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0] Graph break from Tensor.item(), consider setting:
[rank0]:W0202 03:40:07.939000 7752 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0] torch._dynamo.config.capture_scalar_outputs = True
[rank0]:W0202 03:40:07.939000 7752 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0] or:
[rank0]:W0202 03:40:07.939000 7752 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
[rank0]:W0202 03:40:07.939000 7752 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0] to include these operations in the captured graph.
[rank0]:W0202 03:40:07.939000 7752 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0]
[rank0]:W0202 03:40:07.939000 7752 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0] Graph break: from user code at:
[rank0]:W0202 03:40:07.939000 7752 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0] File "/workspace/SD-parrallel/distrifuser/distrifuser/models/distri_sdxl_unet_pp.py", line 94, in forward
[rank0]:W0202 03:40:07.939000 7752 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0] static_inputs["timestep"][b] = timestep.item()
[rank0]:W0202 03:40:07.939000 7752 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0]
[rank0]:W0202 03:40:07.939000 7752 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0]
[rank1]:W0202 03:40:08.147000 7753 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0] Graph break from Tensor.item(), consider setting:
[rank1]:W0202 03:40:08.147000 7753 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0] torch._dynamo.config.capture_scalar_outputs = True
[rank1]:W0202 03:40:08.147000 7753 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0] or:
[rank1]:W0202 03:40:08.147000 7753 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
[rank1]:W0202 03:40:08.147000 7753 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0] to include these operations in the captured graph.
[rank1]:W0202 03:40:08.147000 7753 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0]
[rank1]:W0202 03:40:08.147000 7753 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0] Graph break: from user code at:
[rank1]:W0202 03:40:08.147000 7753 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0] File "/workspace/SD-parrallel/distrifuser/distrifuser/models/distri_sdxl_unet_pp.py", line 94, in forward
[rank1]:W0202 03:40:08.147000 7753 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0] static_inputs["timestep"][b] = timestep.item()
[rank1]:W0202 03:40:08.147000 7753 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0]
[rank1]:W0202 03:40:08.147000 7753 site-packages/torch/_dynamo/variables/tensor.py:776] [0/0]
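The warning says Dynamo stops the graph at `timestep.item()`, because `.item()` turns a tensor into a Python number. A graph break is not a crash by itself: Dynamo compiles the code before and after it separately and runs the break point in Python. A minimal sketch to confirm the break outside Distrifuser, assuming PyTorch 2.x (the toy function `write_timestep_item` and the helper `breaks_graph` are hypothetical, modeled on line 94): compiling with `fullgraph=True` makes Dynamo raise `torch._dynamo.exc.Unsupported` at a graph break instead of logging a warning, and `backend="eager"` limits the check to tracing.

```python
import torch
from torch._dynamo.exc import Unsupported


def write_timestep_item(buf: torch.Tensor, timestep: torch.Tensor, b: int) -> torch.Tensor:
    # Hypothetical toy version of line 94: a Python scalar is pulled to the
    # host with .item() and written into a preallocated buffer.
    buf[b] = timestep.item()
    return buf


def breaks_graph(fn, *args) -> bool:
    """Return True if Dynamo cannot trace fn as a single graph."""
    # fullgraph=True raises at the first graph break instead of warning;
    # backend="eager" skips Inductor code generation, so only tracing is checked.
    compiled = torch.compile(fn, fullgraph=True, backend="eager")
    try:
        compiled(*args)
    except Unsupported:
        return True
    return False


if __name__ == "__main__":
    # Under default settings (capture_scalar_outputs=False) a break is expected.
    print(breaks_graph(write_timestep_item, torch.zeros(4), torch.tensor(981.0), 1))
```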

Any ideas on how to fix this?
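One possible workaround, sketched under the assumption that `static_inputs["timestep"]` is a preallocated tensor and `timestep` is a 0-dim tensor on a compatible device: assign the tensor itself instead of its `.item()` value, so the copy stays on-device and inside the traced graph. The function `write_timestep_tensor` is a hypothetical toy, not a tested patch to `distri_sdxl_unet_pp.py`, and it only addresses the break reported at line 94; other parts of the pipeline may still break the graph. The alternative is the one the warning itself names, `torch._dynamo.config.capture_scalar_outputs = True`, which asks Dynamo to trace the `.item()` call instead.

```python
import torch


def write_timestep_tensor(buf: torch.Tensor, timestep: torch.Tensor, b: int) -> torch.Tensor:
    # Hypothetical rewrite of line 94: write the 0-dim tensor directly
    # (an in-place tensor op) instead of timestep.item() (a host read).
    buf[b] = timestep
    return buf


if __name__ == "__main__":
    # fullgraph=True would raise if this function still contained a graph break.
    compiled = torch.compile(write_timestep_tensor, fullgraph=True, backend="eager")
    buf = torch.zeros(4)
    compiled(buf, torch.tensor(981.0), 1)
    print(buf)
```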
