Skip to content

Commit 930d659

Browse files
Improve PersonaGenerator initialization
Added methods to infer placeholder values and build persona initialization data. Without them, the code crashed when a persona was defined with mandatory attributes that have no default values.
1 parent 7a52b13 commit 930d659

1 file changed

Lines changed: 62 additions & 3 deletions

File tree

src/sdialog/generators/__init__.py

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99

1010
from tqdm.auto import tqdm
1111
from jinja2 import Template
12-
from pydantic import BaseModel
13-
from typing import Union, List, Any, Optional
12+
from pydantic import BaseModel, ValidationError
13+
from typing import Union, List, Any, Optional, get_origin, get_args
1414
from langchain_core.messages import HumanMessage, SystemMessage
1515
from langchain_core.language_models.base import BaseLanguageModel
1616

@@ -367,6 +367,60 @@ class PersonaGenerator(BaseAttributeModelGenerator):
367367
:param llm_kwargs: Extra LLM keyword arguments.
368368
:type llm_kwargs: dict
369369
"""
370+
@staticmethod
def _infer_placeholder_value(annotation: Any) -> Any:
    """Infer a conservative placeholder value from a type annotation.

    The annotation is examined in this order:

    1. Unions -- ``Union[...]``, ``Optional[...]`` and PEP 604 ``X | Y``:
       ``None`` when ``None`` is one of the members; otherwise the first
       non-``None`` placeholder inferred from the members, or ``""``.
    2. ``list`` / ``dict`` / ``tuple`` generics: the matching empty container.
    3. ``str``, ``int``, ``float``, ``bool``, ``Any``: ``""``, ``0``, ``0.0``,
       ``False`` and ``""`` respectively.
    4. Any other parametrized generic whose origin is a class (for example
       ``set[int]``): the origin called with no arguments.
    5. Plain classes: the class called with no arguments.
    6. Everything else: ``""``.

    :param annotation: Type annotation to derive a placeholder for.
    :type annotation: Any
    :return: Placeholder value chosen according to the rules above.
    :rtype: Any
    """
    # Imported locally so this helper does not depend on a module-level import.
    import types

    origin = get_origin(annotation)

    # On Python 3.10-3.13 ``X | Y`` reports ``types.UnionType`` as its origin
    # rather than ``typing.Union``; accept both so ``str | None`` behaves like
    # ``Optional[str]``. The getattr fallback covers interpreters without it.
    if origin is Union or origin is getattr(types, "UnionType", Union):
        args = get_args(annotation)
        non_none_args = [arg for arg in args if arg is not type(None)]
        # If None is an allowed member, None is the safest placeholder.
        if len(non_none_args) != len(args):
            return None
        for arg in non_none_args:
            value = PersonaGenerator._infer_placeholder_value(arg)
            if value is not None:
                return value
        return ""

    # get_origin() maps both ``List[int]`` and ``list[int]`` to ``list``
    # (likewise for dict/tuple), so identity checks on the builtins suffice.
    if origin is list:
        return []
    if origin is dict:
        return {}
    if origin is tuple:
        return ()

    if annotation is str:
        return ""
    if annotation is int:
        return 0
    if annotation is float:
        return 0.0
    if annotation is bool:
        return False
    if annotation is Any:
        return ""

    # Remaining parametrized generics (``set[int]``, ``deque[str]``, ...):
    # instantiate the origin class. ``isinstance(set[int], type)`` is only
    # true on some Python versions, so relying on the plain-class branch
    # below would make the result interpreter-dependent.
    if isinstance(origin, type):
        try:
            return origin()
        except Exception:
            # Abstract or argument-requiring origins cannot be built blindly.
            return ""

    if isinstance(annotation, type):
        try:
            return annotation()
        except Exception:
            # Arbitrary constructors may demand arguments or raise anything;
            # this is a best-effort placeholder, so fall back to "".
            return ""

    return ""
412+
413+
@classmethod
def _build_persona_init_data(cls, persona_cls: type[BasePersona]) -> dict[str, Any]:
    """Collect one initialization value for each field declared on a persona class.

    Required fields receive a placeholder inferred from their annotation via
    :meth:`_infer_placeholder_value`; every other field receives its declared
    default, with default factories being called.

    :param persona_cls: Persona class whose ``model_fields`` are inspected.
    :type persona_cls: type[BasePersona]
    :return: Mapping from field name to the value chosen for that field.
    :rtype: dict[str, Any]
    """
    return {
        name: (
            cls._infer_placeholder_value(info.annotation)
            if info.is_required()
            else info.get_default(call_default_factory=True)
        )
        for name, info in persona_cls.model_fields.items()
    }
423+
370424
def __init__(self,
371425
persona: BasePersona,
372426
generated_attributes: str = "all",
@@ -377,7 +431,12 @@ def __init__(self,
377431
if isinstance(persona, BasePersona):
378432
persona_instance = persona
379433
elif isinstance(persona, type) and issubclass(persona, BasePersona):
380-
persona_instance = persona()
434+
try:
435+
persona_instance = persona()
436+
except ValidationError:
437+
# Build an instance with placeholders so required fields without defaults
438+
# do not break generator initialization.
439+
persona_instance = persona.model_construct(**self._build_persona_init_data(persona))
381440
else:
382441
raise ValueError("persona must be a BasePersona instance or subclass.")
383442
system_prompt = "You are an expert at generating persona JSON objects for synthetic dialogue generation."

0 commit comments

Comments
 (0)