@@ -200,12 +200,13 @@ namespace lfs::io {
200200 }
201201 }
202202
203- lfs::core::Tensor binarize_mask (lfs::core::Tensor mask) {
203+ lfs::core::Tensor process_mask (lfs::core::Tensor mask, bool binarize ) {
204204 if (!mask.is_valid ())
205205 return {};
206206 if (mask.dtype () == lfs::core::DataType::UInt8 || mask.dtype () == lfs::core::DataType::Bool)
207207 return mask.contiguous ();
208- return mask.ge (0 .5f ).to (lfs::core::DataType::UInt8).contiguous ();
208+ mask = binarize ? mask.ge (0 .5f ) : mask * 255 .f ;
209+ return mask.to (lfs::core::DataType::UInt8).contiguous ();
209210 }
210211
211212 [[nodiscard]] bool is_jpeg_file_signature (const std::filesystem::path& path) {
@@ -1192,7 +1193,7 @@ namespace lfs::io {
11921193 throw std::runtime_error (" Invalid mask tensor" );
11931194 }
11941195
1195- mask_tensor = binarize_mask (std::move (mask_tensor));
1196+ mask_tensor = process_mask (std::move (mask_tensor), batch[i]. mask_params . threshold > 0 );
11961197 try_complete_pair (batch[i].sequence_id , std::nullopt , std::move (mask_tensor), nullptr );
11971198
11981199 } else {
@@ -1397,7 +1398,7 @@ namespace lfs::io {
13971398 }
13981399 }
13991400
1400- alpha = binarize_mask (std::move (alpha));
1401+ alpha = process_mask (std::move (alpha), item. alpha_mask_params . threshold > 0 );
14011402 if (const cudaError_t err = cudaStreamSynchronize (nullptr ); err != cudaSuccess) {
14021403 throw std::runtime_error (std::string (" CUDA sync failed: " ) + cudaGetErrorString (err));
14031404 }
@@ -1489,7 +1490,7 @@ namespace lfs::io {
14891490 }
14901491
14911492 if (item.is_mask ) {
1492- aux_tensor = binarize_mask (std::move (aux_tensor));
1493+ aux_tensor = process_mask (std::move (aux_tensor), item. mask_params . threshold > 0 );
14931494 } else {
14941495 aux_tensor = aux_tensor.contiguous ();
14951496 }
0 commit comments