diff --git a/nle/tests/test_converter.py b/nle/tests/test_converter.py index 70ca6f98a..7a078be1b 100644 --- a/nle/tests/test_converter.py +++ b/nle/tests/test_converter.py @@ -190,6 +190,12 @@ def test_noexist(self): ): converter.load_ttyrec(fn) + def test_rejects_zero_rows_or_cols(self): + with pytest.raises(ValueError, match=r"rows and cols must be > 0"): + Converter(0, COLUMNS, TTYREC_V3, ROWS, COLUMNS) + with pytest.raises(ValueError, match=r"rows and cols must be > 0"): + Converter(ROWS, 0, TTYREC_V3, ROWS, COLUMNS) + def test_illegal_buffers(self): converter = Converter(ROWS, COLUMNS, TTYREC_V1) converter.load_ttyrec(getfilename(TTYREC_2020)) @@ -284,6 +290,18 @@ def test_illegal_buffers(self): with pytest.raises(ValueError, match=r"Numpy array required"): converter.convert(chars, colors, cursors, timestamps, actions, scores) + chars = np.array(8, dtype=np.uint8) + colors = np.zeros((10, ROWS, COLUMNS), dtype=np.int8) + cursors = np.zeros((10, 2), dtype=np.int16) + actions = np.zeros((10), dtype=np.uint8) + timestamps = np.zeros((10,), dtype=np.int64) + scores = np.zeros((10), dtype=np.int32) + with pytest.raises( + ValueError, + match=r"Array has wrong number of dimensions \(expected 3, got 0\)", + ): + converter.convert(chars, colors, cursors, timestamps, actions, scores) + chars = np.zeros((10, ROWS, COLUMNS), dtype=np.uint8) timestamps = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] with pytest.raises(ValueError, match=r"Numpy array required"): diff --git a/third_party/converter/converter.c b/third_party/converter/converter.c index eae901f7d..7103050d4 100644 --- a/third_party/converter/converter.c +++ b/third_party/converter/converter.c @@ -186,6 +186,10 @@ Conversion *conversion_create(size_t rows, size_t cols, size_t term_rows, stripgfx_init = true; } + if (rows == 0 || cols == 0) { + return NULL; + } + Conversion *c = malloc(sizeof(Conversion)); if (!c) return NULL; c->version = version; @@ -193,6 +197,10 @@ Conversion *conversion_create(size_t rows, size_t cols, size_t term_rows, c->cols = cols; if (!term_rows) term_rows = rows; if (!term_cols) term_cols = cols; + if (term_rows < rows || term_cols < cols) { + free(c); + return NULL; + } assert(rows <= term_rows && cols <= term_cols); c->chars = (UnsignedCharPtr){0}; c->colors = (SignedCharPtr){0}; diff --git a/third_party/converter/pyconverter.cc b/third_party/converter/pyconverter.cc index 2d4d5d751..ba4a06ce1 100644 --- a/third_party/converter/pyconverter.cc +++ b/third_party/converter/pyconverter.cc @@ -65,6 +65,9 @@ class Converter term_rows_((term_rows != 0) ? term_rows : rows), term_cols_((term_cols != 0) ? term_cols : cols) { + if (rows_ == 0 || cols_ == 0) + throw std::invalid_argument("rows and cols must be > 0"); + if (term_rows_ < 2 || term_cols_ < 2) throw std::invalid_argument("Terminal invalid: term_rows and term_cols must be >1"); @@ -117,7 +120,14 @@ class Converter py::array array = py::array::ensure(chars); if (!array.dtype().is(py::dtype::of())) throw std::invalid_argument("Buffer dtype mismatch."); - size_t unroll = array.request().shape[0]; + py::buffer_info chars_buf = array.request(); + if (chars_buf.ndim != 3) { + std::ostringstream ss; + ss << "Array has wrong number of dimensions (expected 3, got " + << chars_buf.ndim << ")"; + throw std::invalid_argument(ss.str()); + } + size_t unroll = chars_buf.shape[0]; conversion_set_buffers( conversion_, @@ -207,4 +217,4 @@ PYBIND11_MODULE(_pyconverter, m) .def_property_readonly("filename", &Converter::filename) .def_property_readonly("part", &Converter::part) .def_property_readonly("gameid", &Converter::gameid); -} \ No newline at end of file +}