@@ -123,34 +123,50 @@ def get_colormap(N):
123
123
return cmap
124
124
125
125
126
- def visualize_cohorts_2d (by , array , merge = True , method = "cohorts" ):
127
- assert by .ndim == 2
128
- print ("finding cohorts..." )
129
- cohorts = find_group_cohorts (
130
- by , [array .chunks [ax ] for ax in range (- by .ndim , 0 )], merge = merge , method = method
131
- )
126
+ def factorize_cohorts (by , cohorts ):
132
127
133
128
factorized = np .full (by .shape , - 1 )
134
129
for idx , cohort in enumerate (cohorts ):
135
130
factorized [np .isin (by , cohort )] = idx
136
- ncohorts = idx
131
+ return factorized
132
+
133
+
134
+ def visualize_cohorts_2d (by , array , method = "cohorts" ):
135
+ assert by .ndim == 2
136
+ print ("finding cohorts..." )
137
+ before_merged = find_group_cohorts (
138
+ by , [array .chunks [ax ] for ax in range (- by .ndim , 0 )], merge = False , method = method
139
+ )
140
+ merged = find_group_cohorts (
141
+ by , [array .chunks [ax ] for ax in range (- by .ndim , 0 )], merge = True , method = method
142
+ )
143
+ print ("finished cohorts..." )
137
144
138
145
xticks = np .cumsum (array .chunks [- 1 ])
139
146
yticks = np .cumsum (array .chunks [- 2 ])
140
147
141
- f , ax = plt .subplots (2 , 1 , constrained_layout = True , sharex = True , sharey = True )
142
-
148
+ f , ax = plt .subplots (2 , 2 , constrained_layout = True , sharex = True , sharey = True )
149
+ ax = ax .ravel ()
150
+ ax [1 ].set_visible (False )
151
+ ax = ax [[0 , 2 , 3 ]]
143
152
flat = by .ravel ()
144
153
ngroups = len (np .unique (flat [~ np .isnan (flat )]))
145
154
146
155
h0 = ax [0 ].imshow (by , cmap = get_colormap (ngroups ))
147
- h1 = ax [1 ].imshow (factorized , aspect = "equal" , vmin = 0 , cmap = get_colormap (ncohorts ))
156
+ h1 = ax [1 ].imshow (
157
+ factorize_cohorts (by , before_merged ),
158
+ vmin = 0 ,
159
+ cmap = get_colormap (len (before_merged )),
160
+ )
161
+ h2 = ax [2 ].imshow (factorize_cohorts (by , merged ), vmin = 0 , cmap = get_colormap (len (merged )))
148
162
for axx in ax :
149
163
axx .grid (True , which = "both" )
150
164
axx .set_xticks (xticks )
151
165
axx .set_yticks (yticks )
152
- f .colorbar (h0 , ax = ax [0 ])
153
- f .colorbar (h1 , ax = ax [1 ])
154
- ax [0 ].set_title ("by" )
155
- ax [1 ].set_title ("cohorts" )
166
+ for h , axx in zip ([h0 , h1 , h2 ], ax ):
167
+ f .colorbar (h , ax = axx , orientation = "horizontal" )
168
+
169
+ ax [0 ].set_title (f"by: { ngroups } groups" )
170
+ ax [1 ].set_title (f"{ len (before_merged )} cohorts" )
171
+ ax [2 ].set_title (f"{ len (merged )} merged cohorts" )
156
172
f .set_size_inches ((6 , 6 ))
0 commit comments