50 changes: 43 additions & 7 deletions README.md
@@ -2,6 +2,15 @@

**This code is for EMNLP 2019 paper [Text Summarization with Pretrained Encoders](https://arxiv.org/abs/1908.08345)**

**Updates Jan 22 2020**: You can now **summarize raw text input!** Switch to the dev branch and use `-text_src $RAW_SRC.TXT` to input your text file.
* Use `-test_from $PT_FILE$` to load your model checkpoint file.
* Format of the source text file (see the examples below):
  * For **abstractive summarization**, each line is a document.
  * For **extractive summarization**, insert ` [CLS] [SEP] ` as your sentence boundaries.
* There are example input files in the [raw_data directory](https://github.com/nlpyang/PreSumm/tree/dev/raw_data).
* If you also have reference summaries aligned with your source input, use `-text_tgt $RAW_TGT.TXT` so document order is kept for evaluation.
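
For illustration only, a hypothetical **abstractive** source file with two documents (one per line; the sentences are made up) could look like:

```
officials declined to comment on the report . a statement is expected tomorrow .
markets were largely unmoved by the announcement . analysts expect little change .
```

And the first document prepared for **extractive** summarization, with ` [CLS] [SEP] ` inserted at the sentence boundary:

```
officials declined to comment on the report . [CLS] [SEP] a statement is expected tomorrow .
```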


Results on CNN/DailyMail (20/8/2019):


@@ -60,17 +69,25 @@ Results on CNN/DailyMail (20/8/2019):

**Package Requirements**: `torch==1.1.0`, `pytorch_transformers`, `tensorboardX`, `multiprocess`, `pyrouge`



**Updates**: To encode text longer than 512 tokens (for example, 800), set `max_pos` to 800 during both preprocessing and training.
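
For instance, a hypothetical 800-token training run might look like the sketch below; it reuses only flags that appear in the training commands later in this README, leaves all other hyperparameters at their defaults, and the log-file name is illustrative:

```
python train.py -task abs -mode train -bert_data_path BERT_DATA_PATH -model_path MODEL_PATH -max_pos 800 -visible_gpus 0,1,2,3 -log_file ../logs/abs_bert_cnndm_800
```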


Some code is borrowed from ONMT (https://github.com/OpenNMT/OpenNMT-py).

## Trained Models
[CNN/DM Extractive](https://drive.google.com/open?id=1kKWoV0QCbeIuFt85beQgJ4v0lujaXobJ)
[CNN/DM BertExt](https://drive.google.com/open?id=1kKWoV0QCbeIuFt85beQgJ4v0lujaXobJ)

[CNN/DM BertExtAbs](https://drive.google.com/open?id=1-IKVCtc4Q-BdZpjXc4s70_fRsWnjtYLr)

[CNN/DM Abstractive](https://drive.google.com/open?id=1-IKVCtc4Q-BdZpjXc4s70_fRsWnjtYLr)
[CNN/DM TransformerAbs](https://drive.google.com/open?id=1yLCqT__ilQ3mf5YUUCw9-UToesX5Roxy)

[XSum](https://drive.google.com/open?id=1H50fClyTkNprWJNh10HWdGEdDdQIkzsI)
[XSum BertExtAbs](https://drive.google.com/open?id=1H50fClyTkNprWJNh10HWdGEdDdQIkzsI)

## System Outputs

[CNN/DM and XSum](https://drive.google.com/file/d/1kYA384UEAQkvmZ-yWZAfxw7htCbCwFzC)

## Data Preparation For XSum
[Pre-processed data](https://drive.google.com/open?id=1BWBN1coTWGBqrWoOfRc5dhojPHhatbYs)
@@ -133,7 +150,7 @@ python train.py -task ext -mode train -bert_data_path BERT_DATA_PATH -ext_dropou

#### TransformerAbs (baseline)
```
python train.py -mode train -accum_count 5 -batch_size 300 -bert_data_path BERT_DATA_PATH -dec_dropout 0.1 -log_file ../../logs/cnndm_baseline -lr 0.1 -model_path MODEL_PATH -save_checkpoint_steps 2000 -seed 777 -sep_optim false -train_steps 200000 -use_bert_emb true -use_interval true -warmup_steps 8000 -visible_gpus 0,1,2,3 -max_pos 512 -report_every 50 -enc_hidden_size 512 -enc_layers 6 -enc_ff_size 2048 -enc_dropout 0.1 -dec_layers 6 -dec_hidden_size 512 -dec_ff_size 2048 -encoder baseline -task abs
python train.py -mode train -accum_count 5 -batch_size 300 -bert_data_path BERT_DATA_PATH -dec_dropout 0.1 -log_file ../../logs/cnndm_baseline -lr 0.05 -model_path MODEL_PATH -save_checkpoint_steps 2000 -seed 777 -sep_optim false -train_steps 200000 -use_bert_emb true -use_interval true -warmup_steps 8000 -visible_gpus 0,1,2,3 -max_pos 512 -report_every 50 -enc_hidden_size 512 -enc_layers 6 -enc_ff_size 2048 -enc_dropout 0.1 -dec_layers 6 -dec_hidden_size 512 -dec_ff_size 2048 -encoder baseline -task abs
```
#### BertAbs
```
@@ -145,15 +162,34 @@ python train.py -task abs -mode train -bert_data_path BERT_DATA_PATH -dec_dropo
```
* `EXT_CKPT` is the saved `.pt` checkpoint of the extractive model.




## Model Evaluation
### CNN/DM
```
python train.py -task abs -mode validate -batch_size 3000 -test_batch_size 500 -bert_data_path BERT_DATA_PATH -log_file ../logs/val_abs_bert_cnndm -model_path MODEL_PATH -sep_optim true -use_interval true -visible_gpus 1 -max_pos 512 -max_length 200 -alpha 0.95 -min_length 50 -result_path ../logs/abs_bert_cnndm
```
### XSum
```
python train.py -task abs -mode validate -batch_size 3000 -test_batch_size 500 -bert_data_path BERT_DATA_PATH -log_file ../logs/val_abs_bert_xsum -model_path MODEL_PATH -sep_optim true -use_interval true -visible_gpus 1 -max_pos 512 -min_length 20 -max_length 100 -alpha 0.9 -result_path ../logs/abs_bert_xsum
```
* `-mode` can be {`validate`, `test`}: `validate` inspects the model directory and evaluates the model on each newly saved checkpoint, while `test` must be used with `-test_from`, indicating the checkpoint you want to use (see the example below)
* `MODEL_PATH` is the directory of saved checkpoints
* If you use `-mode validate` with `-test_all`, the system will load all saved checkpoints and select the top ones to generate summaries (this will take a while)
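
For example, a sketch of scoring one specific CNN/DM checkpoint with `-mode test` (the checkpoint file name is illustrative):

```
python train.py -task abs -mode test -test_from ../models/model_step_148000.pt -batch_size 3000 -test_batch_size 500 -bert_data_path BERT_DATA_PATH -log_file ../logs/test_abs_bert_cnndm -sep_optim true -use_interval true -visible_gpus 1 -max_pos 512 -min_length 50 -max_length 200 -alpha 0.95 -result_path ../logs/abs_bert_cnndm
```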

## Raw Text Input

### Abstractive Summarization

Source File
```
python train.py -task abs -mode test_text -visible_gpus 0 -test_from PATH_TO_CHECKPOINT -text_src PATH_TO_SRC -text_tgt PATH_TO_TGT -log_file ../logs/abs_bert_cnndm
```

String Input
```
python train.py -task abs -mode test_text -visible_gpus 0 -test_from ../models/model_step_148000.pt -text_src 'this is a string test' -input_type str -log_file ../logs/abs_bert_cnndm
```
### Extractive Summarization

```
python train.py -task ext -mode test_text -visible_gpus 0 -test_from PATH_TO_CHECKPOINT -text_src PATH_TO_SRC -text_tgt PATH_TO_TGT -log_file ../logs/ext_bert_cnndm
```
25 changes: 23 additions & 2 deletions src/models/data_loader.py
Expand Up @@ -2,6 +2,7 @@
import gc
import glob
import random
import os.path

import torch
from tqdm import tqdm
@@ -295,7 +296,11 @@ def load_text(args, source_fp, target_fp, device):
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
sep_vid = tokenizer.vocab['[SEP]']
cls_vid = tokenizer.vocab['[CLS]']
n_lines = len(open(source_fp).read().split('\n'))

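# source_fp may be a file path or, with -input_type str, the raw input text itself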
if os.path.isfile(source_fp):
n_lines = len(open(source_fp).read().split('\n'))
else:
n_lines = 1

def _process_src(raw):
raw = raw.strip().lower()
@@ -323,7 +328,7 @@ def _process_src(raw):

return src, mask_src, segments_ids, clss, mask_cls

if(target_fp==''):
if(target_fp=='' and args.input_type == 'doc'):
with open(source_fp) as source:
for x in tqdm(source, total=n_lines):
src, mask_src, segments_ids, clss, mask_cls = _process_src(x)
@@ -341,6 +346,22 @@ def _process_src(raw):

batch.batch_size=1
yield batch
elif(args.input_type == 'str'):
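# raw-string input: source_fp holds the text itself, so build a single one-document batch from it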
src, mask_src, segments_ids, clss, mask_cls = _process_src(source_fp)
segs = torch.tensor(segments_ids)[None, :].to(device)
batch = Batch()
batch.src = src
batch.tgt = None
batch.mask_src = mask_src
batch.mask_tgt = None
batch.segs = segs
batch.src_str = [[sent.replace('[SEP]','').strip() for sent in source_fp.split('[CLS]')]]
batch.tgt_str = ['']
batch.clss = clss
batch.mask_cls = mask_cls

batch.batch_size=1
yield batch
else:
with open(source_fp) as source, open(target_fp) as target:
for x, y in tqdm(zip(source, target), total=n_lines):
23 changes: 18 additions & 5 deletions src/train.py
Expand Up @@ -5,6 +5,7 @@
from __future__ import division

import argparse
import torch
import os
from others.logging import init_logger
from train_abstractive import validate_abs, train_abs, baseline, test_abs, test_text_abs
@@ -30,16 +31,16 @@ def str2bool(v):
parser.add_argument("-task", default='ext', type=str, choices=['ext', 'abs'])
parser.add_argument("-encoder", default='bert', type=str, choices=['bert', 'baseline'])
parser.add_argument("-mode", default='train', type=str, choices=['train', 'validate', 'test', 'test_text'])
parser.add_argument("-bert_data_path", default='../bert_data_new/cnndm')
parser.add_argument("-text_src", default='../raw_data/temp.raw_src')
parser.add_argument("-input_type",default='doc',type=str, choices=['doc','str'])
parser.add_argument("-text_tgt", default='')
parser.add_argument("-bert_data_path", default='../bert_data/cnndm')
parser.add_argument("-model_path", default='../models/')
parser.add_argument("-result_path", default='../results/cnndm')
parser.add_argument("-temp_dir", default='../temp')
parser.add_argument("-text_src", default='')
parser.add_argument("-text_tgt", default='')

parser.add_argument("-batch_size", default=140, type=int)
parser.add_argument("-test_batch_size", default=200, type=int)
parser.add_argument("-max_ndocs_in_batch", default=6, type=int)

parser.add_argument("-max_pos", default=512, type=int)
parser.add_argument("-use_interval", type=str2bool, nargs='?',const=True,default=True)
@@ -120,6 +121,8 @@ def str2bool(v):
device = "cpu" if args.visible_gpus == '-1' else "cuda"
device_id = 0 if device == "cuda" else -1

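# release unused cached GPU memory before loading a model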
torch.cuda.empty_cache()

if (args.task == 'abs'):
if (args.mode == 'train'):
train_abs(args, device_id)
@@ -137,6 +140,11 @@ def str2bool(v):
step = 0
test_abs(args, device_id, cp, step)
elif (args.mode == 'test_text'):
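# recover the step number from the checkpoint filename (e.g. model_step_148000.pt); test_text_abs currently ignores it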
cp = args.test_from
try:
step = int(cp.split('.')[-2].split('_')[-1])
except:
step = 0
test_text_abs(args)

elif (args.task == 'ext'):
@@ -152,4 +160,9 @@ def str2bool(v):
step = 0
test_ext(args, device_id, cp, step)
elif (args.mode == 'test_text'):
test_text_ext(args)
cp = args.test_from
try:
step = int(cp.split('.')[-2].split('_')[-1])
except:
step = 0
test_text_ext(args)
2 changes: 0 additions & 2 deletions src/train_abstractive.py
Original file line number Diff line number Diff line change
@@ -307,8 +307,6 @@ def train_iter_fct():
trainer.train(train_iter_fct, args.train_steps)




def test_text_abs(args):

logger.info('Loading checkpoint from %s' % args.test_from)