
Commit cf44222

Author: Thejas Venkatesh (committed)

Add new document-to-passage splitting mechanism

1 parent 4e56f4d · commit cf44222

File tree: 2 files changed (+161, -4)


utility/preprocess/docs2passages.py

+12-4
@@ -29,10 +29,18 @@ def process_page(inp):
     else:
         words = tokenizer.tokenize(content)
 
-    words_ = (words + words) if len(words) > nwords else words
-    passages = [words_[offset:offset + nwords] for offset in range(0, len(words) - overlap, nwords - overlap)]
-
-    assert all(len(psg) in [len(words), nwords] for psg in passages), (list(map(len, passages)), len(words))
+    n_passages = (len(words) + nwords - 1) // nwords
+    if n_passages > 1:
+        last_2_passage_length = len(words) - nwords * (n_passages - 2)
+        passage_lengths = [0] + [nwords] * (n_passages - 2) + [last_2_passage_length // 2] + [last_2_passage_length - last_2_passage_length // 2]
+        assert sum(passage_lengths) == len(words)
+    elif n_passages == 1:
+        passage_lengths = [0, len(words)]
+    else:
+        passage_lengths = [0]
+    print(n_passages, passage_lengths)
+    assert len(passage_lengths) == n_passages + 1
+    passages = [words[passage_lengths[idx-1]:passage_lengths[idx]] for idx in range(1, len(passage_lengths))]
 
     if tokenizer is None:
         passages = [' '.join(psg) for psg in passages]
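
The new hunk replaces the wrap-around split with a length-balanced split: the document is cut into ceil(len(words) / nwords) passages, all but the last two exactly nwords long, and the leftover tail is divided roughly in half so the final passage is never much shorter than the one before it. The sketch below is an illustration, not the committed code: it replays that arithmetic on a toy 250-word document with nwords = 100, and adds an explicit cumulative-offset step before slicing, which is an assumption here (the hunk's last line slices with the length list directly).

from itertools import accumulate

# Toy illustration of the new length computation (values are made up).
nwords = 100
words = [f"w{i}" for i in range(250)]

n_passages = (len(words) + nwords - 1) // nwords                 # ceil(250 / 100) = 3
last_2_passage_length = len(words) - nwords * (n_passages - 2)   # 250 - 100 = 150
passage_lengths = ([0] + [nwords] * (n_passages - 2)
                   + [last_2_passage_length // 2]
                   + [last_2_passage_length - last_2_passage_length // 2])
assert sum(passage_lengths) == len(words)                        # [0, 100, 75, 75]

# Assumed extra step: turn per-passage lengths into slice boundaries.
offsets = list(accumulate(passage_lengths))                      # [0, 100, 175, 250]
passages = [words[offsets[idx - 1]:offsets[idx]] for idx in range(1, len(offsets))]
assert [len(p) for p in passages] == [100, 75, 75]
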
+149
@@ -0,0 +1,149 @@
"""
    Divide a document collection into N-word/token passage spans (with wrap-around for last passage).
"""

import os
import math
import ujson
import random

from multiprocessing import Pool
from argparse import ArgumentParser
from colbert.utils.utils import print_message

Format1 = 'docid,text'            # MS MARCO Passages
Format2 = 'docid,text,title'      # DPR Wikipedia
Format3 = 'docid,url,title,text'  # MS MARCO Documents
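
# NOTE (annotation, not part of the committed file): each Format string names
# the tab-separated column order expected in --input. For example, a Format3
# row would look like (placeholder values):
#   docid <TAB> url <TAB> title <TAB> text
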
def process_page(inp):
    """
        Wraps around if we split: make sure last passage isn't too short.
        This is meant to be similar to the DPR preprocessing.
    """

    (nwords, overlap, tokenizer), (title_idx, docid, title, url, content) = inp

    if tokenizer is None:
        words = content.split()
    else:
        words = tokenizer.tokenize(content)

    words_ = (words + words) if len(words) > nwords else words
    passages = [words_[offset:offset + nwords] for offset in range(0, len(words) - overlap, nwords - overlap)]

    assert all(len(psg) in [len(words), nwords] for psg in passages), (list(map(len, passages)), len(words))

    if tokenizer is None:
        passages = [' '.join(psg) for psg in passages]
    else:
        passages = [' '.join(psg).replace(' ##', '') for psg in passages]

    if title_idx % 100000 == 0:
        print("#> ", title_idx, '\t\t\t', title)

        for p in passages:
            print("$$$ ", '\t\t', p)
        print()

        print()
        print()
        print()

    return (docid, title, url, passages)
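
# NOTE (annotation, not part of the committed file): worked example of the
# wrap-around split above, with nwords=4 and overlap=0 on a 10-word document.
# words_ doubles the word list so the last slice can run past the end:
#   offsets 0, 4, 8  ->  words_[0:4], words_[4:8], words_[8:12]
# i.e. the final passage borrows the first two words again instead of being
# only two words long, which is what the assert on passage lengths checks.
# The tokenizer branch afterwards re-joins WordPiece pieces by stripping ' ##'.
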

def main(args):
    random.seed(12345)
    print_message("#> Starting...")

    letter = 'w' if not args.use_wordpiece else 't'
    output_path = f'{args.input}.{letter}{args.nwords}_{args.overlap}'
    assert not os.path.exists(output_path)

    RawCollection = []
    Collection = []

    NumIllFormattedLines = 0

    with open(args.input) as f:
        for line_idx, line in enumerate(f):
            if line_idx % (100*1000) == 0:
                print(line_idx, end=' ')

            title, url = None, None

            try:
                line = line.strip().split('\t')

                if args.format == Format1:
                    docid, doc = line
                elif args.format == Format2:
                    docid, doc, title = line
                elif args.format == Format3:
                    docid, url, title, doc = line

                RawCollection.append((line_idx, docid, title, url, doc))
            except:
                NumIllFormattedLines += 1

                if NumIllFormattedLines % 1000 == 0:
                    print(f'\n[{line_idx}] NumIllFormattedLines = {NumIllFormattedLines}\n')

    print()
    print_message("# of documents is", len(RawCollection), '\n')

    p = Pool(args.nthreads)

    print_message("#> Starting parallel processing...")

    tokenizer = None
    if args.use_wordpiece:
        from transformers import BertTokenizerFast
        tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

    process_page_params = [(args.nwords, args.overlap, tokenizer)] * len(RawCollection)
    Collection = p.map(process_page, zip(process_page_params, RawCollection))
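
    # NOTE (annotation, not part of the committed file): Pool.map takes a single
    # iterable, so the fixed (nwords, overlap, tokenizer) triple is replicated
    # once per document and zipped with RawCollection; process_page receives the
    # resulting pair as its only argument and unpacks both halves itself.
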
    print_message(f"#> Writing to {output_path} ...")
    with open(output_path, 'w') as f:
        line_idx = 1

        if args.format == Format1:
            f.write('\t'.join(['id', 'text']) + '\n')
        elif args.format == Format2:
            f.write('\t'.join(['id', 'text', 'title']) + '\n')
        elif args.format == Format3:
            f.write('\t'.join(['id', 'text', 'title', 'docid']) + '\n')

        for docid, title, url, passages in Collection:
            for passage in passages:
                if args.format == Format1:
                    f.write('\t'.join([str(line_idx), passage]) + '\n')
                elif args.format == Format2:
                    f.write('\t'.join([str(line_idx), passage, title]) + '\n')
                elif args.format == Format3:
                    f.write('\t'.join([str(line_idx), passage, title, docid]) + '\n')

                line_idx += 1
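
# NOTE (annotation, not part of the committed file): with --input collection.tsv,
# --nwords 100 and --overlap 0 (placeholder values), the output path becomes
# 'collection.tsv.w100_0', or 't100_0' under --use-wordpiece. Each passage row
# gets a fresh sequential id starting at 1, so the original docid survives only
# in the Format3 output, where it is written as the fourth column.
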

if __name__ == "__main__":
    parser = ArgumentParser(description="docs2passages.")

    # Input Arguments.
    parser.add_argument('--input', dest='input', required=True)
    parser.add_argument('--format', dest='format', required=True, choices=[Format1, Format2, Format3])

    # Output Arguments.
    parser.add_argument('--use-wordpiece', dest='use_wordpiece', default=False, action='store_true')
    parser.add_argument('--nwords', dest='nwords', default=100, type=int)
    parser.add_argument('--overlap', dest='overlap', default=0, type=int)

    # Other Arguments.
    parser.add_argument('--nthreads', dest='nthreads', default=28, type=int)

    args = parser.parse_args()
    assert args.nwords in range(50, 500)

    main(args)
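
The CLI above requires --input and --format and constrains --nwords to 50 through 499 (the assert uses range(50, 500)). A minimal sketch of driving main() directly, with placeholder values not taken from the commit, assuming the new file is importable as a module:

# Hypothetical driver mirroring the argparse defaults above; the input path
# and format are placeholders.
from argparse import Namespace

args = Namespace(input='collection.tsv', format='docid,text',
                 use_wordpiece=False, nwords=100, overlap=0, nthreads=28)
# main(args)  # would write passages to 'collection.tsv.w100_0'
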
