Skip to content

Commit 4dd272c

Browse files
authored
fix: Properly broadcast array arithmetic (#18851)
1 parent dcad7d8 commit 4dd272c

File tree

3 files changed

+52
-3
lines changed

3 files changed

+52
-3
lines changed

.typos.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ extend-exclude = [
44
"*.csv",
55
"*.gz",
66
"dists.dss",
7+
"**/images/*",
78
]
89
ignore-hidden = false
910

crates/polars-core/src/series/arithmetic/borrowed.rs

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,20 +130,42 @@ fn array_shape(dt: &DataType, infer: bool) -> Vec<i64> {
130130
buf
131131
}
132132

133+
#[cfg(feature = "dtype-array")]
134+
fn broadcast_array(lhs: &ArrayChunked, rhs: &Series) -> PolarsResult<(ArrayChunked, Series)> {
135+
let out = match (lhs.len(), rhs.len()) {
136+
(1, _) => (lhs.new_from_index(0, rhs.len()), rhs.clone()),
137+
(_, 1) => {
138+
// Numeric scalars will be broadcasted implicitly without intermediate allocation.
139+
if rhs.dtype().is_numeric() {
140+
(lhs.clone(), rhs.clone())
141+
} else {
142+
(lhs.clone(), rhs.new_from_index(0, lhs.len()))
143+
}
144+
},
145+
(a, b) if a == b => (lhs.clone(), rhs.clone()),
146+
_ => {
147+
polars_bail!(InvalidOperation: "can only do arithmetic of array's of the same type and shape; got {} and {}", lhs.dtype(), rhs.dtype())
148+
},
149+
};
150+
Ok(out)
151+
}
152+
133153
#[cfg(feature = "dtype-array")]
134154
impl ArrayChunked {
135155
fn arithm_helper(
136156
&self,
137157
rhs: &Series,
138158
op: &dyn Fn(Series, Series) -> PolarsResult<Series>,
139159
) -> PolarsResult<Series> {
140-
let l_leaf_array = self.clone().into_series().get_leaf_array();
141-
let shape = array_shape(self.dtype(), true);
160+
let (lhs, rhs) = broadcast_array(self, rhs)?;
161+
162+
let l_leaf_array = lhs.clone().into_series().get_leaf_array();
163+
let shape = array_shape(lhs.dtype(), true);
142164

143165
let r_leaf_array = if rhs.dtype().is_numeric() && rhs.len() == 1 {
144166
rhs.clone()
145167
} else {
146-
polars_ensure!(self.dtype() == rhs.dtype(), InvalidOperation: "can only do arithmetic of array's of the same type and shape; got {} and {}", self.dtype(), rhs.dtype());
168+
polars_ensure!(lhs.dtype() == rhs.dtype(), InvalidOperation: "can only do arithmetic of array's of the same type and shape; got {} and {}", self.dtype(), rhs.dtype());
147169
rhs.get_leaf_array()
148170
};
149171

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import polars as pl
2+
3+
4+
def test_literal_broadcast_array() -> None:
5+
df = pl.DataFrame({"A": [[0.1, 0.2], [0.3, 0.4]]}).cast(pl.Array(float, 2))
6+
7+
lit = pl.lit([3, 5], pl.Array(float, 2))
8+
assert df.select(
9+
mul=pl.all() * lit,
10+
div=pl.all() / lit,
11+
add=pl.all() + lit,
12+
sub=pl.all() - lit,
13+
div_=lit / pl.all(),
14+
add_=lit + pl.all(),
15+
sub_=lit - pl.all(),
16+
mul_=lit * pl.all(),
17+
).to_dict(as_series=False) == {
18+
"mul": [[0.30000000000000004, 1.0], [0.8999999999999999, 2.0]],
19+
"div": [[0.03333333333333333, 0.04], [0.09999999999999999, 0.08]],
20+
"add": [[3.1, 5.2], [3.3, 5.4]],
21+
"sub": [[-2.9, -4.8], [-2.7, -4.6]],
22+
"div_": [[30.0, 25.0], [10.0, 12.5]],
23+
"add_": [[3.1, 5.2], [3.3, 5.4]],
24+
"sub_": [[2.9, 4.8], [2.7, 4.6]],
25+
"mul_": [[0.30000000000000004, 1.0], [0.8999999999999999, 2.0]],
26+
}

0 commit comments

Comments
 (0)