Skip to content

Handle max_num_warnings properly #1089

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
2 changes: 2 additions & 0 deletions native/python/src/fairseq2n/bindings/data/data_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ def_data_pipeline(py::module_ &data_module)

.def_property_readonly("is_broken", &data_pipeline::is_broken)

.def_property_readonly("warning_count", &data_pipeline::warning_count)

// state_dict
.def(
"state_dict",
Expand Down
14 changes: 11 additions & 3 deletions native/src/fairseq2n/data/data_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,16 @@ data_pipeline::next()
if (ex.recoverable() && warning_count_ < max_num_warnings_) {
warning_count_++;

// TODO: log exception
// Log the exception with the current warning count
(void) fprintf(stderr, "Data pipeline warning (%zu/%zu): %s\n",
warning_count_, max_num_warnings_, ex.what());

// Continue to the next example
continue;
} else {
if (max_num_warnings_ > 0) {
// TODO: log max number of warnings reached.
if (max_num_warnings_ > 0 && warning_count_ >= max_num_warnings_) {
(void) fprintf(stderr, "Data pipeline error: Maximum number of warnings (%zu) reached.\n",
max_num_warnings_);
}

// If the error is not recoverable, any further attempt to read
Expand All @@ -87,6 +93,8 @@ data_pipeline::reset(bool reset_rng)

try {
source_->reset(reset_rng);
// Reset warning counter when pipeline is reset
warning_count_ = 0;
} catch (const std::exception &) {
is_broken_ = true;

Expand Down
6 changes: 6 additions & 0 deletions native/src/fairseq2n/data/data_pipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ class FAIRSEQ2_API data_pipeline {
return is_broken_;
}

// Returns the current value of the pipeline's warning counter.
//
// data_pipeline::next() increments the counter each time it skips an example
// after a recoverable error, and data_pipeline::reset() sets it back to zero
// once the source has been reset.
std::size_t
warning_count() const noexcept
{
return warning_count_;
}

private:
bool
is_initialized() const noexcept;
Expand Down
10 changes: 9 additions & 1 deletion native/src/fairseq2n/data/filter_data_source.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,15 @@ filter_data_source::finitude_type() const noexcept
// Evaluates the user-supplied predicate on `example` and returns its result.
//
// If the predicate throws, the failure is rethrown as a recoverable
// data_pipeline_error that carries the offending example, so the caller can
// decide whether to skip the example or abort. A data_pipeline_error raised
// by the predicate itself is propagated unchanged so that its own
// recoverability flag and message are not overwritten.
bool
filter_data_source::invoke_function(data &example)
{
    try {
        return predicate_fn_(example);
    } catch (const data_pipeline_error &) {
        // Already a pipeline error; do not re-wrap it as recoverable.
        throw;
    } catch (const std::exception &ex) {
        throw data_pipeline_error(
            std::string("Error in filter function: ") + ex.what(),
            example,  // the example that triggered the failure
            true);    // mark as recoverable
    }
}

} // fairseq2n::detail
4 changes: 1 addition & 3 deletions tests/unit/data/data_pipeline/test_data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def fn(d: int) -> bool:

assert output == [1, 3, 5]

@pytest.mark.skip("need additional work in data_pipeline::next")
def test_next_does_not_raise_error_when_num_errors_is_less_than_max_num_warnings(
self,
) -> None:
Expand All @@ -89,7 +88,6 @@ def fn(d: int) -> bool:
# TODO: assert log warning

@pytest.mark.parametrize("max_num_warnings", [0, 1, 2])
@pytest.mark.skip("need additional work in data_pipeline::next")
def test_next_raises_error_when_num_errors_exceed_max_num_warnings(
self, max_num_warnings: int
) -> None:
Expand All @@ -109,7 +107,7 @@ def fn(d: int) -> bool:

pipeline = read_sequence(seq).filter(fn).and_return(max_num_warnings)

with pytest.raises(ValueError):
with pytest.raises(DataPipelineError):
for _ in pipeline:
pass

Expand Down
7 changes: 4 additions & 3 deletions tests/unit/data/data_pipeline/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import pytest

from fairseq2.data import read_sequence
from fairseq2.data import DataPipelineError, read_sequence


class TestFilterOp:
Expand All @@ -34,8 +34,9 @@ def fn(d: int) -> bool:

pipeline = read_sequence([1, 2, 3, 4]).filter(fn).and_return()

with pytest.raises(ValueError) as exc_info:
with pytest.raises(DataPipelineError) as exc_info:
for d in pipeline:
pass

assert str(exc_info.value) == "filter error"
# Check that the original error message is included in the DataPipelineError
assert "filter error" in str(exc_info.value)
Loading