diff --git a/d2l/mxnet.py b/d2l/mxnet.py index 8e7d7673a3..e0db1a4f4b 100644 --- a/d2l/mxnet.py +++ b/d2l/mxnet.py @@ -2272,8 +2272,8 @@ def forward(self, tokens, segments, valid_lens=None, pred_positions=None): return encoded_X, mlm_Y_hat, nsp_Y_hat d2l.DATA_HUB['wikitext-2'] = ( - 'https://s3.amazonaws.com/research.metamind.io/wikitext/' - 'wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe') + 'https://www.kaggle.com/api/v1/datasets/download/bestwater/wikitext-2-v1' + '', 'ca5f319246c1e34d406780c0b6c5d1b0ec9b9a10') def _read_wiki(data_dir): """Defined in :numref:`sec_bert-dataset`""" @@ -3104,7 +3104,11 @@ def download(url, folder='../data', sha1_hash=None): # For back compatability url, sha1_hash = DATA_HUB[url] os.makedirs(folder, exist_ok=True) - fname = os.path.join(folder, url.split('/')[-1]) + file_name = url.split('/')[-1] + if (not "." in file_name) and file_name in ["wikitext-2-v1"]: + file_name += ".zip" + fname = os.path.join(folder, file_name) + # Check if hit cache if os.path.exists(fname) and sha1_hash: sha1 = hashlib.sha1() diff --git a/d2l/torch.py b/d2l/torch.py index 84ce7da901..9cc25062f8 100644 --- a/d2l/torch.py +++ b/d2l/torch.py @@ -2285,8 +2285,8 @@ def forward(self, tokens, segments, valid_lens=None, pred_positions=None): return encoded_X, mlm_Y_hat, nsp_Y_hat d2l.DATA_HUB['wikitext-2'] = ( - 'https://s3.amazonaws.com/research.metamind.io/wikitext/' - 'wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe') + 'https://www.kaggle.com/api/v1/datasets/download/bestwater/wikitext-2-v1' + '', 'ca5f319246c1e34d406780c0b6c5d1b0ec9b9a10') def _read_wiki(data_dir): """Defined in :numref:`sec_bert-dataset`""" @@ -3202,7 +3202,11 @@ def download(url, folder='../data', sha1_hash=None): # For back compatability url, sha1_hash = DATA_HUB[url] os.makedirs(folder, exist_ok=True) - fname = os.path.join(folder, url.split('/')[-1]) + file_name = url.split('/')[-1] + if (not "." in file_name) and file_name in ["wikitext-2-v1"]: + file_name += ".zip" + fname = os.path.join(folder, file_name) + # Check if hit cache if os.path.exists(fname) and sha1_hash: sha1 = hashlib.sha1()