Skip to content

Commit f85bfb9

Browse files
authored
Feature/pose copy and remove components (#148)
* CDL: minor doc typo fix * Undoing some changes that got mixed in * Add Pose .copy() and .remove_components() * fix type annotation in pose remove_components * Adding copy functions to posebodies, and tests for this * Some pylint changes * Fix return type annotations, use copy() in zero_filled * uncomment pytests for torchposebody, rename MaskedTensor imports * import numpy.ma instead of 'from numpy import ma'
1 parent 4a2e6d0 commit f85bfb9

File tree

6 files changed

+264
-28
lines changed

6 files changed

+264
-28
lines changed

src/python/pose_format/numpy/pose_body.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,11 @@ def write(self, version: float, buffer: BinaryIO):
128128
buffer.write(np.array(self.data.data, dtype=np.float32).tobytes())
129129
buffer.write(np.array(self.confidence, dtype=np.float32).tobytes())
130130

131+
def copy(self) -> 'NumPyPoseBody':
132+
return type(self)(fps=self.fps,
133+
data=self.data.copy(),
134+
confidence=self.confidence.copy())
135+
131136
@property
132137
def mask(self):
133138
""" Returns mask associated with data. """
@@ -181,8 +186,9 @@ def zero_filled(self):
181186
NumPyPoseBody
182187
changed pose body data.
183188
"""
184-
self.data = ma.array(self.data.filled(0), mask=self.data.mask)
185-
return self
189+
copy = self.copy()
190+
copy.data = ma.array(copy.data.filled(0), mask=copy.data.mask)
191+
return copy
186192

187193
def matmul(self, matrix: np.ndarray):
188194
"""

src/python/pose_format/pose.py

+26
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,28 @@ def frame_dropout_normal(self, dropout_mean: float = 0.5, dropout_std: float = 0
202202
"""
203203
body, selected_indexes = self.body.frame_dropout_normal(dropout_mean=dropout_mean, dropout_std=dropout_std)
204204
return Pose(header=self.header, body=body), selected_indexes
205+
206+
207+
def remove_components(self, components_to_remove: Union[str, List[str]], points_to_remove: Union[Dict[str, List[str]],None] = None):
208+
209+
if isinstance(components_to_remove, str):
210+
components_to_remove = [components_to_remove]
211+
212+
components_to_keep = []
213+
points_dict = {}
214+
215+
for component in self.header.components:
216+
if component.name not in components_to_remove:
217+
components_to_keep.append(component.name)
218+
points_dict[component.name] = []
219+
if points_to_remove is not None:
220+
for point in component.points:
221+
if point not in points_to_remove[component.name]:
222+
points_dict[component.name].append(point)
223+
224+
return self.get_components(components_to_keep, points_dict)
225+
226+
205227

206228
def get_components(self, components: List[str], points: Union[Dict[str, List[str]],None] = None):
207229
"""
@@ -253,6 +275,10 @@ def get_components(self, components: List[str], points: Union[Dict[str, List[str
253275
new_body = self.body.get_points(flat_indexes)
254276

255277
return Pose(header=new_header, body=new_body)
278+
279+
280+
def copy(self):
281+
return self.__class__(self.header, self.body.copy())
256282

257283
def bbox(self):
258284
"""

src/python/pose_format/pose_body.py

+21-15
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
import math
12
from random import sample
2-
from typing import BinaryIO, List, Tuple
3+
from typing import BinaryIO, List, Tuple, Optional
34

45
import numpy as np
5-
import math
6+
67

78
from pose_format.pose_header import PoseHeader
89
from pose_format.utils.reader import BufferReader, ConstStructs
@@ -60,9 +61,9 @@ def read(cls, header: PoseHeader, reader: BufferReader, **kwargs) -> "PoseBody":
6061

6162
if header.version == 0:
6263
return cls.read_v0_0(header, reader, **kwargs)
63-
elif round(header.version, 3) == 0.1:
64+
if round(header.version, 3) == 0.1:
6465
return cls.read_v0_1(header, reader, **kwargs)
65-
elif round(header.version, 3) == 0.2:
66+
if round(header.version, 3) == 0.2:
6667
return cls.read_v0_2(header, reader, **kwargs)
6768

6869
raise NotImplementedError("Unknown version - %f" % header.version)
@@ -93,8 +94,8 @@ def read_v0_1_frames(cls,
9394
frames: int,
9495
shape: List[int],
9596
reader: BufferReader,
96-
start_frame: int = None,
97-
end_frame: int = None):
97+
start_frame: Optional[int] = None,
98+
end_frame: Optional[int] = None):
9899
"""
99100
Reads frame data for version 0.1 from a buffer.
100101
@@ -149,8 +150,8 @@ def read_v0_1_frames(cls,
149150
def read_v0_1(cls,
150151
header: PoseHeader,
151152
reader: BufferReader,
152-
start_frame: int = None,
153-
end_frame: int = None,
153+
start_frame: Optional[int] = None,
154+
end_frame: Optional[int] = None,
154155
**unused_kwargs) -> "PoseBody":
155156
"""
156157
Reads pose data for version 0.1 from a buffer.
@@ -176,7 +177,7 @@ def read_v0_1(cls,
176177
fps, _frames = reader.unpack(ConstStructs.double_ushort)
177178

178179
_people = reader.unpack(ConstStructs.ushort)
179-
_points = sum([len(c.points) for c in header.components])
180+
_points = sum(len(c.points) for c in header.components)
180181
_dims = header.num_dims()
181182

182183
# _frames is defined as short, which sometimes is not enough! TODO change to int
@@ -191,10 +192,10 @@ def read_v0_1(cls,
191192
def read_v0_2(cls,
192193
header: PoseHeader,
193194
reader: BufferReader,
194-
start_frame: int = None,
195-
end_frame: int = None,
196-
start_time: int = None,
197-
end_time: int = None,
195+
start_frame: Optional[int] = None,
196+
end_frame: Optional[int] = None,
197+
start_time: Optional[int] = None,
198+
end_time: Optional[int] = None,
198199
**unused_kwargs) -> "PoseBody":
199200
"""
200201
Reads pose data for version 0.2 from a buffer.
@@ -256,6 +257,11 @@ def write(self, version: float, buffer: BinaryIO):
256257
Buffer to write the pose data to.
257258
"""
258259
raise NotImplementedError("'write' not implemented on '%s'" % self.__class__)
260+
261+
def copy(self)->"PoseBody":
262+
return self.__class__(fps=self.fps,
263+
data=self.data,
264+
confidence=self.confidence)
259265

260266
def __getitem__(self, index):
261267
"""
@@ -306,7 +312,7 @@ def torch(self):
306312
Raises
307313
------
308314
NotImplementedError
309-
If toch is not implemented.
315+
If torch is not implemented.
310316
"""
311317
raise NotImplementedError("'torch' not implemented on '%s'" % self.__class__)
312318

@@ -474,7 +480,7 @@ def get_points(self, indexes: List[int]) -> __qualname__:
474480
Returns
475481
-------
476482
PoseBody
477-
PoseBody instance containing only choosen points.
483+
PoseBody instance containing only chosen points.
478484
479485
Raises
480486
------

src/python/pose_format/tensorflow/pose_body.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class TensorflowPoseBody(PoseBody):
1717
"""
1818
Representation of pose body data, optimized for TensorFlow operations.
1919
20-
* Inherites from PoseBody
20+
* Inherits from PoseBody
2121
2222
Parameters
2323
----------
@@ -43,10 +43,11 @@ def __init__(self, fps: float, data: Union[MaskedTensor, tf.Tensor], confidence:
4343

4444
super().__init__(fps, data, confidence)
4545

46-
def zero_filled(self):
46+
def zero_filled(self) -> 'TensorflowPoseBody':
4747
"""Return an instance with zero-filled data."""
48-
self.data = self.data.zero_filled()
49-
return self
48+
copy = self.copy()
49+
copy.data = self.data.zero_filled()
50+
return copy
5051

5152
def select_frames(self, frame_indexes: List[int]):
5253
"""
@@ -152,6 +153,17 @@ def points_perspective(self) -> MaskedTensor:
152153
"""
153154
return self.data.transpose(perm=POINTS_DIMS)
154155

156+
def copy(self) -> 'TensorflowPoseBody':
157+
# Ensure copies are fully detached from the TF computation graph by round-trip through numpy
158+
detached_data = tf.convert_to_tensor(self.data.tensor.numpy())
159+
detached_mask = tf.convert_to_tensor(self.data.mask.numpy())
160+
data_copy = MaskedTensor(detached_data, detached_mask)
161+
confidence_copy = tf.convert_to_tensor(self.confidence.numpy())
162+
return self.__class__(
163+
fps=self.fps,
164+
data=data_copy,
165+
confidence=confidence_copy)
166+
155167
def get_points(self, indexes: List[int]):
156168
"""
157169
Gets and returns points from pose data based on indexes

src/python/pose_format/torch/pose_body.py

+18-5
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
import torch
55

66
from ..pose_body import POINTS_DIMS, PoseBody
7-
from ..pose_header import PoseHeader
8-
from ..utils.reader import BufferReader
97
from .masked.tensor import MaskedTensor
108

119

@@ -28,10 +26,21 @@ def __init__(self, fps: float, data: Union[MaskedTensor, torch.Tensor], confiden
2826
super().__init__(fps, data, confidence)
2927

3028
def cuda(self):
31-
"""Move data and cofidence of tensors to GPU"""
29+
"""Move data and confidence of tensors to GPU"""
3230
self.data = self.data.cuda()
3331
self.confidence = self.confidence.cuda()
3432

33+
def copy(self) -> 'TorchPoseBody':
34+
data_copy = MaskedTensor(tensor=self.data.tensor.detach().clone().to(self.data.tensor.device),
35+
mask=self.data.mask.detach().clone().to(self.data.mask.device),
36+
)
37+
confidence_copy = self.confidence.detach().clone().to(self.confidence.device)
38+
39+
return self.__class__(fps=self.fps,
40+
data=data_copy,
41+
confidence=confidence_copy)
42+
43+
3544
def zero_filled(self) -> 'TorchPoseBody':
3645
"""
3746
Fill invalid values with zeros.
@@ -42,8 +51,9 @@ def zero_filled(self) -> 'TorchPoseBody':
4251
TorchPoseBody instance with masked data filled with zeros.
4352
4453
"""
45-
self.data.zero_filled()
46-
return self
54+
copy = self.copy()
55+
copy.data = copy.data.zero_filled()
56+
return copy
4757

4858
def matmul(self, matrix: np.ndarray) -> 'TorchPoseBody':
4959
"""
@@ -120,3 +130,6 @@ def flatten(self):
120130
scalar = torch.ones(len(shape) + shape[-1], device=data.device)
121131
scalar[0] = 1 / self.fps
122132
return flat * scalar
133+
134+
135+

0 commit comments

Comments
 (0)