3 changes: 2 additions & 1 deletion instructor/client_anthropic.py
@@ -47,7 +47,7 @@ def from_anthropic(

Args:
client: An instance of Anthropic client (sync or async)
mode: The mode to use for the client (ANTHROPIC_JSON or ANTHROPIC_TOOLS)
mode: The mode to use for the client (ANTHROPIC_JSON, ANTHROPIC_XML, or ANTHROPIC_TOOLS)
beta: Whether to use beta API features (uses client.beta.messages.create)
**kwargs: Additional keyword arguments to pass to the Instructor constructor

@@ -60,6 +60,7 @@ def from_anthropic(
"""
valid_modes = {
instructor.Mode.ANTHROPIC_JSON,
instructor.Mode.ANTHROPIC_XML,
instructor.Mode.ANTHROPIC_TOOLS,
instructor.Mode.ANTHROPIC_REASONING_TOOLS,
}
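For reference, a minimal usage sketch of the newly accepted mode (not part of the PR itself; the ExtractedUser model, prompt, and model name are illustrative assumptions):

import anthropic
import instructor
from pydantic import BaseModel

class ExtractedUser(BaseModel):
    name: str
    age: int

# ANTHROPIC_XML is now accepted by from_anthropic's valid_modes check.
client = instructor.from_anthropic(
    anthropic.Anthropic(),
    mode=instructor.Mode.ANTHROPIC_XML,
)

user = client.chat.completions.create(
    model="claude-3-5-sonnet-20240620",  # placeholder model name
    max_tokens=1024,
    messages=[{"role": "user", "content": "John is 25 years old."}],
    response_model=ExtractedUser,
)
print(user)  # ExtractedUser(name='John', age=25)
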
220 changes: 216 additions & 4 deletions instructor/function_calls.py
@@ -19,6 +19,7 @@
from instructor.utils import (
classproperty,
extract_json_from_codeblock,
extract_xml_from_codeblock,
map_to_gemini_function_schema,
)

@@ -156,6 +157,165 @@ def anthropic_schema(cls) -> dict[str, Any]:
"input_schema": cls.model_json_schema(),
}

@classproperty
def xml_schema(cls) -> str:
"""
Generate XML schema representation for the model, tailored for consumption by an LLM.

Returns:
str: XML schema template string that describes the expected structure
"""

def generate_xml_elements(schema: dict[str, Any], level: int = 0) -> str:
"""Generate XML element structure from JSON schema."""
if schema.get("type") == "object":
properties = schema.get("properties", {})
required = schema.get("required", [])
elements = []

for prop_name, prop_schema in properties.items():
is_required = prop_name in required
if prop_schema.get("type") == "array":
item_schema = prop_schema.get("items", {})
if item_schema.get("type") == "object":
elements.append(f" {' ' * level}<{prop_name}>")
elements.append(
generate_xml_elements(item_schema, level + 1)
)
elements.append(f" {' ' * level}</{prop_name}>")
else:
elements.append(
f" {' ' * level}<{prop_name}>[{item_schema.get('type', 'value')}]</{prop_name}>"
)
elif prop_schema.get("type") == "object":
elements.append(f" {' ' * level}<{prop_name}>")
elements.append(generate_xml_elements(prop_schema, level + 1))
elements.append(f" {' ' * level}</{prop_name}>")
else:
prop_type = prop_schema.get("type", "string")
elements.append(
f" {' ' * level}<{prop_name}>[{prop_type}]</{prop_name}>"
)

return "\n".join(elements)
return ""

json_schema = cls.model_json_schema()
class_name = cls.__name__

xml_template = f"""<{class_name}>
{generate_xml_elements(json_schema)}
</{class_name}>"""

return xml_template

@classproperty
def xsd_schema(cls) -> str:
"""
Generate XSD schema representation for the model.

Returns:
str: XSD schema string that can be used with xmlschema.XMLSchema()
"""

def json_type_to_xsd_type(json_type: str) -> str:
"""Map JSON Schema types to XSD types."""
type_mapping = {
"string": "xs:string",
"integer": "xs:int",
"number": "xs:decimal",
"boolean": "xs:boolean",
}
return type_mapping.get(json_type, "xs:string")

def resolve_ref(ref_path: str, full_schema: dict[str, Any]) -> dict[str, Any]:
"""Resolve $ref references in JSON schema."""
if ref_path.startswith("#/$defs/"):
def_name = ref_path.replace("#/$defs/", "")
return full_schema.get("$defs", {}).get(def_name, {})
return {}

def generate_xsd_elements(schema: dict[str, Any], level: int = 0, full_schema: Optional[dict[str, Any]] = None) -> str:
"""Generate XSD element definitions from JSON schema."""
if full_schema is None:
full_schema = schema

if schema.get("type") == "object":
properties = schema.get("properties", {})
required = schema.get("required", [])
elements = []

for prop_name, prop_schema in properties.items():
is_required = prop_name in required
min_occurs = "1" if is_required else "0"

if prop_schema.get("type") == "array":
# Handle arrays - arrays in XML are represented as repeated elements
item_schema = prop_schema.get("items", {})

# Handle $ref in array items
if "$ref" in item_schema:
resolved_schema = resolve_ref(item_schema["$ref"], full_schema)
if resolved_schema.get("type") == "object":
# Array of objects - each item is a separate element
elements.append(f'{" " * (level + 1)}<xs:element name="{prop_name}" minOccurs="{min_occurs}" maxOccurs="unbounded">')
elements.append(f'{" " * (level + 2)}<xs:complexType>')
elements.append(f'{" " * (level + 3)}<xs:sequence>')
elements.append(generate_xsd_elements(resolved_schema, level + 3, full_schema))
elements.append(f'{" " * (level + 3)}</xs:sequence>')
elements.append(f'{" " * (level + 2)}</xs:complexType>')
elements.append(f'{" " * (level + 1)}</xs:element>')
else:
# Array of primitives via ref
xsd_type = json_type_to_xsd_type(resolved_schema.get("type", "string"))
elements.append(f'{" " * (level + 1)}<xs:element name="{prop_name}" type="{xsd_type}" minOccurs="{min_occurs}" maxOccurs="unbounded" />')
elif item_schema.get("type") == "object":
# Array of objects - each item is a separate element
elements.append(f'{" " * (level + 1)}<xs:element name="{prop_name}" minOccurs="{min_occurs}" maxOccurs="unbounded">')
elements.append(f'{" " * (level + 2)}<xs:complexType>')
elements.append(f'{" " * (level + 3)}<xs:sequence>')
elements.append(generate_xsd_elements(item_schema, level + 3, full_schema))
elements.append(f'{" " * (level + 3)}</xs:sequence>')
elements.append(f'{" " * (level + 2)}</xs:complexType>')
elements.append(f'{" " * (level + 1)}</xs:element>')
else:
# Array of primitives - repeated elements with same name
xsd_type = json_type_to_xsd_type(item_schema.get("type", "string"))
elements.append(f'{" " * (level + 1)}<xs:element name="{prop_name}" type="{xsd_type}" minOccurs="{min_occurs}" maxOccurs="unbounded" />')

elif prop_schema.get("type") == "object":
# Nested object
elements.append(f'{" " * (level + 1)}<xs:element name="{prop_name}" minOccurs="{min_occurs}" maxOccurs="1">')
elements.append(f'{" " * (level + 2)}<xs:complexType>')
elements.append(f'{" " * (level + 3)}<xs:sequence>')
elements.append(generate_xsd_elements(prop_schema, level + 3, full_schema))
elements.append(f'{" " * (level + 3)}</xs:sequence>')
elements.append(f'{" " * (level + 2)}</xs:complexType>')
elements.append(f'{" " * (level + 1)}</xs:element>')

else:
# Primitive type
xsd_type = json_type_to_xsd_type(prop_schema.get("type", "string"))
elements.append(f'{" " * (level + 1)}<xs:element name="{prop_name}" type="{xsd_type}" minOccurs="{min_occurs}" maxOccurs="1" />')

return "\n".join(elements)
return ""

json_schema = cls.model_json_schema()
class_name = cls.__name__

xsd_template = f'''<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
<xs:element name="{class_name}">
<xs:complexType>
<xs:sequence>
{generate_xsd_elements(json_schema)}
</xs:sequence>
</xs:complexType>
</xs:element>
</xs:schema>'''

return xsd_template

@classproperty
def gemini_schema(cls) -> Any:
import google.generativeai.types as genai_types
@@ -191,15 +351,15 @@ def from_response(
cls (OpenAISchema): An instance of the class
"""

if mode == Mode.ANTHROPIC_TOOLS:
return cls.parse_anthropic_tools(completion, validation_context, strict)

if mode == Mode.ANTHROPIC_TOOLS or mode == Mode.ANTHROPIC_REASONING_TOOLS:
if mode in {Mode.ANTHROPIC_TOOLS, Mode.ANTHROPIC_REASONING_TOOLS}:
return cls.parse_anthropic_tools(completion, validation_context, strict)

if mode == Mode.ANTHROPIC_JSON:
return cls.parse_anthropic_json(completion, validation_context, strict)

if mode == Mode.ANTHROPIC_XML:
return cls.parse_anthropic_xml(completion, validation_context, strict)

if mode == Mode.BEDROCK_JSON:
return cls.parse_bedrock_json(completion, validation_context, strict)

@@ -411,6 +571,58 @@ def parse_anthropic_json(

return model

@classmethod
def parse_anthropic_xml(
cls: type[BaseModel],
completion: ChatCompletion,
validation_context: Optional[dict[str, Any]] = None,
strict: Optional[bool] = None,
) -> BaseModel:
from anthropic.types import Message

last_block = None

if hasattr(completion, "choices"):
completion = completion.choices[0]
if completion.finish_reason == "length":
raise IncompleteOutputException(last_completion=completion)
text = completion.message.content
else:
assert isinstance(completion, Message)
if completion.stop_reason == "max_tokens":
raise IncompleteOutputException(last_completion=completion)
# Find the last text block in the completion
text_blocks = [c for c in completion.content if c.type == "text"]
last_block = text_blocks[-1]
# strip (\u0000-\u001F) control characters
text = re.sub(r"[\u0000-\u001F]", "", last_block.text)

xml_text = extract_xml_from_codeblock(text)

try:
import xmlschema
except ImportError:
raise ImportError(
"xmlschema is required for XML parsing. Install it with: pip install xmlschema"
) from None

try:
# Generate XSD schema for this model & validate against it
xsd_schema_str = cls.xsd_schema
schema = xmlschema.XMLSchema(xsd_schema_str)
parsed_dict = schema.to_dict(xml_text)
model = cls.model_validate(
parsed_dict, context=validation_context, strict=True if strict else False
)

return model

except Exception as e:
logger.error(f"Error parsing XML: {e}")
raise ValueError(
f"Failed to parse XML content: {e}. XML content: {xml_text}"
) from None

@classmethod
def parse_bedrock_json(
cls: type[BaseModel],
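For reference, a minimal sketch of what the new schema properties produce (not from the PR; the User model and the exact indentation of the output are illustrative assumptions, and parse_anthropic_xml additionally needs the optional xmlschema package):

import instructor

class User(instructor.OpenAISchema):
    name: str
    age: int

print(User.xml_schema)
# Expected to print a template roughly like:
# <User>
#     <name>[string]</name>
#     <age>[integer]</age>
# </User>

print(User.xsd_schema)
# An <xs:schema> document declaring a <User> element whose required
# children are name (xs:string) and age (xs:int); parse_anthropic_xml
# passes this string to xmlschema.XMLSchema() to validate and convert
# the model's XML reply before Pydantic validation.
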
2 changes: 2 additions & 0 deletions instructor/mode.py
@@ -32,6 +32,7 @@ class Mode(enum.Enum):
ANTHROPIC_TOOLS = "anthropic_tools"
ANTHROPIC_REASONING_TOOLS = "anthropic_reasoning_tools"
ANTHROPIC_JSON = "anthropic_json"
ANTHROPIC_XML = "anthropic_xml"

# Mistral modes
MISTRAL_TOOLS = "mistral_tools"
@@ -99,6 +100,7 @@ def json_modes(cls) -> set["Mode"]:
cls.MD_JSON,
cls.JSON_SCHEMA,
cls.ANTHROPIC_JSON,
cls.ANTHROPIC_XML,
cls.VERTEXAI_JSON,
cls.GEMINI_JSON,
cls.COHERE_JSON_SCHEMA,
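A quick check of the new enum member, assuming this branch is installed:

import instructor

# New member added by this diff:
assert instructor.Mode.ANTHROPIC_XML.value == "anthropic_xml"
# The diff also registers ANTHROPIC_XML in Mode.json_modes alongside
# ANTHROPIC_JSON, so it is treated as a prompt-based rather than a
# tool-calling mode.
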
37 changes: 36 additions & 1 deletion instructor/process_response.py
@@ -509,6 +509,40 @@ def handle_anthropic_json(
return response_model, new_kwargs


def handle_anthropic_xml(
response_model: type[T], new_kwargs: dict[str, Any]
) -> tuple[type[T], dict[str, Any]]:
system_messages = extract_system_messages(new_kwargs.get("messages", []))

if system_messages:
new_kwargs["system"] = combine_system_messages(
new_kwargs.get("system"), system_messages
)

new_kwargs["messages"] = [
m for m in new_kwargs.get("messages", []) if m["role"] != "system"
]

xml_schema_message = dedent(
f"""
As a genius expert, your task is to understand the content and provide
the parsed objects in XML that match the following xml_schema:\n

{response_model.xml_schema}

Make sure to return an instance of the XML, not the schema itself.
Use proper XML formatting with appropriate opening and closing tags.
"""
)

new_kwargs["system"] = combine_system_messages(
new_kwargs.get("system"),
[{"type": "text", "text": xml_schema_message}],
)

return response_model, new_kwargs


def handle_cohere_modes(new_kwargs: dict[str, Any]) -> tuple[None, dict[str, Any]]:
messages = new_kwargs.pop("messages", [])
chat_history = []
@@ -1099,7 +1133,7 @@ def handle_response_model(
mode,
autodetect_images=autodetect_images,
)
if mode in {Mode.ANTHROPIC_JSON, Mode.ANTHROPIC_TOOLS}:
if mode in {Mode.ANTHROPIC_JSON, Mode.ANTHROPIC_XML, Mode.ANTHROPIC_TOOLS}:
# Handle OpenAI style or Anthropic style messages
new_kwargs["messages"] = [m for m in messages if m["role"] != "system"]
if "system" not in new_kwargs:
@@ -1137,6 +1171,7 @@
Mode.ANTHROPIC_TOOLS: handle_anthropic_tools,
Mode.ANTHROPIC_REASONING_TOOLS: handle_anthropic_reasoning_tools,
Mode.ANTHROPIC_JSON: handle_anthropic_json,
Mode.ANTHROPIC_XML: handle_anthropic_xml,
Mode.COHERE_JSON_SCHEMA: handle_cohere_json_schema,
Mode.COHERE_TOOLS: handle_cohere_tools,
Mode.GEMINI_JSON: handle_gemini_json,
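A hedged sketch of what handle_anthropic_xml does to the request kwargs, reusing the illustrative User model from above; this call pattern mirrors handle_anthropic_json and is not shown in the PR itself:

from instructor.process_response import handle_anthropic_xml

_, kwargs = handle_anthropic_xml(
    User,
    {"messages": [{"role": "user", "content": "John is 25 years old."}]},
)
# kwargs["messages"] keeps only the non-system messages, and
# kwargs["system"] now ends with a text block that embeds User.xml_schema
# and instructs the model to return an XML instance, not the schema.
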
37 changes: 37 additions & 0 deletions instructor/utils.py
@@ -99,6 +99,10 @@ def get_provider(base_url: str) -> Provider:
_JSON_CODEBLOCK_PATTERN = re.compile(r"```(?:json)?\s*(.*?)\s*```", re.DOTALL)
_JSON_PATTERN = re.compile(r"({[\s\S]*})")

# Regex patterns for XML extraction
_XML_CODEBLOCK_PATTERN = re.compile(r"```(?:xml)?\s*(.*?)\s*```", re.DOTALL)
_XML_PATTERN = re.compile(r"(<[\s\S]*>)")


def extract_json_from_codeblock(content: str) -> str:
"""
@@ -133,6 +137,39 @@ def extract_json_from_codeblock(content: str) -> str:
return json_content


def extract_xml_from_codeblock(content: str) -> str:
"""
Extract XML from a string that may contain markdown code blocks or plain XML.

This function mirrors extract_json_from_codeblock, using precompiled regex patterns with plain-text fallbacks.

Args:
content: The string that may contain XML

Returns:
The extracted XML string
"""
# First try to find XML in code blocks
match = _XML_CODEBLOCK_PATTERN.search(content)
if match:
xml_content = match.group(1).strip()
else:
# Look for XML tags with the pattern < ... >
match = _XML_PATTERN.search(content)
if match:
xml_content = match.group(1)
else:
# Fallback: try to find XML start and end tags
first_bracket = content.find("<")
last_bracket = content.rfind(">")
if first_bracket != -1 and last_bracket != -1:
xml_content = content[first_bracket : last_bracket + 1]
else:
xml_content = content # Return as is if no XML-like content found

return xml_content


def extract_json_from_stream(
chunks: Iterable[str],
) -> Generator[str, None, None]:
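A usage sketch of the new extractor, assuming this branch is installed; the sample strings are illustrative:

from instructor.utils import extract_xml_from_codeblock

fenced = "Here you go:\n```xml\n<User>\n  <name>John</name>\n  <age>25</age>\n</User>\n```"
bare = "Sure thing. <User><name>John</name><age>25</age></User>"

print(extract_xml_from_codeblock(fenced))  # code fence stripped, <User>...</User> kept
print(extract_xml_from_codeblock(bare))    # the <User>...</User> span only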