Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,16 @@ python llm_judge/gen_model_answer.py --config <CONFIG-PATH>

Arguments & Options:
- `<CONFIG-PATH>` is the path to a configuration file. Examples are in `configs/`.
- `--num_answers_per_question <n>` specifies how many answers to generate per question (default: 1)

For example:

```bash
python llm_judge/gen_model_answer.py --config configs/rinna--japanese-gpt-neox-3.6b-instruction-ppo.json
python llm_judge/gen_model_answer.py --config configs/rinna--japanese-gpt-neox-3.6b-instruction-ppo.json --num_answers_per_question <n>
```



#### Step 2. Generate GPT-4 judgments

There are several options to use GPT-4 as a judge, such as pairwise win-rate and single-answer grading.
Expand All @@ -43,7 +46,8 @@ OPENAI_API_KEY=<YOUR-KEY> python llm_judge/gen_judgment.py \
[--baseline-model <BASELINE-MODEL-ID>] \
[--model-list <LIST-OF-MODEL-IDS>] \
[--yes] \
[--wandb]
[--wandb] \
  [--num_answers_per_question <n>]
```

Arguments & Options:
Expand All @@ -55,6 +59,7 @@ Arguments & Options:
- `--model-list <LIST-OF-MODEL-IDS>` is a list of model IDs to be evaluated. If not specified, all models in `data/jp_bench/model_answer` will be evaluated.
- `--yes` is a flag to skip the confirmation prompt.
- `--wandb` is a flag to enable logging to W&B. You can upload the results later to W&B by running `upload_result.py`, as described in the next section.
- `--num_answers_per_question <n>` is the number of answers to evaluate per question. If not specified, all answers are evaluated.

**Mode: `pairwise-baseline` (Default)**

Expand Down
30 changes: 0 additions & 30 deletions configs/README.md
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

configs/ を消しているのはなぜですか?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

こちらは必要ないコミットをしてしまったので削除しました。

This file was deleted.

13 changes: 0 additions & 13 deletions configs/cyberagent--calm2-7b-chat.json

This file was deleted.

This file was deleted.

This file was deleted.

16 changes: 0 additions & 16 deletions configs/openai--text-davinci-003.json

This file was deleted.

16 changes: 0 additions & 16 deletions configs/rinna--japanese-gpt-neox-3.6b-instruction-ppo.json

This file was deleted.

16 changes: 0 additions & 16 deletions configs/rinna--japanese-gpt-neox-3.6b-instruction-sft-v2.json

This file was deleted.

13 changes: 0 additions & 13 deletions configs/tokyotech-llm--Swallow-70b-instruct-hf.json

This file was deleted.

24 changes: 16 additions & 8 deletions llm_judge/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,20 @@
from typing import Optional, Union

import openai
from openai import AzureOpenAI
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Azure の API だけでなく OpenAI の API でも動く実装にしてください.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AzureのAPIしか現状使用できないので、検証はできませんが大丈夫でしょうか。

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

それならこの部分はこちらで実装 & テストします.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

承知しました。


client = AzureOpenAI(api_key=os.getenv("OPENAI_API_KEY"),
api_version=os.getenv("OPENAI_API_VERSION"))
import tiktoken
from dotenv import load_dotenv

logger = logging.getLogger(__name__)

load_dotenv()
openai.api_key = os.getenv("OPENAI_API_KEY")
openai.organization = os.getenv("OPENAI_ORGANIZATION")
openai.api_type = os.getenv("OPENAI_API_TYPE")
openai.api_base = os.getenv("OPENAI_API_BASE")
openai.api_version = os.getenv("OPENAI_API_VERSION")
# TODO: The 'openai.organization' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(organization=os.getenv("OPENAI_ORGANIZATION"))'
# openai.organization = os.getenv("OPENAI_ORGANIZATION")
# TODO: The 'openai.api_base' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(base_url=os.getenv("OPENAI_API_BASE"))'
# openai.api_base = os.getenv("OPENAI_API_BASE")

# Data paths
JP_BENCH_DIR = Path(__file__).resolve().parent.parent / "data" / "jp_bench"
Expand Down Expand Up @@ -68,9 +71,9 @@ def judge(self, **kwargs):
params["engine"] = self.model
else:
params["model"] = self.model
response = openai.ChatCompletion.create(**params)
return response["choices"][0]["message"]["content"]
except openai.error.OpenAIError as e:
response = client.chat.completions.create(**params)
return response.choices[0].message.content
except openai.OpenAIError as e:
logger.warning(f"OpenAI API error: {e}")
time.sleep(API_RETRY_SLEEP)

Expand Down Expand Up @@ -363,3 +366,8 @@ def filter_pairwise_judgements(
else:
filtered_result_id_results_map[result_id] = results
return filtered_result_id_results_map





16 changes: 15 additions & 1 deletion llm_judge/gen_judgment.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def make_match_groups_single(
ref_answers: dict[str, dict[int, dict]],
judge_default: Judge,
judge_math: Judge,
num_answers_per_question: Optional[int] = None,
):
"""Make match groups for single answer grading.

Expand All @@ -41,6 +42,7 @@ def make_match_groups_single(
ref_answers (dict): A dict of reference answers.
judge_default (Judge): A judge for default questions.
judge_math (Judge): A judge for math questions.
num_answers_per_question (Optional[int]): Number of answers to evaluate per question.
"""
match_groups = {}
for model in model_answers:
Expand All @@ -63,6 +65,8 @@ def make_match_groups_single(
ref_answer=ref_answer,
)
)
if num_answers_per_question:
matches = matches[:num_answers_per_question]
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

実装が間違っています.各質問について num_answers_per_question 件の回答を抽出してください.

match_groups[f"single:{model}"] = matches
return match_groups

Expand All @@ -74,6 +78,7 @@ def make_match_groups_pairwise(
judge_default: Judge,
judge_math: Judge,
baseline_model: Optional[str] = None,
num_answers_per_question: Optional[int] = None,
):
"""Make match groups for pairwise comparison.

Expand All @@ -84,6 +89,7 @@ def make_match_groups_pairwise(
judge_default (Judge): A judge for default questions.
judge_math (Judge): A judge for math questions.
baseline_model (Optional[str]): The baseline model.
num_answers_per_question (Optional[int]): Number of answers to evaluate per question.
"""
match_groups = {}
for model_1, model_2 in combinations(model_answers, 2):
Expand Down Expand Up @@ -111,6 +117,8 @@ def make_match_groups_pairwise(
ref_answer=ref_answer,
)
)
if num_answers_per_question:
matches = matches[:num_answers_per_question]
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

実装が間違っています.各質問について num_answers_per_question 件の回答を抽出してください.

match_groups[f"pairwise:{model_1}_{model_2}"] = matches
return match_groups

Expand All @@ -132,7 +140,7 @@ def make_match_groups_pairwise(
parser.add_argument(
"--judge-model",
type=str,
default="gpt-4",
default="gpt-4-0613",
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

judge-model のデフォルト値は gpt-4 のままにしておいてください.これは複数回の評価をサポートするための PR なので,それと関係ない変更はしないでください.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

すみません、こちらについては自分の環境のままpushしてしまいました。修正しておきました。

choices=["gpt-4", "gpt-4-0613", "gpt-4-1106-preview", "gpt-3.5-turbo"],
help="The judge model.",
)
Expand Down Expand Up @@ -167,6 +175,9 @@ def make_match_groups_pairwise(
parser.add_argument(
"--verbose", "-v", action="count", default=0, help="Verbosity level"
)
parser.add_argument(
"--num_answers_per_question", type=int, default=None, help="Number of answers to evaluate per question."
)
args = parser.parse_args()

if args.verbose == 0:
Expand Down Expand Up @@ -227,6 +238,7 @@ def make_match_groups_pairwise(
ref_answers=ref_answers,
judge_default=Judge(args.judge_model, judge_prompts["single"]),
judge_math=Judge(args.judge_model, judge_prompts["single-math"]),
num_answers_per_question=args.num_answers_per_question,
)
output_dir = JUDGEMENT_DIR / "single" / args.judge_model
else:
Expand All @@ -242,6 +254,7 @@ def make_match_groups_pairwise(
judge_default=Judge(args.judge_model, judge_prompts["pair"]),
judge_math=Judge(args.judge_model, judge_prompts["pair-math"]),
baseline_model=baseline_model,
num_answers_per_question=args.num_answers_per_question,
)
output_dir = JUDGEMENT_DIR / "pairwise" / args.judge_model
target_match_ids = set()
Expand Down Expand Up @@ -290,3 +303,4 @@ def make_match_groups_pairwise(
if args.wandb:
logger.info("Log to wandb")
upload_results(args.mode, match_id, results, args.baseline_model)

44 changes: 23 additions & 21 deletions llm_judge/gen_model_answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
"generic": 0.1,
}


Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

品質管理のためにリンターとフォーマッターを入れています.以下のコマンドを実行してください.

$ pre-commit install  # 以降,コミット時に自動的にリンターとフォーマッターが走ります
$ pre-commit run -a  # 今いるディレクトリ以下の全ファイルにリンターとフォーマッターを適用します

def generate_response(
input_text, model, tokenizer, generation_config=None, special_token_map=None
):
Expand Down Expand Up @@ -64,7 +63,6 @@ def generate_response(

return output


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
Expand All @@ -79,6 +77,9 @@ def generate_response(
parser.add_argument(
"--overwrite", action="store_true", help="Overwrite the existing results"
)
parser.add_argument(
"--num_answers_per_question", type=int, default=1, help="Number of answers to generate per question"
)
args = parser.parse_args()

if args.verbose == 0:
Expand Down Expand Up @@ -159,25 +160,26 @@ def generate_response(
category = question["category"]
generation_config["temperature"] = DEFAULT_TEMPERATURE_MAP[category]

output = generate_response(
input_text=prompt_template.format_map({"instruction": instruction}),
model=model,
tokenizer=tokenizer,
generation_config=generation_config,
special_token_map=special_token_map,
)

logger.debug(f"{instruction}\n\n{output}")

results.append(
{
"question_id": int(question["question_id"]),
"answer_id": shortuuid.uuid(),
"model_id": model_id,
"choices": [{"index": 0, "turns": [output]}],
"tstamp": time.time(),
}
)
for _ in range(args.num_answers_per_question):
output = generate_response(
input_text=prompt_template.format_map({"instruction": instruction}),
model=model,
tokenizer=tokenizer,
generation_config=generation_config,
special_token_map=special_token_map,
)

logger.debug(f"{instruction}\n\n{output}")

results.append(
{
"question_id": int(question["question_id"]),
"answer_id": shortuuid.uuid(),
"model_id": model_id,
"choices": [{"index": 0, "turns": [output]}],
"tstamp": time.time(),
}
)

logger.info("Save the results")
prediction_dir.mkdir(parents=True, exist_ok=True)
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ classifiers = [
dependencies = [
"accelerate", "fastapi", "gradio==3.35.2", "httpx", "markdown2[all]", "nh3", "numpy",
"peft==0.5", "prompt_toolkit>=3.0.0", "pydantic<=2.0", "requests", "rich>=10.0.0", "sentencepiece",
"shortuuid", "shortuuid", "tiktoken", "tokenizers>=0.12.1", "torch",
"transformers", "uvicorn", "wandb", "openai==0.28.1", "ray", "python-dotenv", "protobuf==3.19",
"wandb", "tiktoken"
"shortuuid", "tiktoken", "tokenizers>=0.12.1", "torch",
"transformers", "uvicorn", "wandb", "openai==1.35.3", "ray", "python-dotenv", "protobuf==3.19"
]


[tool.setuptools.packages.find]
exclude = ["*"]