Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions libs/langchain/langchain_classic/chains/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,15 +732,15 @@ async def arun(
)
raise ValueError(msg)

def dict(self, **kwargs: Any) -> dict:
def model_dump(self, **kwargs: Any) -> dict:
"""Dictionary representation of chain.

Expects `Chain._chain_type` property to be implemented and for memory to be
null.

Args:
**kwargs: Keyword arguments passed to default `pydantic.BaseModel.dict`
method.
**kwargs: Keyword arguments passed to default
`pydantic.BaseModel.model_dump` method.

Returns:
A dictionary representation of the chain.
Expand Down
57 changes: 57 additions & 0 deletions libs/langchain/tests/unit_tests/chains/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import re
import uuid
from pathlib import Path
from typing import Any

import pytest
Expand Down Expand Up @@ -66,6 +67,14 @@ def _call(
return {"baz": "bar"}


class FakeSavableChain(FakeChain):
"""Fake chain that supports saving via _chain_type."""

@property
def _chain_type(self) -> str:
return "fake_savable"


def test_bad_inputs() -> None:
"""Test errors are raised if input keys are not found."""
chain = FakeChain()
Expand Down Expand Up @@ -243,3 +252,51 @@ def test_run_with_callback_and_output_error() -> None:
assert handler.starts == 1
assert handler.ends == 0
assert handler.errors == 1


def test_model_dump_includes_type() -> None:
"""Test that model_dump includes _type when _chain_type is implemented."""
chain = FakeSavableChain()
dumped = chain.model_dump()
assert "_type" in dumped
assert dumped["_type"] == "fake_savable"


def test_model_dump_excludes_type_when_not_implemented() -> None:
"""Test that model_dump omits _type when _chain_type raises."""
chain = FakeChain()
dumped = chain.model_dump()
assert "_type" not in dumped


def test_save_yaml(tmp_path: Path) -> None:
"""Test that save() works for a chain that implements _chain_type."""
chain = FakeSavableChain()
file_path = tmp_path / "chain.yaml"
chain.save(str(file_path))
assert file_path.exists()
import yaml

with file_path.open() as f:
data = yaml.safe_load(f)
assert data["_type"] == "fake_savable"


def test_save_json(tmp_path: Path) -> None:
"""Test that save() works with JSON format."""
chain = FakeSavableChain()
file_path = tmp_path / "chain.json"
chain.save(str(file_path))
assert file_path.exists()
import json

with file_path.open() as f:
data = json.load(f)
assert data["_type"] == "fake_savable"


def test_save_raises_when_chain_type_not_implemented(tmp_path: Path) -> None:
"""Test that save() raises NotImplementedError for unsavable chains."""
chain = FakeChain()
with pytest.raises(NotImplementedError, match="does not support saving"):
chain.save(str(tmp_path / "chain.yaml"))
Loading