Commit

chore: updates
peri044 committed Feb 5, 2025
1 parent 7eb327a commit 763346f
Showing 2 changed files with 21 additions and 246 deletions.
234 changes: 0 additions & 234 deletions examples/dynamo/flux.py

This file was deleted.

33 changes: 21 additions & 12 deletions examples/dynamo/torch_export_flux_dev.py
@@ -7,15 +7,16 @@
This example illustrates the state of the art model `FLUX.1-dev <https://huggingface.co/black-forest-labs/FLUX.1-dev>`_ optimized using
Torch-TensorRT.
**FLUX.1 [dev]** is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions. It is an open-weight, guidance-distilled model for non-commercial applications
**FLUX.1 [dev]** is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions. It is an open-weight, guidance-distilled model for non-commercial applications.
Install the following dependencies before compilation
.. code-block:: python
pip install -r requirements.txt
There are different components of the FLUX.1-dev pipeline such as `transformer`, `vae`, `text_encoder`, `tokenizer` and `scheduler`. In this example,
we demonstrate optimizing the `transformer` component of the model (which typically consumes >95% of the e2el diffusion latency)
There are different components of the ``FLUX.1-dev`` pipeline, such as ``transformer``, ``vae``, ``text_encoder``, ``tokenizer`` and ``scheduler``. In this example,
we demonstrate optimizing the ``transformer`` component of the model (which typically consumes >95% of the e2e diffusion latency).
"""

# %%
@@ -30,7 +31,7 @@
# Define the FLUX-1.dev model
# -----------------------------
# Load the ``FLUX-1.dev`` pretrained pipeline using ``FluxPipeline`` class.
# ``FluxPipeline`` includes all the different components such as `transformer`, `vae`, `text_encoder`, `tokenizer` and `scheduler` necessary
# ``FluxPipeline`` includes different components such as ``transformer``, ``vae``, ``text_encoder``, ``tokenizer`` and ``scheduler`` necessary
# to generate an image. We load the weights in ``FP16`` precision using the ``torch_dtype`` argument.
DEVICE = "cuda:0"
pipe = FluxPipeline.from_pretrained(
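The diff truncates the call here. As a minimal sketch of how the load typically looks (the model id and the trailing lines are assumptions based on the standard ``diffusers`` API, and ``backbone`` is an illustrative name for the module exported below):

import torch
from diffusers import FluxPipeline

DEVICE = "cuda:0"

# Load the pretrained pipeline with the weights in FP16.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.float16,
).to(DEVICE)

# The transformer backbone is the component exported and compiled below.
backbone = pipe.transformer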
@@ -44,14 +45,14 @@


# %%
# Export the backbone using ``torch.export``
# Export the backbone using torch.export
# --------------------------------------------------
# Define the dummy inputs and their respective dynamic shapes. We export the transformer backbone with dynamic shapes with a batch size of 2
# due to 0/1 specialization <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ&tab=t.0#heading=h.ez923tomjvyk>`_
# Define the dummy inputs and their respective dynamic shapes. We export the transformer backbone with dynamic shapes with a ``batch_size=2``
# due to `0/1 specialization <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ&tab=t.0#heading=h.ez923tomjvyk>`_
batch_size = 2
BATCH = torch.export.Dim("batch", min=1, max=2)
SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512)
# This particular min, max values are recommended by torch dynamo during the export of the model.
# These particular min/max values for the ``img_id`` input are recommended by Torch Dynamo during export of the model.
# To see this recommendation, you can try exporting using min=1, max=4096
IMG_ID = torch.export.Dim("img_id", min=3586, max=4096)
dynamic_shapes = {
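The diff truncates the ``dynamic_shapes`` dict here. A minimal sketch of the pattern it follows, mapping each input of the backbone's ``forward`` to the symbolic dims declared above (the kwarg names come from the FLUX transformer's signature and, like ``dummy_inputs``, are assumptions based on the surrounding code):

# Map each forward() kwarg of the backbone to its dynamic dimensions.
dynamic_shapes = {
    "hidden_states": {0: BATCH},                      # packed latent tokens
    "encoder_hidden_states": {0: BATCH, 1: SEQ_LEN},  # text-encoder embeddings
    "pooled_projections": {0: BATCH},
    "timestep": {0: BATCH},
    "txt_ids": {0: SEQ_LEN},
    "img_ids": {0: IMG_ID},                           # constrained to [3586, 4096] above
    "guidance": {0: BATCH},
}

# Export the backbone to an ExportedProgram using the dummy inputs.
with torch.no_grad():
    ep = torch.export.export(
        backbone,
        args=(),
        kwargs=dummy_inputs,
        dynamic_shapes=dynamic_shapes,
        strict=False,
    )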
@@ -92,18 +93,24 @@
# %%
# Torch-TensorRT compilation
# ---------------------------
# We enable FP32 matmul accumulation using ``use_fp32_acc=True`` to preserve accuracy with the original Pytorch model.
# Since this is a 12 billion parameter model, it takes around 20-30 min on H100 GPU
# .. note::
#    The compilation requires a GPU with high memory (> 80GB) since TensorRT stores the weights in ``FP32`` precision. This is a known issue and will be resolved in the future.
#
#
# We enable ``FP32`` matmul accumulation using ``use_fp32_acc=True`` to ensure accuracy is preserved by inserting casts to ``FP32`` nodes.
# We also enable explicit typing to ensure TensorRT respects the datatypes set by the user, which is a requirement for ``FP32`` matmul accumulation.
# Since this is a 12 billion parameter model, compilation takes around 20-30 min on an H100 GPU. The model is completely convertible and results in
# a single TensorRT engine.
trt_gm = torch_tensorrt.dynamo.compile(
ep,
inputs=dummy_inputs,
enabled_precisions={torch.float32},
truncate_double=True,
min_block_size=1,
debug=True,
use_fp32_acc=True,
use_explicit_typing=True,
)
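As a quick sanity check (not part of this commit), the returned module can be invoked directly with the dummy inputs to confirm the engine executes:

# trt_gm is a torch.fx.GraphModule wrapping the TensorRT engine and accepts
# the same kwargs as the exported backbone.
with torch.no_grad():
    out = trt_gm(**dummy_inputs)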

# %%
# Post Processing
# ---------------------------
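The body of this section is elided by the diff view. A minimal sketch of the post-processing step, assuming the compiled module is swapped in for the FP16 backbone and a small helper saves the output (the config hand-off, prompt, and pipeline arguments are illustrative assumptions, not the commit's exact code):

# Swap the compiled TensorRT module in for the original transformer.
# The pipeline reads self.transformer.config, so carry the config over.
trt_gm.config = backbone.config
pipe.transformer = trt_gm

def generate_image(pipe, prompt, image_name):
    # Run the diffusion pipeline and save the first generated image.
    image = pipe(prompt, num_inference_steps=20).images[0]
    image.save(f"{image_name}.png")

generate_image(pipe, ["A dog writing code on a laptop"], "dog_code")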
@@ -138,4 +145,6 @@ def generate_image(pipe, prompt, image_name):

# %%
# The generated image is as shown below
# .. image:: dog_code.png
#
# .. image:: dog_code.png
#
