Skip to content

Commit c1f0d65

Browse files
authored
Merge pull request #1142 from anibalsolon/fix/invalid_tck_handling
ENH: Assume TCK is open in binary mode
2 parents 69307d3 + 203ed0c commit c1f0d65

File tree

2 files changed

+46
-24
lines changed

2 files changed

+46
-24
lines changed

nibabel/streamlines/tck.py

+18-24
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import warnings
99

1010
import numpy as np
11-
from numpy.compat.py3k import asbytes, asstr
1211

1312
from nibabel.openers import Opener
1413

@@ -44,7 +43,7 @@ class TckFile(TractogramFile):
4443
.. [#] http://nipy.org/nibabel/coordinate_systems.html#voxel-coordinates-are-in-voxel-space
4544
"""
4645
# Constants
47-
MAGIC_NUMBER = "mrtrix tracks"
46+
MAGIC_NUMBER = b"mrtrix tracks"
4847
SUPPORTS_DATA_PER_POINT = False # Not yet
4948
SUPPORTS_DATA_PER_STREAMLINE = False # Not yet
5049

@@ -94,7 +93,7 @@ def is_correct_format(cls, fileobj):
9493
magic_number = f.read(len(cls.MAGIC_NUMBER))
9594
f.seek(-len(cls.MAGIC_NUMBER), os.SEEK_CUR)
9695

97-
return asstr(magic_number) == cls.MAGIC_NUMBER
96+
return magic_number == cls.MAGIC_NUMBER
9897

9998
@classmethod
10099
def create_empty_header(cls):
@@ -230,7 +229,7 @@ def save(self, fileobj):
230229
header[Field.NB_STREAMLINES] = nb_streamlines
231230

232231
# Add the EOF_DELIMITER.
233-
f.write(asbytes(self.EOF_DELIMITER.tobytes()))
232+
f.write(self.EOF_DELIMITER.tobytes())
234233
self._finalize_header(f, header, offset=beginning)
235234

236235
@staticmethod
@@ -251,41 +250,36 @@ def _write_header(fileobj, header):
251250
"count", "datatype", "file"] # Fields being replaced.
252251

253252
lines = []
254-
lines.append(asstr(header[Field.MAGIC_NUMBER]))
255253
lines.append(f"count: {header[Field.NB_STREAMLINES]:010}")
256254
lines.append("datatype: Float32LE") # Always Float32LE.
257255
lines.extend([f"{k}: {v}"
258256
for k, v in header.items()
259257
if k not in exclude and not k.startswith("_")])
260-
lines.append("file: . ") # Manually add this last field.
261258
out = "\n".join(lines)
262259

263260
# Check the header is well formatted.
264261
if out.count("\n") > len(lines) - 1: # \n only allowed between lines.
265262
msg = f"Key-value pairs cannot contain '\\n':\n{out}"
266263
raise HeaderError(msg)
267264

268-
if out.count(":") > len(lines) - 1:
265+
if out.count(":") > len(lines):
269266
# : only one per line (except the last one which contains END).
270267
msg = f"Key-value pairs cannot contain ':':\n{out}"
271268
raise HeaderError(msg)
272269

270+
out = header[Field.MAGIC_NUMBER] + b"\n" + out.encode('utf-8')
271+
272+
# Compute data offset considering the offset string representation
273+
# headers + "file" header + END + \n's
274+
hdr_offset = len(out) + 8 + 3 + 3
275+
offset_repr = f'{hdr_offset}'
276+
277+
# Adding the offset may increase one char to the offset repr
278+
hdr_offset += len(f'{hdr_offset + len(offset_repr)}')
279+
273280
# Write header to file.
274-
fileobj.write(asbytes(out))
275-
276-
hdr_len_no_offset = len(out) + 5
277-
# Need to add number of bytes to store offset as decimal string. We
278-
# start with estimate without string, then update if the
279-
# offset-as-decimal-string got longer after adding length of the
280-
# offset string.
281-
new_offset = -1
282-
old_offset = hdr_len_no_offset
283-
while new_offset != old_offset:
284-
old_offset = new_offset
285-
new_offset = hdr_len_no_offset + len(str(old_offset))
286-
287-
fileobj.write(asbytes(str(new_offset) + "\n"))
288-
fileobj.write(asbytes("END\n"))
281+
fileobj.write(out)
282+
fileobj.write(f'\nfile: . {hdr_offset}\nEND\n'.encode('utf-8'))
289283

290284
@classmethod
291285
def _read_header(cls, fileobj):
@@ -320,7 +314,7 @@ def _read_header(cls, fileobj):
320314
# Read magic number
321315
magic_number = f.read(len(cls.MAGIC_NUMBER))
322316

323-
if asstr(magic_number) != cls.MAGIC_NUMBER:
317+
if magic_number != cls.MAGIC_NUMBER:
324318
raise HeaderError(f"Invalid magic number: {magic_number}")
325319

326320
hdr[Field.MAGIC_NUMBER] = magic_number
@@ -331,7 +325,7 @@ def _read_header(cls, fileobj):
331325

332326
# Read all key-value pairs contained in the header, stop at EOF
333327
for n_line, line in enumerate(f, 1):
334-
line = asstr(line).strip()
328+
line = line.decode('utf-8').strip()
335329

336330
if not line: # Skip empty lines
337331
continue

nibabel/streamlines/tests/test_tck.py

+28
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,34 @@ def test_write_simple_file(self):
216216
with pytest.raises(HeaderError):
217217
tck.save(tck_file)
218218

219+
def test_write_bigheader_file(self):
220+
tractogram = Tractogram(DATA['streamlines'],
221+
affine_to_rasmm=np.eye(4))
222+
223+
# Offset is represented by 2 characters.
224+
tck_file = BytesIO()
225+
tck = TckFile(tractogram)
226+
tck.header['new_entry'] = ' ' * 20
227+
tck.save(tck_file)
228+
tck_file.seek(0, os.SEEK_SET)
229+
230+
new_tck = TckFile.load(tck_file)
231+
assert_tractogram_equal(new_tck.tractogram, tractogram)
232+
assert new_tck.header['_offset_data'] == 99
233+
234+
# We made the jump, now offset is represented by 3 characters
235+
# and we need to adjust the offset!
236+
tck_file = BytesIO()
237+
tck = TckFile(tractogram)
238+
tck.header['new_entry'] = ' ' * 21
239+
tck.save(tck_file)
240+
tck_file.seek(0, os.SEEK_SET)
241+
242+
new_tck = TckFile.load(tck_file)
243+
assert_tractogram_equal(new_tck.tractogram, tractogram)
244+
assert new_tck.header['_offset_data'] == 101
245+
246+
219247
def test_load_write_file(self):
220248
for fname in [DATA['empty_tck_fname'],
221249
DATA['simple_tck_fname']]:

0 commit comments

Comments
 (0)