-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathmain.py
65 lines (50 loc) · 2.12 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
import json
import fire
from models.knowledge_harvester import KnowledgeHarvester
def _load_json(path):
    """Read and return the JSON document stored at ``path`` (UTF-8)."""
    with open(path, encoding='utf-8') as f:
        return json.load(f)


def _dump_json(obj, path, indent=None):
    """Write ``obj`` as JSON to ``path`` (UTF-8), closing the file before returning."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(obj, f, indent=indent)


def main(rel_set='conceptnet',
         model_name='roberta-large',
         max_n_ent_tuples=1000,
         max_n_prompts=20,
         prompt_temp=2.,
         max_word_repeat=5,
         max_ent_subwords=2,
         use_init_prompts=False):
    """Harvest weighted prompts and entity tuples for every relation in a set.

    Loads ``relation_info/{rel_set}.json``, iterates over its
    ``relation -> info`` items, and for each relation writes
    ``prompts.json`` and ``ent_tuples.json`` into
    ``results/{rel_set}/{setting}/{model_name}/{rel}/``, where ``setting``
    encodes ``max_n_ent_tuples`` and the prompt mode. Relations whose
    ``ent_tuples.json`` already exists are skipped.

    Args:
        rel_set: Name of the relation set; selects the relation_info JSON file
            and the top-level results directory.
        model_name: Forwarded to ``KnowledgeHarvester``; also a path component
            of the output directory.
        max_n_ent_tuples: Forwarded to ``KnowledgeHarvester``; also part of
            the output directory name.
        max_n_prompts: Forwarded to ``KnowledgeHarvester``; part of the output
            directory name when ``use_init_prompts`` is False.
        prompt_temp: Forwarded to ``KnowledgeHarvester``.
        max_word_repeat: Forwarded to ``KnowledgeHarvester``.
        max_ent_subwords: Forwarded to ``KnowledgeHarvester``.
        use_init_prompts: If True, use only each relation's ``init_prompts``.
            Otherwise, use the de-duplicated union of ``init_prompts`` and
            ``prompts``.
    """
    knowledge_harvester = KnowledgeHarvester(
        model_name=model_name,
        max_n_ent_tuples=max_n_ent_tuples,
        max_n_prompts=max_n_prompts,
        max_word_repeat=max_word_repeat,
        max_ent_subwords=max_ent_subwords,
        prompt_temp=prompt_temp)

    relation_info = _load_json(f'relation_info/{rel_set}.json')

    # The output directory depends only on the arguments, not on the
    # relation, so build it once outside the loop.
    setting = f'{max_n_ent_tuples}tuples'
    if use_init_prompts:
        setting += '_initprompts'
    else:
        setting += f'_top{max_n_prompts}prompts'
    output_dir = f'results/{rel_set}/{setting}/{model_name}'

    for rel, info in relation_info.items():
        print(f'Harvesting for relation {rel}...')
        rel_dir = f'{output_dir}/{rel}'
        ent_tuples_path = f'{rel_dir}/ent_tuples.json'

        if os.path.exists(ent_tuples_path):
            print(f'file {ent_tuples_path} exists, skipped.')
            continue

        # Write an empty placeholder first. Any other run of the same
        # command that reaches this relation afterwards will see the file
        # and skip it (presumably so several processes can share the work).
        # Caveat: if this run dies before the final dump below, the empty
        # placeholder remains and the relation will keep being skipped until
        # the file is deleted.
        os.makedirs(rel_dir, exist_ok=True)
        _dump_json([], ent_tuples_path)

        knowledge_harvester.clear()
        knowledge_harvester.set_seed_ent_tuples(
            seed_ent_tuples=info['seed_ent_tuples'])

        if use_init_prompts:
            prompts = info['init_prompts']
        else:
            # De-duplicate while keeping first-occurrence order. A plain set()
            # yields an order that varies between runs with string hash
            # randomization, which makes results non-reproducible.
            prompts = list(dict.fromkeys(info['init_prompts'] + info['prompts']))
        knowledge_harvester.set_prompts(prompts=prompts)

        knowledge_harvester.update_prompts()
        _dump_json(knowledge_harvester.weighted_prompts,
                   f'{rel_dir}/prompts.json', indent=4)
        for prompt, weight in knowledge_harvester.weighted_prompts:
            print(f'{weight:.4f} {prompt}')

        knowledge_harvester.update_ent_tuples()
        _dump_json(knowledge_harvester.weighted_ent_tuples,
                   ent_tuples_path, indent=4)
if __name__ == '__main__':
    # Hand main() to python-fire, which turns its keyword parameters into
    # command-line flags (e.g. --rel_set=conceptnet --use_init_prompts).
    fire.Fire(component=main)