1
+ import random
2
+ from typing import Any , ClassVar , Iterator
1
3
from datasets import load_dataset
4
+ from datasets .dataset_dict import DatasetDict , IterableDatasetDict
5
+ from datasets .arrow_dataset import Dataset
6
+ from datasets .iterable_dataset import IterableDataset
2
7
from pydantic import ConfigDict , model_validator
3
8
4
9
from shared .base import BaseDataset , DatasetEntry
13
18
OUTPUT_LINES = 10
14
19
MAX_LINES = 500
15
20
RETRIES = 50 # Increased retry limit
21
+ RANDOM_SKIP = 1_000
16
22
17
23
18
24
class HuggingFaceGithubDatasetEntry (DatasetEntry ):
@@ -24,17 +30,18 @@ class HuggingFaceGithubDatasetEntry(DatasetEntry):
24
30
25
31
class HuggingFaceGithubDataset (BaseDataset ):
26
32
language : str = "python"
27
- dataset : any = None
28
- iterator : any = None
33
+ dataset : ClassVar [ DatasetDict | Dataset | IterableDatasetDict | IterableDataset | None ] = None
34
+ iterator : ClassVar [ Iterator [ Any ] | None ] = None
29
35
30
36
model_config = ConfigDict (arbitrary_types_allowed = True )
31
37
32
38
@model_validator (mode = "after" )
33
39
def load_dataset (self ) -> "HuggingFaceGithubDataset" :
34
- self .dataset = load_dataset (
35
- "macrocosm-os/code-parrot-github-code" , streaming = True , split = "train" , trust_remote_code = True
36
- )
37
- self .iterator = iter (self .dataset .filter (self ._filter_function ))
40
+ if HuggingFaceGithubDataset .dataset is None or self .iterator is None :
41
+ HuggingFaceGithubDataset .dataset = load_dataset (
42
+ "macrocosm-os/code-parrot-github-code" , streaming = True , split = "train" , trust_remote_code = True
43
+ )
44
+ HuggingFaceGithubDataset .iterator = iter (HuggingFaceGithubDataset .dataset .filter (self ._filter_function ))
38
45
return self
39
46
40
47
def _filter_function (self , example ):
@@ -55,9 +62,12 @@ def get(self) -> HuggingFaceGithubDatasetEntry:
55
62
return self .next ()
56
63
57
64
def next (self ) -> HuggingFaceGithubDatasetEntry :
65
+ for _ in range (random .randint (0 , RANDOM_SKIP )):
66
+ next (HuggingFaceGithubDataset .iterator )
67
+
58
68
for _ in range (RETRIES ):
59
69
try :
60
- entry = next (self .iterator )
70
+ entry = next (HuggingFaceGithubDataset .iterator )
61
71
return self ._process_entry (entry ) # Throws failed to get a valid file after multiple attempts
62
72
except StopIteration :
63
73
self .reset ()
@@ -69,7 +79,7 @@ def random(self) -> HuggingFaceGithubDatasetEntry:
69
79
return self .next ()
70
80
71
81
def reset (self ):
72
- self .iterator = iter (self .dataset .filter (self ._filter_function ))
82
+ HuggingFaceGithubDataset .iterator = iter (HuggingFaceGithubDataset .dataset .filter (self ._filter_function ))
73
83
74
84
75
85
if __name__ == "__main__" :
0 commit comments