Skip to content

Commit 74cc5eb

Browse files
committed
Update 2d cohorts visualization
1 parent: b0dcf33 · commit: 74cc5eb

File tree

1 file changed

+30
-14
lines changed

1 file changed

+30
-14
lines changed

flox/visualize.py

Lines changed: 30 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -123,34 +123,50 @@ def get_colormap(N):
123123
return cmap
124124

125125

126-
def visualize_cohorts_2d(by, array, merge=True, method="cohorts"):
127-
assert by.ndim == 2
128-
print("finding cohorts...")
129-
cohorts = find_group_cohorts(
130-
by, [array.chunks[ax] for ax in range(-by.ndim, 0)], merge=merge, method=method
131-
)
126+
def factorize_cohorts(by, cohorts):
132127

133128
factorized = np.full(by.shape, -1)
134129
for idx, cohort in enumerate(cohorts):
135130
factorized[np.isin(by, cohort)] = idx
136-
ncohorts = idx
131+
return factorized
132+
133+
134+
def visualize_cohorts_2d(by, array, method="cohorts"):
135+
assert by.ndim == 2
136+
print("finding cohorts...")
137+
before_merged = find_group_cohorts(
138+
by, [array.chunks[ax] for ax in range(-by.ndim, 0)], merge=False, method=method
139+
)
140+
merged = find_group_cohorts(
141+
by, [array.chunks[ax] for ax in range(-by.ndim, 0)], merge=True, method=method
142+
)
143+
print("finished cohorts...")
137144

138145
xticks = np.cumsum(array.chunks[-1])
139146
yticks = np.cumsum(array.chunks[-2])
140147

141-
f, ax = plt.subplots(2, 1, constrained_layout=True, sharex=True, sharey=True)
142-
148+
f, ax = plt.subplots(2, 2, constrained_layout=True, sharex=True, sharey=True)
149+
ax = ax.ravel()
150+
ax[1].set_visible(False)
151+
ax = ax[[0, 2, 3]]
143152
flat = by.ravel()
144153
ngroups = len(np.unique(flat[~np.isnan(flat)]))
145154

146155
h0 = ax[0].imshow(by, cmap=get_colormap(ngroups))
147-
h1 = ax[1].imshow(factorized, aspect="equal", vmin=0, cmap=get_colormap(ncohorts))
156+
h1 = ax[1].imshow(
157+
factorize_cohorts(by, before_merged),
158+
vmin=0,
159+
cmap=get_colormap(len(before_merged)),
160+
)
161+
h2 = ax[2].imshow(factorize_cohorts(by, merged), vmin=0, cmap=get_colormap(len(merged)))
148162
for axx in ax:
149163
axx.grid(True, which="both")
150164
axx.set_xticks(xticks)
151165
axx.set_yticks(yticks)
152-
f.colorbar(h0, ax=ax[0])
153-
f.colorbar(h1, ax=ax[1])
154-
ax[0].set_title("by")
155-
ax[1].set_title("cohorts")
166+
for h, axx in zip([h0, h1, h2], ax):
167+
f.colorbar(h, ax=axx, orientation="horizontal")
168+
169+
ax[0].set_title(f"by: {ngroups} groups")
170+
ax[1].set_title(f"{len(before_merged)} cohorts")
171+
ax[2].set_title(f"{len(merged)} merged cohorts")
156172
f.set_size_inches((6, 6))

0 commit comments

Comments
 (0)