@@ -357,6 +357,10 @@ async function inferenceFullVolumeSeqCovLayerPhase2(
357357 opts ,
358358 niftiImage
359359 )
360+ if ( modelEntry . isScalar ) {
361+ console . log ( ':::: This model ignores isScalar:::::::::::::::::::::::::' )
362+ modelEntry . isScalar = false
363+ }
360364 console . log ( ' Phase-2 num of tensors after generateOutputSlicesV2: ' , tf . memory ( ) . numTensors )
361365
362366 tf . dispose ( outLabelVolume )
@@ -568,7 +572,7 @@ async function inferenceFullVolumePhase2(
568572 statData . Model_Layers = await getModelNumLayers ( res )
569573 statData . Model = modelEntry . modelName
570574 // statData.Extra_Info = null
571-
575+ const isScalar = modelEntry . isScalar === true
572576 const curTensor = [ ]
573577 curTensor [ 0 ] = cropped_slices_3d_w_pad . reshape ( adjusted_input_shape )
574578 // console.log("curTensor[0] :", curTensor[0].dataSync())
@@ -619,7 +623,21 @@ async function inferenceFullVolumePhase2(
619623 try {
620624 const argMaxTime = performance . now ( )
621625 console . log ( ' Try tf.argMax for fullVolume ..' )
622- prediction_argmax = tf . argMax ( curTensor [ i ] , axis )
626+ if ( isScalar ) {
627+ const input = tf . softmax ( curTensor [ i ] , - 1 ) // shape: [..., C], with C >= 2
628+ const shape = input . shape
629+ const lastDim = shape . length - 1
630+ // Slice the last dimension to keep only channels 1 and onward (ignore channel 0)
631+ const start = Array ( lastDim ) . fill ( 0 ) . concat ( 1 )
632+ const size = shape . slice ( 0 , lastDim ) . concat ( shape [ lastDim ] - 1 )
633+ const sliced = input . slice ( start , size )
634+ // Sum across the last dimension (collapsing the remaining classes)
635+ const summed = tf . sum ( sliced , lastDim )
636+ // Remove any leading singleton dimensions (optional)
637+ prediction_argmax = summed . squeeze ( ) // only if you want to remove shape [1, ...]
638+ } else {
639+ prediction_argmax = tf . argMax ( curTensor [ i ] , axis )
640+ }
623641 console . log ( 'tf.argMax for fullVolume takes : ' , ( ( performance . now ( ) - argMaxTime ) / 1000 ) . toFixed ( 4 ) )
624642 } catch ( err1 ) {
625643 // if channel last
@@ -691,10 +709,12 @@ async function inferenceFullVolumePhase2(
691709 statData . Expect_Labels = expected_Num_labels
692710 statData . NumLabels_Match = numSegClasses === expected_Num_labels
693711
694- if ( numSegClasses !== expected_Num_labels ) {
695- // errTxt = "expected " + expected_Num_labels + " labels, but the predicted are " + numSegClasses + ". For possible solutions please refer to <a href='https://github.com/neuroneural/brainchop/wiki/FAQ#Q3' target='_blank'><b> FAQ </b></a>.", "alert-error"
696- const errTxt = 'expected ' + expected_Num_labels + ' labels, but the predicted are ' + numSegClasses
697- callbackUI ( errTxt , - 1 , errTxt )
712+ if ( ! isScalar ) {
713+ if ( numSegClasses !== expected_Num_labels ) {
714+ // errTxt = "expected " + expected_Num_labels + " labels, but the predicted are " + numSegClasses + ". For possible solutions please refer to <a href='https://github.com/neuroneural/brainchop/wiki/FAQ#Q3' target='_blank'><b> FAQ </b></a>.", "alert-error"
715+ const errTxt = 'expected ' + expected_Num_labels + ' labels, but the predicted are ' + numSegClasses
716+ callbackUI ( errTxt , - 1 , errTxt )
717+ }
698718 }
699719
700720 // -- Transpose back to original unpadded size
@@ -722,11 +742,20 @@ async function inferenceFullVolumePhase2(
722742 )
723743 console . log ( ' outLabelVolume final shape after resizing : ' , outLabelVolume . shape )
724744
745+ if ( isScalar ) {
746+ const thresh = tf . scalar ( 0.04 ) ; // threshold
747+ const scale255 = tf . scalar ( 255.0 ) ; // if you still need the scaling step
748+ const mask = outLabelVolume . greaterEqual ( thresh ) . toFloat ( ) ;
749+ outLabelVolume = outLabelVolume . mul ( mask ) ;
750+ outLabelVolume = outLabelVolume . mul ( scale255 ) ;
751+ }
725752 const filterOutWithPreMask = modelEntry . filterOutWithPreMask
726753 // To clean the skull area wrongly segmented in phase-2.
727- if ( pipeline1_out != null && opts . isBrainCropMaskBased && filterOutWithPreMask ) {
728- const bin = binarizeVolumeDataTensor ( pipeline1_out )
729- outLabelVolume = outLabelVolume . mul ( bin )
754+ if ( ! isScalar ) {
755+ if ( pipeline1_out != null && opts . isBrainCropMaskBased && filterOutWithPreMask ) {
756+ const bin = binarizeVolumeDataTensor ( pipeline1_out )
757+ outLabelVolume = outLabelVolume . mul ( bin )
758+ }
730759 }
731760
732761 startTime = performance . now ( )
0 commit comments