Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion cosmos_transfer1/diffusion/inference/transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,12 @@ def parse_arguments() -> argparse.Namespace:
action="store_true",
help="Run the generation in benchmark mode. It means that generation will be rerun a few times and the average generation time will be shown.",
)
parser.add_argument(
"--benchmark_iterations",
type=int,
default=4,
help="Number of iterations to run in benchmark mode, default is 4, only used if benchmark is True",
)
cmd_args = parser.parse_args()

# Load and parse JSON input
Expand Down Expand Up @@ -340,7 +346,7 @@ def demo(cfg, control_inputs):
pipeline.region_definitions = region_definitions

# Generate videos in batch
num_repeats = 4 if cfg.benchmark else 1
num_repeats = cfg.benchmark_iterations if cfg.benchmark else 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would need to be (cfg.benchmark_iterations + 1) given the first run is ignored and used as a warm-up.

time_sum = 0
for i in range(num_repeats):
if cfg.benchmark and i > 0:
Expand Down
2 changes: 1 addition & 1 deletion cosmos_transfer1/diffusion/training/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@

import amp_C
import torch
from apex.multi_tensor_apply import multi_tensor_applier
from einops import rearrange
from megatron.core import parallel_state
from torch import Tensor
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed import broadcast_object_list, get_process_group_ranks
from torch.distributed.utils import _verify_param_shape_across_processes

from apex.multi_tensor_apply import multi_tensor_applier
from cosmos_transfer1.diffusion.conditioner import BaseVideoCondition, DataType
from cosmos_transfer1.diffusion.diffusion.modules.res_sampler import COMMON_SOLVER_OPTIONS
from cosmos_transfer1.diffusion.module.parallel import cat_outputs_cp, split_inputs_cp
Expand Down
2 changes: 1 addition & 1 deletion cosmos_transfer1/utils/fused_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
# limitations under the License.

import torch

from apex.multi_tensor_apply import multi_tensor_applier

from cosmos_transfer1.utils import distributed, log


Expand Down