Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update README.md #36

Open
wants to merge 54 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
01fcbf9
update README.md
smiyawaki0820 May 21, 2020
1a19d32
create parallel corpus from CoNLL
smiyawaki0820 May 23, 2020
49f6dbe
update preprocess/
smiyawaki0820 May 23, 2020
6295215
update preprocess/
smiyawaki0820 May 23, 2020
ced51af
update README.md
smiyawaki0820 May 23, 2020
3c708ff
update preprocess/
smiyawaki0820 May 23, 2020
39d680d
update README.md
smiyawaki0820 May 23, 2020
b8ef8fd
add srl_setting.sh
smiyawaki0820 May 23, 2020
85b9b6a
add train/
smiyawaki0820 May 23, 2020
419ca0c
update config.sh
smiyawaki0820 May 23, 2020
d42440e
add run/train .sh
smiyawaki0820 May 23, 2020
1f854d7
update README.md
smiyawaki0820 May 23, 2020
d947521
update run.sh
smiyawaki0820 May 23, 2020
bdb01a3
update preprocess/
smiyawaki0820 May 26, 2020
3030ffe
comment preprocess.py
smiyawaki0820 May 26, 2020
ddebf89
update preprocess/
smiyawaki0820 May 26, 2020
bd073bf
update preprocess/
smiyawaki0820 May 26, 2020
d6d00dd
update preprocess/
smiyawaki0820 May 26, 2020
4bbe54a
update config.sh
smiyawaki0820 May 26, 2020
7c7ff04
comment train.py
smiyawaki0820 May 26, 2020
78d1a38
update run.sh
smiyawaki0820 May 26, 2020
4b9d951
modified _
smiyawaki0820 May 26, 2020
475b6cd
add train
smiyawaki0820 May 26, 2020
9c3cfb1
add comment
smiyawaki0820 May 26, 2020
42130b8
rename train/ to sh/
smiyawaki0820 May 26, 2020
9383b6d
add sh/evaluate.sh
smiyawaki0820 May 26, 2020
ae64265
add sh/srl_eval.sh
smiyawaki0820 May 26, 2020
edb4452
update sh/srl_eval.sh
smiyawaki0820 May 26, 2020
6869ad3
add eval module
smiyawaki0820 May 26, 2020
4da228a
add logging into generate.py
smiyawaki0820 May 27, 2020
1b8f92b
update sh
smiyawaki0820 May 27, 2020
294e6dc
add comment
smiyawaki0820 May 27, 2020
56c469d
overwrite dicts/dict.src.txt
smiyawaki0820 May 28, 2020
7e260dc
update preprocess/
smiyawaki0820 May 28, 2020
c3e1050
update preprocess/
smiyawaki0820 May 28, 2020
1c6bbd5
add creating small
smiyawaki0820 May 28, 2020
2e5fac9
update run.sh
smiyawaki0820 May 31, 2020
f6f9c17
comment fairseq/trainer.py
smiyawaki0820 May 31, 2020
189aadc
add sh/create_datasets
smiyawaki0820 May 31, 2020
d760fcf
update config.sh
smiyawaki0820 May 31, 2020
fbfa70d
update sh/create_datasets
smiyawaki0820 May 31, 2020
bbcbd79
comment fairseq/criterions/cross_entropy.py
smiyawaki0820 May 31, 2020
1dd8473
update fairseq/
smiyawaki0820 May 31, 2020
10f30b5
add sh/src/JSON_to_CoNLL.py
smiyawaki0820 Jun 1, 2020
a0ad9c0
memo fairseq/criterions/cross_entropy.py
smiyawaki0820 Jun 1, 2020
657913d
add eval_step()
smiyawaki0820 Jun 1, 2020
4116797
update run.sh
smiyawaki0820 Jun 1, 2020
a861b31
update train.py
smiyawaki0820 Jun 1, 2020
82b590f
add srl eval
smiyawaki0820 Jun 2, 2020
7a5b53f
update sh/
smiyawaki0820 Jun 2, 2020
051d22c
update train.py
smiyawaki0820 Jun 2, 2020
4119eda
data 10% train
smiyawaki0820 Jun 3, 2020
ffcea20
update run.sh
smiyawaki0820 Jun 3, 2020
cca41a5
update train.py
smiyawaki0820 Jun 3, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
update train.py
smiyawaki0820 committed Jun 2, 2020
commit 051d22ca540100f6c472fb7c1b89c0d0c0f57be8
109 changes: 81 additions & 28 deletions train.py
Original file line number Diff line number Diff line change
@@ -19,6 +19,7 @@
import time
import numpy as np
import pandas as pd
from pprint import pprint

import torch

@@ -40,6 +41,33 @@
)
logger = logging.getLogger(__name__)

class GlobalParameters():
    """Mutable bag of per-epoch validation counters.

    Tracks how many instances were processed and how often the predicted
    sequence was structurally unusable (gold/pred length mismatch,
    unbalanced bracketing) so the tallies can be dumped next to the loss.
    """

    def __init__(self, epoch, lr):
        # identifying info for the current validation pass
        self.epoch = epoch
        self.lr = lr
        # counters, all reset on every construction
        self.n_instances = 0      # number of instances seen
        self.n_diff_gp_len = 0    # gold and pred sequence lengths differ
        self.n_invld_bracket = 0  # predicted bracketing is not closed

    def basic_params(self) -> dict:
        """Snapshot the current state as a plain dict."""
        snapshot = {'epoch': self.epoch, 'lr': self.lr}
        snapshot['n_instances'] = self.n_instances
        snapshot['n_diff_gp_len'] = self.n_diff_gp_len
        snapshot['n_invld_bracket'] = self.n_invld_bracket
        return snapshot

    def store_params(self, params: list, **kwargs) -> list:
        """Append a snapshot merged with ``kwargs`` to ``params`` in place.

        Returns the same ``params`` list for call-chaining convenience.
        """
        record = self.basic_params()
        record.update(kwargs)
        params.append(record)
        return params

epochs = []  # accumulated per-epoch validation records (filled via GlobalParameters.store_params)
global_params = GlobalParameters(-1, -1)  # placeholder instance; re-created for each validation epoch
sys.setrecursionlimit(2000)  # NOTE(review): raised recursion limit — the recursive code needing it is not visible here; confirm it is still required


@@ -142,9 +170,16 @@ def main(args, init_distributed=False):

if epoch_itr.epoch % args.validate_interval == 0:
logger.info('===== validate =====')
global global_params
global_params = GlobalParameters(epoch_itr.epoch, lr)
valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
global_params.store_params(
epochs,
loss=valid_losses[0]
)
pprint(epochs[-1])
logger.info("epoch_itr.epoch ... %d" % epoch_itr.epoch)
logger.debug("valid_losses ... ".format(valid_losses[0]))
logger.debug("valid_losses ... {}".format(valid_losses[0]))

# ema process (直近の比重を大きく?)
logger.info('ema process')
@@ -169,6 +204,7 @@ def main(args, init_distributed=False):
train_meter.stop()
print('| DONE training in {:.1f} seconds'.format(train_meter.sum))
logger.info('destdir is ... {}'.format(args.save_dir))
pprint(epochs)
sys.exit()


@@ -310,10 +346,10 @@ def validate(args, trainer, task, epoch_itr, subsets):
extra_meters = collections.defaultdict(lambda: AverageMeter())

#TODO ここで計算する
fo_gold, fo_pred = open('work/gold.prop', 'a'), open('work/pred.prop', 'a')
fo_gold, fo_pred = open(f'work/gold_{epoch_itr.epoch}.prop', 'w'), open(f'work/pred_{epoch_itr.epoch}.prop', 'w')
logger.info('open gold prop ... {}'.format(fo_gold))
logger.info('open pred prop ... {}'.format(fo_pred))

for sample in progress:
"""
sample.keys()
@@ -334,8 +370,8 @@ def validate(args, trainer, task, epoch_itr, subsets):
pred_tokens = np.array(list(dict_itos(task.tgt_dict, pred))).reshape(nsents, -1).tolist()
assert len(gold_tokens) == nsents
assert len(pred_tokens) == nsents
wrap_write_prop(fo_gold, gold_tokens, gold_tokens)
wrap_write_prop(fo_pred, pred_tokens, gold_tokens)
wrap_write_prop(args, fo_gold, gold_tokens, gold_tokens, oracle='max')
wrap_write_prop(args, fo_pred, pred_tokens, gold_tokens, oracle='max')


### original ###
@@ -347,24 +383,29 @@ def validate(args, trainer, task, epoch_itr, subsets):
extra_meters[k].update(v)

fo_gold.close(), fo_pred.close()


### srl-eval.pl ###
path_pl = os.path.abspath('/home/miyawaki_shumpei/soft/srlconll-1.1/bin/srl-eval.pl')
path_gldp = os.path.abspath(fo_gold.name)
path_prdp = os.path.abspath(fo_pred.name)
command = f"perl {path_pl} {path_gldp} {path_prdp}"
process = subprocess.run(command.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
result = process.stdout.decode("utf8")
result = [o.split() for o in result.split('\n') if len(o.split()) > 1]
statistics, scores = result[:3], result[3:]
column, *scores = scores
column.insert(0, 'label')
overall = (
pd.DataFrame(scores, columns=column)
.set_index('label')
.loc['Overall']
)

logger.debug('overall scores ::: {}' % overall)
try:
result = [o.split() for o in result.split('\n') if len(o.split()) > 1]
statistics, scores = result[:3], result[3:]
column, *scores = scores
column.insert(0, 'label')
overall = (
pd.DataFrame(scores, columns=column)
.set_index('label')
.loc['Overall']
)

logger.debug('overall scores ::: {}'.format(overall))
except:
import ipdb; ipdb.set_trace()

# log validation stats
stats = get_valid_stats(trainer)
def extract_index(tokens: list, query='<V>', pos=1):
    """Return the index ``pos`` places after the first ``query`` token.

    Args:
        tokens: token sequence to search.
        query: marker token to locate (default ``'<V>'``).
        pos: offset added to the marker's index.

    Returns:
        The resulting index, or -1 when ``query`` is absent or the index
        would fall outside ``tokens``.
    """
    # offset past the marker, or -1 when the marker is missing
    index = pos + tokens.index(query) if query in tokens else -1
    # reject results that do not address an existing element
    return index if (index + 1 <= len(tokens)) and (index != -1) else -1

def eliminate_labels(tokens):
    """Return ``tokens`` with all bracket-label tokens removed."""
    kept = []
    for tok in tokens:
        # drop both opening and closing label markers, keep everything else
        if is_start(tok) or is_end(tok):
            continue
        kept.append(tok)
    return kept

def is_closed(tokens: list):
    """Check that the bracket tokens in ``tokens`` are properly balanced.

    Walks the sequence keeping a nesting depth: an opening token raises it,
    a closing token lowers it. The sequence is rejected as soon as the depth
    leaves {0, 1} (nesting or stray closers are not allowed) or when a
    closing token's label does not match the most recent opening token's.

    NOTE(review): label extraction assumes openers look like '<LABEL>'
    (``token[1:-1]``) and closers like '</LABEL>' (``token[2:-1]``) —
    confirm against the ``is_start``/``is_end`` definitions.
    """
    isInvalid, label = 0, ''
    for token in tokens:
        # depth outside {0, 1} means nested or unmatched brackets -> invalid
        if isInvalid < 0 or 1 < isInvalid:
            return False
        if is_start(token):
            label = token[1:-1]
            isInvalid += 1
        elif is_end(token):
            isInvalid -= 1
            # a closer must carry the same label as the opener it pairs with
            if label != token[2:-1]:
                return False
    # balanced only if every opener was closed by the end
    return True if isInvalid == 0 else False

#@develop('dev.1 convert into bio')
def to_bio(tokens: list, init_label='O', excepts=('<pad>', '<EN-SRL>', '<DE-SRL>', '</s>')) -> list:
@@ -437,12 +485,13 @@ def to_prop(bios:list):
# dev なので,同一文でも新しく 2cols を作成
#@develop('dev.3 write prop')
def write_prop(fo, tokens:list, golds:list, oracle='min'):
isPred = 'pred' in fo.name
assert oracle in ('min', 'max'), "OracleError"
gold_props = to_prop(to_bio(golds))
verb_idx = extract_index(golds, query='<V>', pos=1)
col_0 = ['-' for _ in gold_props]
col_1 = []

if isPred: global_params.n_instances += 1
if verb_idx == -1:
# 述語なし
col_1 = []
@@ -458,35 +507,39 @@ def extract_index_for_cols(golds, target='<V>'):
col_0[col_vix] = verb
except IndexError:
import ipdb; ipdb.set_trace()
if len(tokens) != len(golds):
if len(eliminate_labels(tokens)) != len(eliminate_labels(golds)):
# len 不一致
import ipdb; ipdb.set_trace()
col_1 = ['*' if oracle=='min' else gold for gold in gold_props]
elif is_closed(tokens):
#import ipdb; ipdb.set_trace()
if isPred: global_params.n_diff_gp_len += 1
elif not is_closed(tokens):
# unclosed (不適切 bracket)
col_1 = ['*' if oracle=='min' else gold for gold in gold_props]
if isPred: global_params.n_invld_bracket += 1
else:
# 問題なし
col_1 = to_prop(to_bio(tokens))
if len(col_0) != len(col_1):
# len 不一致
import ipdb; ipdb.set_trace()
col_1 = ['*' if oracle=='min' else gold for gold in gold_props]

try:
assert (len(col_0)==len(col_1)) or (set(col_0) == {'-'})
except:
import ipdb; ipdb.set_trace()
for c0, c1 in itertools.zip_longest(col_0, col_1):
print(f"{c0}\t{c1}", file=fo, end='\n')
print("\n", file=fo)
if c1 is not None: print(f"{c0}\t{c1}", file=fo, end='\n')
else: print(f"{c0}", file=fo, end='\n')


def wrap_write_prop(args, fo, sents: list, golds: list, oracle='min'):
    """Write one prop block per sentence in ``sents`` to file ``fo``.

    ``write_prop`` handles a single sentence, so this loops over the batch,
    stripping '<pad>' tokens from both the system output and the gold side
    before delegating. The caller's ``oracle`` mode is forwarded instead of
    being hard-coded to 'min'.

    NOTE(review): nothing is written when ``args.no_ema`` is set — the
    intent of gating prop output on the EMA flag is not visible here;
    confirm with the training configuration.
    """
    if args.no_ema:
        return
    for sent, gold in zip(sents, golds):
        sent = [token for token in sent if token != '<pad>']
        gold = [token for token in gold if token != '<pad>']
        write_prop(fo, sent, gold, oracle=oracle)


def get_valid_stats(trainer):