Skip to content

Commit 639e2e6

Browse files
committed
New (hidden) feature for scalar models
1 parent 61de209 commit 639e2e6

File tree

3 files changed

+68
-21
lines changed

3 files changed

+68
-21
lines changed

brainchop-webworker.js

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,10 @@ async function inferenceFullVolumeSeqCovLayerPhase2(
357357
opts,
358358
niftiImage
359359
)
360+
if (modelEntry.isScalar) {
361+
console.log(':::: This model ignores isScalar:::::::::::::::::::::::::')
362+
modelEntry.isScalar = false
363+
}
360364
console.log(' Phase-2 num of tensors after generateOutputSlicesV2: ', tf.memory().numTensors)
361365

362366
tf.dispose(outLabelVolume)
@@ -568,7 +572,7 @@ async function inferenceFullVolumePhase2(
568572
statData.Model_Layers = await getModelNumLayers(res)
569573
statData.Model = modelEntry.modelName
570574
// statData.Extra_Info = null
571-
575+
const isScalar = modelEntry.isScalar === true
572576
const curTensor = []
573577
curTensor[0] = cropped_slices_3d_w_pad.reshape(adjusted_input_shape)
574578
// console.log("curTensor[0] :", curTensor[0].dataSync())
@@ -619,7 +623,21 @@ async function inferenceFullVolumePhase2(
619623
try {
620624
const argMaxTime = performance.now()
621625
console.log(' Try tf.argMax for fullVolume ..')
622-
prediction_argmax = tf.argMax(curTensor[i], axis)
626+
if (isScalar) {
627+
const input = tf.softmax(curTensor[i], -1) // shape: [..., C], with C >= 2
628+
const shape = input.shape
629+
const lastDim = shape.length - 1
630+
// Slice the last dimension to keep only channels 1 and onward (ignore channel 0)
631+
const start = Array(lastDim).fill(0).concat(1)
632+
const size = shape.slice(0, lastDim).concat(shape[lastDim] - 1)
633+
const sliced = input.slice(start, size)
634+
// Sum across the last dimension (collapsing the remaining classes)
635+
const summed = tf.sum(sliced, lastDim)
636+
// Remove any leading singleton dimensions (optional)
637+
prediction_argmax = summed.squeeze() // only if you want to remove shape [1, ...]
638+
} else {
639+
prediction_argmax = tf.argMax(curTensor[i], axis)
640+
}
623641
console.log('tf.argMax for fullVolume takes : ', ((performance.now() - argMaxTime) / 1000).toFixed(4))
624642
} catch (err1) {
625643
// if channel last
@@ -691,10 +709,12 @@ async function inferenceFullVolumePhase2(
691709
statData.Expect_Labels = expected_Num_labels
692710
statData.NumLabels_Match = numSegClasses === expected_Num_labels
693711

694-
if (numSegClasses !== expected_Num_labels) {
695-
// errTxt = "expected " + expected_Num_labels + " labels, but the predicted are " + numSegClasses + ". For possible solutions please refer to <a href='https://github.com/neuroneural/brainchop/wiki/FAQ#Q3' target='_blank'><b> FAQ </b></a>.", "alert-error"
696-
const errTxt = 'expected ' + expected_Num_labels + ' labels, but the predicted are ' + numSegClasses
697-
callbackUI(errTxt, -1, errTxt)
712+
if (!isScalar) {
713+
if (numSegClasses !== expected_Num_labels) {
714+
// errTxt = "expected " + expected_Num_labels + " labels, but the predicted are " + numSegClasses + ". For possible solutions please refer to <a href='https://github.com/neuroneural/brainchop/wiki/FAQ#Q3' target='_blank'><b> FAQ </b></a>.", "alert-error"
715+
const errTxt = 'expected ' + expected_Num_labels + ' labels, but the predicted are ' + numSegClasses
716+
callbackUI(errTxt, -1, errTxt)
717+
}
698718
}
699719

700720
// -- Transpose back to original unpadded size
@@ -722,11 +742,20 @@ async function inferenceFullVolumePhase2(
722742
)
723743
console.log(' outLabelVolume final shape after resizing : ', outLabelVolume.shape)
724744

745+
if (isScalar) {
746+
const thresh = tf.scalar(0.04); // threshold
747+
const scale255 = tf.scalar(255.0); // if you still need the scaling step
748+
const mask = outLabelVolume.greaterEqual(thresh).toFloat();
749+
outLabelVolume = outLabelVolume.mul(mask);
750+
outLabelVolume = outLabelVolume.mul(scale255);
751+
}
725752
const filterOutWithPreMask = modelEntry.filterOutWithPreMask
726753
// To clean the skull area wrongly segmented in phase-2.
727-
if (pipeline1_out != null && opts.isBrainCropMaskBased && filterOutWithPreMask) {
728-
const bin = binarizeVolumeDataTensor(pipeline1_out)
729-
outLabelVolume = outLabelVolume.mul(bin)
754+
if (!isScalar) {
755+
if (pipeline1_out != null && opts.isBrainCropMaskBased && filterOutWithPreMask) {
756+
const bin = binarizeVolumeDataTensor(pipeline1_out)
757+
outLabelVolume = outLabelVolume.mul(bin)
758+
}
730759
}
731760

732761
startTime = performance.now()

index.html

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828
<select id="modelSelect">
2929
</select>
3030
&nbsp;
31+
<label for="scalarCheck" hidden>Scalar</label>
32+
<input type="checkbox" id="scalarCheck" hidden unchecked />
33+
&nbsp;
3134
<button disabled id="createMeshBtn">Create Mesh</button>
3235
&nbsp;
3336
<br>

main.js

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ async function main() {
8383
await ensureConformed()
8484
let model = inferenceModelsList[this.selectedIndex]
8585
model.isNvidia = false
86+
model.isScalar = scalarCheck.checked
8687
const rendererInfo = nv1.gl.getExtension("WEBGL_debug_renderer_info")
8788
if (rendererInfo) {
8889
model.isNvidia = nv1.gl
@@ -151,6 +152,9 @@ async function main() {
151152
nv1.setClipPlane([2, 0, 90])
152153
}
153154
}
155+
scalarCheck.onchange = function () {
156+
modelSelect.selectedIndex = -1
157+
}
154158
function doLoadImage() {
155159
saveBtn.disabled = true
156160
opacitySlider0.oninput()
@@ -167,18 +171,26 @@ async function main() {
167171
overlayVolume.hdr.scl_inter = 0
168172
overlayVolume.hdr.scl_slope = 1
169173
overlayVolume.img = new Uint8Array(img)
170-
if (modelEntry.colormapPath) {
171-
let cmap = await fetchJSON(modelEntry.colormapPath)
172-
overlayVolume.setColormapLabel(cmap)
173-
// n.b. most models create indexed labels, but those without colormap mask scalar input
174-
overlayVolume.hdr.intent_code = 1002 // NIFTI_INTENT_LABEL
174+
const isScalar = modelEntry.isScalar === true
175+
if (isScalar) {
176+
overlayVolume.hdr.scl_slope = 1 / 255
177+
overlayVolume.colormap = "viridis"
175178
} else {
176-
let colormap = opts.atlasSelectedColorTable.toLowerCase()
177-
const cmaps = nv1.colormaps()
178-
if (!cmaps.includes(colormap)) {
179-
colormap = "actc"
179+
180+
if (modelEntry.colormapPath) {
181+
let cmap = await fetchJSON(modelEntry.colormapPath)
182+
overlayVolume.setColormapLabel(cmap)
183+
// n.b. most models create indexed labels, but those without colormap mask scalar input
184+
overlayVolume.hdr.intent_code = 1002 // NIFTI_INTENT_LABEL
185+
} else {
186+
let colormap = opts.atlasSelectedColorTable.toLowerCase()
187+
const cmaps = nv1.colormaps()
188+
if (!cmaps.includes(colormap)) {
189+
colormap = "actc"
190+
}
191+
overlayVolume.colormap = colormap
180192
}
181-
overlayVolume.colormap = colormap
193+
182194
}
183195
overlayVolume.opacity = opacitySlider1.value / 255
184196
await nv1.addVolume(overlayVolume)
@@ -243,12 +255,15 @@ async function main() {
243255
console.log(`Execution time: ${Math.round(performance.now() - startTime)} ms`)
244256
}
245257
async function applyFaster() {
246-
const niiBuffer = await nv1.saveImage({volumeByIndex: nv1.volumes.length - 1}).buffer
258+
const niiBuffer = await nv1.saveImage({volumeByIndex: nv1.volumes.length - 1})
247259
const niiFile = new File([niiBuffer], 'image.nii')
248260
let processor = niimath.image(niiFile)
249261
loadingCircle.classList.remove('hidden')
250262
//mesh with specified isosurface
251-
const isoValue = 0.5
263+
let isoValue = 0.5
264+
if (nv1.volumes[nv1.volumes.length - 1].hdr.intent_code === 0) {
265+
isoValue = 222 //isScalar
266+
}
252267
//const largestCheckValue = largestCheck.checked
253268
let reduce = Math.min(Math.max(Number(shrinkPct.value) / 100, 0.01), 1)
254269
let hollowSz = Number(hollowSelect.value )

0 commit comments

Comments
 (0)