Skip to content

Commit e133736

Browse files
authored
Update cli.py
1 parent 2d9ee84 commit e133736

File tree

1 file changed

+15
-13
lines changed

1 file changed

+15
-13
lines changed

llama_cpp/server/cli.py

+15-13
Original file line numberDiff line numberDiff line change
@@ -6,34 +6,35 @@
66
from pydantic import BaseModel
77

88

9-
def _get_base_type(annotation: type[Any]) -> type[Any]:
9+
def _get_base_type(annotation: Type[Any]) -> Type[Any]:
1010
if getattr(annotation, "__origin__", None) is Literal:
1111
assert hasattr(annotation, "__args__") and len(annotation.__args__) >= 1 # type: ignore
1212
return type(annotation.__args__[0]) # type: ignore
13-
if getattr(annotation, "__origin__", None) is Union:
13+
elif getattr(annotation, "__origin__", None) is Union:
1414
assert hasattr(annotation, "__args__") and len(annotation.__args__) >= 1 # type: ignore
15-
non_optional_args: list[type[Any]] = [
15+
non_optional_args: List[Type[Any]] = [
1616
arg for arg in annotation.__args__ if arg is not type(None) # type: ignore
1717
]
1818
if non_optional_args:
1919
return _get_base_type(non_optional_args[0])
2020
elif (
2121
getattr(annotation, "__origin__", None) is list
22-
or getattr(annotation, "__origin__", None) is list
22+
or getattr(annotation, "__origin__", None) is List
2323
):
2424
assert hasattr(annotation, "__args__") and len(annotation.__args__) >= 1 # type: ignore
2525
return _get_base_type(annotation.__args__[0]) # type: ignore
2626
return annotation
2727

2828

29-
def _contains_list_type(annotation: type[Any] | None) -> bool:
29+
def _contains_list_type(annotation: Type[Any] | None) -> bool:
3030
origin = getattr(annotation, "__origin__", None)
3131

32-
if origin is list or origin is list:
32+
if origin is list or origin is List:
3333
return True
34-
if origin in (Literal, Union):
34+
elif origin in (Literal, Union):
3535
return any(_contains_list_type(arg) for arg in annotation.__args__) # type: ignore
36-
return False
36+
else:
37+
return False
3738

3839

3940
def _parse_bool_arg(arg: str | bytes | bool) -> bool:
@@ -47,12 +48,13 @@ def _parse_bool_arg(arg: str | bytes | bool) -> bool:
4748

4849
if arg_str in true_values:
4950
return True
50-
if arg_str in false_values:
51+
elif arg_str in false_values:
5152
return False
52-
raise ValueError(f"Invalid boolean argument: {arg}")
53+
else:
54+
raise ValueError(f"Invalid boolean argument: {arg}")
5355

5456

55-
def add_args_from_model(parser: argparse.ArgumentParser, model: type[BaseModel]):
57+
def add_args_from_model(parser: argparse.ArgumentParser, model: Type[BaseModel]):
5658
"""Add arguments from a pydantic model to an argparse parser."""
5759

5860
for name, field in model.model_fields.items():
@@ -80,7 +82,7 @@ def add_args_from_model(parser: argparse.ArgumentParser, model: type[BaseModel])
8082
)
8183

8284

83-
T = TypeVar("T", bound=type[BaseModel])
85+
T = TypeVar("T", bound=Type[BaseModel])
8486

8587

8688
def parse_model_from_args(model: T, args: argparse.Namespace) -> T:
@@ -90,5 +92,5 @@ def parse_model_from_args(model: T, args: argparse.Namespace) -> T:
9092
k: v
9193
for k, v in vars(args).items()
9294
if v is not None and k in model.model_fields
93-
},
95+
}
9496
)

0 commit comments

Comments
 (0)