import argparse  # noqa: I001
import os

import google.generativeai as genai
import pandas as pd
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
from tqdm import tqdm

from templates import PROMPT_STRATEGY

# TODO: converge generator-gemini.py with generator.py
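# NOTE: replace the placeholder below with a real Gemini API key before running.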
API_KEY = "..."
MODEL_NAME = "gemini-1.5-pro-001"

genai.configure(api_key=API_KEY)
model = genai.GenerativeModel(MODEL_NAME)

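# All harm-category filters are set to BLOCK_NONE, presumably so that benchmark
# prompts are not rejected by Gemini's default safety filtering.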
safety_settings = {
    "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_NONE",
    "HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE",
    "HARM_CATEGORY_HARASSMENT": "BLOCK_NONE",
    "HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_NONE",
}

parser = argparse.ArgumentParser()
parser.add_argument("-o", "--output_dir", help="Directory to save outputs", default="./generated")
args = parser.parse_args()
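# Example invocation: python generator-gemini.py -o ./generated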

df_questions = pd.read_json("questions.jsonl", orient="records", encoding="utf-8-sig", lines=True)
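# questions.jsonl is expected (inferred from the columns used below) to contain one JSON
# record per line with "id", "category", "questions" (a two-turn list), and "references".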

if not os.path.exists(args.output_dir):
    os.makedirs(args.output_dir)


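# Retry transient API failures: up to 10 attempts with a fixed 1-second wait, on any exception.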
@retry(stop=stop_after_attempt(10), wait=wait_fixed(1), retry=retry_if_exception_type(Exception))
def call_gemini_api(input_text):
    """Call the Gemini API and return the generated text."""
    response = model.generate_content([input_text], safety_settings=safety_settings)

    if not response.candidates:
        raise ValueError("Invalid operation: No candidates returned in the response.")

    candidate = response.candidates[0]
    if not candidate.content.parts:
        print(candidate)  # Log the offending candidate (e.g. a safety block) before failing.
        raise ValueError("Invalid operation: No parts found in the candidate.")

    return candidate.content.parts[0].text


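# PROMPT_STRATEGY is assumed (from its use here) to map a strategy name to a list of
# system/few-shot messages shaped like {"role": ..., "content": ...}. For each strategy,
# the script generates a first-turn answer, then a follow-up answer conditioned on it.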
for strategy_name, prompts in PROMPT_STRATEGY.items():

    def format_single_turn_question(question):
        messages = prompts + [{"role": "user", "content": question[0]}]
        formatted_text = "\n".join([f"{message['role']}: {message['content']}" for message in messages])
        return formatted_text

    single_turn_questions = df_questions["questions"].map(format_single_turn_question)
    single_turn_outputs = []
    for formatted_text in tqdm(single_turn_questions, desc=f"Generating single-turn outputs for {strategy_name}"):
        generated_text = call_gemini_api(formatted_text)
        single_turn_outputs.append(generated_text)

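    # Second pass: feed the first-turn answer back as the assistant turn, then append
    # the follow-up question from questions[1].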
    def format_double_turn_question(question, single_turn_output):
        messages = prompts + [
            {"role": "user", "content": question[0]},
            {"role": "assistant", "content": single_turn_output},
            {"role": "user", "content": question[1]},
        ]
        formatted_text = "\n".join([f"{message['role']}: {message['content']}" for message in messages])
        return formatted_text

    multi_turn_questions = df_questions[["questions", "id"]].apply(
        lambda x: format_double_turn_question(x["questions"], single_turn_outputs[x["id"] - 1]),
        axis=1,
    )
    multi_turn_outputs = []
    for formatted_text in tqdm(multi_turn_questions, desc=f"Generating multi-turn outputs for {strategy_name}"):
        generated_text = call_gemini_api(formatted_text)
        multi_turn_outputs.append(generated_text)

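    # Pair each single-turn output with its multi-turn output and save one JSONL file per strategy.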
    df_output = pd.DataFrame(
        {
            "id": df_questions["id"],
            "category": df_questions["category"],
            "questions": df_questions["questions"],
            "outputs": list(zip(single_turn_outputs, multi_turn_outputs)),
            "references": df_questions["references"],
        }
    )
    output_path = os.path.join(args.output_dir, f"{strategy_name}.jsonl")
    df_output.to_json(output_path, orient="records", lines=True, force_ascii=False)
    print(f"Saved outputs to {output_path}")