-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathrun_descriptive_analysis.py
More file actions
85 lines (68 loc) · 2.72 KB
/
run_descriptive_analysis.py
File metadata and controls
85 lines (68 loc) · 2.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import argparse
from src.agent import HFAgent, OpenAIAgent
from src.descriptive_analysis.descriptive_analysis import generate_descriptive_analysis_answers
from src.utils.general import set_seed_everywhere
from src.utils.io import read_json
from src.utils.logger import freeze_args, get_logger
def get_args(argv=None):
    """Parse command-line options for the descriptive-analysis run.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) reads ``sys.argv[1:]``, which
            is the original behaviour.

    Returns:
        An ``argparse.Namespace`` with one attribute per option below.
    """
    parser = argparse.ArgumentParser()
    # Both paths are mandatory: without them the run can only crash later,
    # after the (possibly expensive) agent has already been constructed.
    parser.add_argument(
        "--fakepedia_path",
        type=str,
        required=True,
        help="Path to the base Fakepedia JSON file.",
    )
    parser.add_argument(
        "--model_name_path",
        type=str,
        required=True,
        help="HF model/tokenizer path, or the OpenAI model name when --openai_api_key is given.",
    )
    # Sampling parameters are real numbers; type=int rejected values such as 0.7.
    parser.add_argument("--temperature", default=0.0, type=float)
    parser.add_argument("--max_new_tokens", default=100, type=int)
    parser.add_argument("--top_p", default=1.0, type=float)
    parser.add_argument(
        "--system_message",
        type=str,
    )
    parser.add_argument(
        "--human_message_prompt_template",
        type=str,
    )
    parser.add_argument("--num_examples", type=int)
    parser.add_argument(
        "--openai_api_key",
        default=None,
        type=str,
        help="If set, an OpenAIAgent is used instead of a local HFAgent.",
    )
    # The next two flags are only forwarded to HFAgent.
    parser.add_argument("--bfloat16", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--merge_system_message", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--resume_path", default=None, type=str)
    parser.add_argument("--seed", default=100, type=int)
    return parser.parse_args(argv)
def generate_answers(args):
    """Load the base Fakepedia, build a chat agent and generate descriptive-analysis answers.

    Args:
        args: Parsed CLI namespace (see ``get_args``). Fields read here:
            fakepedia_path, model_name_path, temperature, top_p,
            max_new_tokens, system_message, human_message_prompt_template,
            bfloat16, merge_system_message, openai_api_key, num_examples,
            resume_path.

    Returns:
        None. All output handling is delegated to
        ``generate_descriptive_analysis_answers``.
    """
    logger = get_logger()

    # Read the input before building the agent: reading JSON is cheap, while
    # constructing an HFAgent presumably loads model weights. A bad
    # --fakepedia_path should fail before that cost is paid.
    logger.info("Loading base fakepedia...")
    base_fakepedia = read_json(args.fakepedia_path)

    generation_parameters = {
        "temperature": args.temperature,
        "top_p": args.top_p,
        "max_new_tokens": args.max_new_tokens,
    }

    logger.info("Initializing agent...")
    if args.openai_api_key is None:
        # Local Hugging Face model: the same path serves as model and tokenizer.
        agent = HFAgent(
            model_path=args.model_name_path,
            tokenizer_path=args.model_name_path,
            generation_parameters=generation_parameters,
            system_message=args.system_message,
            human_message_prompt_template=args.human_message_prompt_template,
            bfloat16=args.bfloat16,
            merge_system_message=args.merge_system_message,
            verbose=True,
        )
    else:
        # Remote OpenAI model: bfloat16 / merge_system_message do not apply.
        agent = OpenAIAgent(
            model_name=args.model_name_path,
            api_key=args.openai_api_key,
            generation_parameters=generation_parameters,
            system_message=args.system_message,
            human_message_prompt_template=args.human_message_prompt_template,
            verbose=True,
        )

    logger.info("Generating descriptive analysis answers...")
    generate_descriptive_analysis_answers(base_fakepedia, agent, args.num_examples, resume_path=args.resume_path)
def main():
    """Script entry point.

    Parses the command line, hands the resulting namespace to ``freeze_args``,
    seeds via ``set_seed_everywhere`` using the ``--seed`` value, and then
    runs ``generate_answers``.
    """
    cli_options = get_args()
    freeze_args(cli_options)
    set_seed_everywhere(cli_options.seed)
    generate_answers(cli_options)


if __name__ == "__main__":
    # Only run when executed as a script, not when imported as a module.
    main()