Skip to content

Commit 663e87e

Browse files
committed
feat(pose): read pose as bytesio, without reading the entire file
1 parent 4e799fe commit 663e87e

File tree

4 files changed

+80
-14
lines changed

4 files changed

+80
-14
lines changed

src/python/pose_format/pose.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from io import BytesIO
12
from itertools import chain
23
from typing import BinaryIO, Dict, List, Tuple, Type, Union
34

@@ -7,9 +8,10 @@
78
from pose_format.pose_body import PoseBody
89
from pose_format.pose_header import (PoseHeader, PoseHeaderComponent,
910
PoseHeaderDimensions,
10-
PoseNormalizationInfo)
11+
PoseNormalizationInfo, PoseHeaderCache)
1112
from pose_format.utils.fast_math import distance_batch
12-
from pose_format.utils.reader import BufferReader
13+
from pose_format.utils.reader import BufferReader, BytesIOReader
14+
1315

1416

1517
class Pose:
@@ -29,7 +31,7 @@ def __init__(self, header: PoseHeader, body: PoseBody):
2931
self.body = body
3032

3133
@staticmethod
32-
def read(buffer: bytes, pose_body: Type[PoseBody] = NumPyPoseBody, **kwargs):
34+
def read(buffer: Union[bytes, BytesIO], pose_body: Type[PoseBody] = NumPyPoseBody, **kwargs):
3335
"""
3436
Read Pose object from buffer.
3537
@@ -45,7 +47,8 @@ def read(buffer: bytes, pose_body: Type[PoseBody] = NumPyPoseBody, **kwargs):
4547
Pose
4648
Pose object.
4749
"""
48-
reader = BufferReader(buffer)
50+
reader = BufferReader(buffer) if isinstance(buffer, bytes) else BytesIOReader(buffer)
51+
reader.expect_to_read(PoseHeaderCache.end_offset or 10 * 1024) # Expect to read the header at least (or 10kb)
4952
header = PoseHeader.read(reader)
5053
body = pose_body.read(header, reader, **kwargs)
5154

src/python/pose_format/pose_body.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,9 @@ def read_v0_1_frames(cls,
128128
_frames = frames
129129
if start_frame is not None and start_frame > 0:
130130
if start_frame >= frames:
131-
raise ValueError("Start frame is greater than the number of frames")
131+
raise ValueError(f"Start frame {start_frame} is greater than the number of frames {frames}")
132132
# Advance to the start frame
133-
reader.advance(s, int(np.prod((start_frame, *shape))))
133+
reader.skip(s, int(np.prod((start_frame, *shape))))
134134
_frames -= start_frame
135135

136136
remove_frames = None
@@ -142,7 +142,7 @@ def read_v0_1_frames(cls,
142142
tensor = tensor_reader(ConstStructs.float, shape=(_frames, *shape))
143143

144144
if remove_frames is not None:
145-
reader.advance(s, int(np.prod((remove_frames, *shape))))
145+
reader.skip(s, int(np.prod((remove_frames, *shape))))
146146

147147
return tensor
148148

src/python/pose_format/utils/reader.py

+46-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import struct
22
from dataclasses import dataclass
3-
from typing import Tuple
3+
from io import BytesIO
4+
from typing import Tuple, Union
45

56
import numpy as np
67

@@ -35,6 +36,8 @@ class ConstStructs:
3536
Struct format for unsigned integer"""
3637

3738

39+
40+
3841
class BufferReader:
3942
"""
4043
Class is used to read binary data from buffer
@@ -45,11 +48,18 @@ class BufferReader:
4548
buffer from which to read data
4649
read_offset: int
4750
current read offset in buffer
51+
read_skipped: int
52+
how many bytes were skipped during reading
4853
"""
4954

50-
def __init__(self, buffer: bytes):
51-
self.buffer = buffer
55+
def __init__(self, buffer: Union[bytearray, bytes]):
56+
self.buffer: bytearray = buffer
57+
self.total_bytes_read = len(buffer)
5258
self.read_offset = 0
59+
self.read_skipped = 0
60+
61+
def expect_to_read(self, n: int):
    """
    Announce that the next ``n`` bytes are about to be unpacked.

    The base reader already holds its whole buffer in memory, so there is
    nothing to prepare and this is a no-op. Subclasses that load data
    lazily override it to fetch the missing bytes.

    Parameters
    ----------
    n : int
        Number of bytes the caller intends to read next.
    """
    return None
5363

5464
def bytes_left(self):
5565
"""
@@ -60,7 +70,7 @@ def bytes_left(self):
6070
int
6171
The number of bytes left to read.
6272
"""
63-
return len(self.buffer) - self.read_offset
73+
return len(self.buffer) - self.read_offset +self.read_skipped
6474

6575
def unpack_f(self, s_format: str):
6676
"""
@@ -97,7 +107,9 @@ def unpack_numpy(self, s: struct.Struct, shape: Tuple):
97107
np.ndarray
98108
The unpacked NumPy array.
99109
"""
100-
arr = np.ndarray(shape, s.format, self.buffer, self.read_offset).copy()
110+
self.expect_to_read(s.size * int(np.prod(shape)))
111+
112+
arr = np.ndarray(shape, s.format, self.buffer, self.read_offset - self.read_skipped).copy()
101113
self.advance(s, int(np.prod(shape)))
102114
return arr
103115

@@ -155,7 +167,8 @@ def unpack(self, s: struct.Struct):
155167
-------
156168
Unpacked data as specified by the struct format.
157169
"""
158-
unpack: tuple = s.unpack_from(self.buffer, self.read_offset)
170+
self.expect_to_read(s.size)
171+
unpack: tuple = s.unpack_from(self.buffer, self.read_offset - self.read_skipped)
159172
self.advance(s)
160173
if len(unpack) == 1:
161174
return unpack[0]
@@ -174,6 +187,9 @@ def advance(self, s: struct.Struct, times=1):
174187
"""
175188
self.read_offset += s.size * times
176189

190+
def skip(self, s: struct.Struct, times=1):
191+
self.advance(s, times)
192+
177193
def unpack_str(self) -> str:
178194
"""
179195
Unpacks a string from the buffer.
@@ -184,10 +200,34 @@ def unpack_str(self) -> str:
184200
The unpacked string, encoded in UTF-8.
185201
"""
186202
length: int = self.unpack(ConstStructs.ushort)
203+
self.expect_to_read(length)
187204
bytes_: bytes = self.unpack_f("%ds" % length)
188205
return bytes_.decode("utf-8")
189206

190207

208+
class BytesIOReader(BufferReader):
    """
    Reader that loads bytes lazily from a seekable binary stream.

    Bytes are fetched from the stream only when ``expect_to_read`` finds
    that the buffer does not hold enough of them. Regions passed to
    ``skip`` are never read from the stream.

    ``read_offset`` is the absolute position in the stream, while the
    matching position inside ``buffer`` is ``read_offset - read_skipped``.

    Parameters
    ----------
    reader : BytesIO
        Seekable binary stream to read from (any object with compatible
        ``seek`` and ``read`` methods is used the same way).
    """

    def __init__(self, reader: BytesIO):
        super().__init__(bytearray())
        self.reader = reader

    def skip(self, s: struct.Struct, times=1):
        """
        Skip ``times`` items of struct ``s`` without loading them.

        Parameters
        ----------
        s : struct.Struct
            Struct describing one item to skip.
        times : int, optional
            Number of items to skip. Default is 1.
        """
        # Drop bytes that were fetched but not consumed: after the skip they
        # no longer sit at the current position. The cut must use the buffer
        # position, not the stream position, otherwise nothing is dropped
        # once an earlier skip has made the two differ.
        buffer_offset = self.read_offset - self.read_skipped
        self.buffer = self.buffer[:buffer_offset]
        self.read_skipped += s.size * times
        super().skip(s, times)

    def read_chunk(self, chunk_size: int):
        """
        Append up to ``chunk_size`` bytes from the stream to the buffer.

        Parameters
        ----------
        chunk_size : int
            Maximum number of bytes to fetch.

        Raises
        ------
        EOFError
            If bytes were requested but the stream had none left.
        """
        # The buffer may still hold unread bytes past ``read_offset`` (for
        # example from a larger prefetch), so continue right after them
        # rather than at ``read_offset`` to avoid appending duplicates.
        self.reader.seek(self.read_offset + self.bytes_left(), 0)  # 0 means absolute seek
        chunk = self.reader.read(chunk_size)
        self.buffer.extend(chunk)
        # Count what was actually read; a read near the end of the stream
        # can return fewer bytes than requested.
        self.total_bytes_read += len(chunk)

        if chunk_size > 0 and not chunk:
            raise EOFError("End of file reached")

    def expect_to_read(self, n: int):
        """
        Make sure ``n`` bytes are buffered, fetching the shortfall if needed.

        Parameters
        ----------
        n : int
            Number of bytes the caller intends to read next.
        """
        missing = n - self.bytes_left()
        if missing > 0:
            self.read_chunk(missing)
230+
191231
if __name__ == "__main__":
192232
from tqdm import tqdm
193233

src/python/pose_format/utils/reader_test.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
import os
22
import struct
3+
import tempfile
34
from unittest import TestCase
45

56
import numpy as np
67
import torch
8+
from pose_format import Pose
9+
from pose_format.utils.generic import fake_pose
10+
from pose_format.utils.openpose import OpenPose_Components
711

812
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
9-
import tensorflow as tf
1013

1114
from pose_format.utils.reader import BufferReader, ConstStructs
1215

@@ -83,6 +86,8 @@ def test_unpack_torch(self):
8386

8487
def test_unpack_tensorflow(self):
8588
""" Test that unpack_tensorflow returns the correct value"""
89+
import tensorflow as tf
90+
8691
buffer = struct.pack("<ffff", 1., 2.5, 3.5, 4.5)
8792
reader = BufferReader(buffer)
8893

@@ -91,3 +96,21 @@ def test_unpack_tensorflow(self):
9196
res = tf.constant([[1., 2.5], [3.5, 4.5]])
9297
self.assertTrue(tf.reduce_all(tf.equal(arr, res)),
9398
msg="Tensorflow unpacked array is not equal to expected array")
99+
100+
def test_file_reader_equal_buffer_reader(self):
    """Test that reading from an open file handle matches reading from in-memory bytes"""
    pose = fake_pose(100, fps=25, components=OpenPose_Components)

    # A temporary directory is removed when the block exits, so the test
    # leaves no file or open handle behind, and the file can be re-opened
    # by name on every platform.
    with tempfile.TemporaryDirectory() as temp_dir:
        file_path = os.path.join(temp_dir, "fake.pose")

        with open(file_path, "wb") as f:
            pose.write(f)

        # Lazy path: the file handle is passed directly
        with open(file_path, "rb") as f:
            pose_1 = Pose.read(f, start_frame=10, end_frame=50)

        # Eager path: the whole file is loaded into bytes first
        with open(file_path, "rb") as f:
            pose_2 = Pose.read(f.read(), start_frame=10, end_frame=50)

    self.assertEqual(pose_1.header, pose_2.header)
    self.assertEqual(pose_1.body.fps, pose_2.body.fps)
    self.assertTrue(np.all(pose_1.body.data == pose_2.body.data))
    self.assertTrue(np.all(pose_1.body.confidence == pose_2.body.confidence))

0 commit comments

Comments
 (0)