@@ -536,7 +536,11 @@ static PyObject *add_scalar_array(PyObject *array_obj, PyObject *scalar_obj, PyO
536536
537537 nk_dtype_t dtype = resolve_nk_dtype_in_py_buffer (& a_buffer );
538538 if (out_dtype_obj ) { dtype = py_object_to_nk_dtype (out_dtype_obj ); }
539- if (dtype == nk_dtype_unknown_k ) goto cleanup ;
539+ if (dtype == nk_dtype_unknown_k ) {
540+ if (!PyErr_Occurred ())
541+ PyErr_SetString (PyExc_TypeError , "unsupported buffer dtype for the requested elementwise operation" );
542+ goto cleanup ;
543+ }
540544
541545 nk_each_scale_punned_t scale_kernel = NULL ;
542546 nk_capability_t capability = nk_cap_serial_k ;
@@ -555,7 +559,11 @@ static PyObject *add_scalar_array(PyObject *array_obj, PyObject *scalar_obj, PyO
555559
556560 size_t const element_size = nk_dtype_bytes_per_value (dtype );
557561 size_t total_elements = 1 ;
558- for (int dim = 0 ; dim < a_buffer .ndim ; dim ++ ) total_elements *= (size_t )a_buffer .shape [dim ];
562+ for (int dim = 0 ; dim < a_buffer .ndim ; dim ++ )
563+ if (!nk_size_mul_checked_ (total_elements , (size_t )a_buffer .shape [dim ], & total_elements )) {
564+ PyErr_SetString (PyExc_OverflowError , "tensor element count overflows size_t" );
565+ goto cleanup ;
566+ }
559567
560568 char * result_data = NULL ;
561569 Py_ssize_t result_strides [NK_TENSOR_MAX_RANK ];
@@ -655,7 +663,11 @@ static PyObject *add_array_array(PyObject *a_obj, PyObject *b_obj, PyObject *out
655663 }
656664
657665 if (out_dtype_obj ) { dtype = py_object_to_nk_dtype (out_dtype_obj ); }
658- if (dtype == nk_dtype_unknown_k ) goto cleanup ;
666+ if (dtype == nk_dtype_unknown_k ) {
667+ if (!PyErr_Occurred ())
668+ PyErr_SetString (PyExc_TypeError , "unsupported buffer dtype for the requested elementwise operation" );
669+ goto cleanup ;
670+ }
659671
660672 nk_each_sum_punned_t sum_kernel = NULL ;
661673 nk_capability_t capability = nk_cap_serial_k ;
@@ -668,7 +680,11 @@ static PyObject *add_array_array(PyObject *a_obj, PyObject *b_obj, PyObject *out
668680
669681 int const num_dims = a_buffer .ndim ;
670682 size_t total_elements = 1 ;
671- for (int dim = 0 ; dim < num_dims ; dim ++ ) total_elements *= (size_t )a_buffer .shape [dim ];
683+ for (int dim = 0 ; dim < num_dims ; dim ++ )
684+ if (!nk_size_mul_checked_ (total_elements , (size_t )a_buffer .shape [dim ], & total_elements )) {
685+ PyErr_SetString (PyExc_OverflowError , "tensor element count overflows size_t" );
686+ goto cleanup ;
687+ }
672688
673689 a_promoted = ensure_contiguous_buffer (a_buffer .buf , a_dtype , dtype , num_dims , a_buffer .shape , a_buffer .strides ,
674690 total_elements , & a_needs_free );
@@ -827,7 +843,11 @@ static PyObject *multiply_scalar_array(PyObject *array_obj, PyObject *scalar_obj
827843
828844 nk_dtype_t dtype = resolve_nk_dtype_in_py_buffer (& a_buffer );
829845 if (out_dtype_obj ) { dtype = py_object_to_nk_dtype (out_dtype_obj ); }
830- if (dtype == nk_dtype_unknown_k ) goto cleanup ;
846+ if (dtype == nk_dtype_unknown_k ) {
847+ if (!PyErr_Occurred ())
848+ PyErr_SetString (PyExc_TypeError , "unsupported buffer dtype for the requested elementwise operation" );
849+ goto cleanup ;
850+ }
831851
832852 nk_each_scale_punned_t scale_kernel = NULL ;
833853 nk_capability_t capability = nk_cap_serial_k ;
@@ -846,7 +866,11 @@ static PyObject *multiply_scalar_array(PyObject *array_obj, PyObject *scalar_obj
846866
847867 size_t const element_size = nk_dtype_bytes_per_value (dtype );
848868 size_t total_elements = 1 ;
849- for (int dim = 0 ; dim < a_buffer .ndim ; dim ++ ) total_elements *= (size_t )a_buffer .shape [dim ];
869+ for (int dim = 0 ; dim < a_buffer .ndim ; dim ++ )
870+ if (!nk_size_mul_checked_ (total_elements , (size_t )a_buffer .shape [dim ], & total_elements )) {
871+ PyErr_SetString (PyExc_OverflowError , "tensor element count overflows size_t" );
872+ goto cleanup ;
873+ }
850874
851875 char * result_data = NULL ;
852876 Py_ssize_t result_strides [NK_TENSOR_MAX_RANK ];
@@ -946,7 +970,11 @@ static PyObject *multiply_array_array(PyObject *a_obj, PyObject *b_obj, PyObject
946970 }
947971
948972 if (out_dtype_obj ) { dtype = py_object_to_nk_dtype (out_dtype_obj ); }
949- if (dtype == nk_dtype_unknown_k ) goto cleanup ;
973+ if (dtype == nk_dtype_unknown_k ) {
974+ if (!PyErr_Occurred ())
975+ PyErr_SetString (PyExc_TypeError , "unsupported buffer dtype for the requested elementwise operation" );
976+ goto cleanup ;
977+ }
950978
951979 nk_each_fma_punned_t fma_kernel = NULL ;
952980 nk_capability_t capability = nk_cap_serial_k ;
@@ -965,7 +993,11 @@ static PyObject *multiply_array_array(PyObject *a_obj, PyObject *b_obj, PyObject
965993
966994 int const num_dims = a_buffer .ndim ;
967995 size_t total_elements = 1 ;
968- for (int dim = 0 ; dim < num_dims ; dim ++ ) total_elements *= (size_t )a_buffer .shape [dim ];
996+ for (int dim = 0 ; dim < num_dims ; dim ++ )
997+ if (!nk_size_mul_checked_ (total_elements , (size_t )a_buffer .shape [dim ], & total_elements )) {
998+ PyErr_SetString (PyExc_OverflowError , "tensor element count overflows size_t" );
999+ goto cleanup ;
1000+ }
9691001
9701002 a_promoted = ensure_contiguous_buffer (a_buffer .buf , a_dtype , dtype , num_dims , a_buffer .shape , a_buffer .strides ,
9711003 total_elements , & a_needs_free );
0 commit comments