Skip to content

Commit 5c6b146

Browse files
committed
Support 3-argument form of GroupBy with reducer function
GroupBy[list, f, reducer] now applies the reducer function to each group. E.g. GroupBy[{1,2,3,4,5,6}, EvenQ, Total] returns <|False -> 9, True -> 12|>. Previously only the 2-argument form was supported.
1 parent 9ddf6c1 commit 5c6b146

3 files changed

Lines changed: 39 additions & 3 deletions

File tree

src/evaluator/dispatch/arg_count.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ pub fn get_arg_count_range(name: &str) -> Option<(usize, usize)> {
399399
"GreaterEqual" => Some((2, usize::MAX)),
400400
"GreaterEqualThan" => Some((1, 1)),
401401
"GreaterThan" => Some((1, 1)),
402-
"GroupBy" => Some((1, 2)),
402+
"GroupBy" => Some((1, 3)),
403403
"GroupGenerators" => Some((1, 1)),
404404
"Groupings" => Some((2, 2)),
405405
"Gudermannian" => Some((1, 1)),

src/evaluator/dispatch/list_operations.rs

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -300,8 +300,28 @@ pub fn dispatch_list_operations(
300300
"CountBy" if args.len() == 2 => {
301301
return Some(list_helpers_ast::count_by_ast(&args[0], &args[1]));
302302
}
303-
"GroupBy" if args.len() == 2 => {
304-
return Some(list_helpers_ast::group_by_ast(&args[0], &args[1]));
303+
"GroupBy" if args.len() == 2 || args.len() == 3 => {
304+
let result = list_helpers_ast::group_by_ast(&args[0], &args[1]);
305+
if args.len() == 3 {
306+
// GroupBy[list, f, reducer] - apply reducer to each group
307+
return Some(result.and_then(|grouped| match &grouped {
308+
Expr::Association(pairs) => {
309+
let new_pairs: Result<Vec<(Expr, Expr)>, InterpreterError> = pairs
310+
.iter()
311+
.map(|(k, v)| {
312+
let reduced =
313+
crate::functions::list_helpers_ast::apply_func_ast(
314+
&args[2], v,
315+
)?;
316+
Ok((k.clone(), reduced))
317+
})
318+
.collect();
319+
Ok(Expr::Association(new_pairs?))
320+
}
321+
_ => Ok(grouped),
322+
}));
323+
}
324+
return Some(result);
305325
}
306326
"SortBy" if args.len() == 2 => {
307327
return Some(list_helpers_ast::sort_by_ast(&args[0], &args[1]));

tests/interpreter_tests/list.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4225,6 +4225,22 @@ mod join_non_list {
42254225
);
42264226
}
42274227

4228+
#[test]
4229+
fn group_by_with_reducer() {
4230+
assert_eq!(
4231+
interpret("GroupBy[{1, 2, 3, 4, 5, 6}, EvenQ, Total]").unwrap(),
4232+
"<|False -> 9, True -> 12|>"
4233+
);
4234+
}
4235+
4236+
#[test]
4237+
fn group_by_with_length_reducer() {
4238+
assert_eq!(
4239+
interpret("GroupBy[{1, 2, 3, 4, 5, 6}, EvenQ, Length]").unwrap(),
4240+
"<|False -> 3, True -> 3|>"
4241+
);
4242+
}
4243+
42284244
#[test]
42294245
fn counts_by_operator_form() {
42304246
assert_eq!(

0 commit comments

Comments
 (0)