Skip to content

Commit 23d10f7

Browse files
committed
refactor dim sorting
1 parent 2c7bdbe commit 23d10f7

File tree

1 file changed

+44
-35
lines changed

1 file changed

+44
-35
lines changed

suite2p/io/nd2.py

Lines changed: 44 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99

1010
def nd2_to_binary(ops):
11-
""" finds nd2 files and writes them to binaries
11+
"""finds nd2 files and writes them to binaries
1212
1313
Parameters
1414
----------
@@ -41,72 +41,81 @@ def nd2_to_binary(ops):
4141
nd2_file = nd2.ND2File(file_name)
4242
nd2_dims = {k: i for i, k in enumerate(nd2_file.sizes)}
4343

44-
valid_dimensions = set("TZCYX")
45-
assert valid_dimensions == set(nd2_dims), f"Unknown or missing dimension in {nd2_dims}"
44+
valid_dimensions = "TZCYX"
45+
assert set(nd2_dims) <= set(
46+
valid_dimensions
47+
), f"Unknown dimensions {set(nd2_dims)-set(valid_dimensions)} in file {file_name}."
4648

47-
im = nd2_file.asarray().transpose([nd2_dims[x] for x in valid_dimensions])
49+
# Sort the dimensions in the order of TZCYX, skipping the missing ones.
50+
im = nd2_file.asarray().transpose(
51+
[nd2_dims[x] for x in valid_dimensions if x in nd2_dims]
52+
)
4853

49-
# expand dimensions to have [Time (T), Depth (Z), Channel (C), Height (Y), Width (X)].
54+
# Expand array to include the missing dimensions.
5055
for i, dim in enumerate("TZC"):
5156
if dim not in nd2_dims:
5257
im = np.expand_dims(im, i)
53-
54-
nplanes = nd2_file.sizes['Z'] if 'Z' in nd2_file.sizes else 1
55-
nchannels = nd2_file.sizes['C'] if 'C' in nd2_file.sizes else 1
56-
nframes = nd2_file.sizes['T'] if 'T' in nd2_file.sizes else 1
5758

58-
iblocks = np.arange(0, nframes, ops1[0]['batch_size'])
59+
nplanes = nd2_file.sizes["Z"] if "Z" in nd2_file.sizes else 1
60+
nchannels = nd2_file.sizes["C"] if "C" in nd2_file.sizes else 1
61+
nframes = nd2_file.sizes["T"] if "T" in nd2_file.sizes else 1
62+
63+
iblocks = np.arange(0, nframes, ops1[0]["batch_size"])
5964
if iblocks[-1] < nframes:
6065
iblocks = np.append(iblocks, nframes)
6166

6267
if nchannels > 1:
63-
nfunc = ops1[0]['functional_chan'] - 1
68+
nfunc = ops1[0]["functional_chan"] - 1
6469
else:
6570
nfunc = 0
6671

67-
if im.dtype.type == np.uint16:
68-
im = (im // 2).astype(np.int16)
69-
elif im.dtype.type == np.int32:
70-
im = (im // 2).astype(np.int16)
71-
elif im.dtype.type != np.int16:
72-
im = im.astype(np.int16)
72+
assert im.max() < 32768 and im.min() >= -32768, "image data is out of range"
73+
im = im.astype(np.int16)
7374

7475
# loop over all frames
7576
for ichunk, onset in enumerate(iblocks[:-1]):
7677
offset = iblocks[ichunk + 1]
7778
im_p = np.array(im[onset:offset, :, :, :, :])
78-
im2mean = im_p.mean(axis = 0).astype(np.float32) / len(iblocks)
79+
im2mean = im_p.mean(axis=0).astype(np.float32) / len(iblocks)
7980
for ichan in range(nchannels):
8081
nframes = im_p.shape[0]
8182
im2write = im_p[:, :, ichan, :, :]
8283
for j in range(0, nplanes):
8384
if iall == 0:
84-
ops1[j]['meanImg'] = np.zeros((im_p.shape[3], im_p.shape[4]), np.float32)
85-
if nchannels>1:
86-
ops1[j]['meanImg_chan2'] = np.zeros((im_p.shape[3], im_p.shape[4]), np.float32)
87-
ops1[j]['nframes'] = 0
85+
ops1[j]["meanImg"] = np.zeros(
86+
(im_p.shape[3], im_p.shape[4]), np.float32
87+
)
88+
if nchannels > 1:
89+
ops1[j]["meanImg_chan2"] = np.zeros(
90+
(im_p.shape[3], im_p.shape[4]), np.float32
91+
)
92+
ops1[j]["nframes"] = 0
8893
if ichan == nfunc:
89-
ops1[j]['meanImg'] += np.squeeze(im2mean[j, ichan, :, :])
90-
reg_file[j].write(bytearray(im2write[:, j, :, :].astype('int16')))
94+
ops1[j]["meanImg"] += np.squeeze(im2mean[j, ichan, :, :])
95+
reg_file[j].write(
96+
bytearray(im2write[:, j, :, :].astype("int16"))
97+
)
9198
else:
92-
ops1[j]['meanImg_chan2'] += np.squeeze(im2mean[j, ichan, :, :])
93-
reg_file_chan2[j].write(bytearray(im2write[:, j, :, :].astype('int16')))
94-
95-
ops1[j]['nframes'] += im2write.shape[0]
99+
ops1[j]["meanImg_chan2"] += np.squeeze(im2mean[j, ichan, :, :])
100+
reg_file_chan2[j].write(
101+
bytearray(im2write[:, j, :, :].astype("int16"))
102+
)
103+
104+
ops1[j]["nframes"] += im2write.shape[0]
96105
ik += nframes
97106
iall += nframes
98-
107+
99108
nd2_file.close()
100109

101110
# write ops files
102-
do_registration = ops1[0]['do_registration']
111+
do_registration = ops1[0]["do_registration"]
103112
for ops in ops1:
104-
ops['Ly'] = im.shape[3]
105-
ops['Lx'] = im.shape[4]
113+
ops["Ly"] = im.shape[3]
114+
ops["Lx"] = im.shape[4]
106115
if not do_registration:
107-
ops['yrange'] = np.array([0, ops['Ly']])
108-
ops['xrange'] = np.array([0, ops['Lx']])
109-
np.save(ops['ops_path'], ops)
116+
ops["yrange"] = np.array([0, ops["Ly"]])
117+
ops["xrange"] = np.array([0, ops["Lx"]])
118+
np.save(ops["ops_path"], ops)
110119
# close all binary files and write ops files
111120
for j in range(0, nplanes):
112121
reg_file[j].close()

0 commit comments

Comments
 (0)