diff --git a/libs/core/langchain_core/utils/_merge.py b/libs/core/langchain_core/utils/_merge.py index 6a0cb38f07850..7035bab01697b 100644 --- a/libs/core/langchain_core/utils/_merge.py +++ b/libs/core/langchain_core/utils/_merge.py @@ -68,7 +68,7 @@ def merge_dicts(left: dict[str, Any], *others: dict[str, Any]) -> dict[str, Any] merged[right_k] = merge_lists(merged[right_k], right_v) elif merged[right_k] == right_v: continue - elif isinstance(merged[right_k], int): + elif isinstance(merged[right_k], (int, float)): # Preserve identification and temporal fields using last-wins strategy # instead of summing: # - index: identifies which tool call a chunk belongs to @@ -201,6 +201,8 @@ def merge_obj(left: Any, right: Any) -> Any: return merge_lists(left, right) if left == right: return left + if isinstance(left, (int, float)): + return left + right msg = ( f"Unable to merge {left=} and {right=}. Both must be of type str, dict, or " f"list, or else be two equal objects." diff --git a/libs/core/langchain_core/utils/usage.py b/libs/core/langchain_core/utils/usage.py index 47e483a5555cd..975c54d7936bb 100644 --- a/libs/core/langchain_core/utils/usage.py +++ b/libs/core/langchain_core/utils/usage.py @@ -6,23 +6,23 @@ def _dict_int_op( left: dict, right: dict, - op: Callable[[int, int], int], + op: Callable[[float, float], float], *, default: int = 0, depth: int = 0, max_depth: int = 100, ) -> dict: - """Apply an integer operation to corresponding values in two dictionaries. + """Apply a numeric operation to corresponding values in two dictionaries. - Recursively combines two dictionaries by applying the given operation to integer - values at matching keys. + Recursively combines two dictionaries by applying the given operation to numeric + (int or float) values at matching keys. Supports nested dictionaries. Args: left: First dictionary to combine. right: Second dictionary to combine. - op: Binary operation function to apply to integer values. + op: Binary operation function to apply to numeric values. default: Default value to use when a key is missing from a dictionary. depth: Current recursion depth (used internally). max_depth: Maximum recursion depth (to prevent infinite loops). @@ -38,8 +38,8 @@ def _dict_int_op( raise ValueError(msg) combined: dict = {} for k in set(left).union(right): - if isinstance(left.get(k, default), int) and isinstance( - right.get(k, default), int + if isinstance(left.get(k, default), (int, float)) and isinstance( + right.get(k, default), (int, float) ): combined[k] = op(left.get(k, default), right.get(k, default)) elif isinstance(left.get(k, {}), dict) and isinstance(right.get(k, {}), dict): @@ -54,7 +54,8 @@ def _dict_int_op( else: types = [type(d[k]) for d in (left, right) if k in d] msg = ( - f"Unknown value types: {types}. Only dict and int values are supported." + f"Unknown value types: {types}. " + f"Only dict, int, and float values are supported." ) raise ValueError(msg) # noqa: TRY004 return combined diff --git a/libs/core/tests/unit_tests/utils/test_usage.py b/libs/core/tests/unit_tests/utils/test_usage.py index 1ad3500d6d830..8440758a526bb 100644 --- a/libs/core/tests/unit_tests/utils/test_usage.py +++ b/libs/core/tests/unit_tests/utils/test_usage.py @@ -40,6 +40,30 @@ def test_dict_int_op_invalid_types() -> None: right = {"a": 2, "b": 3} with pytest.raises( ValueError, - match="Only dict and int values are supported", + match="Only dict, int, and float values are supported", ): _dict_int_op(left, right, operator.add) + + +def test_dict_int_op_float_add() -> None: + """Test that float values are handled correctly.""" + left = {"a": 0.5, "b": 1.2} + right = {"b": 0.3, "c": 2.5} + result = _dict_int_op(left, right, operator.add) + assert result == {"a": 0.5, "b": 1.5, "c": 2.5} + + +def test_dict_int_op_mixed_int_float() -> None: + """Test that mixed int and float values are handled correctly.""" + left = {"a": 1, "b": 0.5} + right = {"a": 2.5, "b": 3} + result = _dict_int_op(left, right, operator.add) + assert result == {"a": 3.5, "b": 3.5} + + +def test_dict_int_op_nested_float() -> None: + """Test that nested dictionaries with float values are handled correctly.""" + left = {"a": 1, "b": {"c": 0.5, "d": 3}} + right = {"a": 2, "b": {"c": 1.5, "e": 0.1}} + result = _dict_int_op(left, right, operator.add) + assert result == {"a": 3, "b": {"c": 2.0, "d": 3, "e": 0.1}} diff --git a/libs/core/tests/unit_tests/utils/test_utils.py b/libs/core/tests/unit_tests/utils/test_utils.py index 815296f88892a..4c55920717547 100644 --- a/libs/core/tests/unit_tests/utils/test_utils.py +++ b/libs/core/tests/unit_tests/utils/test_utils.py @@ -129,6 +129,18 @@ def test_check_package_version( # Other integer fields should still be summed (e.g., token counts) ({"tokens": 10}, {"tokens": 5}, {"tokens": 15}), ({"count": 1}, {"count": 2}, {"count": 3}), + # Float fields should be summed (e.g., cost, score) + ({"score": 0.5}, {"score": 0.3}, {"score": 0.8}), + ({"total_cost": 0.01}, {"total_cost": 0.02}, {"total_cost": 0.03}), + # Float 'index' should be preserved, not summed + ({"index": 0.0}, {"index": 1.0}, {"index": 1.0}), + # Float 'created'/'timestamp' should be preserved, not summed + ( + {"created": 1700000000.5}, + {"created": 1700000001.5}, + {"created": 1700000001.5}, + ), + ({"timestamp": 100.0}, {"timestamp": 200.0}, {"timestamp": 200.0}), ], ) def test_merge_dicts( @@ -478,6 +490,10 @@ def test_merge_lists_all_none() -> None: (42, 42, 42), (3.14, 3.14, 3.14), (True, True, True), + # Numeric addition (int) + (10, 5, 15), + # Numeric addition (float) + (0.5, 0.3, 0.8), ], ) def test_merge_obj(left: Any, right: Any, expected: Any) -> None: @@ -494,7 +510,7 @@ def test_merge_obj_type_mismatch() -> None: def test_merge_obj_unmergeable_values() -> None: """Test `merge_obj` raises `ValueError` on unmergeable values.""" with pytest.raises(ValueError, match="Unable to merge"): - merge_obj(1, 2) # Different integers + merge_obj({1, 2}, {3, 4}) # Sets are not mergeable def test_merge_obj_tuple_raises() -> None: