Skip to content

Commit c4b8ece

Browse files
committed
feat(torch): add dataloadaer collator
1 parent 79fdf23 commit c4b8ece

File tree

5 files changed

+73
-4
lines changed

5 files changed

+73
-4
lines changed

src/python/pose_format/numpy/pose_body.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ class NumPyPoseBody(PoseBody):
3737
confidence array of the pose keypoints.
3838
"""
3939

40-
tensor_reader = 'unpack_numpy' """Specifies the method name for unpacking a numpy array (Value: 'unpack_numpy')."""
40+
"""Specifies the method name for unpacking a numpy array (Value: 'unpack_numpy')."""
41+
tensor_reader = 'unpack_numpy'
4142

4243
def __init__(self, fps: float, data: Union[ma.MaskedArray, np.ndarray], confidence: np.ndarray):
4344
"""

src/python/pose_format/tensorflow/pose_body.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ class TensorflowPoseBody(PoseBody):
2828
confidence : tf.Tensor
2929
The confidence scores for the pose data.
3030
"""
31-
tensor_reader = 'unpack_tensorflow' """str: The method used to read the tensor data. (Type: str)"""
31+
32+
"""str: The method used to read the tensor data. (Type: str)"""
33+
tensor_reader = 'unpack_tensorflow'
3234

3335
def __init__(self, fps: float, data: Union[MaskedTensor, tf.Tensor], confidence: tf.Tensor):
3436
"""
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
from typing import Dict, List, Tuple, Union
2+
3+
import numpy as np
4+
import torch
5+
from pose_format.torch.masked import MaskedTensor, MaskedTorch
6+
7+
8+
def pad_tensors(batch: List[Union[torch.Tensor, MaskedTensor]], pad_value=0):
9+
datum = batch[0]
10+
torch_cls = MaskedTorch if isinstance(datum, MaskedTensor) else torch
11+
12+
max_len = max(len(t) for t in batch)
13+
if max_len == 1:
14+
return torch_cls.stack(batch, dim=0)
15+
16+
new_batch = []
17+
for tensor in batch:
18+
missing = list(tensor.shape)
19+
missing[0] = max_len - tensor.shape[0]
20+
21+
if missing[0] > 0:
22+
padding_tensor = torch.full(missing, fill_value=pad_value, dtype=tensor.dtype, device=tensor.device)
23+
tensor = torch_cls.cat([tensor, padding_tensor], dim=0)
24+
25+
new_batch.append(tensor)
26+
27+
return torch_cls.stack(new_batch, dim=0)
28+
29+
30+
def collate_tensors(batch: List, pad_value=0) -> Union[torch.Tensor, List]:
31+
datum = batch[0]
32+
33+
if isinstance(datum, dict): # Recurse over dictionaries
34+
return zero_pad_collator(batch)
35+
36+
if isinstance(datum, (int, np.int32)):
37+
return torch.tensor(batch, dtype=torch.long)
38+
39+
if isinstance(datum, (MaskedTensor, torch.Tensor)):
40+
return pad_tensors(batch, pad_value=pad_value)
41+
42+
return batch
43+
44+
45+
def zero_pad_collator(batch) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, ...]]:
46+
datum = batch[0]
47+
48+
# For strings
49+
if isinstance(datum, str):
50+
return batch
51+
52+
# For tuples
53+
if isinstance(datum, tuple):
54+
return tuple(collate_tensors([b[i] for b in batch]) for i in range(len(datum)))
55+
56+
# For tensors
57+
if isinstance(datum, MaskedTensor):
58+
return collate_tensors(batch)
59+
60+
# For dictionaries
61+
keys = datum.keys()
62+
return {k: collate_tensors([b[k] for b in batch]) for k in keys}
63+
64+

src/python/pose_format/torch/pose_body.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@ class TorchPoseBody(PoseBody):
1515
1616
This class extends the PoseBody class and provides methods for manipulating pose data using PyTorch tensors.
1717
"""
18-
tensor_reader = 'unpack_torch' """str: Reader format for unpacking Torch tensors."""
18+
19+
"""str: Reader format for unpacking Torch tensors."""
20+
tensor_reader = 'unpack_torch'
1921

2022
def __init__(self, fps: float, data: Union[MaskedTensor, torch.Tensor], confidence: torch.Tensor):
2123
if isinstance(data, torch.Tensor): # If array is not masked

src/python/pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "pose_format"
33
description = "Library for viewing, augmenting, and handling .pose files"
4-
version = "0.2.3"
4+
version = "0.3.0"
55
keywords = ["Pose Files", "Pose Interpolation", "Pose Augmentation"]
66
authors = [
77
{ name = "Amit Moryossef", email = "[email protected]" },

0 commit comments

Comments
 (0)