From 9453aa93573a0ddfe074066dc7fcd0e229759d39 Mon Sep 17 00:00:00 2001 From: Julien Yao Date: Wed, 19 Mar 2025 14:42:44 +0100 Subject: [PATCH 1/4] test --- .../src/fairseq2n/bindings/data/data_pipeline.cc | 2 ++ native/src/fairseq2n/data/data_pipeline.cc | 14 +++++++++++--- native/src/fairseq2n/data/data_pipeline.h | 6 ++++++ native/src/fairseq2n/data/filter_data_source.cc | 10 +++++++++- 4 files changed, 28 insertions(+), 4 deletions(-) diff --git a/native/python/src/fairseq2n/bindings/data/data_pipeline.cc b/native/python/src/fairseq2n/bindings/data/data_pipeline.cc index deb6f4891..5f7f536ec 100644 --- a/native/python/src/fairseq2n/bindings/data/data_pipeline.cc +++ b/native/python/src/fairseq2n/bindings/data/data_pipeline.cc @@ -228,6 +228,8 @@ def_data_pipeline(py::module_ &data_module) .def_property_readonly("is_broken", &data_pipeline::is_broken) + .def_property_readonly("warning_count", &data_pipeline::warning_count) + // state_dict .def( "state_dict", diff --git a/native/src/fairseq2n/data/data_pipeline.cc b/native/src/fairseq2n/data/data_pipeline.cc index ee6b4edd7..a84988f7e 100644 --- a/native/src/fairseq2n/data/data_pipeline.cc +++ b/native/src/fairseq2n/data/data_pipeline.cc @@ -58,10 +58,16 @@ data_pipeline::next() if (ex.recoverable() && warning_count_ < max_num_warnings_) { warning_count_++; - // TODO: log exception + // Log the exception with the current warning count + fprintf(stderr, "Data pipeline warning (%zu/%zu): %s\n", + warning_count_, max_num_warnings_, ex.what()); + + // Continue to the next example + continue; } else { - if (max_num_warnings_ > 0) { - // TODO: log max number of warnings reached. + if (max_num_warnings_ > 0 && warning_count_ >= max_num_warnings_) { + fprintf(stderr, "Data pipeline error: Maximum number of warnings (%zu) reached.\n", + max_num_warnings_); } // If the error is not recoverable, any further attempt to read @@ -87,6 +93,8 @@ data_pipeline::reset(bool reset_rng) try { source_->reset(reset_rng); + // Reset warning counter when pipeline is reset + warning_count_ = 0; } catch (const std::exception &) { is_broken_ = true; diff --git a/native/src/fairseq2n/data/data_pipeline.h b/native/src/fairseq2n/data/data_pipeline.h index 849a2439c..51f0742c5 100644 --- a/native/src/fairseq2n/data/data_pipeline.h +++ b/native/src/fairseq2n/data/data_pipeline.h @@ -63,6 +63,12 @@ class FAIRSEQ2_API data_pipeline { return is_broken_; } + std::size_t + warning_count() const noexcept + { + return warning_count_; + } + private: bool is_initialized() const noexcept; diff --git a/native/src/fairseq2n/data/filter_data_source.cc b/native/src/fairseq2n/data/filter_data_source.cc index 1e6986abd..24b8188a7 100644 --- a/native/src/fairseq2n/data/filter_data_source.cc +++ b/native/src/fairseq2n/data/filter_data_source.cc @@ -49,7 +49,15 @@ filter_data_source::finitude_type() const noexcept bool filter_data_source::invoke_function(data &example) { - return predicate_fn_(example); + try { + return predicate_fn_(example); + } catch (const std::exception &ex) { + // Convert any exception into a recoverable data_pipeline_error + throw data_pipeline_error( + std::string("Error in filter function: ") + ex.what(), + example, // Pass the example that caused the error + true); // Mark as recoverable + } } } // fairseq2n::detail From b61e63925e37313e572f03ed8d5b3065098acee9 Mon Sep 17 00:00:00 2001 From: Julien Yao Date: Wed, 19 Mar 2025 14:46:31 +0100 Subject: [PATCH 2/4] add 2 tests for max num warnings --- tests/unit/data/data_pipeline/test_data_pipeline.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/unit/data/data_pipeline/test_data_pipeline.py b/tests/unit/data/data_pipeline/test_data_pipeline.py index 5f70f8200..9dcf846f4 100644 --- a/tests/unit/data/data_pipeline/test_data_pipeline.py +++ b/tests/unit/data/data_pipeline/test_data_pipeline.py @@ -70,7 +70,6 @@ def fn(d: int) -> bool: assert output == [1, 3, 5] - @pytest.mark.skip("need additional work in data_pipeline::next") def test_next_does_not_raise_error_when_num_errors_is_less_than_max_num_warnings( self, ) -> None: @@ -89,7 +88,6 @@ def fn(d: int) -> bool: # TODO: assert log warning @pytest.mark.parametrize("max_num_warnings", [0, 1, 2]) - @pytest.mark.skip("need additional work in data_pipeline::next") def test_next_raises_error_when_num_errors_exceed_max_num_warnings( self, max_num_warnings: int ) -> None: From e86d4bf303e29f1ad57056b339d56db20444b43b Mon Sep 17 00:00:00 2001 From: Julien Yao Date: Wed, 19 Mar 2025 14:59:08 +0100 Subject: [PATCH 3/4] update test accordingly --- tests/unit/data/data_pipeline/test_data_pipeline.py | 2 +- tests/unit/data/data_pipeline/test_filter.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/unit/data/data_pipeline/test_data_pipeline.py b/tests/unit/data/data_pipeline/test_data_pipeline.py index 9dcf846f4..915497679 100644 --- a/tests/unit/data/data_pipeline/test_data_pipeline.py +++ b/tests/unit/data/data_pipeline/test_data_pipeline.py @@ -107,7 +107,7 @@ def fn(d: int) -> bool: pipeline = read_sequence(seq).filter(fn).and_return(max_num_warnings) - with pytest.raises(ValueError): + with pytest.raises(DataPipelineError): for _ in pipeline: pass diff --git a/tests/unit/data/data_pipeline/test_filter.py b/tests/unit/data/data_pipeline/test_filter.py index d332eb712..2ddce8ff3 100644 --- a/tests/unit/data/data_pipeline/test_filter.py +++ b/tests/unit/data/data_pipeline/test_filter.py @@ -8,7 +8,7 @@ import pytest -from fairseq2.data import read_sequence +from fairseq2.data import DataPipelineError, read_sequence class TestFilterOp: @@ -34,8 +34,9 @@ def fn(d: int) -> bool: pipeline = read_sequence([1, 2, 3, 4]).filter(fn).and_return() - with pytest.raises(ValueError) as exc_info: + with pytest.raises(DataPipelineError) as exc_info: for d in pipeline: pass - assert str(exc_info.value) == "filter error" + # Check that the original error message is included in the DataPipelineError + assert "filter error" in str(exc_info.value) From 6365f28c096a7ce17ac8f52ab853358ffd116bdd Mon Sep 17 00:00:00 2001 From: Julien Yao Date: Wed, 19 Mar 2025 15:17:02 +0100 Subject: [PATCH 4/4] lint (maybe not the best way for logging yet) --- native/src/fairseq2n/data/data_pipeline.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/native/src/fairseq2n/data/data_pipeline.cc b/native/src/fairseq2n/data/data_pipeline.cc index a84988f7e..75d5d96c8 100644 --- a/native/src/fairseq2n/data/data_pipeline.cc +++ b/native/src/fairseq2n/data/data_pipeline.cc @@ -59,14 +59,14 @@ data_pipeline::next() warning_count_++; // Log the exception with the current warning count - fprintf(stderr, "Data pipeline warning (%zu/%zu): %s\n", + (void) fprintf(stderr, "Data pipeline warning (%zu/%zu): %s\n", warning_count_, max_num_warnings_, ex.what()); // Continue to the next example continue; } else { if (max_num_warnings_ > 0 && warning_count_ >= max_num_warnings_) { - fprintf(stderr, "Data pipeline error: Maximum number of warnings (%zu) reached.\n", + (void) fprintf(stderr, "Data pipeline error: Maximum number of warnings (%zu) reached.\n", max_num_warnings_); }