# downstream.py
import abc
import logging
import re
from typing import Any, Dict, List, Optional, Sequence, Union
import datasets
import torch
import torch.nn.functional as F
from sklearn.metrics import f1_score
from torchmetrics import Metric
from olmo.util import load_hf_dataset, load_oe_eval_requests
from ..tokenizer import Tokenizer
log = logging.getLogger(__name__)
# Map from oe-eval metrics to metrics used here
METRIC_FROM_OE_EVAL = {
"acc_raw": "acc",
"acc_per_char": "len_norm",
"acc_uncond": "pmi_dc",
"logits_per_byte": "bpb",
}
LOG_2_OF_E = 1.44269504089
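# Note: LOG_2_OF_E converts natural-log likelihoods to base-2 bits, since
# log2(x) = ln(x) * log2(e). The "bpb" (bits-per-byte) branch below computes,
# for a continuation spanning B bytes:
#     bpb = -(sum of ln p(token)) / B * LOG_2_OF_E
# e.g. (illustrative) a continuation with total log-likelihood ln(0.25) over
# 1 byte gives bpb = -ln(0.25) * 1.44269504089 = 2.0 bits per byte.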
class ICLMetric(Metric):
# update method does not require access to global metric state
full_state_update: bool = False
def __init__(self, metric_type="acc") -> None:
"""metric_type: f1, acc, len_norm, pmi_dc, ce_loss, bpb"""
super().__init__(sync_on_compute=True)
self.metric_type = metric_type
self.add_state("loglikelihoods", default=[], dist_reduce_fx=None)
self.add_state("labels", default=[], dist_reduce_fx=None)
def reset(
self,
):
self.loglikelihoods = []
self.labels = []
def update(self, batch: Dict[str, Any], lm_logits: torch.Tensor, dc_lm_logits=None):
lm_logits = F.log_softmax(lm_logits, dim=-1)
if self.metric_type == "pmi_dc":
assert dc_lm_logits is not None, "PMI_DC acc type selected but no domain conditional logits provided"
for idx, (doc_id, cont_id) in enumerate(zip(batch["doc_id"], batch["cont_id"])):
# [cont_len]: continuation is padded for batching
cont_tokens = batch["continuation"][idx][: batch["cont_len"][idx]]
# get logits from LM for the continuation: [cont_len, vocab]
# batch['input_ids'][idx] -> ctx + cont + padding
# -1 in both indices: lm_logits is left-shifted by 1 position, since the 0th position of the input generates the next-token prediction at the 0th position of lm_logits
lm_cont_logits = lm_logits[idx][
batch["ctx_len"][idx] - 1 : batch["ctx_len"][idx] + batch["cont_len"][idx] - 1
]
log_likelihood: torch.Tensor
if self.metric_type == "pmi_dc":
assert dc_lm_logits is not None
# get domain conditional continuation logits: [cont_len, vocab]
dc_lm_cont_logits = dc_lm_logits[idx][
batch["dc_len"][idx] - 1 : batch["dc_len"][idx] + batch["cont_len"][idx] - 1
]
# gather log-probs at continuation token indices but divide by domain conditional prob
log_likelihood = (
torch.gather(lm_cont_logits, 1, cont_tokens.unsqueeze(-1)).sum()
/ torch.gather(dc_lm_cont_logits, 1, cont_tokens.unsqueeze(-1)).sum()
)
elif self.metric_type == "acc" or self.metric_type == "f1":
# gather log-probs at continuation token indices
log_likelihood = torch.gather(lm_cont_logits, 1, cont_tokens.unsqueeze(-1)).sum()
elif self.metric_type == "len_norm" or self.metric_type == "ce_loss":
log_likelihood = (
torch.gather(lm_cont_logits, 1, cont_tokens.unsqueeze(-1)).sum() / batch["cont_str_len"][idx]
)
if self.metric_type == "ce_loss":
log_likelihood = -log_likelihood
elif self.metric_type == "bpb":
# bits per byte
log_likelihood = (
-torch.gather(lm_cont_logits, 1, cont_tokens.unsqueeze(-1)).sum()
/ batch["cont_byte_len"][idx]
* LOG_2_OF_E
)
else:
raise ValueError(self.metric_type)
# because metric states cannot be dict/list of tuples, store this tuple as tensor: (doc_id, cont_id, metric_state)
self.loglikelihoods.append(
torch.Tensor((doc_id, cont_id, log_likelihood)).to(batch["continuation"][idx].device)
)
self.labels.append(
torch.LongTensor((doc_id, cont_id, batch["label_id"][idx])).to(batch["label_id"][idx].device)
)
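# Each state entry packs (doc_id, cont_id, value) into a single tensor because
# torchmetrics states cannot be dicts or lists of tuples; compute() unpacks
# them back into per-document dictionaries below.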
def compute(self) -> torch.Tensor:
# states should have been synced from all accelerators at this point
# account for duplicates here because of DistributedSampler compensating for drop_last=False
loglikelihood_dict: Dict[int, Dict[int, float]] = {}
label_dict = {}
# collect labels
for doc_id, cont_id, label_id in self.labels:
if doc_id.item() not in label_dict:
label_dict[doc_id.item()] = label_id.item()
# collect loglikelihoods
for doc_id, cont_id, loglikelihood in self.loglikelihoods:
if int(doc_id.item()) not in loglikelihood_dict:
loglikelihood_dict[int(doc_id.item())] = {}
if int(cont_id.item()) not in loglikelihood_dict[int(doc_id.item())]:
loglikelihood_dict[int(doc_id.item())][int(cont_id.item())] = loglikelihood
# compute acc
correct = []
preds: Optional[List[float]] = None
labels: Optional[List[int]] = None
if self.metric_type == "f1":
preds = []
labels = []
for doc_id in loglikelihood_dict:
# each doc_id might have a different number of continuations
num_continuations = len(loglikelihood_dict[doc_id].keys())
loglikelihoods = torch.tensor([-float("inf")] * num_continuations)
skip_document = False
for cont_id in loglikelihood_dict[doc_id]:
try:
loglikelihoods[cont_id] = loglikelihood_dict[doc_id][cont_id]
except IndexError:
# We didn't process all of the continuations, so skip this document.
skip_document = True
break
if skip_document:
continue
if self.metric_type in ["ce_loss", "bpb"]:
correct.append(loglikelihoods[0]) # Only one answer is scored
else:
correct.append(1.0 if torch.argmax(loglikelihoods).item() == label_dict[doc_id] else 0.0)
if self.metric_type == "f1":
assert preds is not None
assert labels is not None
preds.append(torch.argmax(loglikelihoods).item())
labels.append(label_dict[doc_id])
if self.metric_type == "f1":
assert preds is not None
assert labels is not None
# for NLI tasks, continuations are yes, no, neither, so idx=0 assigned to pos label
score = f1_score(labels, preds, pos_label=0)
else:
score = sum(correct) / len(correct)
return torch.tensor(score)
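# A minimal usage sketch for ICLMetric (hypothetical names: `model` and
# `eval_loader` stand in for whatever the surrounding eval loop provides):
#
#     metric = ICLMetric(metric_type="len_norm")
#     for batch in eval_loader:
#         lm_logits = model(batch["input_ids"])  # [batch, seq_len, vocab]
#         metric.update(batch, lm_logits)
#     score = metric.compute()  # scalar tensor
#
# For metric_type="pmi_dc", also pass dc_lm_logits computed from
# batch["dc_input_ids"].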
class ICLMultiChoiceTaskDataset(metaclass=abc.ABCMeta):
"""Only supports zero-shot for now."""
metric_type: str
def __init__(
self,
tokenizer: Tokenizer,
dataset_path: str,
dataset_name: Union[str, Sequence[str], None] = None,
model_ctx_len: int = 2048,
split="validation",
metric_type=None, # Override default metric type
prompts=[None], # List of prompt variants to use
):
super().__init__()
self.tokenizer = tokenizer
self.dataset_path = dataset_path
self.dataset_name = dataset_name
self.model_ctx_len = model_ctx_len
self.prompts = prompts
self.current_prompt = None
if metric_type is not None:
self.metric_type = metric_type
self.log_instances = 0 # Set to > 0 to log the first few instances as a sanity check
self.samples: List[Dict[str, Any]] = []
dataset_names: Sequence[Optional[str]]
if isinstance(dataset_name, str) or dataset_name is None:
dataset_names = [dataset_name]
else:
dataset_names = dataset_name
dataset_list = []
for ds_name in dataset_names:
dataset = load_hf_dataset(self.dataset_path, ds_name, split)
dataset_list.append(dataset)
self.dataset = datasets.concatenate_datasets(dataset_list)
# prep examples
self.prep_examples()
def __getitem__(self, index):
return self.samples[index]
def __len__(self):
return len(self.samples)
def prep_examples(self):
"""Append doc_ids to each example so that they are processed together in the metric"""
doc_id = 0
for doc in self.dataset:
for prompt in self.prompts:
self.current_prompt = prompt
# from EAI harness
# how this all works:
# CTX CONT
# inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1]
# gpt2 \ \
# logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the
# cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice
continuations = self.doc_to_continuations(doc)
label_id = self.doc_to_label(doc)
doc_text = self.doc_to_text(doc)
ctx = self.token_encode(doc_text)
dc = self.token_encode(self.doc_to_domain_conditional(doc))
if self.log_instances > 0:
self.log_instances -= 1
ds_name = self.dataset_name
if isinstance(ds_name, list):
ds_name = ds_name[0]
log.info(
f"Sample doc from ({self.dataset_path}, {ds_name}, {self.current_prompt}):"
+ f"\ndoc_text: {doc_text}\ncontinuations: {continuations}"
)
for cont_id, continuation_str in enumerate(continuations):
cont_str_len = len(continuation_str) - 1  # the continuation contains a leading blank
cont_byte_len = len(continuation_str[1:].encode("utf-8"))
continuation = self.token_encode(continuation_str)
# query: remove the last token from the continuation; truncate from the left if longer than the model ctx length
query = ctx + continuation[:-1]
query = query[-self.model_ctx_len :]
# this will be different from len(ctx) when truncated by model_ctx_len
actual_ctx_len = len(query) - len(continuation) + 1
# get domain conditional query
# we don't expect this to be longer than self.model_ctx_len and it won't make sense to truncate from left
dc_query = dc + continuation[:-1]
# form a sample
self.samples.append(
{
"doc_id": doc_id,
"cont_id": cont_id,
"ctx": ctx,
"continuation": continuation,
"ctx_len": actual_ctx_len,
"dc_len": len(dc),
"cont_len": len(
continuation
), # even if query has last token removed, LM will output same cont len
"cont_str_len": cont_str_len,
"cont_byte_len": cont_byte_len,
"query": query, # remove last token from continuation
"dc_query": dc_query,
"label_id": label_id,
}
)
doc_id += 1
def pad_tokens_until_max(self, tokens, max_len=2048):
"""truncate from left if len(tokens) > model_ctx_len, max_len is not considered then
queries are already truncated at max length of model_ctx_len
this acts as additional check for all types of sequences in the batch
"""
if len(tokens) > self.model_ctx_len:
return tokens[-self.model_ctx_len :]
else:
# pad to max_len, but check again if this padding exceeded self.model_ctx_len
# this time truncate from right side of the sequence because additional padding caused len(tokens) > self.model_ctx_len
tokens = tokens + [self.tokenizer.pad_token_id] * (max_len - len(tokens))
if len(tokens) > self.model_ctx_len:
tokens = tokens[: self.model_ctx_len]
return tokens
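# Worked example for pad_tokens_until_max, assuming (hypothetically)
# pad_token_id=0 and model_ctx_len=4:
#     pad_tokens_until_max([7, 8], max_len=3)          -> [7, 8, 0]
#     pad_tokens_until_max([1, 2, 3, 4, 5], max_len=3) -> [2, 3, 4, 5]    # left-truncated
#     pad_tokens_until_max([7, 8], max_len=6)          -> [7, 8, 0, 0]    # padded, then right-truncated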
def collate_fn(self, data):
# pad to max length
# 'ctx', 'continuation', 'query' can all have variable length
max_ctx_len = 0
max_cont_len = 0
max_query_len = 0
max_dc_query_len = 0
for sample in data:
if len(sample["ctx"]) > max_ctx_len:
max_ctx_len = len(sample["ctx"])
if len(sample["continuation"]) > max_cont_len:
max_cont_len = len(sample["continuation"])
if len(sample["query"]) > max_query_len:
max_query_len = len(sample["query"])
if len(sample["dc_query"]) > max_dc_query_len:
max_dc_query_len = len(sample["dc_query"])
doc_ids = []
cont_ids = []
ctxs = []
continuations = []
ctx_lens = []
dc_lens = []
cont_lens = []
cont_str_lens = []
cont_byte_lens = []
queries = []
dc_queries = []
label_ids = []
# pad according to max_lengths
for sample in data:
doc_ids.append(sample["doc_id"])
cont_ids.append(sample["cont_id"])
ctxs.append(torch.LongTensor(self.pad_tokens_until_max(sample["ctx"], max_len=max_ctx_len)))
continuations.append(
torch.LongTensor(self.pad_tokens_until_max(sample["continuation"], max_len=max_cont_len))
)
ctx_lens.append(sample["ctx_len"])
dc_lens.append(sample["dc_len"])
cont_lens.append(sample["cont_len"])
cont_str_lens.append(sample["cont_str_len"])
cont_byte_lens.append(sample["cont_byte_len"])
queries.append(torch.LongTensor(self.pad_tokens_until_max(sample["query"], max_len=max_query_len)))
dc_queries.append(
torch.LongTensor(self.pad_tokens_until_max(sample["dc_query"], max_len=max_dc_query_len))
)
label_ids.append(sample["label_id"])
batch = {
"doc_id": torch.LongTensor(doc_ids),
"cont_id": torch.LongTensor(cont_ids),
"ctx": torch.stack(ctxs),
"continuation": torch.stack(continuations),
"ctx_len": torch.LongTensor(ctx_lens),
"dc_len": torch.LongTensor(dc_lens),
"cont_len": torch.LongTensor(cont_lens), # since query has last token removed from continuation
"cont_str_len": torch.LongTensor(cont_str_lens),
"cont_byte_len": torch.LongTensor(cont_byte_lens),
"input_ids": torch.stack(queries),
"dc_input_ids": torch.stack(dc_queries),
}
if not isinstance(label_ids, str):
batch["label_id"] = torch.LongTensor(label_ids)
return batch
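# For reference, with B samples in the batch: "input_ids" is
# [B, max_query_len], "dc_input_ids" is [B, max_dc_query_len], "continuation"
# is [B, max_cont_len], and the *_len entries are [B] LongTensors that
# ICLMetric.update uses to slice out the continuation logits.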
def token_encode(self, string: str) -> List[int]:
return self.tokenizer.encode(string, add_special_tokens=False)
def token_decode(self, tokens: List[int]) -> str:
return self.tokenizer.decode(tokens)
@abc.abstractmethod
def doc_to_text(self, doc) -> str:
"""Match EAI eval harness
returns a single context string
"""
raise NotImplementedError
@abc.abstractmethod
def doc_to_continuations(self, doc) -> List[str]:
"""Match EAI eval harness
returns a list of continuations
"""
raise NotImplementedError
@abc.abstractmethod
def doc_to_label(self, doc) -> int:
"""Match EAI eval harness
returns continuation id which corresponds to true label
"""
raise NotImplementedError
def doc_to_domain_conditional(self, doc) -> str:
"""Provide string for domain conditional normalization
by default it is a single space; the continuation is normalized by its probability conditioned on that blank
"""
del doc
return " "
class PIQA(ICLMultiChoiceTaskDataset):
"""PIQA sends context in the following fashion: "Question: GOAL\nAnswer:"
space added as prefix to each continuation
implement PMI_DC
{
'goal': "How do I ready a guinea pig cage for it's new occupants?",
'sol1': 'Provide the guinea pig with a cage full of a few inches of bedding made of ripped paper strips, you will also need to supply it with a water bottle and a food dish.',
'sol2': 'Provide the guinea pig with a cage full of a few inches of bedding made of ripped jeans material, you will also need to supply it with a water bottle and a food dish.',
'label': 0
}
"""
metric_type = "len_norm"
def __init__(
self,
tokenizer,
dataset_path="piqa",
dataset_name="plain_text",
):
super().__init__(
tokenizer=tokenizer,
dataset_path=dataset_path,
dataset_name=dataset_name,
)
def doc_to_text(self, doc):
return "Question: " + doc["goal"] + "\nAnswer:"
def doc_to_continuations(self, doc):
# add spaces in front of continuation
return [" " + doc["sol1"], " " + doc["sol2"]]
def doc_to_label(self, doc):
return doc["label"]
def doc_to_domain_conditional(self, doc):
del doc
return "Answer:"
class HellaSwag(ICLMultiChoiceTaskDataset):
"""HellaSwag concats "ACTIVITY_LABEL: CTX_A CTX_B.capitalize()" to form context and then sends endings as continuations
space added as prefix to each continuation
{
'activity_label': 'Roof shingle removal',
'ctx_a': 'A man is sitting on a roof.',
'ctx_b': 'he',
'ctx': 'A man is sitting on a roof. he',
'endings': ['is using wrap to wrap a pair of skis.', 'is ripping level tiles off.', "is holding a rubik's cube.", 'starts pulling up roofing on a roof.'],
'label': '3'
}
"""
metric_type = "len_norm"
def __init__(
self,
tokenizer,
dataset_path="hellaswag",
dataset_name=None,
):
super().__init__(
tokenizer=tokenizer,
dataset_path=dataset_path,
dataset_name=dataset_name,
)
@classmethod
def preprocess(cls, text):
text = text.strip()
# NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
text = text.replace(" [title]", ". ")
text = re.sub("\\[.*?\\]", "", text)
text = text.replace("  ", " ")  # collapse double spaces
return text
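# e.g. (illustrative) preprocess("A man [header] ties his  shoes.") removes the
# bracketed WikiHow artifact and collapses the doubled spaces, yielding
# "A man ties his shoes."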
def doc_to_text(self, doc):
return self.preprocess(doc["activity_label"] + ": " + doc["ctx_a"] + " " + doc["ctx_b"].capitalize())
def doc_to_continuations(self, doc):
# add spaces in front of continuation
return [" " + self.preprocess(ending) for ending in doc["endings"]]
def doc_to_label(self, doc):
return int(doc["label"])
def doc_to_domain_conditional(self, doc):
domain_conditional = self.preprocess(doc["ctx_b"].capitalize())
# ensure a non-zero-length domain conditional
if len(domain_conditional) == 0:
return self.preprocess(doc["ctx_a"]).split(" ")[-1]
return domain_conditional
class WinoGrande(ICLMultiChoiceTaskDataset):
"""Prompt: split sentence at _ "SENTENCE[:idx] + OPTION1/OPTION2", where idx = SENTENCE.index("_")
implement PMI_DC
acc, random at 50%
continuation is everything in the sentence after '_' (" SENTENCE[idx:].strip()")
Req_loglikelihood('People think Samantha', ' is embarassed, because Samantha made snide comments about the shirt Rebecca was wearing.')
Req_loglikelihood('People think Rebecca', ' is embarassed, because Samantha made snide comments about the shirt Rebecca was wearing.')
{
'sentence': 'People think _ is embarassed, because Samantha made snide comments about the shirt Rebecca was wearing.',
'option1': 'Samantha',
'option2': 'Rebecca',
'answer': '2'
}
TODO: might need to write custom metric for Winogrande
"""
metric_type = "acc"
def __init__(
self,
tokenizer,
dataset_path="winogrande",
dataset_name="winogrande_xl",
):
# all winogrande datasets have same val set
super().__init__(
tokenizer=tokenizer,
dataset_path=dataset_path,
dataset_name=dataset_name,
)
def prep_examples(self):
"""Overwrite for WinoGrande as multiple ctx, single continuation"""
doc_id = 0
for doc in self.dataset:
# here ctx is a list
ctxs = self.doc_to_text(doc)
dcs = self.doc_to_domain_conditional(doc)
continuation_str = self.doc_to_continuations(doc)
label_id = self.doc_to_label(doc)
cont_str_len = len(continuation_str) - 1  # the continuation contains a leading blank space
cont_byte_len = len(continuation_str[1:].encode("utf-8"))
# tokenize
continuation = self.token_encode(continuation_str)
for cont_id, (ctx, dc) in enumerate(zip(ctxs, dcs)):
ctx = self.token_encode(ctx)
dc = self.token_encode(dc)
# query: remove the last token from the continuation; truncate from the left if longer than the model ctx length
query = ctx + continuation[:-1]
query = query[-self.model_ctx_len :]
# get domain conditional query
# we don't expect this to be longer than self.model_ctx_len and it won't make sense to truncate from left
dc_query = dc + continuation[:-1]
# form a sample
self.samples.append(
{
"doc_id": doc_id,
"cont_id": cont_id,
"ctx": ctx,
"continuation": continuation,
"ctx_len": len(ctx),
"dc_len": len(dc),
"cont_len": len(
continuation
), # even if query has last token removed, LM will output same cont len
"cont_str_len": cont_str_len,
"cont_byte_len": cont_byte_len,
"query": query, # remove last token from continuation
"dc_query": dc_query,
"label_id": label_id,
}
)
doc_id += 1
def doc_to_text(self, doc):
# special case where there are multiple ctx and single continuation
pronoun_loc = doc["sentence"].index("_")
ctx = []
for option in [doc["option1"], doc["option2"]]:
ctx.append(doc["sentence"][:pronoun_loc] + option)
return ctx
def doc_to_continuations(self, doc):
# add spaces in front of continuation
pronoun_loc = doc["sentence"].index("_") + 1
return " " + doc["sentence"][pronoun_loc:].strip()
def doc_to_label(self, doc):
return int(doc["answer"]) - 1
def doc_to_domain_conditional(self, doc):
"""same number of domain conditionals as context"""
return [doc["option1"], doc["option2"]]
class OpenBookQA(ICLMultiChoiceTaskDataset):
"""OBQA: question_stem is sent as context (no special prompt format) and choices are sent as continuation
space added as prefix to each continuation
implement PMI_DC
{
'question_stem': 'Frilled sharks and angler fish live far beneath the surface of the ocean, which is why they are known as',
'choices': {'text': ['Deep sea animals', 'fish', 'Long Sea Fish', 'Far Sea Animals'],
'label': ['A', 'B', 'C', 'D']},
'answerKey': 'A'
}
"""
metric_type = "len_norm"
def __init__(
self,
tokenizer,
dataset_path="openbookqa",
dataset_name="main",
):
super().__init__(
tokenizer=tokenizer,
dataset_path=dataset_path,
dataset_name=dataset_name,
)
def doc_to_text(self, doc):
return doc["question_stem"]
def doc_to_continuations(self, doc):
# add spaces in front of continuation
return [" " + choice for choice in doc["choices"]["text"]]
def doc_to_label(self, doc):
return ["A", "B", "C", "D"].index(doc["answerKey"].strip())
def doc_to_domain_conditional(self, doc):
return doc["question_stem"].strip().split(" ")[-1]
class BoolQ(ICLMultiChoiceTaskDataset):
"""Prompt: "PASSAGE\nQuestion: QUESTION?\nAnswer:"
acc, random at 50% (SuperGLUE)
continuation: yes, no
{
'question': 'is ncis new orleans over for the season',
'passage': 'NCIS: New Orleans (season 4) -- The fourth season of NCIS: New Orleans premiered on September 26, 2017 on CBS. The series continues to air following Bull, Tuesday at 10:00 p.m. (ET) and contained 24 episodes. The season concluded on May 15, 2018.',
'label': 1
}
"""
metric_type = "acc"
def __init__(
self,
tokenizer,
dataset_path="boolq",
dataset_name=None,
):
super().__init__(
tokenizer=tokenizer,
dataset_path=dataset_path,
dataset_name=dataset_name,
)
def doc_to_text(self, doc):
return doc["passage"] + "\nQuestion: " + doc["question"] + "?\nAnswer:"
def doc_to_continuations(self, doc):
del doc
# add spaces in front of continuation
return [" yes", " no"]
def doc_to_label(self, doc):
# if doc['answer'] is True, return index of " yes" which is 0
if doc["answer"]:
return 0
else:
return 1
def doc_to_domain_conditional(self, doc):
del doc
return "Answer:"
class SciQ(ICLMultiChoiceTaskDataset):
"""SciQ sends context as "SUPPORT\nQuestion: QUESTION\nAnswer:" and then distractors + correct_answer as continuations
space added as prefix to each continuation
implement PMI_DC
{
'question': 'Who proposed the theory of evolution by natural selection?',
'distractor3': 'Scopes',
'distractor1': 'Linnaeus',
'distractor2': 'shaw',
'correct_answer': 'darwin',
'support': ''
}
"""
metric_type = "acc"
def __init__(
self,
tokenizer,
dataset_path="sciq",
dataset_name=None,
):
super().__init__(
tokenizer=tokenizer,
dataset_path=dataset_path,
dataset_name=dataset_name,
)
def doc_to_text(self, doc):
return doc["support"].strip() + "\nQuestion: " + doc["question"] + "\nAnswer:"
def doc_to_continuations(self, doc):
# add spaces in front of continuation
return [
" " + doc["distractor1"],
" " + doc["distractor2"],
" " + doc["distractor3"],
" " + doc["correct_answer"],
]
def doc_to_label(self, doc):
del doc
return 3
def doc_to_domain_conditional(self, doc):
del doc
return "Answer:"
class ArcEasy(ICLMultiChoiceTaskDataset):
"""ArcEasy creates context with "Question: QUESTION\nAnswer:" and sends the choices as continuations
space added as prefix to each continuation
{
'question': 'Which technology was developed most recently?',
'choices': {'text': ['cellular telephone', 'television', 'refrigerator', 'airplane'],
'label': ['A', 'B', 'C', 'D']},
'answerKey': 'A'
}
"""
metric_type = "acc"
def __init__(
self,
tokenizer,
dataset_path="ai2_arc",
dataset_name="ARC-Easy",
):
super().__init__(
tokenizer=tokenizer,
dataset_path=dataset_path,
dataset_name=dataset_name,
)
def doc_to_text(self, doc):
return "Question: " + doc["question"] + "\nAnswer:"
def doc_to_continuations(self, doc):
# add spaces in front of continuation
return [" " + choice for choice in doc["choices"]["text"]]
def doc_to_label(self, doc):
# some doc["answerKey"] are stored as numbers
num_to_letter = {"1": "A", "2": "B", "3": "C", "4": "D", "5": "E"}
if doc["answerKey"] in num_to_letter:
doc["answerKey"] = num_to_letter[doc["answerKey"]]
return ["A", "B", "C", "D", "E"].index(doc["answerKey"])
def doc_to_domain_conditional(self, doc):
del doc
return "Answer:"
class ArcChallenge(ArcEasy):
"""ArcChallenge follows the same prompt format as ArcEasy.
implement PMI_DC
"""
metric_type = "len_norm" # Ideally "pmi_dc"
def __init__(
self,
tokenizer,
dataset_path="ai2_arc",
dataset_name="ARC-Challenge",
):
super().__init__(
tokenizer=tokenizer,
dataset_path=dataset_path,
dataset_name=dataset_name,
)
class ArcEasyCELoss(ArcEasy):
"""ArcEasyCELoss is ARCEasy using an alternate ce_loss metric"""
metric_type = "ce_loss"
def doc_to_continuations(self, doc):
# We only consider the correct answer for this metric. Use the parent's
# doc_to_label here: this class overrides doc_to_label to return 0, since
# the single scored continuation always sits at index 0.
answer = doc["choices"]["text"][ArcEasy.doc_to_label(self, doc)]
return [" " + answer]
def doc_to_label(self, doc):
del doc
return 0
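# Note: for the "ce_loss" and "bpb" metric types, ICLMetric.compute scores only
# the continuation at index 0, which is why the gold answer is exposed as the
# single continuation (at index 0) here.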
class BasicArithmetic(ArcEasy):
"""This is a basic arithmetic task follows the same prompt format as ArcEasy.
Example:
{"id": "q85_1d1d_max1d_plus",
"question": "Calculate 2 + 5 =",
"choices": {"text": ["8", "7", "6", "17"],
"label": ["A", "B", "C", "D"]},
"answerKey": "B", "type_tag": "easy"}
"""
metric_type = "acc"
def __init__(
self,
tokenizer,
dataset_path="allenai/basic_arithmetic",
dataset_name=None,
):
super().__init__(
tokenizer=tokenizer,
dataset_path=dataset_path,
dataset_name=dataset_name,
)
class CommonsenseQA(ArcEasy):
"""CommonsenseQA
Example:
{'id': 'e68fb2448fd74e402aae9982aa76e527',
'question': 'Where are you likely to find a hamburger?',
'question_concept': 'hamburger',
'choices': {'label': ['A', 'B', 'C', 'D', 'E'],
'text': ['fast food restaurant', 'pizza', 'ground up dead cows', 'mouth', 'cow carcus']},
'answerKey': 'A'}
"""
metric_type = "len_norm"
def __init__(
self,
tokenizer,
dataset_path="tau/commonsense_qa",
dataset_name=None,
):
super().__init__(
tokenizer=tokenizer,
dataset_path=dataset_path,
dataset_name=dataset_name,
)
class SocialIQa(ICLMultiChoiceTaskDataset):
"""SocialIQa
Example:
{'context': 'Jordan was in charge of taking the food on the camping trip and left all the food at home.',
'question': 'How would Jordan feel afterwards?',
'answerA': 'horrible that he let his friends down on the camping trip',
'answerB': "happy that he doesn't need to do the cooking on the trip",
'answerC': 'very proud and accomplished about the camping trip', 'label': '1'}
"""
metric_type = "len_norm"
def __init__(
self,
tokenizer,
dataset_path="social_i_qa",
dataset_name=None,
):
super().__init__(
tokenizer=tokenizer,
dataset_path=dataset_path,
dataset_name=dataset_name,
)
def doc_to_text(self, doc):
return "Question: " + doc["context"] + " " + doc["question"] + "\nAnswer:"
def doc_to_continuations(self, doc):
# add spaces in front of continuation
return [
" " + doc["answerA"],
" " + doc["answerB"],
" " + doc["answerC"],
]
def doc_to_label(self, doc):
return int(doc["label"]) - 1
def doc_to_domain_conditional(self, doc):
return "Answer:"
class COPA(ICLMultiChoiceTaskDataset):
"""Prompt: "PREMISE.strip()[:-1] because/therefore"
Req_loglikelihood('The pair of students came under scrutiny by the teacher because', ' the students both received excellent grades.')
continuations: CHOICE1/CHOICE2
"cause": "because",
"effect": "therefore",
implement PMI_DC
acc, random at 50%
{
'premise': 'The pair of students came under scrutiny by the teacher.',
'choice1': 'The students both received excellent grades.',
'choice2': 'Their responses on the assignment were identical.',
'question': 'cause',
'label': 1
}
"""
metric_type = "acc"
def __init__(
self,
tokenizer,
dataset_path="super_glue",
dataset_name="copa",
):
super().__init__(
tokenizer=tokenizer,
dataset_path=dataset_path,
dataset_name=dataset_name,
)
def doc_to_text(self, doc):
connector = "because" if doc["question"] == "cause" else "therefore"
# remove the period
return doc["premise"].strip()[:-1] + " " + connector
def doc_to_continuations(self, doc):
# add spaces in front of continuation
def convert_choice(choice):
return choice[0].lower() + choice[1:]
return [" " + convert_choice(doc["choice1"]), " " + convert_choice(doc["choice2"])]
def doc_to_label(self, doc):
return doc["label"]
def doc_to_domain_conditional(self, doc):
return "because" if doc["question"] == "cause" else "therefore"
class RTE(ICLMultiChoiceTaskDataset):
"""Prompt: "SENTENCE1\nQuestion: SENTENCE2 True or False?\nAnswer:"
implement PMI_DC
acc, random at 50% (GLUE)
continuations: True, False
{
'sentence1': 'The number of Danes opposed to swapping the krone for the euro has increased slightly to 35.3 percent, up from 34.6 percent in April, according to a poll published on Thursday by Danske Bank.',
'sentence2': 'The introduction of the euro has been opposed.',
'label': 0,
}
"""
metric_type = "len_norm"