Skip to content

Commit 5d12e4f

Browse files
authored
Merge pull request #26 from Leengit/feature_vectors_squashed
ENH: Support features of shape (1024,)
2 parents 61627c4 + 31e5d87 commit 5d12e4f

File tree

6 files changed

+413
-108
lines changed

6 files changed

+413
-108
lines changed

Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ RUN apt-get update && \
1313

1414
COPY . /opt/scw
1515
WORKDIR /opt/scw
16-
RUN python -m pip install --no-cache-dir -e .[tensorflow,torch] --find-links https://girder.github.io/large_image_wheels --extra-index-url https://download.pytorch.org/whl/cu117 && \
16+
RUN python -m pip install --no-cache-dir -e .[tensorflow,torch] --find-links https://girder.github.io/large_image_wheels --extra-index-url https://download.pytorch.org/whl/cu126 && \
1717
rm -rf /root/.cache/pip/* && \
1818
rdfind -minsize 32768 -makehardlinks true -makeresultsfile false /usr/local
1919

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ tensorflow = [
2828
"keras",
2929
]
3030
torch = [
31-
"torch==1.13.1+cu117",
31+
"torch",
32+
"torchvision",
3233
"batchbald_redux",
34+
"huggingface_hub",
35+
"timm",
3336
]

superpixel_classification/SuperpixelClassification/SuperpixelClassification.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
from histomicstk.cli.utils import CLIArgumentParser
2+
from SuperpixelClassificationBase import SuperpixelClassificationBase
23

34
if __name__ == '__main__':
45
args = CLIArgumentParser().parse_args()
5-
if args.certainty == 'batchbald':
6+
# Use tensorflow unless the dependency requires torch
7+
superpixel_classification: SuperpixelClassificationBase
8+
if args.certainty == 'batchbald' or args.feature == 'vector':
69
from SuperpixelClassificationTorch import SuperpixelClassificationTorch
710

811
superpixel_classification = SuperpixelClassificationTorch()

superpixel_classification/SuperpixelClassification/SuperpixelClassification.xml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,15 @@
163163
<element>negative_entropy</element>
164164
<element>batchbald</element>
165165
</string-enumeration>
166+
<string-enumeration>
167+
<name>feature</name>
168+
<label>Feature Shape</label>
169+
<description>Whether a feature is superpixel image data or a foundation model vector</description>
170+
<longflag>feature</longflag>
171+
<default>image</default>
172+
<element>image</element>
173+
<element>vector</element>
174+
</string-enumeration>
166175
</parameters>
167176
<parameters advanced="true">
168177
<label>Girder API URL and Key</label>

superpixel_classification/SuperpixelClassification/SuperpixelClassificationBase.py

Lines changed: 112 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
import girder_client
1515
import h5py
1616
import numpy as np
17+
import tenacity
1718
from numpy.typing import NDArray
1819
from progress_helper import ProgressHelper
19-
from tenacity import Retrying, stop_after_attempt
2020

2121

2222
def summary_repr(contents, collapseSequences=False):
@@ -37,9 +37,7 @@ def summary_repr(contents, collapseSequences=False):
3737
A string representation of a summary of the object
3838
3939
"""
40-
if isinstance(
41-
contents, (bool, int, float, str, np.int32, np.int64, np.float32, np.float64),
42-
):
40+
if isinstance(contents, (bool, int, float, str, np.int32, np.int64, np.float32, np.float64)):
4341
return repr(contents)
4442
if isinstance(contents, (list, tuple, dict, set)) and len(contents) == 0:
4543
return repr(type(contents)())
@@ -90,11 +88,7 @@ def summary_repr(contents, collapseSequences=False):
9088
f", 'and {len(contents) - 1} more'" +
9189
'}'
9290
)
93-
return (
94-
'{' +
95-
', '.join([summary_repr(elem, collapseSequences) for elem in contents]) +
96-
'}'
97-
)
91+
return '{' + ', '.join([summary_repr(elem, collapseSequences) for elem in contents]) + '}'
9892
if isinstance(contents, np.ndarray):
9993
return (
10094
repr(type(contents)) +
@@ -245,7 +239,7 @@ def progCallback(step, count, total):
245239
SuperpixelSegmentation.createSuperPixels(spopts)
246240
del spopts.callback
247241
prog.item_progress(item, 0.9)
248-
for attempt in Retrying(stop=stop_after_attempt(self.uploadRetries)):
242+
for attempt in tenacity.Retrying(stop=tenacity.stop_after_attempt(self.uploadRetries)):
249243
with attempt:
250244
outImageFile = gc.uploadFileToFolder(annotationFolderId, outImagePath)
251245
outImageId = outImageFile['itemId']
@@ -260,7 +254,7 @@ def progCallback(step, count, total):
260254
with open(outAnnotationPath, 'w') as annotation_file:
261255
json.dump(annot, annotation_file, indent=2, sort_keys=False)
262256
count = len(gc.get('annotation', parameters=dict(itemId=item['_id'])))
263-
for attempt in Retrying(stop=stop_after_attempt(self.uploadRetries)):
257+
for attempt in tenacity.Retrying(stop=tenacity.stop_after_attempt(self.uploadRetries)):
264258
with attempt:
265259
gc.uploadFileToItem(
266260
item['_id'], outAnnotationPath,
@@ -300,12 +294,51 @@ def createSuperpixels(self, gc, folderId, annotationName, radius, magnification,
300294
results[item['_id']] = future.result()
301295
return results
302296

297+
def initializeCreateFeatureFromPatchAndMaskSimple(self):
298+
# There is nothing to initialize
299+
pass
300+
301+
def initializeCreateFeatureFromPatchAndMask(self):
302+
# This SuperpixelClassificationBase implementation allows only the "Simple"
303+
# approach.
304+
# assert self.feature_is_image
305+
self.initializeCreateFeatureFromPatchAndMaskSimple()
306+
307+
def createFeatureFromPatchAndMaskSimple(self, patch, mask, maskvals):
308+
feature = np.array(patch.copy()).astype(np.uint8)
309+
feature[(mask != maskvals[0]).any(axis=-1) & (mask != maskvals[1]).any(axis=-1)] = [0, 0, 0]
310+
return feature
311+
312+
def createFeatureListFromPatchAndMaskListSimple(self, patch_list, mask_list, maskvals_list):
313+
feature_list = [
314+
self.createFeatureFromPatchAndMaskSimple(patch, mask, maskvals)
315+
for patch, mask, maskvals in zip(patch_list, mask_list, maskvals_list)
316+
]
317+
return feature_list
318+
319+
def createFeatureFromPatchAndMask(self, patch, mask, maskvals):
320+
# This SuperpixelClassificationBase implementation allows only the "Simple"
321+
# approach.
322+
# assert self.feature_is_image
323+
feature = self.createFeatureFromPatchAndMaskSimple(patch, mask, maskvals)
324+
return feature
325+
326+
def createFeatureListFromPatchAndMaskList(self, patch_list, mask_list, maskvals_list):
327+
# This SuperpixelClassificationBase implementation allows only the "Simple"
328+
# approach.
329+
# assert self.feature_is_image
330+
feature_list = self.createFeatureListFromPatchAndMaskListSimple(
331+
patch_list, mask_list, maskvals_list,
332+
)
333+
return feature_list
334+
303335
def createFeaturesForItem(self, gc, item, elem, featureFolderId, fileName, patchSize, prog):
304336
import large_image
305337

306338
print('Create feature', fileName)
307339
lastlog = starttime = time.time()
308340
ds = None
341+
self.initializeCreateFeatureFromPatchAndMask()
309342
with tempfile.TemporaryDirectory(dir=os.getcwd()) as tempdir:
310343
filePath = os.path.join(tempdir, fileName)
311344
imagePath = os.path.join(tempdir, item['name'])
@@ -317,57 +350,69 @@ def createFeaturesForItem(self, gc, item, elem, featureFolderId, fileName, patch
317350
tsMask = large_image.open(maskPath)
318351

319352
with h5py.File(filePath, 'w') as fptr:
320-
for idx, _ in enumerate(elem['values']):
321-
prog.item_progress(item, 0.9 * idx / len(elem['values']))
322-
bbox = elem['user']['bbox'][idx * 4: idx * 4 + 4]
323-
# use masked superpixel
324-
patch = ts.getRegion(
325-
region=dict(
326-
left=int(bbox[0]), top=int(bbox[1]),
327-
right=int(bbox[2]), bottom=int(bbox[3])),
328-
output=dict(maxWidth=patchSize, maxHeight=patchSize),
329-
fill='#000',
330-
format=large_image.constants.TILE_FORMAT_NUMPY)[0]
331-
if patch.shape[2] in (2, 4):
332-
patch = patch[:, :, :-1]
333-
scale = 1
334-
try:
335-
scale = elem['transform']['matrix'][0][0]
336-
except Exception:
337-
pass
338-
mask = tsMask.getRegion(
339-
region=dict(
340-
left=int(bbox[0] / scale), top=int(bbox[1] / scale),
341-
right=int(bbox[2] / scale), bottom=int(bbox[3] / scale)),
342-
output=dict(maxWidth=patchSize, maxHeight=patchSize),
343-
fill='#000',
344-
format=large_image.constants.TILE_FORMAT_NUMPY)[0]
345-
if mask.shape[2] == 4:
346-
mask = mask[:, :, :-1]
347-
maskvals = [[val % 256, val // 256 % 256, val // 65536 % 256]
348-
for val in [idx * 2, idx * 2 + 1]]
349-
patch = patch.copy()
350-
patch[(mask != maskvals[0]).any(axis=-1) &
351-
(mask != maskvals[1]).any(axis=-1)] = [0, 0, 0]
352-
# TODO: ensure this is uint8
353-
if not ds:
354-
ds = fptr.create_dataset(
355-
'images', (1,) + patch.shape, maxshape=(None,) + patch.shape,
356-
dtype=patch.dtype, chunks=True)
357-
else:
358-
ds.resize((ds.shape[0] + 1,) + patch.shape)
359-
ds[ds.shape[0] - 1] = patch
360-
if time.time() - lastlog > 5:
361-
lastlog = time.time()
362-
print(ds.shape, len(elem['values']),
363-
'%5.3f' % (time.time() - starttime),
364-
'%5.3f' % ((len(elem['values']) - idx - 1) / (idx + 1) *
365-
(time.time() - starttime)),
366-
item['name'])
353+
batch_size = 1024 # TODO: Is this the best value?
354+
for batch_start in range(0, len(elem['values']), batch_size):
355+
batch_list = elem['values'][batch_start: batch_start + batch_size]
356+
patch_list = []
357+
mask_list = []
358+
maskvals_list = []
359+
for idx, _ in enumerate(batch_list, start=batch_start):
360+
prog.item_progress(item, 0.9 * idx / len(elem['values']))
361+
bbox = elem['user']['bbox'][idx * 4: idx * 4 + 4]
362+
# use masked superpixel
363+
patch = ts.getRegion(
364+
region=dict(
365+
left=int(bbox[0]), top=int(bbox[1]),
366+
right=int(bbox[2]), bottom=int(bbox[3])),
367+
output=dict(maxWidth=patchSize, maxHeight=patchSize),
368+
fill='#000',
369+
format=large_image.constants.TILE_FORMAT_NUMPY)[0]
370+
if patch.shape[2] in (2, 4):
371+
patch = patch[:, :, :-1]
372+
scale = 1
373+
try:
374+
scale = elem['transform']['matrix'][0][0]
375+
except Exception:
376+
pass
377+
mask = tsMask.getRegion(
378+
region=dict(
379+
left=int(bbox[0] / scale), top=int(bbox[1] / scale),
380+
right=int(bbox[2] / scale), bottom=int(bbox[3] / scale)),
381+
output=dict(maxWidth=patchSize, maxHeight=patchSize),
382+
fill='#000',
383+
format=large_image.constants.TILE_FORMAT_NUMPY)[0]
384+
if mask.shape[2] == 4:
385+
mask = mask[:, :, :-1]
386+
maskvals = [[val % 256, val // 256 % 256, val // 65536 % 256]
387+
for val in [idx * 2, idx * 2 + 1]]
388+
patch_list.append(patch)
389+
mask_list.append(mask)
390+
maskvals_list.append(maskvals)
391+
# Make sure only the *_list forms are used subsequently
392+
del patch, mask, maskvals
393+
feature_list = self.createFeatureListFromPatchAndMaskList(
394+
patch_list, mask_list, maskvals_list,
395+
)
396+
for idx, feature in enumerate(feature_list, start=batch_start):
397+
if not ds:
398+
ds = fptr.create_dataset(
399+
'images', (1,) + feature.shape, maxshape=(None,) + feature.shape,
400+
dtype=np.float32, chunks=True)
401+
else:
402+
ds.resize((ds.shape[0] + 1,) + feature.shape)
403+
ds[ds.shape[0] - 1] = feature
404+
if time.time() - lastlog > 5:
405+
lastlog = time.time()
406+
print(ds.shape, len(elem['values']),
407+
'%5.3f' % (time.time() - starttime),
408+
'%5.3f' % ((len(elem['values']) - idx - 1) / (idx + 1) *
409+
(time.time() - starttime)),
410+
item['name'])
411+
del batch_list, patch_list, mask_list, maskvals_list, feature_list
367412
print(ds.shape, len(elem['values']), '%5.3f' % (time.time() - starttime),
368413
item['name'])
369414
prog.item_progress(item, 0.9)
370-
for attempt in Retrying(stop=stop_after_attempt(self.uploadRetries)):
415+
for attempt in tenacity.Retrying(stop=tenacity.stop_after_attempt(self.uploadRetries)):
371416
with attempt:
372417
file = gc.uploadFileToFolder(featureFolderId, filePath)
373418
prog.item_progress(item, 1)
@@ -503,11 +548,11 @@ def trainModel(self, gc, folderId, annotationName, features, modelFolderId,
503548
except AttributeError as exc:
504549
print(f'Cannot pickle history; skipping. {exc}')
505550
prog.progress(1)
506-
for attempt in Retrying(stop=stop_after_attempt(self.uploadRetries)):
551+
for attempt in tenacity.Retrying(stop=tenacity.stop_after_attempt(self.uploadRetries)):
507552
with attempt:
508553
modelFile = gc.uploadFileToFolder(modelFolderId, modelPath)
509554
print('Saved model')
510-
for attempt in Retrying(stop=stop_after_attempt(self.uploadRetries)):
555+
for attempt in tenacity.Retrying(stop=tenacity.stop_after_attempt(self.uploadRetries)):
511556
with attempt:
512557
modTrainingFile = gc.uploadFileToFolder(modelFolderId, modTrainingPath)
513558
print('Saved modTraining')
@@ -596,7 +641,7 @@ def predictLabelsForItem(self, gc, annotationName, annotationFolderId, tempdir,
596641
print_fully('annot', annot)
597642
with open(outAnnotationPath, 'w') as annotation_file:
598643
json.dump(annot, annotation_file, indent=2, sort_keys=False)
599-
for attempt in Retrying(stop=stop_after_attempt(self.uploadRetries)):
644+
for attempt in tenacity.Retrying(stop=tenacity.stop_after_attempt(self.uploadRetries)):
600645
with attempt:
601646
gc.uploadFileToItem(
602647
item['_id'], outAnnotationPath, reference=json.dumps({
@@ -614,7 +659,7 @@ def predictLabelsForItem(self, gc, annotationName, annotationFolderId, tempdir,
614659
print_fully('newAnnot', newAnnot)
615660
with open(outAnnotationPath, 'w') as annotation_file:
616661
json.dump(newAnnot, annotation_file, indent=2, sort_keys=False)
617-
for attempt in Retrying(stop=stop_after_attempt(self.uploadRetries)):
662+
for attempt in tenacity.Retrying(stop=tenacity.stop_after_attempt(self.uploadRetries)):
618663
with attempt:
619664
gc.uploadFileToItem(
620665
item['_id'], outAnnotationPath, reference=json.dumps({
@@ -706,7 +751,7 @@ def makeHeatmapsForItem(self, gc, annotationName, userId, tempdir, radius, item,
706751
print_fully('heatmaps', heatmaps)
707752
with open(outAnnotationPath, 'w') as annotation_file:
708753
json.dump(heatmaps, annotation_file, indent=2, sort_keys=False)
709-
for attempt in Retrying(stop=stop_after_attempt(self.uploadRetries)):
754+
for attempt in tenacity.Retrying(stop=tenacity.stop_after_attempt(self.uploadRetries)):
710755
with attempt:
711756
gc.uploadFileToItem(
712757
item['_id'],
@@ -784,6 +829,9 @@ def predictLabels(self, gc, folderId, annotationName, features, modelFolderId,
784829
prog.progress(1)
785830

786831
def main(self, args):
832+
self.feature_is_image = args.feature != 'vector'
833+
self.certainty = args.certainty
834+
787835
print('\n>> CLI Parameters ...\n')
788836
pprint.pprint(vars(args))
789837

0 commit comments

Comments
 (0)