[BUG] _reshape_norm() leading to unexpected array shape when using 3D arrays with one channel #1338

@PandaGab

Describe the bug
In the file train.py, the function _reshape_norm() behaves in an unexpected way.

Specifically, when the input data has shape (1, H, W), the current preprocessing logic expands it to (3, 1, H, W) instead of the expected (3, H, W). This is due to the following code:

if td.ndim == 2 or (td.ndim == 3 and td.shape[0] == 1):
    td = np.stack((td, 0*td, 0*td), axis=0)
elif td.ndim == 3 and td.shape[0] < 3:
    td = np.concatenate((td, 0*td[:1]), axis=0)

This condition treats a 3D array with a single channel (e.g. shape (1, H, W)) the same way as a 2D image. np.stack joins the three arrays along a new first axis, so the existing channel axis of length 1 is kept and the result is a 4D array.

To Reproduce
Steps to reproduce the behavior:

  1. Prepare input data shaped (1, H, W) (e.g., a single-channel image with explicit channel dimension).
  2. Call _reshape_norm(data, channel_axis=0).
  3. Observe that the output shape becomes (3, 1, H, W) instead of (3, H, W).
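
The steps above can be checked without the rest of the training pipeline. The snippet below is a minimal sketch: `pad_channels_current` is a hypothetical standalone copy of only the branch quoted earlier, not the full _reshape_norm(), and the sizes H=4, W=5 are arbitrary.

```python
import numpy as np


def pad_channels_current(td):
    """Hypothetical standalone copy of the branch quoted in this issue."""
    if td.ndim == 2 or (td.ndim == 3 and td.shape[0] == 1):
        # np.stack always creates a NEW axis, even if td is already 3D.
        td = np.stack((td, 0 * td, 0 * td), axis=0)
    elif td.ndim == 3 and td.shape[0] < 3:
        td = np.concatenate((td, 0 * td[:1]), axis=0)
    return td


H, W = 4, 5
single_channel = np.ones((1, H, W), dtype=np.float32)
out = pad_channels_current(single_channel)
print(out.shape)  # (3, 1, 4, 5): a 4D array instead of (3, 4, 5)
```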

Suggested fix

if td.ndim == 2:
    td = np.stack((td, 0 * td, 0 * td), axis=0)
elif td.ndim == 3 and td.shape[0] < 3:
    channel_to_add = 3 - td.shape[0]
    pad_width = [(0, channel_to_add)] + [(0, 0)] * 2
    td = np.pad(td, pad_width=pad_width, mode="constant", constant_values=0)
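
To sanity-check the suggestion, here is a minimal sketch that wraps it in a hypothetical helper, `pad_channels_fixed` (the name and the test shapes are assumptions for illustration, not part of the library), and runs it on 2D, 1-channel, 2-channel and 3-channel inputs.

```python
import numpy as np


def pad_channels_fixed(td):
    """Hypothetical helper wrapping the suggested fix from this issue."""
    if td.ndim == 2:
        # A plain 2D image gets a new leading channel axis of length 3.
        td = np.stack((td, 0 * td, 0 * td), axis=0)
    elif td.ndim == 3 and td.shape[0] < 3:
        # Zero-pad the existing channel axis up to 3 channels.
        channels_to_add = 3 - td.shape[0]
        pad_width = [(0, channels_to_add), (0, 0), (0, 0)]
        td = np.pad(td, pad_width=pad_width, mode="constant", constant_values=0)
    return td


results = {}
for shape in [(4, 5), (1, 4, 5), (2, 4, 5), (3, 4, 5)]:
    results[shape] = pad_channels_fixed(np.ones(shape, dtype=np.float32))
    print(shape, "->", results[shape].shape)
```

With this version every input ends up with shape (3, H, W), the original channels are left untouched, and any added channels are filled with zeros.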
