Skip to content

Commit beb02b7

Browse files
authored
New mask mode combining Ignore and Segment [Volinga] (#1314)
Multimask mode added
1 parent 5c4cf9a commit beb02b7

22 files changed

Lines changed: 119 additions & 35 deletions

File tree

src/core/argument_parser.cpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -273,12 +273,13 @@ namespace {
273273
::args::Group mask_sep(parser, " ");
274274
::args::Group mask_group(parser, "MASK / DEPTH OPTIONS:");
275275
::args::MapFlag<std::string, lfs::core::param::MaskMode> mask_mode(mask_group, "mask_mode",
276-
"Mask mode: none, segment, ignore, alpha_consistent (default: none)",
276+
"Mask mode: none, segment, ignore, segment_and_ignore, alpha_consistent (default: none)",
277277
{"mask-mode"},
278278
std::unordered_map<std::string, lfs::core::param::MaskMode>{
279279
{"none", lfs::core::param::MaskMode::None},
280280
{"segment", lfs::core::param::MaskMode::Segment},
281281
{"ignore", lfs::core::param::MaskMode::Ignore},
282+
{"segment_and_ignore", lfs::core::param::MaskMode::SegmentAndIgnore},
282283
{"alpha_consistent", lfs::core::param::MaskMode::AlphaConsistent}});
283284
::args::Flag invert_masks(mask_group, "invert_masks", "Invert mask values (swap object/background)", {"invert-masks"});
284285
::args::Flag no_alpha_as_mask(mask_group, "no_alpha_as_mask", "Disable automatic alpha-as-mask for RGBA images", {"no-alpha-as-mask"});

src/core/camera.cpp

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -384,7 +384,8 @@ namespace lfs::core {
384384
}
385385

386386
Tensor Camera::load_and_get_mask(const int resize_factor, const int max_width,
387-
const bool invert_mask, const float mask_threshold) {
387+
const bool invert_mask, const float mask_threshold,
388+
const bool binarize) {
388389
if (_mask_loaded && _cached_mask.is_valid()) {
389390
return _cached_mask;
390391
}
@@ -446,7 +447,7 @@ namespace lfs::core {
446447
}
447448

448449
// Threshold before undistort; final binarization happens after geometric resampling.
449-
if (mask_threshold > 0.0f && mask_threshold < 1.0f) {
450+
if (binarize && mask_threshold > 0.0f && mask_threshold < 1.0f) {
450451
mask = mask.ge(mask_threshold).to(DataType::Float32);
451452
}
452453

@@ -458,7 +459,11 @@ namespace lfs::core {
458459
mask = undistort_mask(mask, scaled, _stream);
459460
}
460461

461-
mask = mask.ge(0.5f).to(DataType::UInt8).contiguous();
462+
if (binarize) {
463+
mask = mask.ge(0.5f).to(DataType::UInt8).contiguous();
464+
} else {
465+
mask = (mask * 255.f).to(DataType::UInt8).contiguous();
466+
}
462467
_cached_mask = mask;
463468
_mask_loaded = true;
464469

src/core/include/core/camera.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -60,7 +60,7 @@ namespace lfs::core {
6060

6161
// Load mask from disk, process it, and return it (cached)
6262
Tensor load_and_get_mask(int resize_factor = -1, int max_width = 0,
63-
bool invert_mask = false, float mask_threshold = 0.5f);
63+
bool invert_mask = false, float mask_threshold = 0.5f, bool binarize = true);
6464

6565
// Load depth map from disk, convert to [H,W] float32 [0,1], and return it (cached)
6666
Tensor load_and_get_depth(int resize_factor = -1, int max_width = 0);

src/core/include/core/parameters.hpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@ namespace lfs::core {
2525
None, // No masking applied
2626
Segment, // Soft penalty to enforce alpha→0 in masked areas
2727
Ignore, // Completely ignore masked regions in loss
28+
SegmentAndIgnore, // 3-band mask (0-255): value<128 ignore, 128<=value<=250 segment, value>250 keep
2829
AlphaConsistent // Enforce exact alpha values from mask
2930
};
3031

src/core/parameters.cpp

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -166,7 +166,7 @@ namespace lfs::core {
166166
}
167167

168168
// Mask parameters
169-
static constexpr const char* MASK_MODE_NAMES[] = {"none", "segment", "ignore", "alpha_consistent"};
169+
static constexpr const char* MASK_MODE_NAMES[] = {"none", "segment", "ignore", "segment_and_ignore", "alpha_consistent"};
170170
opt_json["mask_mode"] = MASK_MODE_NAMES[static_cast<int>(mask_mode)];
171171
opt_json["invert_masks"] = invert_masks;
172172
opt_json["mask_opacity_penalty_weight"] = mask_opacity_penalty_weight;
@@ -495,6 +495,8 @@ namespace lfs::core {
495495
params.mask_mode = MaskMode::Segment;
496496
} else if (mode == "ignore") {
497497
params.mask_mode = MaskMode::Ignore;
498+
} else if (mode == "segment_and_ignore") {
499+
params.mask_mode = MaskMode::SegmentAndIgnore;
498500
} else if (mode == "alpha_consistent") {
499501
params.mask_mode = MaskMode::AlphaConsistent;
500502
}

src/io/pipelined_image_loader.cpp

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -200,12 +200,13 @@ namespace lfs::io {
200200
}
201201
}
202202

203-
lfs::core::Tensor binarize_mask(lfs::core::Tensor mask) {
203+
lfs::core::Tensor process_mask(lfs::core::Tensor mask, bool binarize) {
204204
if (!mask.is_valid())
205205
return {};
206206
if (mask.dtype() == lfs::core::DataType::UInt8 || mask.dtype() == lfs::core::DataType::Bool)
207207
return mask.contiguous();
208-
return mask.ge(0.5f).to(lfs::core::DataType::UInt8).contiguous();
208+
mask = binarize ? mask.ge(0.5f) : mask * 255.f;
209+
return mask.to(lfs::core::DataType::UInt8).contiguous();
209210
}
210211

211212
[[nodiscard]] bool is_jpeg_file_signature(const std::filesystem::path& path) {
@@ -1192,7 +1193,7 @@ namespace lfs::io {
11921193
throw std::runtime_error("Invalid mask tensor");
11931194
}
11941195

1195-
mask_tensor = binarize_mask(std::move(mask_tensor));
1196+
mask_tensor = process_mask(std::move(mask_tensor), batch[i].mask_params.threshold > 0);
11961197
try_complete_pair(batch[i].sequence_id, std::nullopt, std::move(mask_tensor), nullptr);
11971198

11981199
} else {
@@ -1397,7 +1398,7 @@ namespace lfs::io {
13971398
}
13981399
}
13991400

1400-
alpha = binarize_mask(std::move(alpha));
1401+
alpha = process_mask(std::move(alpha), item.alpha_mask_params.threshold > 0);
14011402
if (const cudaError_t err = cudaStreamSynchronize(nullptr); err != cudaSuccess) {
14021403
throw std::runtime_error(std::string("CUDA sync failed: ") + cudaGetErrorString(err));
14031404
}
@@ -1489,7 +1490,7 @@ namespace lfs::io {
14891490
}
14901491

14911492
if (item.is_mask) {
1492-
aux_tensor = binarize_mask(std::move(aux_tensor));
1493+
aux_tensor = process_mask(std::move(aux_tensor), item.mask_params.threshold > 0);
14931494
} else {
14941495
aux_tensor = aux_tensor.contiguous();
14951496
}

src/python/lfs/py_params.cpp

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -115,6 +115,7 @@ namespace lfs::python {
115115
{{"None", MaskMode::None},
116116
{"Segment", MaskMode::Segment},
117117
{"Ignore", MaskMode::Ignore},
118+
{"SegmentAndIgnore", MaskMode::SegmentAndIgnore},
118119
{"AlphaConsistent", MaskMode::AlphaConsistent}},
119120
"Attention mask behavior during training")
120121
.bool_prop(&OptimizationParameters::invert_masks,
@@ -999,6 +1000,7 @@ namespace lfs::python {
9991000
.value("NONE", MaskMode::None)
10001001
.value("SEGMENT", MaskMode::Segment)
10011002
.value("IGNORE", MaskMode::Ignore)
1003+
.value("SEGMENT_AND_IGNORE", MaskMode::SegmentAndIgnore)
10021004
.value("ALPHA_CONSISTENT", MaskMode::AlphaConsistent);
10031005

10041006
nb::enum_<BackgroundMode>(m, "BackgroundMode")

src/python/lfs_plugins/training_panel.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -186,6 +186,7 @@ def _depth_loss_mode_or_default(mode):
186186
"mask_none": "training.options.mask.none",
187187
"mask_segment": "training.options.mask.segment",
188188
"mask_ignore": "training.options.mask.ignore",
189+
"mask_segment_and_ignore": "training.options.mask.segment_and_ignore",
189190
"mask_alpha_consistent": "training.options.mask.alpha_consistent",
190191
"depth_loss_pearson": "training.options.depth_loss.pearson",
191192
"depth_loss_adaptive_warped_l1": "training.options.depth_loss.adaptive_warped_l1",
@@ -560,7 +561,11 @@ def _iteration():
560561
)
561562
model.bind_func(
562563
"dep_mask_segment",
563-
lambda: p() is not None and p().has_params() and p().mask_mode.value == 1,
564+
lambda: p() is not None and p().has_params() and (p().mask_mode.value == 1 or p().mask_mode.value == 3),
565+
)
566+
model.bind_func(
567+
"dep_mask_threshold",
568+
lambda: p() is not None and p().has_params() and p().mask_mode.value != 3,
564569
)
565570
model.bind_func(
566571
"dep_depth_loss",
@@ -2413,6 +2418,7 @@ def _on_strategy_conflict(
24132418
tr("training.options.mask.none"),
24142419
tr("training.options.mask.segment"),
24152420
tr("training.options.mask.ignore"),
2421+
tr("training.options.mask.segment_and_ignore"),
24162422
tr("training.options.mask.alpha_consistent"),
24172423
]
24182424
changed, new_idx = layout.combo("##py_mask_mode", mask_idx, mask_mode_items)

src/python/stubs/lichtfeld/__init__.pyi

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1771,7 +1771,9 @@ class MaskMode(enum.Enum):
17711771

17721772
IGNORE = 2
17731773

1774-
ALPHA_CONSISTENT = 3
1774+
SEGMENT_AND_IGNORE = 3
1775+
1776+
ALPHA_CONSISTENT = 4
17751777

17761778
class BackgroundMode(enum.Enum):
17771779
SOLID_COLOR = 0

src/training/dataset.hpp

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -579,11 +579,13 @@ namespace lfs::training {
579579
// Direct-scene plugins attach masks as in-memory tensors
580580
// via Camera::set_mask_tensor — load_and_get_mask returns
581581
// the processed-and-cached tensor (skips file I/O).
582+
bool segment_and_ignore = aux_config_.mask_threshold <= 0.f;
582583
auto m = cam->load_and_get_mask(
583584
dataset_->get_resize_factor(),
584585
dataset_->get_max_width(),
585586
aux_config_.invert_masks,
586-
aux_config_.mask_threshold);
587+
aux_config_.mask_threshold,
588+
!segment_and_ignore);
587589
if (m.is_valid()) {
588590
example.mask = std::move(m);
589591
}

0 commit comments

Comments
 (0)