Skip to content

Relaxed model checks #154

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def __init__(self, **kwargs: Any):
if not model.client:
warnings.warn(f"Unable to determine validity of {model.id}")
elif model.client != self.cls:
raise ValueError(
warnings.warn(
f"Model {model.id} is incompatible with client {self.cls}. "
f"Please check `{self.cls}.get_available_models()`."
)
Expand Down Expand Up @@ -222,10 +222,15 @@ def __init__(self, **kwargs: Any):
if self.mdl_name.startswith("nvdev/"): # assume valid
model = Model(id=self.mdl_name)
else:
raise ValueError(
warnings.warn(
f"Model {self.mdl_name} is unknown, "
"check `available_models`"
)

if model is None:
raise ValueError(
f"Unable to find {self.mdl_name}. Please verify configuration."
)
self.model = model
self.mdl_name = self.model.id # name may change because of aliasing
else:
Expand All @@ -246,7 +251,7 @@ def __init__(self, **kwargs: Any):
UserWarning,
)
else:
raise ValueError("No locally hosted model was found.")
warnings.warn("No locally hosted model was found.")

###################################################################################
################### LangChain functions ###########################################
Expand Down
10 changes: 5 additions & 5 deletions libs/ai-endpoints/tests/unit_tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def test_unknown_unknown(public_class: type, empty_v1_models: None) -> None:
# todo: make this work for local NIM
with pytest.raises(ValueError) as e:
public_class(model="test/unknown-unknown", nvidia_api_key="a-bogus-key")
assert "unknown" in str(e.value)
assert "Unable to find" in str(e.value)


def test_default_known(public_class: type, known_unknown: str) -> None:
Expand Down Expand Up @@ -174,8 +174,8 @@ def test_all_incompatible(public_class: type, model: str, client: str) -> None:
if client == public_class.__name__:
pytest.skip("Compatibility expected.")

with pytest.raises(ValueError) as err_msg:
with pytest.warns(UserWarning) as record:
public_class(model=model, nvidia_api_key="a-bogus-key")
assert f"Model {model} is incompatible with client {public_class.__name__}" in str(
err_msg.value
)

assert len(record) == 1
assert "incompatible with client" in str(record[0].message)