Skip to content

Commit 0da27c0

Browse files
author
Shing Lyu
committed
Fix: fix FLAN-XXL input/output format
1 parent 54a37e1 commit 0da27c0

2 files changed

Lines changed: 5 additions & 4 deletions

File tree

kendra_retriever_samples/kendra_chat_flan_xxl.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,13 @@ class ContentHandler(LLMContentHandler):
3030
accepts = "application/json"
3131

3232
def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
33-
input_str = json.dumps({"inputs": prompt, "parameters": model_kwargs})
33+
input_str = json.dumps({"text_inputs": prompt, **model_kwargs})
3434
return input_str.encode('utf-8')
3535

3636
def transform_output(self, output: bytes) -> str:
3737
response_json = json.loads(output.read().decode("utf-8"))
38-
return response_json[0]["generated_text"]
38+
print(response_json)
39+
return response_json["generated_texts"][0]
3940

4041
content_handler = ContentHandler()
4142

kendra_retriever_samples/kendra_retriever_flan_xxl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@ class ContentHandler(LLMContentHandler):
1818
accepts = "application/json"
1919

2020
def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
21-
input_str = json.dumps({"inputs": prompt, "parameters": model_kwargs})
21+
input_str = json.dumps({"text_inputs": prompt, **model_kwargs})
2222
return input_str.encode('utf-8')
2323

2424
def transform_output(self, output: bytes) -> str:
2525
response_json = json.loads(output.read().decode("utf-8"))
26-
return response_json[0]["generated_text"]
26+
return response_json["generated_texts"][0]
2727

2828
content_handler = ContentHandler()
2929

0 commit comments

Comments
 (0)