|
26 | 26 | import warnings
|
27 | 27 | from functools import partial
|
28 | 28 | from itertools import product
|
| 29 | +import io |
29 | 30 | import pathlib
|
30 | 31 |
|
31 | 32 | import numpy as np
|
@@ -523,34 +524,41 @@ def validate_affine_deprecated(self, imaker, params):
|
523 | 524 | img.get_affine()
|
524 | 525 |
|
525 | 526 |
|
526 |
| -class SerializeMixin(object): |
527 |
| - def validate_to_bytes(self, imaker, params): |
| 527 | +class SerializeMixin: |
| 528 | + def validate_to_from_stream(self, imaker, params): |
528 | 529 | img = imaker()
|
529 |
| - serialized = img.to_bytes() |
530 |
| - with InTemporaryDirectory(): |
531 |
| - fname = 'img' + self.standard_extension |
532 |
| - img.to_filename(fname) |
533 |
| - with open(fname, 'rb') as fobj: |
534 |
| - file_contents = fobj.read() |
535 |
| - assert serialized == file_contents |
| 530 | + klass = getattr(self, 'klass', img.__class__) |
| 531 | + stream = io.BytesIO() |
| 532 | + img.to_stream(stream) |
| 533 | + |
| 534 | + rt_img = klass.from_stream(stream) |
| 535 | + assert self._header_eq(img.header, rt_img.header) |
| 536 | + assert np.array_equal(img.get_fdata(), rt_img.get_fdata()) |
536 | 537 |
|
537 |
| - def validate_from_bytes(self, imaker, params): |
| 538 | + def validate_file_stream_equivalence(self, imaker, params): |
538 | 539 | img = imaker()
|
539 | 540 | klass = getattr(self, 'klass', img.__class__)
|
540 | 541 | with InTemporaryDirectory():
|
541 | 542 | fname = 'img' + self.standard_extension
|
542 | 543 | img.to_filename(fname)
|
543 | 544 |
|
544 |
| - all_images = list(getattr(self, 'example_images', [])) + [{'fname': fname}] |
545 |
| - for img_params in all_images: |
546 |
| - img_a = klass.from_filename(img_params['fname']) |
547 |
| - with open(img_params['fname'], 'rb') as fobj: |
548 |
| - img_b = klass.from_bytes(fobj.read()) |
| 545 | + with open("stream", "wb") as fobj: |
| 546 | + img.to_stream(fobj) |
549 | 547 |
|
550 |
| - assert self._header_eq(img_a.header, img_b.header) |
| 548 | + # Check that writing gets us the same thing |
| 549 | + contents1 = pathlib.Path(fname).read_bytes() |
| 550 | + contents2 = pathlib.Path("stream").read_bytes() |
| 551 | + assert contents1 == contents2 |
| 552 | + |
| 553 | + # Check that reading gets us the same thing |
| 554 | + img_a = klass.from_filename(fname) |
| 555 | + with open(fname, "rb") as fobj: |
| 556 | + img_b = klass.from_stream(fobj) |
| 557 | + # This needs to happen while the filehandle is open |
551 | 558 | assert np.array_equal(img_a.get_fdata(), img_b.get_fdata())
|
552 |
| - del img_a |
553 |
| - del img_b |
| 559 | + assert self._header_eq(img_a.header, img_b.header) |
| 560 | + del img_a |
| 561 | + del img_b |
554 | 562 |
|
555 | 563 | def validate_to_from_bytes(self, imaker, params):
|
556 | 564 | img = imaker()
|
|
0 commit comments