From a055aeda16d30153b1ad3816f4e92ace67981149 Mon Sep 17 00:00:00 2001 From: keenborder786 <21110290@lums.edu.pk> Date: Sun, 8 Mar 2026 22:22:19 +0500 Subject: [PATCH] fix: .save for chains --- .../langchain_classic/chains/base.py | 6 +- .../tests/unit_tests/chains/test_base.py | 57 +++++++++++++++++++ 2 files changed, 60 insertions(+), 3 deletions(-) diff --git a/libs/langchain/langchain_classic/chains/base.py b/libs/langchain/langchain_classic/chains/base.py index c91ccf6c6882e..a70e3424f2d5d 100644 --- a/libs/langchain/langchain_classic/chains/base.py +++ b/libs/langchain/langchain_classic/chains/base.py @@ -732,15 +732,15 @@ async def arun( ) raise ValueError(msg) - def dict(self, **kwargs: Any) -> dict: + def model_dump(self, **kwargs: Any) -> dict: """Dictionary representation of chain. Expects `Chain._chain_type` property to be implemented and for memory to be null. Args: - **kwargs: Keyword arguments passed to default `pydantic.BaseModel.dict` - method. + **kwargs: Keyword arguments passed to default + `pydantic.BaseModel.model_dump` method. Returns: A dictionary representation of the chain. diff --git a/libs/langchain/tests/unit_tests/chains/test_base.py b/libs/langchain/tests/unit_tests/chains/test_base.py index a607eada5b901..bd5d83625e43b 100644 --- a/libs/langchain/tests/unit_tests/chains/test_base.py +++ b/libs/langchain/tests/unit_tests/chains/test_base.py @@ -2,6 +2,7 @@ import re import uuid +from pathlib import Path from typing import Any import pytest @@ -66,6 +67,14 @@ def _call( return {"baz": "bar"} +class FakeSavableChain(FakeChain): + """Fake chain that supports saving via _chain_type.""" + + @property + def _chain_type(self) -> str: + return "fake_savable" + + def test_bad_inputs() -> None: """Test errors are raised if input keys are not found.""" chain = FakeChain() @@ -243,3 +252,51 @@ def test_run_with_callback_and_output_error() -> None: assert handler.starts == 1 assert handler.ends == 0 assert handler.errors == 1 + + +def test_model_dump_includes_type() -> None: + """Test that model_dump includes _type when _chain_type is implemented.""" + chain = FakeSavableChain() + dumped = chain.model_dump() + assert "_type" in dumped + assert dumped["_type"] == "fake_savable" + + +def test_model_dump_excludes_type_when_not_implemented() -> None: + """Test that model_dump omits _type when _chain_type raises.""" + chain = FakeChain() + dumped = chain.model_dump() + assert "_type" not in dumped + + +def test_save_yaml(tmp_path: Path) -> None: + """Test that save() works for a chain that implements _chain_type.""" + chain = FakeSavableChain() + file_path = tmp_path / "chain.yaml" + chain.save(str(file_path)) + assert file_path.exists() + import yaml + + with file_path.open() as f: + data = yaml.safe_load(f) + assert data["_type"] == "fake_savable" + + +def test_save_json(tmp_path: Path) -> None: + """Test that save() works with JSON format.""" + chain = FakeSavableChain() + file_path = tmp_path / "chain.json" + chain.save(str(file_path)) + assert file_path.exists() + import json + + with file_path.open() as f: + data = json.load(f) + assert data["_type"] == "fake_savable" + + +def test_save_raises_when_chain_type_not_implemented(tmp_path: Path) -> None: + """Test that save() raises NotImplementedError for unsavable chains.""" + chain = FakeChain() + with pytest.raises(NotImplementedError, match="does not support saving"): + chain.save(str(tmp_path / "chain.yaml"))