diff --git a/pytorch360convert/pytorch360convert.py b/pytorch360convert/pytorch360convert.py index e65a3d2..a7b25d1 100644 --- a/pytorch360convert/pytorch360convert.py +++ b/pytorch360convert/pytorch360convert.py @@ -514,9 +514,10 @@ def sample_cubefaces( # coor_x, coor_y and tp. cube_faces_mod = cube_faces.clone() - face_w = cube_faces_mod.shape[1] + d = 1 if cube_faces.dim() == 4 else 2 + face_w = cube_faces_mod.shape[d] cube_h = torch.cat( - [cube_faces_mod[i] for i in range(6)], dim=1 + [cube_faces_mod[i] for i in range(6)], dim=d ) # [face_w, face_w*6, C] # We need to map (tp, coor_y, coor_x) -> coordinates in cube_h @@ -752,38 +753,40 @@ def c2e( # Ensure input is in HWC format for processing if channels_first: if cube_format == "list" and isinstance(cubemap, (list, tuple)): - cubemap = [r.permute(1, 2, 0) for r in cubemap] + cubemap = [_nchw2nhwc(r) for r in cubemap] elif cube_format == "dict" and torch.jit.isinstance( cubemap, Dict[str, torch.Tensor] ): - cubemap = {k: v.permute(1, 2, 0) for k, v in cubemap.items()} # type: ignore + cubemap = {k: _nchw2nhwc(v) for k, v in cubemap.items()} # type: ignore elif cube_format in ["horizon", "dice"] and isinstance(cubemap, torch.Tensor): - cubemap = cubemap.permute(1, 2, 0) + cubemap = _nchw2nhwc(cubemap) else: raise NotImplementedError("unknown cube_format and cubemap type") if cube_format == "horizon" and isinstance(cubemap, torch.Tensor): - assert cubemap.dim() == 3 + assert cubemap.dim() in [3, 4] cube_h = cubemap elif cube_format == "list" and isinstance(cubemap, (list, tuple)): - assert all([r.dim() == 3 for r in cubemap]) + assert all([r.dim() in [3, 4] for r in cubemap]) cube_h = cube_list2h(cubemap) elif cube_format == "dict" and torch.jit.isinstance( cubemap, Dict[str, torch.Tensor] ): - assert all(v.dim() == 3 for k, v in cubemap.items()) # type: ignore[union-attr] + assert all(v.dim() in [3, 4] for k, v in cubemap.items()) # type: ignore[union-attr] cube_h = cube_dict2h(cubemap) # type: ignore[arg-type] elif cube_format == "dice" and isinstance(cubemap, torch.Tensor): - assert len(cubemap.shape) == 3 + assert len(cubemap.shape) in [3, 4] cube_h = cube_dice2h(cubemap) else: raise NotImplementedError("unknown cube_format and cubemap type") assert isinstance(cube_h, torch.Tensor) # Mypy wants this + # cube_h -> B, H, W, C + d = 1 if cube_h.dim() == 3 else 2 device = cube_h.device dtype = cube_h.dtype - face_w = cube_h.shape[0] - assert cube_h.shape[1] == face_w * 6 + face_w = cube_h.shape[d - 1] + assert cube_h.shape[d] == face_w * 6 h = face_w * 2 if h is None else h w = face_w * 4 if w is None else w @@ -792,7 +795,7 @@ def c2e( u, v = uv[..., 0], uv[..., 1] cube_faces = torch.stack( - torch.split(cube_h, face_w, dim=1), dim=0 + torch.split(cube_h, face_w, dim=d), dim=0 ) # [6, face_w, face_w, C] tp = equirect_facetype(h, w, device=device, dtype=dtype)