@mdabek-nvidia
Collaborator

Category:

New feature

Description:

Torchvision objective API.
The implementation is currently limited to a few selected operators and to composing them into a single pipeline. The operators have been selected to enable a TIMM pipeline implementation. Input data support is limited to PIL Images and PyTorch tensors.

Additional information:

Affected modules and functionalities:

Key points relevant for the review:

Tests:

  • Existing tests apply
  • New tests added
    • Python tests
    • GTests
    • Benchmark
    • Other
  • N/A

Checklist

Documentation

  • Existing documentation applies
  • Documentation updated
    • Docstring
    • Doxygen
    • RST
    • Jupyter
    • Other
  • N/A

There will be an additional PR with documentation.

DALI team only

Requirements

  • Implements new requirements
  • Affects existing requirements
  • N/A

REQ IDs: N/A

JIRA TASK: DALI-4309

- Selected operators, that enable implementing TIMM pipeline
- Unit tests
- Test scripts

Signed-off-by: Marek Dabek <[email protected]>
@greptile-apps
greptile-apps bot commented Dec 17, 2025

Greptile Summary

This PR introduces a torchvision-compatible API for DALI, enabling users to use familiar torchvision transforms with DALI's accelerated execution. The implementation covers essential transforms (Resize, CenterCrop, ColorJitter, Normalize, Pad, GaussianBlur, and flip operations) that can be composed into pipelines.

Key Changes:

  • Added nvidia.dali.experimental.torchvision module with v2 API structure matching torchvision's interface
  • Implemented Compose class that orchestrates DALI pipelines for both PIL Images (HWC layout) and torch.Tensors (CHW layout)
  • Each transform wraps corresponding DALI operators while maintaining torchvision-compatible parameter names and behavior
  • Added autograph support for the new module to enable conditional control flow
  • Comprehensive test suite covering all transforms with CPU/GPU device support

Implementation Details:

  • The Compose class dynamically builds DALI pipelines based on input type (PIL/torch.Tensor) on first call
  • Transform operators follow a consistent pattern: parameter validation in __init__, DALI operator application in __call__
  • CPU/GPU execution controlled via device parameter across all transforms
  • Input validation ensures compatibility with torchvision's expected parameter ranges

Limitations:

  • Currently limited to PIL Images and PyTorch tensors (as stated in PR description)
  • ToTensor only works as the last operation in pipeline
  • Some interpolation modes use approximations (e.g., BOX→LINEAR, HAMMING→GAUSSIAN)
  • Batch processing with ToTensor raises NotImplementedError

Confidence Score: 4/5

  • This PR is generally safe to merge with one potential indexing issue that should be addressed
  • The implementation is well-structured with proper parameter validation, comprehensive test coverage, and follows established DALI patterns. The main concern is a potential IndexError in compose.py:93 when checking tensor dimensions. The code successfully integrates with existing DALI infrastructure (autograph, conditionals) and includes extensive tests validating behavior against torchvision. Documentation is noted as coming in a separate PR.
  • Pay close attention to dali/python/nvidia/dali/experimental/torchvision/v2/compose.py line 93 for the indexing issue

Important Files Changed

Filename | Overview
dali/python/nvidia/dali/_conditionals.py | Added nvidia.dali.experimental.torchvision to autograph conversion modules to enable conditional control flow support
dali/python/nvidia/dali/experimental/torchvision/v2/compose.py | Core Compose implementation with pipeline orchestration for HWC (PIL) and CHW (torch.Tensor) layouts
dali/python/nvidia/dali/experimental/torchvision/v2/resize.py | Resize transform with torchvision-compatible size/max_size logic and interpolation mode mapping
dali/python/nvidia/dali/experimental/torchvision/v2/pad.py | Pad transform with multiple padding modes (constant, edge, reflect, symmetric) using DALI's slice operator

Sequence Diagram

sequenceDiagram
    participant User
    participant Compose
    participant PipelineHWC/CHW
    participant DALI Pipeline
    participant Transform Ops
    participant Output

    User->>Compose: __call__(PIL Image or torch.Tensor)
    Compose->>Compose: _build_pipeline() on first call
    alt PIL Image Input
        Compose->>PipelineHWC: Create pipeline with HWC layout
    else torch.Tensor Input
        Compose->>PipelineCHW: Create pipeline with CHW layout
    end
    
    Compose->>PipelineHWC/CHW: run(data_input)
    
    alt PIL Image
        PipelineHWC/CHW->>PipelineHWC/CHW: Convert PIL to torch tensor
    else torch.Tensor
        PipelineHWC/CHW->>PipelineHWC/CHW: Add batch dimension if needed
    end
    
    PipelineHWC/CHW->>DALI Pipeline: _pipeline_function(op_list)
    DALI Pipeline->>DALI Pipeline: fn.external_source(input_data)
    
    loop For each transform in op_list
        DALI Pipeline->>Transform Ops: Apply transform (Resize, CenterCrop, etc.)
        Transform Ops->>Transform Ops: __call__(data_input)
        alt GPU device specified
            Transform Ops->>Transform Ops: data_input.gpu()
        end
        Transform Ops->>DALI Pipeline: Return transformed DataNode
    end
    
    DALI Pipeline->>PipelineHWC/CHW: Return output
    PipelineHWC/CHW->>PipelineHWC/CHW: to_torch_tensor()
    
    alt ToTensor in pipeline
        PipelineHWC/CHW->>Output: Return torch.Tensor
    else PIL Image mode
        PipelineHWC/CHW->>PipelineHWC/CHW: Convert to PIL Image
        PipelineHWC/CHW->>Output: Return PIL Image
    else Batch mode
        PipelineHWC/CHW->>Output: Return list or torch.Tensor
    end
    
    Output->>User: Transformed output

@greptile-apps greptile-apps bot left a comment

Additional Comments (7)

  1. dali/python/nvidia/dali/experimental/torchvision/v2/compose.py, line 137 (link)

    syntax: typo: "cound" should be "count"

  2. dali/python/nvidia/dali/experimental/torchvision/v2/compose.py, line 138 (link)

    syntax: typo: "unsable" should be "unusable"

  3. dali/python/nvidia/dali/experimental/torchvision/v2/compose.py, line 59 (link)

    syntax: typo: "runst" should be "runs"

  4. dali/python/nvidia/dali/experimental/torchvision/v2/compose.py, line 186 (link)

    logic: parentheses mismatch: append() call has wrong parentheses placement, causing layout to be passed as second argument to range() instead of _convert_tensor_to_image()

  5. dali/python/nvidia/dali/experimental/torchvision/v2/resize.py, line 40 (link)

    syntax: typo: "resizedimage" should be "resized image" (missing space)

  6. dali/python/nvidia/dali/experimental/torchvision/v2/resize.py, line 108-113 (link)

    logic: logic error: condition will never be false - checking len(self.size) == 2 after already confirming not isinstance(self.size, (tuple, list)) is contradictory

  7. dali/python/nvidia/dali/experimental/torchvision/v2/pad.py, line 29 (link)

    syntax: typo: "boarder_type" should be "border_type"

21 files reviewed, 7 comments

Signed-off-by: Marek Dabek <[email protected]>
@greptile-apps greptile-apps bot left a comment

Additional Comments (6)

  1. dali/python/nvidia/dali/experimental/torchvision/v2/compose.py, line 48 (link)

    logic: This assumes tensor_or_tl has a __len__ method in the else branch, but not all inputs may support this (e.g., if it's not a tuple but a single TensorList). Consider checking type before calling len().

  2. dali/python/nvidia/dali/experimental/torchvision/v2/pad.py, line 56-57 (link)

    logic: Missing validation: when padding is a sequence, its length should be either 2 or 4 (per torchvision API). Length 2 means left/right and top/bottom, while length 4 is left, top, right, bottom. This unpacking will fail for sequences of length 2.

  3. dali/python/nvidia/dali/experimental/torchvision/v2/pad.py, line 65-66 (link)

    style: Layout check is fragile - comparing against encoded byte of 'C' to detect CHW layout. Consider using string comparison or DALI's layout constants for better readability and maintainability.

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

  4. dali/python/nvidia/dali/experimental/torchvision/v2/resize.py, line 76-77 (link)

    style: This validation is incomplete - when size > max_size, should raise a more descriptive error message explaining why the configuration is invalid.

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

  5. dali/python/nvidia/dali/experimental/torchvision/v2/compose.py, line 117-118 (link)

    style: This restriction prevents ToTensor from being used mid-pipeline. Consider whether this should raise a more specific exception class or provide guidance on workarounds.

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

  6. dali/python/nvidia/dali/experimental/torchvision/v2/color.py, line 138 (link)

    style: Comment asks "what if it is HSV?" but no validation is performed. If non-RGB color spaces are unsupported, should validate and raise an error.
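
The padding-length check from comment 2 could look roughly like the sketch below. The helper name `normalize_padding` is hypothetical and not part of the PR; it only follows the convention described in that comment (an int pads all sides, length 2 is left/right and top/bottom, length 4 is left, top, right, bottom).

```python
from collections.abc import Sequence


def normalize_padding(padding):
    """Return padding as a (left, top, right, bottom) tuple.

    Hypothetical helper, not taken from the PR under review.
    """
    if isinstance(padding, int):
        # A single int pads all four borders equally.
        return (padding, padding, padding, padding)
    if not isinstance(padding, Sequence) or isinstance(padding, (str, bytes)):
        raise TypeError(
            f"padding must be an int or a sequence, got {type(padding).__name__}"
        )
    if len(padding) == 2:
        # (left/right, top/bottom)
        horizontal, vertical = padding
        return (horizontal, vertical, horizontal, vertical)
    if len(padding) == 4:
        # (left, top, right, bottom)
        return tuple(padding)
    raise ValueError(
        f"padding sequence must have length 2 or 4, got length {len(padding)}"
    )
```

Validating the length up front turns the unpacking failure into an explicit error that names the accepted lengths.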

21 files reviewed, 6 comments

elif in_tensor.shape[channels] == 3:
    mode = "RGB"
else:
    raise ValueError(f"Unsupported channels count: {channels}")
Contributor

@mzient mzient Dec 17, 2025

Suggested change
- raise ValueError(f"Unsupported channels count: {channels}")
+ raise ValueError(f"Unsupported channel count: {channels}")

or

Suggested change
- raise ValueError(f"Unsupported channels count: {channels}")
+ raise ValueError(f"Unsupported number of channels: {channels}")

also, I'd add some info like "Expected 1 or 3." to make users' life easier.
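
A minimal sketch of what the suggested message could look like in context. The function name `pil_mode_for_channels` is hypothetical, and mapping one channel to PIL mode "L" is an assumption; the quoted diff only shows the three-channel "RGB" branch.

```python
def pil_mode_for_channels(channels: int) -> str:
    """Map a channel count to a PIL image mode (hypothetical helper)."""
    # "L" for single-channel data is an assumption; "RGB" comes from the diff.
    modes = {1: "L", 3: "RGB"}
    if channels not in modes:
        raise ValueError(
            f"Unsupported number of channels: {channels}. Expected 1 or 3."
        )
    return modes[channels]
```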

Collaborator Author

Done

return _to_torch_tensor(tensor_or_tl)


class Compose(Pipeline):
Contributor

Since it's not going to be used as a regular pipeline (?), why inheritance? Composition would work much better here, since you'd be able to create the inner pipeline with @pipeline_def and the conditional execution would be handled for you.
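
To illustrate the composition approach in general terms, here is a library-free sketch. `FakePipeline` is a hypothetical stand-in for a built DALI pipeline, and the lazy build on first call mirrors the behavior described in the bot summary; none of this is the PR's actual code.

```python
class FakePipeline:
    """Hypothetical stand-in for a built pipeline; applies ops in order."""

    def __init__(self, ops):
        self.ops = list(ops)

    def run(self, data):
        for op in self.ops:
            data = op(data)
        return data


class Compose:
    """Owns a pipeline object instead of inheriting from a pipeline class."""

    def __init__(self, ops, pipeline_factory=FakePipeline):
        self._ops = list(ops)
        self._pipeline_factory = pipeline_factory
        self._pipeline = None  # built lazily on the first call

    def __call__(self, data):
        if self._pipeline is None:
            self._pipeline = self._pipeline_factory(self._ops)
        return self._pipeline.run(data)
```

With this shape, `Compose` exposes only `__call__`, and how the inner pipeline is created stays an implementation detail that can be swapped out.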

Collaborator Author

Done


def test_compose_tensor():
    test_tensor = make_test_tensor(shape=(5, 5, 5, 3))
    dali_out = Compose([RandomHorizontalFlip(p=1.0)], batch_size=test_tensor.shape[0])(test_tensor)
Contributor

I don't think we should perpetuate this pattern even in the tests - creating an ad-hoc pipeline like this is extremely expensive.
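
One way to avoid rebuilding a pipeline in every test is to cache the built object and reuse it. The sketch below uses only the standard library; `get_flip_pipeline` and its row-reversing body are hypothetical stand-ins for an expensive pipeline build, not code from the PR.

```python
import functools

# Records each (expensive) build so the reuse is observable.
BUILD_LOG = []


@functools.lru_cache(maxsize=None)
def get_flip_pipeline(batch_size: int):
    """Build the pipeline once per batch size and reuse it afterwards."""
    BUILD_LOG.append(batch_size)
    # Hypothetical stand-in for a real pipeline: reverse every row.
    return lambda batch: [row[::-1] for row in batch]


def test_flip_first():
    assert get_flip_pipeline(2)([[1, 2], [3, 4]]) == [[2, 1], [4, 3]]


def test_flip_second():
    # Same batch size, so the cached pipeline is reused rather than rebuilt.
    assert get_flip_pipeline(2)([[5, 6], [7, 8]]) == [[6, 5], [8, 7]]
```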

Collaborator Author

Done



"""
TODO: DALI ColorJitter does not work on tensors
Contributor

What do you mean by that? That it doesn't work with HWC layout?

Collaborator Author

HWC layout works, but CHW does not.
The API converts a PIL image to HWC but expects tensor input to be CHW, so passing a tensor will not work. This needs to be solved at some point if the API is to work efficiently with batches.



def build_centercrop_transform(
    size: Union[int, Sequence[int]], batch_size: int = 1, device: Literal["cpu", "gpu"] = "cpu"
Member

Nitpick: I think we can use the | syntax already.

size: int | Sequence[int]

def __init__(
    self,
    op_list: List[Callable[..., Union[Sequence[_DataNode], _DataNode]]],
    batch_size: int = 1,
Member

Food for thought:
Generally, the batch_size argument in the DALI Pipeline ctor denotes max_batch_size rather than the actual batch size. Handling batches smaller than the max does not introduce observable memory overhead. I wonder whether it would be a good idea to set batch_size here to some default value (which value is a good question), so that in most cases the user won't need to set it. Like 256, off the top of my head?

Member

And the same might apply to num threads, like nproc-1 or something?
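
A sketch of how such defaults might be computed, assuming the values floated in the two comments above (256 and "nproc-1"). Neither the constant nor the helper name comes from the PR, and neither value is a decided default.

```python
import os

# Value mentioned in the review as an example only.
DEFAULT_MAX_BATCH_SIZE = 256


def default_num_threads() -> int:
    """Leave one core free, but never return fewer than one thread."""
    # os.cpu_count() may return None, so fall back to a single core.
    return max(1, (os.cpu_count() or 1) - 1)
```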

batch_size: int = 1,
num_threads: int = 1,
*args,
**kwargs,
Member

Those args and kwargs are the arguments I'd pass to DALI pipeline? If so it might be good to provide this info in documentation. Maybe even call it such? **dali_pipeline_args.

Also, I'm not sure if having *args after defined keyword arguments is a good pattern. Do we have any specific usages for this? Maybe we can just omit *args and have keyword args only?
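
A keyword-only variant of the signature might look like the sketch below. The class name `ComposeConfig` is hypothetical, and `dali_pipeline_kwargs` is the name proposed in the comment above; the sketch only stores the extra arguments rather than building a real pipeline.

```python
class ComposeConfig:
    """Hypothetical holder showing a keyword-only constructor signature."""

    def __init__(self, op_list, *, batch_size=1, num_threads=1, **dali_pipeline_kwargs):
        self.op_list = list(op_list)
        self.batch_size = batch_size
        self.num_threads = num_threads
        # Extra keyword arguments are kept untouched so they can be
        # forwarded to the underlying pipeline constructor.
        self.dali_pipeline_kwargs = dali_pipeline_kwargs
```

The bare `*` rejects stray positional arguments, so a call such as `ComposeConfig(ops, 4)` fails with a `TypeError` instead of silently binding `4` to `batch_size`.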

Comment on lines +153 to +154
if not isinstance(data_input, (Image.Image, torch.Tensor)):
    raise TypeError(f"input should be PIL Image or torch.Tensor. Got {type(data_input)}")
Member

Would it make sense to also allow numpy array and nvimgcodec image?

Collaborator Author

@mdabek-nvidia mdabek-nvidia Dec 29, 2025

Vanilla Torchvision currently supports only the following formats:

  • Images as pure tensors, Image or PIL image
  • Videos as Video
  • Axis-aligned and rotated bounding boxes as BoundingBoxes
  • Segmentation and detection masks as Mask
  • KeyPoints as KeyPoints

I chose to support images and tensors for now, since these seemed to be supported by DALI out of the box or with minimal effort.

if self.device == "gpu":
    data = data.gpu()

return fn.crop(
Contributor

Would it work on bounding boxes? Shouldn't there be a detection of the contents and rescaling the bbox to a cropped coordinate?

Collaborator Author

No, bounding boxes need separate handling and are not supported yet.

Signed-off-by: Marek Dabek <[email protected]>

return output

def get_layout(self) -> str: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.

def get_layout(self) -> str: ...

def get_channel_reverse_idx(self) -> int: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.
@greptile-apps greptile-apps bot left a comment

Additional Comments (4)

  1. dali/python/nvidia/dali/experimental/torchvision/v2/compose.py, line 148-151 (link)

    logic: ValueError is instantiated but never raised, silently failing validation.

  2. dali/python/nvidia/dali/experimental/torchvision/v2/compose.py, line 208-211 (link)

    logic: ValueError is instantiated but never raised, silently failing validation.

  3. dali/python/nvidia/dali/experimental/torchvision/v2/compose.py, line 260 (link)

    syntax: Typo in error message: "torch.Tesors" should be "torch.Tensors"

  4. dali/python/nvidia/dali/experimental/torchvision/v2/pad.py, line 65-66 (link)

    syntax: Comment says "CWH" but should be "CHW" (channels-height-width)
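
The "instantiated but never raised" pattern from comments 1 and 2 is easy to miss, so here is a generic illustration. Both functions are hypothetical and unrelated to the actual checks in compose.py.

```python
def validate_positive_buggy(value):
    if value <= 0:
        # Bug: the exception object is created and immediately discarded.
        ValueError(f"value must be > 0, got {value}")
    return value


def validate_positive(value):
    if value <= 0:
        # Fix: the `raise` keyword actually propagates the error.
        raise ValueError(f"value must be > 0, got {value}")
    return value
```

The buggy variant returns invalid input unchanged, which is why the validation appears to pass silently.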

21 files reviewed, 4 comments

Signed-off-by: Marek Dabek <[email protected]>
@greptile-apps greptile-apps bot left a comment

Additional Comments (1)

  1. dali/python/nvidia/dali/experimental/torchvision/v2/compose.py, line 93-94 (link)

    logic: Indexing with -4 may fail if output has fewer than 4 dimensions. Add bounds check or handle IndexError.

21 files reviewed, 1 comment

    or validated_param[1] < 0
    or validated_param[0] > validated_param[1]
):
    raise ValueError("Parameters must be > 0")
Contributor

Nitpick, but maybe propagate the name of the parameter here, and report that it should also form a valid range if a sequence is provided.

So we would write:

Suggested change
- raise ValueError("Parameters must be > 0")
+ raise ValueError(f"Parameter {name} must be > 0, got {param}.")

or

Suggested change
- raise ValueError("Parameters must be > 0")
+ raise ValueError(f"Parameters {name} must form a range, got {param}")
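
Putting both suggestions together, a validation helper could report the sign problem and the ordering problem separately. This is a sketch under assumptions: `validate_range` is a hypothetical name, it handles only two-element sequences, and "non-negative" is a reading of the `< 0` checks visible in the quoted diff.

```python
def validate_range(name, param):
    """Validate a (low, high) pair and return it as a tuple."""
    low, high = param
    if low < 0 or high < 0:
        raise ValueError(f"Parameter {name} must be non-negative, got {param}.")
    if low > high:
        raise ValueError(
            f"Parameter {name} must form a range (low <= high), got {param}."
        )
    return (low, high)
```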

self.saturation = self._create_validated_param(saturation)

if isinstance(hue, float):
    self.hue = (-hue, hue)
Contributor

hue hue :D

self.hue = hue

if self.hue is not None and (len(self.hue) != 2 or self.hue[0] < -0.5 or self.hue[1] > 0.5):
    raise ValueError(f"hue values should be between (-0.5, 0.5) but got {self.hue}")
Contributor

I know this looks like a list, but the 0.5 is supposed to be inclusive:

Suggested change
- raise ValueError(f"hue values should be between (-0.5, 0.5) but got {self.hue}")
+ raise ValueError(f"hue values should be between [-0.5, 0.5] but got {self.hue}")
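
A compact sketch of hue validation with inclusive bounds. `validate_hue` is a hypothetical helper; the rules follow torchvision's documented ColorJitter convention (a scalar must lie in [0, 0.5] and expands to (-hue, hue); a pair must satisfy -0.5 <= min <= max <= 0.5).

```python
def validate_hue(hue):
    """Return hue as a (min, max) tuple inside [-0.5, 0.5], or None."""
    if hue is None:
        return None
    if isinstance(hue, (int, float)):
        if not 0 <= hue <= 0.5:
            raise ValueError(f"hue value should be in [0, 0.5] but got {hue}")
        return (-float(hue), float(hue))
    # Chained comparison checks both inclusive bounds and the ordering.
    if len(hue) != 2 or not -0.5 <= hue[0] <= hue[1] <= 0.5:
        raise ValueError(
            f"hue values should be in [-0.5, 0.5] with min <= max but got {hue}"
        )
    return (float(hue[0]), float(hue[1]))
```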

Comment on lines +46 to +47
if isinstance(kernel_size, int) and kernel_size <= 0:
    raise ValueError("Kernel size value should be an odd and positive number")
Contributor

There is no check for odd.
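
A sketch of a check that enforces both conditions named in the error message. The helper name `validate_kernel_size` and the expansion of an int to a (k, k) pair are assumptions, not code from the PR.

```python
def validate_kernel_size(kernel_size):
    """Return kernel_size as a tuple after checking each value is odd and positive."""
    if isinstance(kernel_size, int):
        sizes = (kernel_size, kernel_size)
    else:
        sizes = tuple(kernel_size)
    for ks in sizes:
        # Even values and values <= 0 are both rejected.
        if ks <= 0 or ks % 2 == 0:
            raise ValueError(
                f"Kernel size value should be an odd and positive number, got {kernel_size}"
            )
    return sizes
```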


mean (sequence) – Sequence of means for each channel.
std (sequence) – Sequence of standard deviations for each channel.
inplace (bool,optional) – Bool to make this operation in-place.
Contributor

Should we just hard error on this?

Comment on lines +143 to +147
if self.size is None:
    if orig_h > orig_w:
        target_w = (self.max_size * orig_w) / orig_h
    else:
        target_h = (self.max_size * orig_h) / orig_w
Contributor

I think we are missing a calculation for the case where size is an int and size * max(height, width) / min(height, width) > max_size (the longer edge crosses the max_size threshold). In that case the image should be max_size on the longer edge and max_size * min(height, width) / max(height, width) on the shorter edge?

Also, you are setting target_w or target_h here, and only using target_h in the if self.mode == "resize_shorter":

Looks like we are missing some cases here unless I'm mistaken.
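
The case described above can be sketched as a pure function. This is an illustration of the semantics the comment describes for an int `size`, not the PR's implementation; the name `compute_resized_size` and the truncation to int are assumptions.

```python
def compute_resized_size(orig_h, orig_w, size, max_size=None):
    """Return (new_h, new_w) when `size` is an int target for the shorter edge."""
    short, long = (orig_h, orig_w) if orig_h <= orig_w else (orig_w, orig_h)

    # Scale so that the shorter edge equals `size`, keeping the aspect ratio.
    new_short = size
    new_long = int(size * long / short)

    # If the longer edge would cross max_size, clamp it and rescale the
    # shorter edge by the same factor.
    if max_size is not None and new_long > max_size:
        new_short = int(max_size * short / long)
        new_long = max_size

    if orig_h <= orig_w:
        return (new_short, new_long)
    return (new_long, new_short)
```

For a 100x200 (height x width) image with `size=80`, the result is (80, 160) without a limit and (60, 120) with `max_size=120`.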

loop_images_test(t, td)


@params((480, 512), (100, 124), (None, 512), (1024, 512), ([256, 256], 512))
Contributor

Do we have test images for which scaling the shorter edge to the requested size makes the longer edge cross the max_size threshold?

Comment on lines +33 to +52
def _to_torch_tensor(tensor_or_tl: Union[TensorListGPU, TensorListCPU]) -> torch.Tensor:
    if isinstance(tensor_or_tl, (TensorListGPU, TensorListCPU)):
        dali_tensor = tensor_or_tl.as_tensor()
    else:
        dali_tensor = tensor_or_tl

    return torch.from_dlpack(dali_tensor)


def to_torch_tensor(tensor_or_tl: Union[tuple, TensorListGPU, TensorListCPU]) -> torch.Tensor:

    if isinstance(tensor_or_tl, tuple) and len(tensor_or_tl) > 1:
        tl = []
        for elem in tensor_or_tl:
            tl.append(_to_torch_tensor(elem))
        return tuple(tl)
    else:
        if len(tensor_or_tl) == 1:
            tensor_or_tl = tensor_or_tl[0]
        return _to_torch_tensor(tensor_or_tl)
Contributor

Can you add a bit of a doc here? Is it just to convert dali pipeline output to a Torch Tensor?
The batch of 1 case is not reflected in the signature of _to_torch_tensor.
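
One possible shape for a documented version of the unwrapping logic, written without torch or DALI so it runs standalone. `unwrap_outputs` and its `convert` argument are hypothetical; `convert` stands in for the per-element conversion that `_to_torch_tensor` performs. Checking the type before calling `len()` also covers the concern raised earlier in the thread.

```python
def unwrap_outputs(outputs, convert):
    """Convert pipeline outputs element by element.

    - A tuple with several elements is returned as a tuple of converted values.
    - A one-element tuple collapses to a single converted value.
    - Anything else is treated as a single output and converted directly.
    """
    if isinstance(outputs, tuple):
        if len(outputs) == 1:
            return convert(outputs[0])
        return tuple(convert(elem) for elem in outputs)
    return convert(outputs)
```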

return output

if isinstance(output, tuple):
    output = self._convert_tensor_to_image(output[0])
Contributor

This is because it's always one output, right?

# We need to convert tensor to CPU, otherwise it will be unsable
return Image.fromarray(in_tensor.cpu().numpy(), mode=mode)

def run(self, data_input):
Contributor

Would it be cleaner if in both of the Pipelines we had implemented

  • convert_input_to_dali_batch
  • convert_dali_batch_to_output

or something similar, and just call them in the base run()?

Btw, I'm a bit lost what the input and outputs are expected to be. Is it just a torch tensor or PIL image that can have leading batch dimension?
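
The suggested split could be sketched as a template method. The two hook names come from the comment above; everything else (`PipelineBase`, `process`, the toy `SingleItemPipeline`) is hypothetical and independent of DALI.

```python
from abc import ABC, abstractmethod


class PipelineBase(ABC):
    """Base class whose run() fixes the order of the three steps."""

    def run(self, data_input):
        batch = self.convert_input_to_dali_batch(data_input)
        result = self.process(batch)
        return self.convert_dali_batch_to_output(result)

    def process(self, batch):
        # Stand-in for running the real pipeline; returns the batch unchanged.
        return batch

    @abstractmethod
    def convert_input_to_dali_batch(self, data_input):
        """Turn user input into a batch the pipeline can consume."""

    @abstractmethod
    def convert_dali_batch_to_output(self, batch):
        """Turn the pipeline's batch back into the user-facing type."""


class SingleItemPipeline(PipelineBase):
    """Toy subclass: wraps one item into a batch of one and unwraps it again."""

    def convert_input_to_dali_batch(self, data_input):
        return [data_input]

    def convert_dali_batch_to_output(self, batch):
        return batch[0]
```

Using docstrings as the bodies of the abstract methods, rather than a bare `...`, would also sidestep the CodeQL "Statement has no effect" notices quoted earlier in the thread.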

4 participants