Skip to content
This repository was archived by the owner on Oct 17, 2024. It is now read-only.

Commit 43c10e5

Browse files
authored
Merge pull request #44 from instructkr/feat/gemini
[Feature] Add Gemini Client for generations
2 parents 1a584c9 + 8e688b1 commit 43c10e5

File tree

15 files changed

+636
-0
lines changed

15 files changed

+636
-0
lines changed

.github/workflows/ci.yaml

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Lint-only CI workflow: installs the formatter/linter pin and runs `make check`.
name: CI

on: [push]

env:
  # Cap native thread pools on the shared runner.
  OMP_NUM_THREADS: 2
  MKL_NUM_THREADS: 2
  PIP_DISABLE_PIP_VERSION_CHECK: 1

jobs:
  lint:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.11"]

    steps:
      - uses: actions/checkout@v4

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - uses: actions/cache@v4
        name: Cache uv packages
        with:
          path: ~/.cache/uv
          # Include a hash of the requirements file so the cache is rebuilt
          # when the pinned tooling changes; a cache entry is immutable per key.
          key: ${{ runner.os }}-python-${{ matrix.python-version }}-${{ hashFiles('requirements-format.txt') }}
          # Fall back to the newest cache for the same OS / Python version.
          restore-keys: |
            ${{ runner.os }}-python-${{ matrix.python-version }}-

      - name: Install uv
        run: curl -LsSf https://astral.sh/uv/install.sh | sh

      - name: Install dependencies
        run: uv pip install --system -r requirements-format.txt

      - name: Check lint
        run: make check

evaluated/google/gemini_1.5_flash_001/1-shot.jsonl

+42
Large diffs are not rendered by default.

evaluated/google/gemini_1.5_flash_001/cot-1-shot.jsonl

+42
Large diffs are not rendered by default.

evaluated/google/gemini_1.5_flash_001/default.jsonl

+42
Large diffs are not rendered by default.

evaluated/google/gemini_1.5_pro_001/1-shot.jsonl

+42
Large diffs are not rendered by default.

evaluated/google/gemini_1.5_pro_001/cot-1-shot.jsonl

+42
Large diffs are not rendered by default.

evaluated/google/gemini_1.5_pro_001/default.jsonl

+42
Large diffs are not rendered by default.

generated/google/gemini_1.5_flash_001/1-shot.jsonl

+42
Large diffs are not rendered by default.

generated/google/gemini_1.5_flash_001/cot-1-shot.jsonl

+42
Large diffs are not rendered by default.

generated/google/gemini_1.5_flash_001/default.jsonl

+42
Large diffs are not rendered by default.

generated/google/gemini_1.5_pro_001/1-shot.jsonl

+42
Large diffs are not rendered by default.

generated/google/gemini_1.5_pro_001/cot-1-shot.jsonl

+42
Large diffs are not rendered by default.

generated/google/gemini_1.5_pro_001/default.jsonl

+42
Large diffs are not rendered by default.

generator-gemini.py

+93
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import argparse # noqa: I001
2+
import os
3+
4+
import google.generativeai as genai
5+
import pandas as pd
6+
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
7+
from tqdm import tqdm
8+
9+
from templates import PROMPT_STRATEGY
10+
11+
# TODO: generator-gemini.py to converge with generator.py

# Read the key from the environment so a real secret never has to be written
# into this file; the "..." placeholder is kept as the fallback so editing the
# literal in place still works as before.
API_KEY = os.environ.get("GOOGLE_API_KEY", "...")
MODEL_NAME = "gemini-1.5-pro-001"

genai.configure(api_key=API_KEY)
model = genai.GenerativeModel(MODEL_NAME)

# Request BLOCK_NONE for each listed harm category on every generation call.
safety_settings = {
    "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_NONE",
    "HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE",
    "HARM_CATEGORY_HARASSMENT": "BLOCK_NONE",
    "HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_NONE",
}

parser = argparse.ArgumentParser()
parser.add_argument("-o", "--output_dir", help="Directory to save outputs", default="./generated")
args = parser.parse_args()

# One JSON record per line; "utf-8-sig" also accepts a file with a leading BOM.
df_questions = pd.read_json("questions.jsonl", orient="records", encoding="utf-8-sig", lines=True)

# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(args.output_dir, exist_ok=True)
33+
34+
35+
@retry(stop=stop_after_attempt(10), wait=wait_fixed(1), retry=retry_if_exception_type(Exception))
def call_gemini_api(input_text):
    """Send ``input_text`` to the module-level Gemini ``model`` and return its text.

    The call is retried on any ``Exception`` (including the ``ValueError``s
    raised below), up to 10 attempts with a fixed 1-second wait between them.

    Args:
        input_text: The fully formatted prompt to send as the only content item.

    Returns:
        The text of the first candidate, with all of its parts concatenated.

    Raises:
        ValueError: If the response has no candidates, or the first candidate
            has no content parts. Surfaces once the retry budget is exhausted
            (tenacity wraps the final failure according to its own defaults).
    """
    response = model.generate_content([input_text], safety_settings=safety_settings)

    if not response.candidates:
        raise ValueError("Invalid operation: No candidates returned in the response.")

    candidate = response.candidates[0]
    parts = candidate.content.parts
    if not parts:
        # Dump the candidate so the reason for the empty content is visible in the log.
        print(candidate)
        raise ValueError("Invalid operation: No parts found in the candidate.")

    # A candidate may carry several parts; join them all instead of silently
    # dropping everything after the first one.
    return "".join(part.text for part in parts)
49+
50+
51+
# For every prompt strategy: generate a first-turn answer per question, feed it
# back to generate the second-turn answer, then write one JSONL file per strategy.
for strategy_name, prompts in PROMPT_STRATEGY.items():

    def format_single_turn_question(question):
        """Render the strategy prompts plus the first question as "role: content" lines."""
        messages = prompts + [{"role": "user", "content": question[0]}]
        return "\n".join(f"{message['role']}: {message['content']}" for message in messages)

    def format_double_turn_question(question, single_turn_output):
        """Render the two-turn conversation, inserting the model's first answer."""
        messages = prompts + [
            {"role": "user", "content": question[0]},
            {"role": "assistant", "content": single_turn_output},
            {"role": "user", "content": question[1]},
        ]
        return "\n".join(f"{message['role']}: {message['content']}" for message in messages)

    single_turn_questions = df_questions["questions"].map(format_single_turn_question)
    single_turn_outputs = [
        call_gemini_api(formatted_text)
        for formatted_text in tqdm(single_turn_questions, desc=f"Generating single-turn outputs for {strategy_name}")
    ]

    # Pair each question with its own first-turn answer by row position.
    # Looking the answer up via `id - 1` is only correct when ids are exactly
    # 1..N in file order; positional pairing gives the same result in that
    # case and stays correct otherwise.
    multi_turn_questions = [
        format_double_turn_question(question, single_turn_output)
        for question, single_turn_output in zip(df_questions["questions"], single_turn_outputs)
    ]
    multi_turn_outputs = [
        call_gemini_api(formatted_text)
        for formatted_text in tqdm(multi_turn_questions, desc=f"Generating multi-turn outputs for {strategy_name}")
    ]

    df_output = pd.DataFrame(
        {
            "id": df_questions["id"],
            "category": df_questions["category"],
            "questions": df_questions["questions"],
            # One (first-turn, second-turn) pair per question row.
            "outputs": list(zip(single_turn_outputs, multi_turn_outputs)),
            "references": df_questions["references"],
        }
    )
    output_path = os.path.join(args.output_dir, f"{strategy_name}.jsonl")
    # force_ascii=False keeps non-ASCII text readable instead of \u-escaped.
    df_output.to_json(output_path, orient="records", lines=True, force_ascii=False)
    print(f"Saved outputs to {output_path}")

requirements-format.txt

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
ruff==0.4.9

0 commit comments

Comments
 (0)