@@ -1,3 +1,4 @@
+import json
 import os
 import subprocess
 import sys
@@ -31,6 +32,7 @@ def generate_response(
         question_num: Optional[int] = None,
         test_output: Optional[Path] = None,
         llama_mode: Optional[str] = None,
+        json_schema: Optional[str] = None,
     ) -> Optional[Tuple[str, str]]:
         """
         Generate a model response using the prompt and assignment files.
@@ -44,18 +46,28 @@ def generate_response(
             test_output (Optional[Path]): Path object pointing to the test output file.
             llama_mode (Optional[str]): Optional mode to invoke llama.cpp in.
             question_num (Optional[int]): An optional question number to target specific content.
+            json_schema (Optional[str]): Optional path to a JSON schema file to apply to the response.

         Returns:
             Optional[Tuple[str, str]]: A tuple containing the prompt and the model's response,
                 or None if the response was invalid.
         """
+        if json_schema:
+            schema_path = Path(json_schema)
+            if not schema_path.exists():
+                raise FileNotFoundError(f"JSON schema file not found: {schema_path}")
+            with open(schema_path, "r", encoding="utf-8") as f:
+                schema = json.load(f)
+        else:
+            schema = None
+
         prompt = f"{system_instructions}\n{prompt}"
         if llama_mode == 'server':
             self._ensure_env_vars('LLAMA_SERVER_URL')
-            response = self._get_response_server(prompt)
+            response = self._get_response_server(prompt, schema)
         else:
             self._ensure_env_vars('LLAMA_MODEL_PATH', 'LLAMA_CLI_PATH')
-            response = self._get_response_cli(prompt)
+            response = self._get_response_cli(prompt, schema)

         response = response.strip()

@@ -81,24 +93,24 @@ def _ensure_env_vars(self, *names):
         if missing:
             raise RuntimeError(f"Error: Environment variable(s) {', '.join(missing)} not set")

-    def _get_response_server(
-        self,
-        prompt: str,
-    ) -> str:
+    def _get_response_server(self, prompt: str, schema: Optional[dict] = None) -> str:
         """
         Generate a model response using the prompt

         Args:
             prompt (str): The input prompt provided by the user.
+            schema (Optional[dict]): Optional schema provided by the user.

         Returns:
             str: The model response, or None if the response was invalid.
         """
         url = f"{LLAMA_SERVER_URL}/v1/completions"

-        payload = {
-            "prompt": prompt,
-        }
+        payload = {"prompt": prompt, "temperature": 0.7, "max_tokens": 1000}
+
+        if schema:
+            raw_schema = schema.get("schema", schema)
+            payload["json_schema"] = raw_schema

         try:
             response = requests.post(url, json=payload, timeout=3000)
@@ -116,15 +128,13 @@ def _get_response_server(

         return model_output

-    def _get_response_cli(
-        self,
-        prompt: str,
-    ) -> str:
+    def _get_response_cli(self, prompt: str, schema: Optional[dict] = None) -> str:
         """
         Generate a model response using the prompt

         Args:
             prompt (str): The input prompt provided by the user.
+            schema (Optional[dict]): Optional schema provided by the user.

         Returns:
             str: The model response or None if the response was invalid.
@@ -141,6 +151,10 @@ def _get_response_cli(
141151 "--no-display-prompt" ,
142152 ]
143153
154+ if schema :
155+ raw_schema = schema ["schema" ] if "schema" in schema else schema
156+ cmd += ["--json-schema" , json .dumps (raw_schema )]
157+
144158 try :
145159 completed = subprocess .run (
146160 cmd , input = prompt .encode (), check = True , stdout = subprocess .PIPE , stderr = subprocess .PIPE , timeout = 300
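
For reference (not part of the diff): a minimal sketch of what a schema file passed via the new json_schema argument could contain. Per the handling above, both the server path (schema.get("schema", schema)) and the CLI path accept either a bare JSON schema or one wrapped under a top-level "schema" key; the field names and the schema.json path below are illustrative, not taken from the PR.

```python
import json
from pathlib import Path

# Illustrative only: a bare JSON schema, forwarded unchanged to the server
# payload's "json_schema" field or to llama-cli's --json-schema flag.
bare = {
    "type": "object",
    "properties": {"answer": {"type": "string"}},
    "required": ["answer"],
}

# Wrapped form: the new code unwraps the top-level "schema" key before forwarding.
wrapped = {"schema": bare}

# Either form can be written to the file whose path is passed as json_schema.
Path("schema.json").write_text(json.dumps(wrapped, indent=2), encoding="utf-8")
```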