|
9 | 9 |
|
10 | 10 | **FLUX.1 [dev]** is a 12-billion-parameter rectified flow transformer capable of generating images from text descriptions. It is an open-weight, guidance-distilled model for non-commercial applications.
|
11 | 11 |
|
12 |
| -Install the following dependencies before compilation |
| 12 | +To run this demo, you need access to the FLUX.1-dev model (request access on the `FLUX.1-dev <https://huggingface.co/black-forest-labs/FLUX.1-dev>`_ page if you do not already have it) and the following dependencies installed:
13 | 13 |
|
14 | 14 | .. code-block:: sh
|
15 | 15 |
|
16 |
| - pip install sentencepiece=="0.2.0" transformers=="4.48.2" accelerate=="1.3.0" diffusers=="0.32.2" |
| 16 | + pip install sentencepiece=="0.2.0" transformers=="4.48.2" accelerate=="1.3.0" diffusers=="0.32.2" protobuf=="5.29.3" |
17 | 17 |
|
18 | 18 | The ``FLUX.1-dev`` pipeline consists of several components such as ``transformer``, ``vae``, ``text_encoder``, ``tokenizer`` and ``scheduler`` (they can be listed from the loaded pipeline, as shown in the sketch further below). In this example,
|
19 | 19 | we demonstrate optimizing the ``transformer`` component of the model, which typically consumes >95% of the end-to-end diffusion latency.
|
|
38 | 38 | "black-forest-labs/FLUX.1-dev",
|
39 | 39 | torch_dtype=torch.float16,
|
40 | 40 | )
|
41 |
| -pipe.to(DEVICE).to(torch.float16) |
| 41 | + |
42 | 42 | # Store the config and transformer backbone
|
43 | 43 | config = pipe.transformer.config
|
44 |
| -backbone = pipe.transformer |
45 |
| - |
| 44 | +backbone = pipe.transformer.to(DEVICE) |
46 | 45 |
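+# %%
+# The components mentioned above can be listed directly from the loaded
+# pipeline (a quick sanity-check sketch, not part of the original example;
+# ``DiffusionPipeline.components`` returns a dict mapping names to modules):
+
+for name, component in pipe.components.items():
+    print(name, type(component).__name__)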
|
47 | 46 | # %%
|
48 | 47 | # Export the backbone using torch.export
|
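+# %%
+# The dynamic dimensions used below (``BATCH``, ``SEQ_LEN``, ``IMG_ID``) are
+# ``torch.export.Dim`` objects declared earlier in the script. A minimal sketch
+# of how such dims are typically declared (the min/max ranges here are
+# assumptions, not necessarily the script's exact values):
+
+from torch.export import Dim
+
+BATCH = Dim("batch", min=1, max=2)
+SEQ_LEN = Dim("seq_len", min=1, max=512)
+IMG_ID = Dim("img_id", min=3586, max=4096)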
|
63 | 62 | "txt_ids": {0: SEQ_LEN},
|
64 | 63 | "img_ids": {0: IMG_ID},
|
65 | 64 | "guidance": {0: BATCH},
|
| 65 | + "joint_attention_kwargs": {}, |
| 66 | + "return_dict": None, |
66 | 67 | }
|
67 | 68 | # The guidance factor is of type torch.float32
|
68 | 69 | dummy_inputs = {
|
|
79 | 80 | "txt_ids": torch.randn((512, 3), dtype=torch.float16).to(DEVICE),
|
80 | 81 | "img_ids": torch.randn((4096, 3), dtype=torch.float16).to(DEVICE),
|
81 | 82 | "guidance": torch.tensor([1.0, 1.0], dtype=torch.float32).to(DEVICE),
|
| 83 | + "joint_attention_kwargs": {}, |
| 84 | + "return_dict": False, |
82 | 85 | }
|
83 | 86 | # This creates an exported program that will then be compiled with Torch-TensorRT
|
84 | 87 | ep = _export(
|
|
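+# %%
+# The exported program is compiled ahead of time with Torch-TensorRT. A minimal
+# sketch of the compilation call (the precision and options shown here are
+# assumptions, not necessarily the exact settings used in the full script):
+
+import torch_tensorrt
+
+trt_gm = torch_tensorrt.dynamo.compile(
+    ep,
+    inputs=dummy_inputs,  # reuse the keyword inputs defined above
+    enabled_precisions={torch.float16},  # assumed: match the fp16 pipeline
+)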
116 | 119 | # ---------------------------
|
117 | 120 | # Release the GPU memory occupied by the exported program and the pipe.transformer
|
118 | 121 | # Set the transformer in the Flux pipeline to the Torch-TRT compiled model
|
119 |
| -backbone.to("cpu") |
| 122 | + |
120 | 123 | del ep
|
| 124 | +backbone.to("cpu") |
| 125 | +pipe.to(DEVICE) |
| 126 | +torch.cuda.empty_cache() |
121 | 127 | pipe.transformer = trt_gm
|
122 | 128 | pipe.transformer.config = config
|
123 | 129 |
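+# %%
+# With the compiled transformer swapped in, image generation goes through the
+# standard diffusers API. A short usage sketch (the prompt and settings below
+# are illustrative, not from the original script):
+
+prompt = "A majestic castle on a cliff at sunset, detailed, 4k"
+image = pipe(
+    prompt,
+    guidance_scale=3.5,
+    num_inference_steps=20,
+    generator=torch.Generator(DEVICE).manual_seed(0),
+).images[0]
+image.save("flux_output.png")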
|
|