Skip to content

Commit ad45545

Browse files
authored
fix(python): dot product of two integer series is cast to float (#15502)
1 parent a8c9738 commit ad45545

File tree

3 files changed

+47
-4
lines changed

3 files changed

+47
-4
lines changed

py-polars/polars/series/series.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4947,7 +4947,7 @@ def round_sig_figs(self, digits: int) -> Series:
49474947
]
49484948
"""
49494949

4950-
def dot(self, other: Series | ArrayLike) -> float | None:
4950+
def dot(self, other: Series | ArrayLike) -> int | float | None:
49514951
"""
49524952
Compute the dot/inner product between two Series.
49534953

py-polars/src/series/mod.rs

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -604,9 +604,30 @@ impl PySeries {
604604
self.series.shrink_to_fit();
605605
}
606606

607-
fn dot(&self, other: &PySeries) -> PyResult<f64> {
608-
let out = self.series.dot(&other.series).map_err(PyPolarsErr::from)?;
609-
Ok(out)
607+
fn dot(&self, other: &PySeries, py: Python) -> PyResult<PyObject> {
608+
let lhs_dtype = self.series.dtype();
609+
let rhs_dtype = other.series.dtype();
610+
611+
if !lhs_dtype.is_numeric() {
612+
return Err(PyPolarsErr::from(polars_err!(opq = dot, lhs_dtype)).into());
613+
};
614+
if !rhs_dtype.is_numeric() {
615+
return Err(PyPolarsErr::from(polars_err!(opq = dot, rhs_dtype)).into());
616+
}
617+
618+
let result: AnyValue = if lhs_dtype.is_float() || rhs_dtype.is_float() {
619+
(&self.series * &other.series)
620+
.sum::<f64>()
621+
.map_err(PyPolarsErr::from)?
622+
.into()
623+
} else {
624+
(&self.series * &other.series)
625+
.sum::<i64>()
626+
.map_err(PyPolarsErr::from)?
627+
.into()
628+
};
629+
630+
Ok(Wrap(result).into_py(py))
610631
}
611632

612633
#[cfg(feature = "ipc_streaming")]

py-polars/tests/unit/dataframe/test_df.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1407,6 +1407,28 @@ def test_dot_product() -> None:
14071407
assert df["a"].dot(df["b"]) == 20
14081408
assert typing.cast(int, df.select([pl.col("a").dot("b")])[0, "a"]) == 20
14091409

1410+
result = pl.Series([1, 2, 3]) @ pl.Series([4, 5, 6])
1411+
assert isinstance(result, int)
1412+
assert result == 32
1413+
1414+
result = pl.Series([1, 2, 3]) @ pl.Series([4.0, 5.0, 6.0])
1415+
assert isinstance(result, float)
1416+
assert result == 32.0
1417+
1418+
result = pl.Series([1.0, 2.0, 3.0]) @ pl.Series([4.0, 5.0, 6.0])
1419+
assert isinstance(result, float)
1420+
assert result == 32.0
1421+
1422+
with pytest.raises(
1423+
pl.InvalidOperationError, match="`dot` operation not supported for dtype `bool`"
1424+
):
1425+
pl.Series([True, False, False, True]) @ pl.Series([4, 5, 6, 7])
1426+
1427+
with pytest.raises(
1428+
pl.InvalidOperationError, match="`dot` operation not supported for dtype `str`"
1429+
):
1430+
pl.Series([1, 2, 3, 4]) @ pl.Series(["True", "False", "False", "True"])
1431+
14101432

14111433
def test_hash_rows() -> None:
14121434
df = pl.DataFrame({"a": [1, 2, 3, 4], "b": [2, 2, 2, 2]})

0 commit comments

Comments
 (0)