Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion libs/core/langchain_core/utils/_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def merge_dicts(left: dict[str, Any], *others: dict[str, Any]) -> dict[str, Any]
merged[right_k] = merge_lists(merged[right_k], right_v)
elif merged[right_k] == right_v:
continue
elif isinstance(merged[right_k], int):
elif isinstance(merged[right_k], (int, float)):
# Preserve identification and temporal fields using last-wins strategy
# instead of summing:
# - index: identifies which tool call a chunk belongs to
Expand Down Expand Up @@ -201,6 +201,8 @@ def merge_obj(left: Any, right: Any) -> Any:
return merge_lists(left, right)
if left == right:
return left
if isinstance(left, (int, float)):
return left + right
msg = (
f"Unable to merge {left=} and {right=}. Both must be of type str, dict, or "
f"list, or else be two equal objects."
Expand Down
17 changes: 9 additions & 8 deletions libs/core/langchain_core/utils/usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,23 @@
def _dict_int_op(
left: dict,
right: dict,
op: Callable[[int, int], int],
op: Callable[[float, float], float],
*,
default: int = 0,
depth: int = 0,
max_depth: int = 100,
) -> dict:
"""Apply an integer operation to corresponding values in two dictionaries.
"""Apply a numeric operation to corresponding values in two dictionaries.

Recursively combines two dictionaries by applying the given operation to integer
values at matching keys.
Recursively combines two dictionaries by applying the given operation to numeric
(int or float) values at matching keys.

Supports nested dictionaries.

Args:
left: First dictionary to combine.
right: Second dictionary to combine.
op: Binary operation function to apply to integer values.
op: Binary operation function to apply to numeric values.
default: Default value to use when a key is missing from a dictionary.
depth: Current recursion depth (used internally).
max_depth: Maximum recursion depth (to prevent infinite loops).
Expand All @@ -38,8 +38,8 @@ def _dict_int_op(
raise ValueError(msg)
combined: dict = {}
for k in set(left).union(right):
if isinstance(left.get(k, default), int) and isinstance(
right.get(k, default), int
if isinstance(left.get(k, default), (int, float)) and isinstance(
right.get(k, default), (int, float)
):
combined[k] = op(left.get(k, default), right.get(k, default))
elif isinstance(left.get(k, {}), dict) and isinstance(right.get(k, {}), dict):
Expand All @@ -54,7 +54,8 @@ def _dict_int_op(
else:
types = [type(d[k]) for d in (left, right) if k in d]
msg = (
f"Unknown value types: {types}. Only dict and int values are supported."
f"Unknown value types: {types}. "
f"Only dict, int, and float values are supported."
)
raise ValueError(msg) # noqa: TRY004
return combined
26 changes: 25 additions & 1 deletion libs/core/tests/unit_tests/utils/test_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,30 @@ def test_dict_int_op_invalid_types() -> None:
right = {"a": 2, "b": 3}
with pytest.raises(
ValueError,
match="Only dict and int values are supported",
match="Only dict, int, and float values are supported",
):
_dict_int_op(left, right, operator.add)


def test_dict_int_op_float_add() -> None:
"""Test that float values are handled correctly."""
left = {"a": 0.5, "b": 1.2}
right = {"b": 0.3, "c": 2.5}
result = _dict_int_op(left, right, operator.add)
assert result == {"a": 0.5, "b": 1.5, "c": 2.5}


def test_dict_int_op_mixed_int_float() -> None:
"""Test that mixed int and float values are handled correctly."""
left = {"a": 1, "b": 0.5}
right = {"a": 2.5, "b": 3}
result = _dict_int_op(left, right, operator.add)
assert result == {"a": 3.5, "b": 3.5}


def test_dict_int_op_nested_float() -> None:
"""Test that nested dictionaries with float values are handled correctly."""
left = {"a": 1, "b": {"c": 0.5, "d": 3}}
right = {"a": 2, "b": {"c": 1.5, "e": 0.1}}
result = _dict_int_op(left, right, operator.add)
assert result == {"a": 3, "b": {"c": 2.0, "d": 3, "e": 0.1}}
18 changes: 17 additions & 1 deletion libs/core/tests/unit_tests/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,18 @@ def test_check_package_version(
# Other integer fields should still be summed (e.g., token counts)
({"tokens": 10}, {"tokens": 5}, {"tokens": 15}),
({"count": 1}, {"count": 2}, {"count": 3}),
# Float fields should be summed (e.g., cost, score)
({"score": 0.5}, {"score": 0.3}, {"score": 0.8}),
({"total_cost": 0.01}, {"total_cost": 0.02}, {"total_cost": 0.03}),
# Float 'index' should be preserved, not summed
({"index": 0.0}, {"index": 1.0}, {"index": 1.0}),
# Float 'created'/'timestamp' should be preserved, not summed
(
{"created": 1700000000.5},
{"created": 1700000001.5},
{"created": 1700000001.5},
),
({"timestamp": 100.0}, {"timestamp": 200.0}, {"timestamp": 200.0}),
],
)
def test_merge_dicts(
Expand Down Expand Up @@ -478,6 +490,10 @@ def test_merge_lists_all_none() -> None:
(42, 42, 42),
(3.14, 3.14, 3.14),
(True, True, True),
# Numeric addition (int)
(10, 5, 15),
# Numeric addition (float)
(0.5, 0.3, 0.8),
],
)
def test_merge_obj(left: Any, right: Any, expected: Any) -> None:
Expand All @@ -494,7 +510,7 @@ def test_merge_obj_type_mismatch() -> None:
def test_merge_obj_unmergeable_values() -> None:
"""Test `merge_obj` raises `ValueError` on unmergeable values."""
with pytest.raises(ValueError, match="Unable to merge"):
merge_obj(1, 2) # Different integers
merge_obj({1, 2}, {3, 4}) # Sets are not mergeable


def test_merge_obj_tuple_raises() -> None:
Expand Down
Loading