Commit 7c64a87

Merge pull request #66 from neuroneural/mindgrab
Mindgrab: powerful omnimodal skull stripping model
2 parents 536cc24 + f40e344

7 files changed: +2132 additions, -48 deletions

brainchop-mainthread.js

Lines changed: 36 additions & 12 deletions
@@ -4,7 +4,9 @@ import {
   addZeroPaddingTo3dTensor,
   applyMriThreshold,
   binarizeVolumeDataTensor,
-  convByOutputChannelAndInputSlicing,
+  convByOutputChannelAndInputSlicing,
+  gn_convByOutputChannelAndInputSlicing,
+  LayerNormInPlace,
   draw3dObjBoundingVolume,
   firstLastNonZero3D,
   generateBrainMask,
@@ -191,15 +193,31 @@ async function inferenceFullVolumeSeqCovLayerPhase2(
        if (res.layers[i].activation.getClassName() !== 'linear') {
          curTensor[i] = await res.layers[i].apply(curTensor[i - 1])
        } else {
-          curTensor[i] = await convByOutputChannelAndInputSlicing(
-            curTensor[i - 1],
-            res.layers[i].getWeights()[0],
-            res.layers[i].getWeights()[1],
-            res.layers[i].strides,
-            res.layers[i].padding,
-            res.layers[i].dilationRate,
-            3
-          ) // important for memory use
+          // Check if the layer's name ends with our special suffix
+          if (res.layers[i].name.endsWith('_gn')) {
+            // Use the new GroupNorm-aware convolution function
+            curTensor[i] = await gn_convByOutputChannelAndInputSlicing(
+              curTensor[i - 1],
+              res.layers[i].getWeights()[0],
+              res.layers[i].getWeights()[1], // Can be undefined, the function handles it
+              res.layers[i].strides,
+              res.layers[i].padding,
+              res.layers[i].dilationRate,
+              3
+            );
+          } else {
+            // Use the original convolution function for non-GN layers
+            curTensor[i] = await convByOutputChannelAndInputSlicing(
+              curTensor[i - 1],
+              res.layers[i].getWeights()[0],
+              res.layers[i].getWeights()[1],
+              res.layers[i].strides,
+              res.layers[i].padding,
+              res.layers[i].dilationRate,
+              3
+            );
+          }
+
        }
        tf.dispose(curTensor[i - 1])
      } catch (err) {
@@ -555,8 +573,13 @@ async function inferenceFullVolumePhase2(
  const curTensor = []
  curTensor[0] = cropped_slices_3d_w_pad.reshape(adjusted_input_shape)
  const timer = window.setInterval(async function () {
-    try {
-      curTensor[i] = res.layers[i].apply(curTensor[i - 1])
+    try {
+      let resultTensor = await res.layers[i].apply(curTensor[i - 1]);
+      if (res.layers[i].name.endsWith('_gn')) {
+        // LayerNormInPlace will dispose of the old resultTensor internally.
+        resultTensor = LayerNormInPlace(resultTensor);
+      }
+      curTensor[i] = resultTensor;
    } catch (err) {
      callbackUI(err.message, -1, err.message)
      window.clearInterval(timer)
@@ -573,6 +596,7 @@ async function inferenceFullVolumePhase2(

      return 0
    }
+
    callbackUI('Layer ' + i.toString(), (i + 1) / layersLength)
    console.log('layer output Tensor shape : ', curTensor[i].shape)
    console.log('layer count params ', res.layers[i].countParams())
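Note on the new helper: the diff only names LayerNormInPlace; its body lives in tensor-utils.js and is not shown here. The sketch below is an illustrative assumption of what such a normalize-and-release step could look like (the function name layerNormInPlaceSketch and the epsilon default are mine, not the project's), showing why the caller above does not dispose resultTensor itself.

// Illustrative sketch only -- the real LayerNormInPlace in tensor-utils.js may differ.
// Normalizes an activation tensor and frees the input so peak memory stays bounded.
import * as tf from '@tensorflow/tfjs'

function layerNormInPlaceSketch(input, epsilon = 1e-5) {
  const normalized = tf.tidy(() => {
    const { mean, variance } = tf.moments(input)              // statistics over the whole tensor
    return input.sub(mean).div(variance.add(epsilon).sqrt())  // (x - mean) / sqrt(var + eps)
  })
  input.dispose() // "in place" in the memory sense: the old activation tensor is freed here
  return normalized
}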

brainchop-parameters.js

Lines changed: 47 additions & 1 deletion
@@ -376,5 +376,51 @@ const inferenceModelsList = [
    inferenceDelay: 100, // Delay in ms time while looping layers applying.
    description:
      'FreeSurfer aparc+aseg atlas 104 parcellate brain areas into 104 regions. It contains a combination of the Desikan-Killiany atlas for cortical area and also segmentation of subcortical regions. The model use sequential convolution for inference to overcome browser memory limitations but leads to longer computation time. '
-  }
+  },
+  {
+    id: 16,
+    type: 'Brain_Extraction',
+    path: '/models/mindgrab/model.json',
+    modelName: '\u{1F9E0}\u{1FA93} omnimodal Skull Strip (High Mem, Fast)',
+    preModelId: null, // Model run first e.g. crop the brain { null, 1, 2, .. }
+    preModelPostProcess: false, // If true, perform postprocessing to remove noisy regions after preModel inference generate output.
+    isBatchOverlapEnable: false, // create extra overlap batches for inference
+    numOverlapBatches: 0, // Number of extra overlap batches for inference
+    enableTranspose: true, // Keras and tfjs input orientation may need a tranposing step to be matched
+    enableCrop: true, // For speed-up inference, crop brain from background before feeding to inference model to lower memory use.
+    cropPadding: 0, // Padding size add to cropped brain
+    autoThreshold: 0, // Threshold between 0 and 1, given no preModel and tensor is normalized either min-max or by quantiles. Will remove noisy voxels around brain
+    enableQuantileNorm: false, // Some models needs Quantile Normaliztion.
+    filterOutWithPreMask: false, // Can be used to multiply final output with premodel output mask to crean noisy areas
+    enableSeqConv: false, // For low memory system and low configuration, enable sequential convolution instead of last layer
+    textureSize: 0, // Requested Texture size for the model, if unknown can be 0.
+    warning:
+      "This model may need dedicated graphics card. For more info please check with Browser Resources <i class='fa fa-cogs'></i>.",
+    inferenceDelay: 100, // Delay in ms time while looping layers applying.
+    description:
+      'Extract the brain high accuracy model operates on full T1 image in a single pass, but uses only 11 filters per layer. Can work on dedicated graphics cards. Still more accurate than the fast version.'
+  },
+  {
+    id: 17,
+    type: 'Brain_Extraction',
+    path: '/models/mindgrab/model.json',
+    modelName: '\u{1F9E0}\u{1FA93} omnimodal Skull Strip (Low Mem, Slow)',
+    preModelId: null, // Model run first e.g. crop the brain { null, 1, 2, .. }
+    preModelPostProcess: false, // If true, perform postprocessing to remove noisy regions after preModel inference generate output.
+    isBatchOverlapEnable: false, // create extra overlap batches for inference
+    numOverlapBatches: 0, // Number of extra overlap batches for inference
+    enableTranspose: true, // Keras and tfjs input orientation may need a tranposing step to be matched
+    enableCrop: true, // For speed-up inference, crop brain from background before feeding to inference model to lower memory use.
+    cropPadding: 0, // Padding size add to cropped brain
+    autoThreshold: 0, // Threshold between 0 and 1, given no preModel and tensor is normalized either min-max or by quantiles. Will remove noisy voxels around brain
+    enableQuantileNorm: false, // Some models needs Quantile Normaliztion.
+    filterOutWithPreMask: false, // Can be used to multiply final output with premodel output mask to crean noisy areas
+    enableSeqConv: true, // For low memory system and low configuration, enable sequential convolution instead of last layer
+    textureSize: 0, // Requested Texture size for the model, if unknown can be 0.
+    warning:
+      "This model may need dedicated graphics card. For more info please check with Browser Resources <i class='fa fa-cogs'></i>.",
+    inferenceDelay: 100, // Delay in ms time while looping layers applying.
+    description:
+      'Extract the brain high accuracy model operates on image in a single pass, but uses only 11 filters per layer. Can work on dedicated graphics cards. Still more accurate than the fast version.'
+  },
 ] // inferenceModelsList
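For orientation, the two new MindGrab entries differ only in the enableSeqConv flag. A minimal sketch of picking them out of the list (the import path and the id/enableSeqConv fields come from this diff; selecting by id with find is illustrative, not necessarily how brainchop consumes the list):

// Minimal usage sketch, not code from this commit.
import { inferenceModelsList } from './brainchop-parameters.js'

const mindgrabFast = inferenceModelsList.find((m) => m.id === 16)   // High Mem, Fast
const mindgrabLowMem = inferenceModelsList.find((m) => m.id === 17) // Low Mem, Slow

console.log(mindgrabFast.enableSeqConv)   // false: last layer applied in one pass
console.log(mindgrabLowMem.enableSeqConv) // true: sequential convolution to cap memory use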

brainchop-webworker.js

Lines changed: 62 additions & 30 deletions
@@ -1,24 +1,26 @@
 import * as tf from '@tensorflow/tfjs'
 import { inferenceModelsList } from './brainchop-parameters.js'
 import {
-  addZeroPaddingTo3dTensor,
-  applyMriThreshold,
-  binarizeVolumeDataTensor,
-  convByOutputChannelAndInputSlicing,
-  draw3dObjBoundingVolume,
-  firstLastNonZero3D,
-  generateBrainMask,
-  generateOutputSlicesV2,
-  getAllSlicesDataAsTF3D,
-  getModelNumLayers,
-  getModelNumParameters,
-  isModelChnlLast,
-  load_model,
-  minMaxNormalizeVolumeData,
-  quantileNormalizeVolumeData,
-  removeZeroPaddingFrom3dTensor,
-  resizeWithZeroPadding,
-  SequentialConvLayer
+  addZeroPaddingTo3dTensor,
+  applyMriThreshold,
+  binarizeVolumeDataTensor,
+  convByOutputChannelAndInputSlicing,
+  gn_convByOutputChannelAndInputSlicing,
+  LayerNormInPlace,
+  draw3dObjBoundingVolume,
+  firstLastNonZero3D,
+  generateBrainMask,
+  generateOutputSlicesV2,
+  getAllSlicesDataAsTF3D,
+  getModelNumLayers,
+  getModelNumParameters,
+  isModelChnlLast,
+  load_model,
+  minMaxNormalizeVolumeData,
+  quantileNormalizeVolumeData,
+  removeZeroPaddingFrom3dTensor,
+  resizeWithZeroPadding,
+  SequentialConvLayer
 } from './tensor-utils.js'

 function callbackUI(message = '', progressFrac = -1, modalMessage = '', statData = []) {
@@ -209,15 +211,31 @@ async function inferenceFullVolumeSeqCovLayerPhase2(
        if (res.layers[i].activation.getClassName() !== 'linear') {
          curTensor[i] = await res.layers[i].apply(curTensor[i - 1])
        } else {
-          curTensor[i] = await convByOutputChannelAndInputSlicing(
-            curTensor[i - 1],
-            res.layers[i].getWeights()[0],
-            res.layers[i].getWeights()[1],
-            res.layers[i].strides,
-            res.layers[i].padding,
-            res.layers[i].dilationRate,
-            3
-          ) // important for memory use
+          // Check if the layer's name ends with our special suffix
+          if (res.layers[i].name.endsWith('_gn')) {
+            // Use the new GroupNorm-aware convolution function
+            curTensor[i] = await gn_convByOutputChannelAndInputSlicing(
+              curTensor[i - 1],
+              res.layers[i].getWeights()[0],
+              res.layers[i].getWeights()[1], // Can be undefined, the function handles it
+              res.layers[i].strides,
+              res.layers[i].padding,
+              res.layers[i].dilationRate,
+              3
+            );
+          } else {
+            // Use the original convolution function for non-GN layers
+            curTensor[i] = await convByOutputChannelAndInputSlicing(
+              curTensor[i - 1],
+              res.layers[i].getWeights()[0],
+              res.layers[i].getWeights()[1],
+              res.layers[i].strides,
+              res.layers[i].padding,
+              res.layers[i].dilationRate,
+              3
+            );
+          }
+
        }

        tf.dispose(curTensor[i - 1])
@@ -575,8 +593,12 @@ async function inferenceFullVolumePhase2(

  while (true) {
    try {
-      // -- curTensor[i] = res.layers[i].apply( curTensor[i-1])
-      curTensor[i] = res.layers[i].apply(curTensor[i - 1])
+      let resultTensor = await res.layers[i].apply(curTensor[i - 1]);
+      if (res.layers[i].name.endsWith('_gn')) {
+        // LayerNormInPlace will dispose of the old resultTensor internally.
+        resultTensor = LayerNormInPlace(resultTensor);
+      }
+      curTensor[i] = resultTensor;
    } catch (err) {
      callbackUI(err.message, -1, err.message)
      tf.engine().endScope()
@@ -592,10 +614,20 @@ async function inferenceFullVolumePhase2(

      return 0
    }
+    if (res.layers[i].activation.getClassName() == 'linear') {
+      // --- FORCE EXECUTION BY READING THE FIRST ELEMENT (Corrected) ---
+      // 1. Create a tiny slice tensor.
+      const firstElement = curTensor[i].slice([0, 0, 0, 0, 0], [1, 1, 1, 1, 1]);
+      // 2. Awaiting its data forces GPU synchronization.
+      await firstElement.data();
+      // 3. Manually dispose of the temporary slice tensor now that it has served its purpose.
+      firstElement.dispose();
+      // --- SYNCHRONIZATION IS NOW COMPLETE ---
+    }
    callbackUI('Layer ' + i.toString(), (i + 1) / layersLength)
    console.log('layer output Tensor shape : ', curTensor[i].shape)
    console.log('layer count params ', res.layers[i].countParams())
-    res.layers[i].dispose()
+    //res.layers[i].dispose()
    curTensor[i - 1].dispose()
    if (tf.memory().unreliable) {
      const unreliableReasons = 'unreliable reasons :' + tf.memory().reasons
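The forced-readback block above is the least obvious part of this hunk: reading even a single element back to the CPU makes the backend flush all queued kernels, so the progress callback and per-layer logging line up with work that has actually finished. A standalone sketch of the same pattern (the helper name syncGpuQueue is hypothetical; slice, data, and dispose are the tfjs calls used in the diff):

// Hypothetical helper, not part of this commit, illustrating the synchronization trick.
async function syncGpuQueue(tensor5d) {
  const probe = tensor5d.slice([0, 0, 0, 0, 0], [1, 1, 1, 1, 1]) // 1-voxel slice
  await probe.data()                                             // readback waits for the GPU queue to drain
  probe.dispose()                                                // free the temporary probe tensor
}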
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+{
+  "R": [ 0, 255],
+  "G": [ 0, 255],
+  "B": [ 0, 255],
+  "labels": [ "background", "brain mask"]
+}
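One way to read this format (an assumption about the convention, not code from this commit): index i of the parallel R, G, and B arrays gives the colour for labels[i], so "background" maps to (0, 0, 0) and "brain mask" to (255, 255, 255).

// Hypothetical helper, not part of this commit: pair each label with its RGB triple.
function expandColormap(cmap) {
  return cmap.labels.map((label, i) => ({
    label,
    rgb: [cmap.R[i], cmap.G[i], cmap.B[i]],
  }))
}
// expandColormap(json) -> [ { label: 'background', rgb: [0, 0, 0] },
//                           { label: 'brain mask', rgb: [255, 255, 255] } ]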

public/models/mindgrab/model.bin

571 KB (binary file, contents not shown)
