
Commit 38dc0ce

transformation function fix
1 parent 28f8d41 commit 38dc0ce

File tree

2 files changed, +76 −0 lines changed


utils/task_utils.py

+3

@@ -11,6 +11,9 @@ class TasksParam:
     def __init__(self, taskFilePath):
         # dictionary holding all the tasks details with
         # task name as key.
+        # The idea to store and retrieve task information in a yaml file and process it using
+        # dictionary maps and IntEnum classes is inspired by Microsoft's mt-dnn <https://github.com/namisan/mt-dnn>
+
         self.taskDetails = yaml.safe_load(open(taskFilePath))
         self.modelType = self.validity_checks()
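As a side note on the comment added above (not part of this commit): yaml.safe_load returns a plain dictionary keyed by task name, which is what self.taskDetails holds. A minimal sketch of that loading step is below; the task name and fields are illustrative assumptions, not the library's exact schema.

# Minimal sketch of the yaml-driven task setup referred to in the comment above.
# The task name and fields are illustrative assumptions, not the library's exact schema.
import yaml

exampleTaskYaml = """
answerability:
    model_type: BERT
    task_type: SentencePairClassification
    file_names:
        - msmarco_answerability_train.tsv
        - msmarco_answerability_dev.tsv
        - msmarco_answerability_test.tsv
"""

# yaml.safe_load returns a dictionary with the task name as key,
# mirroring what TasksParam stores in self.taskDetails.
taskDetails = yaml.safe_load(exampleTaskYaml)
print(taskDetails['answerability']['task_type'])  # SentencePairClassification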

utils/tranform_functions.py

+73

@@ -339,6 +339,79 @@ def generate_ngram_sequences(data, seq_len_right, seq_len_left):
         sequence_dict[key] = left_seq + right_seq
         i += 1
     return sequence_dict
+
+def msmarco_answerability_detection_to_tsv(dataDir, readFile, wrtDir, transParamDict, isTrainFile=False):
+    """
+    This function transforms the MSMARCO triples data available at `triples <https://msmarco.blob.core.windows.net/msmarcoranking/triples.train.small.tar.gz>`_.
+
+    The data contains triplets where the first entry is the query, the second is a context passage from which the query can be
+    answered (positive passage), while the third is a context passage from which the query cannot be answered (negative passage).
+    The data is transformed into sentence pair classification format, with the query-positive context pair labeled as 1 (answerable)
+    and the query-negative context pair labeled as 0 (non-answerable).
+
+    The following transformed files are written at wrtDir:
+
+    - Sentence pair transformed downsampled file.
+    - Sentence pair transformed train tsv file for the answerability task.
+    - Sentence pair transformed dev tsv file for the answerability task.
+    - Sentence pair transformed test tsv file for the answerability task.
+
+    To use this transform function, set ``transform_func`` : **msmarco_answerability_detection_to_tsv** in the transform file.
+
+    Args:
+        dataDir (:obj:`str`) : Path to the directory where the raw data files to be read are present.
+        readFile (:obj:`str`) : This is the file which is currently being read and transformed by the function.
+        wrtDir (:obj:`str`) : Path to the directory where the transformed tsv files are saved.
+        transParamDict (:obj:`dict`, defaults to :obj:`None`): Dictionary of function specific parameters.
+
+            - ``data_frac`` (defaults to 0.01) : Fraction of data to keep while downsampling, as the original data size is too large.
+    """
+    transParamDict.setdefault("data_frac", 0.01)
+    sampleEvery = int(1/float(transParamDict["data_frac"]))
+    startId = 0
+    print('Making data from file {} ....'.format(readFile))
+    rf = open(os.path.join(dataDir, readFile))
+    sf = open(os.path.join(wrtDir, 'msmarco_triples_sampled.tsv'), 'w')
+
+    # reading the big file line by line
+    for i, row in enumerate(rf):
+        if i % 100000 == 0:
+            print("Processing {} rows...".format(i))
+
+        # sampling: keep every sampleEvery-th triple
+        if i % sampleEvery == 0:
+            rowData = row.split('\t')
+            posRowData = str(startId)+'\t'+str(1)+'\t'+rowData[0]+'\t'+rowData[1]
+            negRowData = str(startId+1)+'\t'+str(0)+'\t'+rowData[0]+'\t'+rowData[2].rstrip('\n')
+
+            # An important point here is to strip the row-ending '\n' present after the negative
+            # passage, otherwise it will hamper the dataframe.
+
+            # writing the positive and negative samples into the new sampled data file
+            sf.write(posRowData+'\n')
+            sf.write(negRowData+'\n')
+
+            # increasing id count
+            startId += 2
+    rf.close()
+    sf.close()
+
+    print('Total Number of rows in original data: ', i)
+    print('Number of answerable samples in downsampled data: ', int(startId / 2))
+    print('Number of non-answerable samples in downsampled data: ', int(startId / 2))
+    print('Downsampled msmarco triples tsv saved at: {}'.format(os.path.join(wrtDir, 'msmarco_triples_sampled.tsv')))
+
+    # making train, test, dev split
+    sampledDf = pd.read_csv(os.path.join(wrtDir, 'msmarco_triples_sampled.tsv'), sep='\t', header=None)
+    trainDf, testDf = train_test_split(sampledDf, shuffle=True, random_state=SEED,
+                                       test_size=0.02)
+    trainDf.to_csv(os.path.join(wrtDir, 'msmarco_answerability_train.tsv'), sep='\t', index=False, header=False)
+    print('Train file written at: ', os.path.join(wrtDir, 'msmarco_answerability_train.tsv'))
+
+    devDf, testDf = train_test_split(testDf, shuffle=True, random_state=SEED,
+                                     test_size=0.5)
+    devDf.to_csv(os.path.join(wrtDir, 'msmarco_answerability_dev.tsv'), sep='\t', index=False, header=False)
+    print('Dev file written at: ', os.path.join(wrtDir, 'msmarco_answerability_dev.tsv'))
+
+    # write the held-out test split (testDf) to the test file
+    testDf.to_csv(os.path.join(wrtDir, 'msmarco_answerability_test.tsv'), sep='\t', index=False, header=False)
+    print('Test file written at: ', os.path.join(wrtDir, 'msmarco_answerability_test.tsv'))
 
 def msmarco_query_type_to_tsv(dataDir, readFile, wrtDir, transParamDict, isTrainFile=False):
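As a usage note (not part of this commit), the sketch below shows roughly how the new transform would be invoked directly. The directory, file name, and transform-file entry are assumptions for illustration; with data_frac = 0.01 the function keeps every 100th triple and turns each kept triple into one answerable and one non-answerable row.

# Hedged usage sketch for the transform added above; paths and file names are assumed,
# not taken from the repository. Assumes the repository root is on PYTHONPATH.
from utils.tranform_functions import msmarco_answerability_detection_to_tsv

# Illustrative transform-file entry (schema assumed, not verified against the library):
#   transform_func: msmarco_answerability_detection_to_tsv
#   transform_params:
#     data_frac: 0.01

msmarco_answerability_detection_to_tsv(
    dataDir='data/msmarco',               # directory holding the extracted triples file (assumed path)
    readFile='triples.train.small.tsv',   # extracted MSMARCO triples file (assumed name)
    wrtDir='data/msmarco',                # transformed tsv files are written here
    transParamDict={'data_frac': 0.01},   # keep roughly 1% of the triples while downsampling
)

# Each sampled triple (query, positive passage, negative passage) becomes two tsv rows:
#   <id>      1    <query>    <positive passage>
#   <id + 1>  0    <query>    <negative passage>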
