From bfd8976571cfa73dd1ac51568e4397b0f7ac4d49 Mon Sep 17 00:00:00 2001
From: Bethany Connolly
Date: Tue, 4 Apr 2023 13:12:49 +0000
Subject: [PATCH 1/4] setting up wandb logging

---
 main.py  | 27 ++++++++++++++-------------
 setup.py |  1 +
 2 files changed, 15 insertions(+), 13 deletions(-)
 mode change 100644 => 100755 main.py
 mode change 100644 => 100755 setup.py

diff --git a/main.py b/main.py
old mode 100644
new mode 100755
index c716183..ff4a007
--- a/main.py
+++ b/main.py
@@ -2,6 +2,7 @@
 import json
 import logging
 import fnmatch
+import wandb
 
 from lm_eval import tasks, evaluator
 
@@ -40,6 +41,9 @@ def parse_args():
     parser.add_argument("--decontamination_ngrams_path", default=None)
     parser.add_argument("--description_dict_path", default=None)
     parser.add_argument("--check_integrity", action="store_true")
+    parser.add_argument("--wandb_log", type=bool, default=False)
+    parser.add_argument("--wandb_project", type=str, default=None)
+    parser.add_argument("--wandb_run_name", type=str, default=None)
 
     return parser.parse_args()
 
@@ -57,10 +61,9 @@ def pattern_match(patterns, source_list):
 def main():
     args = parse_args()
 
-    if args.limit:
-        print(
-            "WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
-        )
+    if args.wandb_log:
+        assert (wandb_project is not None) and (wandb_run_name is not None)
+        wandb.init(project=args.wandb_project, name=argd.wandb_run_name, config=args)
 
     if args.tasks is None:
         task_names = tasks.ALL_TASKS
@@ -84,15 +87,13 @@ def main():
     dumped = json.dumps(results, indent=2)
     print(dumped)
 
-    if args.output_path:
-        with open(args.output_path, "w") as f:
-            f.write(dumped)
-
-    print(
-        f"{args.model} ({args.model_args}), limit: {args.limit}, provide_description: {args.provide_description}, "
-        f"num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}"
-    )
-    print(evaluator.make_table(results))
+    if args.wandb_log:
+        # TODO: where is "filter" coming from?
+        for task, metrics in dumped["results"].items():
+            # wandb.log(
+            #     f"{task.split()[0]}_{metric}": value for metric, value in metrics.items()
+            # )
+            wandb.log({task.split()[0]: metrics})
 
 
 if __name__ == "__main__":
diff --git a/setup.py b/setup.py
old mode 100644
new mode 100755
index 6fde2b1..03bf0b9
--- a/setup.py
+++ b/setup.py
@@ -37,6 +37,7 @@
         "tqdm-multiprocess",
         "transformers>=4.1",
         "zstandard",
+        "wandb",
     ],
     extras_require={
         "dev": ["black", "flake8", "pre-commit", "pytest", "pytest-cov"],
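
Note on the flags added in this patch: argparse's type=bool simply calls
bool() on the raw argument string, so any non-empty value -- including the
literal string "False" -- parses as True. A minimal sketch of the behaviour,
standard library only; the flag name comes from the patch above, the rest is
illustrative:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--wandb_log", type=bool, default=False)

    # bool("False") is True, so any non-empty string enables the flag.
    print(parser.parse_args(["--wandb_log", "False"]).wandb_log)  # True
    print(parser.parse_args([]).wandb_log)                        # False
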
From 3c32578844c5024241a1250af9b23d27555a1960 Mon Sep 17 00:00:00 2001
From: Bethany Connolly
Date: Tue, 4 Apr 2023 15:04:30 +0000
Subject: [PATCH 2/4] updated wandb logging

---
 main.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/main.py b/main.py
index ff4a007..721ac6c 100755
--- a/main.py
+++ b/main.py
@@ -62,8 +62,8 @@ def main():
     args = parse_args()
 
     if args.wandb_log:
-        assert (wandb_project is not None) and (wandb_run_name is not None)
-        wandb.init(project=args.wandb_project, name=argd.wandb_run_name, config=args)
+        assert (args.wandb_project is not None) and (args.wandb_run_name is not None)
+        wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=args)
 
     if args.tasks is None:
         task_names = tasks.ALL_TASKS
@@ -89,10 +89,7 @@ def main():
 
     if args.wandb_log:
         # TODO: where is "filter" coming from?
-        for task, metrics in dumped["results"].items():
-            # wandb.log(
-            #     f"{task.split()[0]}_{metric}": value for metric, value in metrics.items()
-            # )
+        for task, metrics in results["results"].items():
             wandb.log({task.split()[0]: metrics})
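
Note on the corrected loop: it now iterates over the evaluator's results
dict rather than the JSON string dumped. A standalone sketch of what reaches
wandb, using a hypothetical payload that mirrors the
{"results": {task: {metric: value}}} shape returned by
evaluator.simple_evaluate; the project, run name, task, and metric values
are all placeholders:

    import wandb  # assumes wandb is installed and you are logged in

    # Hypothetical results payload; real values come from simple_evaluate.
    results = {"results": {"lambada": {"ppl": 3.2, "acc": 0.72}}}

    wandb.init(project="demo-project", name="demo-run")
    for task, metrics in results["results"].items():
        # Logs each task's metric dict under the task name.
        wandb.log({task.split()[0]: metrics})
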
From 3c0d299aef2ac8bef719645ed4a00fd054cd21cd Mon Sep 17 00:00:00 2001
From: Bethany Connolly
Date: Wed, 5 Apr 2023 13:46:32 +0000
Subject: [PATCH 3/4] updated main.py to take args via yaml

---
 main.py  | 69 +++++++++++++++++++++++++-------------------------
 setup.py |  3 ++-
 2 files changed, 33 insertions(+), 39 deletions(-)

diff --git a/main.py b/main.py
index 721ac6c..9f89129 100755
--- a/main.py
+++ b/main.py
@@ -3,49 +3,37 @@
 import logging
 import fnmatch
 import wandb
+from pathlib import Path
+from typing import Union
+import yaml
+from pydantic import BaseModel
 
 from lm_eval import tasks, evaluator
 
 logging.getLogger("openai").setLevel(logging.WARNING)
 
 
-class MultiChoice:
-    def __init__(self, choices):
-        self.choices = choices
+def load_config(path: Union[str, Path]):
+    with open(path, "r") as stream:
+        try:
+            return yaml.safe_load(stream)
+        except yaml.YAMLError as exc:
+            print(exc)
 
-    # Simple wildcard support (linux filename patterns)
-    def __contains__(self, values):
-        for value in values.split(","):
-            if len(fnmatch.filter(self.choices, value)) == 0:
-                return False
 
-        return True
-
-    def __iter__(self):
-        for choice in self.choices:
-            yield choice
-
-
-def parse_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model", required=True)
-    parser.add_argument("--model_args", default="")
-    parser.add_argument("--tasks", default=None, choices=MultiChoice(tasks.ALL_TASKS))
-    parser.add_argument("--provide_description", action="store_true")
-    parser.add_argument("--num_fewshot", type=int, default=0)
-    parser.add_argument("--batch_size", type=int, default=None)
-    parser.add_argument("--device", type=str, default=None)
-    parser.add_argument("--output_path", default=None)
-    parser.add_argument("--limit", type=int, default=None)
-    parser.add_argument("--no_cache", action="store_true")
-    parser.add_argument("--decontamination_ngrams_path", default=None)
-    parser.add_argument("--description_dict_path", default=None)
-    parser.add_argument("--check_integrity", action="store_true")
-    parser.add_argument("--wandb_log", type=bool, default=False)
-    parser.add_argument("--wandb_project", type=str, default=None)
-    parser.add_argument("--wandb_run_name", type=str, default=None)
-
-    return parser.parse_args()
+class EvalPipelineConfig(BaseModel):
+    model: str
+    model_args: str = ""
+    tasks: str = None  # check the types
+    num_fewshot: int = 0
+    batch_size: int = None
+    device: str = None
+    limit: int = None
+    decontamination_ngrams_path: str = None
+    check_integrity: bool = False
+    wandb_log: bool = False
+    wandb_project: str = None
+    wandb_run_name: str = None
 
 
 # Returns a list containing all values of the source_list that
@@ -58,8 +46,10 @@ def pattern_match(patterns, source_list):
     return list(task_names)
 
-def main():
-    args = parse_args()
+def main(config_path: str) -> None:
+
+    raw_config = load_config(config_path)
+    args = EvalPipelineConfig(**raw_config)
 
     if args.wandb_log:
         assert (args.wandb_project is not None) and (args.wandb_run_name is not None)
         wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=args)
@@ -94,4 +84,7 @@
 
 
 if __name__ == "__main__":
-    main()
+    parser = argparse.ArgumentParser()
+    parser.add_argument("config_path", help="The full path to the YAML config file.")
+    args = parser.parse_args()
+    main(args.config_path)
diff --git a/setup.py b/setup.py
index 03bf0b9..a6d35b0 100755
--- a/setup.py
+++ b/setup.py
@@ -28,6 +28,7 @@
         "omegaconf>=2.2",
         "pybind11>=2.6.2",
         "pycountry",
+        "pydantic",
         "pytablewriter",
         "rouge-score>=0.0.4",
         "sacrebleu==1.5.0",
@@ -36,8 +37,8 @@
         "torch>=1.7",
         "tqdm-multiprocess",
         "transformers>=4.1",
-        "zstandard",
         "wandb",
+        "zstandard",
     ],
     extras_require={
         "dev": ["black", "flake8", "pre-commit", "pytest", "pytest-cov"],
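
Note on the YAML-driven entry point added in this patch: settings are read
from a config file with load_config and validated by EvalPipelineConfig, and
the script is invoked as python main.py path/to/config.yaml. A hypothetical
config might look like the following (the keys are the EvalPipelineConfig
fields above; every value is illustrative):

    # eval_config.yaml (example only)
    model: hf-causal
    model_args: pretrained=gpt2
    tasks: lambada
    num_fewshot: 0
    batch_size: 8
    wandb_log: true
    wandb_project: my-evals        # placeholder project name
    wandb_run_name: gpt2-baseline  # placeholder run name
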
From 8b31a30fd889e121827e009d57d07b32e00a099c Mon Sep 17 00:00:00 2001
From: Bethany Connolly
Date: Tue, 11 Apr 2023 13:09:59 +0000
Subject: [PATCH 4/4] added original main.py back and renamed new script to
 main_eval.py

---
 main.py      | 87 +++++++++++++++++++++++++-----------------
 main_eval.py | 90 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 138 insertions(+), 39 deletions(-)
 mode change 100755 => 100644 main.py
 create mode 100755 main_eval.py

diff --git a/main.py b/main.py
old mode 100755
new mode 100644
index 9f89129..dc3ca6c
--- a/main.py
+++ b/main.py
@@ -2,38 +2,46 @@
 import json
 import logging
 import fnmatch
-import wandb
-from pathlib import Path
-from typing import Union
-import yaml
-from pydantic import BaseModel
 
 from lm_eval import tasks, evaluator
 
 logging.getLogger("openai").setLevel(logging.WARNING)
 
 
-def load_config(path: Union[str, Path]):
-    with open(path, "r") as stream:
-        try:
-            return yaml.safe_load(stream)
-        except yaml.YAMLError as exc:
-            print(exc)
+class MultiChoice:
+    def __init__(self, choices):
+        self.choices = choices
 
+    # Simple wildcard support (linux filename patterns)
+    def __contains__(self, values):
+        for value in values.split(","):
+            if len(fnmatch.filter(self.choices, value)) == 0:
+                return False
 
-class EvalPipelineConfig(BaseModel):
-    model: str
-    model_args: str = ""
-    tasks: str = None  # check the types
-    num_fewshot: int = 0
-    batch_size: int = None
-    device: str = None
-    limit: int = None
-    decontamination_ngrams_path: str = None
-    check_integrity: bool = False
-    wandb_log: bool = False
-    wandb_project: str = None
-    wandb_run_name: str = None
+        return True
+
+    def __iter__(self):
+        for choice in self.choices:
+            yield choice
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", required=True)
+    parser.add_argument("--model_args", default="")
+    parser.add_argument("--tasks", default=None, choices=MultiChoice(tasks.ALL_TASKS))
+    parser.add_argument("--provide_description", action="store_true")
+    parser.add_argument("--num_fewshot", type=int, default=0)
+    parser.add_argument("--batch_size", type=int, default=None)
+    parser.add_argument("--device", type=str, default=None)
+    parser.add_argument("--output_path", default=None)
+    parser.add_argument("--limit", type=int, default=None)
+    parser.add_argument("--no_cache", action="store_true")
+    parser.add_argument("--decontamination_ngrams_path", default=None)
+    parser.add_argument("--description_dict_path", default=None)
+    parser.add_argument("--check_integrity", action="store_true")
+
+    return parser.parse_args()
 
 
 # Returns a list containing all values of the source_list that
@@ -46,14 +54,13 @@ def pattern_match(patterns, source_list):
     return list(task_names)
 
 
-def main(config_path: str) -> None:
+def main():
+    args = parse_args()
 
-    raw_config = load_config(config_path)
-    args = EvalPipelineConfig(**raw_config)
-
-    if args.wandb_log:
-        assert (args.wandb_project is not None) and (args.wandb_run_name is not None)
-        wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=args)
+    if args.limit:
+        print(
+            "WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
+        )
 
     if args.tasks is None:
         task_names = tasks.ALL_TASKS
@@ -77,14 +84,16 @@ def main(config_path: str) -> None:
     dumped = json.dumps(results, indent=2)
     print(dumped)
 
-    if args.wandb_log:
-        # TODO: where is "filter" coming from?
-        for task, metrics in results["results"].items():
-            wandb.log({task.split()[0]: metrics})
+    if args.output_path:
+        with open(args.output_path, "w") as f:
+            f.write(dumped)
+
+    print(
+        f"{args.model} ({args.model_args}), limit: {args.limit}, provide_description: {args.provide_description}, "
+        f"num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}"
+    )
+    print(evaluator.make_table(results))
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("config_path", help="The full path to the YAML config file.")
-    args = parser.parse_args()
-    main(args.config_path)
+    main()
\ No newline at end of file
diff --git a/main_eval.py b/main_eval.py
new file mode 100755
index 0000000..9f89129
--- /dev/null
+++ b/main_eval.py
@@ -0,0 +1,90 @@
+import argparse
+import json
+import logging
+import fnmatch
+import wandb
+from pathlib import Path
+from typing import Union
+import yaml
+from pydantic import BaseModel
+
+from lm_eval import tasks, evaluator
+
+logging.getLogger("openai").setLevel(logging.WARNING)
+
+
+def load_config(path: Union[str, Path]):
+    with open(path, "r") as stream:
+        try:
+            return yaml.safe_load(stream)
+        except yaml.YAMLError as exc:
+            print(exc)
+
+
+class EvalPipelineConfig(BaseModel):
+    model: str
+    model_args: str = ""
+    tasks: str = None  # check the types
+    num_fewshot: int = 0
+    batch_size: int = None
+    device: str = None
+    limit: int = None
+    decontamination_ngrams_path: str = None
+    check_integrity: bool = False
+    wandb_log: bool = False
+    wandb_project: str = None
+    wandb_run_name: str = None
+
+
+# Returns a list containing all values of the source_list that
+# match at least one of the patterns
+def pattern_match(patterns, source_list):
+    task_names = set()
+    for pattern in patterns:
+        for matching in fnmatch.filter(source_list, pattern):
+            task_names.add(matching)
+    return list(task_names)
+
+
+def main(config_path: str) -> None:
+
+    raw_config = load_config(config_path)
+    args = EvalPipelineConfig(**raw_config)
+
+    if args.wandb_log:
+        assert (args.wandb_project is not None) and (args.wandb_run_name is not None)
+        wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=args)
+
+    if args.tasks is None:
+        task_names = tasks.ALL_TASKS
+    else:
+        task_names = pattern_match(args.tasks.split(","), tasks.ALL_TASKS)
+
+    print(f"Selected Tasks: {task_names}")
+
+    results = evaluator.simple_evaluate(
+        model=args.model,
+        model_args=args.model_args,
+        tasks=task_names,
+        num_fewshot=args.num_fewshot,
+        batch_size=args.batch_size,
+        device=args.device,
+        limit=args.limit,
+        decontamination_ngrams_path=args.decontamination_ngrams_path,
+        check_integrity=args.check_integrity,
+    )
+
+    dumped = json.dumps(results, indent=2)
+    print(dumped)
+
+    if args.wandb_log:
+        # TODO: where is "filter" coming from?
+        for task, metrics in results["results"].items():
+            wandb.log({task.split()[0]: metrics})
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("config_path", help="The full path to the YAML config file.")
+    args = parser.parse_args()
+    main(args.config_path)
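
Note on the final state of the series: the repository ends up with two entry
points -- the restored main.py driven by CLI flags, and main_eval.py driven
by a YAML config. Illustrative invocations (the model and task values are
examples only, not prescribed by the patches):

    python main.py --model hf-causal --model_args pretrained=gpt2 --tasks lambada
    python main_eval.py path/to/eval_config.yaml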