-
Notifications
You must be signed in to change notification settings - Fork 3
OpenAI version upgrade (latest version) #56
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 3 commits
a6a3d26
b2ab6c3
21de29b
01e6043
ec5a5e0
b319fd7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. こちらは必要ないコミットをしてしまったので削除しました。 |
This file was deleted.
This file was deleted.
This file was deleted.
This file was deleted.
This file was deleted.
This file was deleted.
This file was deleted.
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,17 +9,20 @@ | |
| from typing import Optional, Union | ||
|
|
||
| import openai | ||
| from openai import AzureOpenAI | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Azure の API だけでなく OpenAI の API でも動く実装にしてください.
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. AzureのAPIしか現状使用できないので、検証はできませんが大丈夫でしょうか。
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. それならこの部分はこちらで実装 & テストします.
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 承知しました。 |
||
|
|
||
| client = AzureOpenAI(api_key=os.getenv("OPENAI_API_KEY"), | ||
| api_version=os.getenv("OPENAI_API_VERSION")) | ||
| import tiktoken | ||
| from dotenv import load_dotenv | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| load_dotenv() | ||
| openai.api_key = os.getenv("OPENAI_API_KEY") | ||
| openai.organization = os.getenv("OPENAI_ORGANIZATION") | ||
| openai.api_type = os.getenv("OPENAI_API_TYPE") | ||
| openai.api_base = os.getenv("OPENAI_API_BASE") | ||
| openai.api_version = os.getenv("OPENAI_API_VERSION") | ||
| # TODO: The 'openai.organization' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(organization=os.getenv("OPENAI_ORGANIZATION"))' | ||
| # openai.organization = os.getenv("OPENAI_ORGANIZATION") | ||
| # TODO: The 'openai.api_base' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(base_url=os.getenv("OPENAI_API_BASE"))' | ||
| # openai.api_base = os.getenv("OPENAI_API_BASE") | ||
|
|
||
| # Data paths | ||
| JP_BENCH_DIR = Path(__file__).resolve().parent.parent / "data" / "jp_bench" | ||
|
|
@@ -68,9 +71,9 @@ def judge(self, **kwargs): | |
| params["engine"] = self.model | ||
| else: | ||
| params["model"] = self.model | ||
| response = openai.ChatCompletion.create(**params) | ||
| return response["choices"][0]["message"]["content"] | ||
| except openai.error.OpenAIError as e: | ||
| response = client.chat.completions.create(**params) | ||
| return response.choices[0].message.content | ||
| except openai.OpenAIError as e: | ||
| logger.warning(f"OpenAI API error: {e}") | ||
| time.sleep(API_RETRY_SLEEP) | ||
|
|
||
|
|
@@ -363,3 +366,8 @@ def filter_pairwise_judgements( | |
| else: | ||
| filtered_result_id_results_map[result_id] = results | ||
| return filtered_result_id_results_map | ||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -32,6 +32,7 @@ def make_match_groups_single( | |
| ref_answers: dict[str, dict[int, dict]], | ||
| judge_default: Judge, | ||
| judge_math: Judge, | ||
| num_answers_per_question: Optional[int] = None, | ||
| ): | ||
| """Make match groups for single answer grading. | ||
|
|
||
|
|
@@ -41,6 +42,7 @@ def make_match_groups_single( | |
| ref_answers (dict): A dict of reference answers. | ||
| judge_default (Judge): A judge for default questions. | ||
| judge_math (Judge): A judge for math questions. | ||
| num_answers_per_question (Optional[int]): Number of answers to evaluate per question. | ||
| """ | ||
| match_groups = {} | ||
| for model in model_answers: | ||
|
|
@@ -63,6 +65,8 @@ def make_match_groups_single( | |
| ref_answer=ref_answer, | ||
| ) | ||
| ) | ||
| if num_answers_per_question: | ||
| matches = matches[:num_answers_per_question] | ||
|
||
| match_groups[f"single:{model}"] = matches | ||
| return match_groups | ||
|
|
||
|
|
@@ -74,6 +78,7 @@ def make_match_groups_pairwise( | |
| judge_default: Judge, | ||
| judge_math: Judge, | ||
| baseline_model: Optional[str] = None, | ||
| num_answers_per_question: Optional[int] = None, | ||
| ): | ||
| """Make match groups for pairwise comparison. | ||
|
|
||
|
|
@@ -84,6 +89,7 @@ def make_match_groups_pairwise( | |
| judge_default (Judge): A judge for default questions. | ||
| judge_math (Judge): A judge for math questions. | ||
| baseline_model (Optional[str]): The baseline model. | ||
| num_answers_per_question (Optional[int]): Number of answers to evaluate per question. | ||
| """ | ||
| match_groups = {} | ||
| for model_1, model_2 in combinations(model_answers, 2): | ||
|
|
@@ -111,6 +117,8 @@ def make_match_groups_pairwise( | |
| ref_answer=ref_answer, | ||
| ) | ||
| ) | ||
| if num_answers_per_question: | ||
| matches = matches[:num_answers_per_question] | ||
|
||
| match_groups[f"pairwise:{model_1}_{model_2}"] = matches | ||
| return match_groups | ||
|
|
||
|
|
@@ -132,7 +140,7 @@ def make_match_groups_pairwise( | |
| parser.add_argument( | ||
| "--judge-model", | ||
| type=str, | ||
| default="gpt-4", | ||
| default="gpt-4-0613", | ||
|
||
| choices=["gpt-4", "gpt-4-0613", "gpt-4-1106-preview", "gpt-3.5-turbo"], | ||
| help="The judge model.", | ||
| ) | ||
|
|
@@ -167,6 +175,9 @@ def make_match_groups_pairwise( | |
| parser.add_argument( | ||
| "--verbose", "-v", action="count", default=0, help="Verbosity level" | ||
| ) | ||
| parser.add_argument( | ||
| "--num_answers_per_question", type=int, default=None, help="Number of answers to evaluate per question." | ||
| ) | ||
| args = parser.parse_args() | ||
|
|
||
| if args.verbose == 0: | ||
|
|
@@ -227,6 +238,7 @@ def make_match_groups_pairwise( | |
| ref_answers=ref_answers, | ||
| judge_default=Judge(args.judge_model, judge_prompts["single"]), | ||
| judge_math=Judge(args.judge_model, judge_prompts["single-math"]), | ||
| num_answers_per_question=args.num_answers_per_question, | ||
| ) | ||
| output_dir = JUDGEMENT_DIR / "single" / args.judge_model | ||
| else: | ||
|
|
@@ -242,6 +254,7 @@ def make_match_groups_pairwise( | |
| judge_default=Judge(args.judge_model, judge_prompts["pair"]), | ||
| judge_math=Judge(args.judge_model, judge_prompts["pair-math"]), | ||
| baseline_model=baseline_model, | ||
| num_answers_per_question=args.num_answers_per_question, | ||
| ) | ||
| output_dir = JUDGEMENT_DIR / "pairwise" / args.judge_model | ||
| target_match_ids = set() | ||
|
|
@@ -290,3 +303,4 @@ def make_match_groups_pairwise( | |
| if args.wandb: | ||
| logger.info("Log to wandb") | ||
| upload_results(args.mode, match_id, results, args.baseline_model) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,7 +26,6 @@ | |
| "generic": 0.1, | ||
| } | ||
|
|
||
|
|
||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 品質管理のためにリンターとフォーマッターを入れています.以下のコマンドを実行してください. |
||
| def generate_response( | ||
| input_text, model, tokenizer, generation_config=None, special_token_map=None | ||
| ): | ||
|
|
@@ -64,7 +63,6 @@ def generate_response( | |
|
|
||
| return output | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| parser = argparse.ArgumentParser() | ||
| parser.add_argument( | ||
|
|
@@ -79,6 +77,9 @@ def generate_response( | |
| parser.add_argument( | ||
| "--overwrite", action="store_true", help="Overwrite the existing results" | ||
| ) | ||
| parser.add_argument( | ||
| "--num_answers_per_question", type=int, default=1, help="Number of answers to generate per question" | ||
| ) | ||
| args = parser.parse_args() | ||
|
|
||
| if args.verbose == 0: | ||
|
|
@@ -159,25 +160,26 @@ def generate_response( | |
| category = question["category"] | ||
| generation_config["temperature"] = DEFAULT_TEMPERATURE_MAP[category] | ||
|
|
||
| output = generate_response( | ||
| input_text=prompt_template.format_map({"instruction": instruction}), | ||
| model=model, | ||
| tokenizer=tokenizer, | ||
| generation_config=generation_config, | ||
| special_token_map=special_token_map, | ||
| ) | ||
|
|
||
| logger.debug(f"{instruction}\n\n{output}") | ||
|
|
||
| results.append( | ||
| { | ||
| "question_id": int(question["question_id"]), | ||
| "answer_id": shortuuid.uuid(), | ||
| "model_id": model_id, | ||
| "choices": [{"index": 0, "turns": [output]}], | ||
| "tstamp": time.time(), | ||
| } | ||
| ) | ||
| for _ in range(args.num_answers_per_question): | ||
| output = generate_response( | ||
| input_text=prompt_template.format_map({"instruction": instruction}), | ||
| model=model, | ||
| tokenizer=tokenizer, | ||
| generation_config=generation_config, | ||
| special_token_map=special_token_map, | ||
| ) | ||
|
|
||
| logger.debug(f"{instruction}\n\n{output}") | ||
|
|
||
| results.append( | ||
| { | ||
| "question_id": int(question["question_id"]), | ||
| "answer_id": shortuuid.uuid(), | ||
| "model_id": model_id, | ||
| "choices": [{"index": 0, "turns": [output]}], | ||
| "tstamp": time.time(), | ||
| } | ||
| ) | ||
|
|
||
| logger.info("Save the results") | ||
| prediction_dir.mkdir(parents=True, exist_ok=True) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.