Skip to content

Commit 9ed7f43

Browse files
committed
Fix hugging face read; Adjust timeouts
1 parent ff352e3 commit 9ed7f43

File tree

3 files changed

+20
-10
lines changed

3 files changed

+20
-10
lines changed

prompting/datasets/huggingface_github.py

+18 -8
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,9 @@
1+
import random
2+
from typing import Any, ClassVar, Iterator
13
from datasets import load_dataset
4+
from datasets.dataset_dict import DatasetDict, IterableDatasetDict
5+
from datasets.arrow_dataset import Dataset
6+
from datasets.iterable_dataset import IterableDataset
27
from pydantic import ConfigDict, model_validator
38

49
from shared.base import BaseDataset, DatasetEntry
@@ -13,6 +18,7 @@
1318
OUTPUT_LINES = 10
1419
MAX_LINES = 500
1520
RETRIES = 50 # Increased retry limit
21+
RANDOM_SKIP = 1_000
1622

1723

1824
class HuggingFaceGithubDatasetEntry(DatasetEntry):
@@ -24,17 +30,18 @@ class HuggingFaceGithubDatasetEntry(DatasetEntry):
2430

2531
class HuggingFaceGithubDataset(BaseDataset):
2632
language: str = "python"
27-
dataset: any = None
28-
iterator: any = None
33+
dataset: ClassVar[DatasetDict | Dataset | IterableDatasetDict | IterableDataset | None] = None
34+
iterator: ClassVar[Iterator[Any] | None] = None
2935

3036
model_config = ConfigDict(arbitrary_types_allowed=True)
3137

3238
@model_validator(mode="after")
3339
def load_dataset(self) -> "HuggingFaceGithubDataset":
34-
self.dataset = load_dataset(
35-
"macrocosm-os/code-parrot-github-code", streaming=True, split="train", trust_remote_code=True
36-
)
37-
self.iterator = iter(self.dataset.filter(self._filter_function))
40+
if HuggingFaceGithubDataset.dataset is None or self.iterator is None:
41+
HuggingFaceGithubDataset.dataset = load_dataset(
42+
"macrocosm-os/code-parrot-github-code", streaming=True, split="train", trust_remote_code=True
43+
)
44+
HuggingFaceGithubDataset.iterator = iter(HuggingFaceGithubDataset.dataset.filter(self._filter_function))
3845
return self
3946

4047
def _filter_function(self, example):
@@ -55,9 +62,12 @@ def get(self) -> HuggingFaceGithubDatasetEntry:
5562
return self.next()
5663

5764
def next(self) -> HuggingFaceGithubDatasetEntry:
65+
for _ in range(random.randint(0, RANDOM_SKIP)):
66+
next(HuggingFaceGithubDataset.iterator)
67+
5868
for _ in range(RETRIES):
5969
try:
60-
entry = next(self.iterator)
70+
entry = next(HuggingFaceGithubDataset.iterator)
6171
return self._process_entry(entry) # Throws failed to get a valid file after multiple attempts
6272
except StopIteration:
6373
self.reset()
@@ -69,7 +79,7 @@ def random(self) -> HuggingFaceGithubDatasetEntry:
6979
return self.next()
7080

7181
def reset(self):
72-
self.iterator = iter(self.dataset.filter(self._filter_function))
82+
HuggingFaceGithubDataset.iterator = iter(HuggingFaceGithubDataset.dataset.filter(self._filter_function))
7383

7484

7585
if __name__ == "__main__":

prompting/tasks/web_retrieval.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,7 @@ class WebRetrievalTask(BaseTextTask):
4040
augmentation_system_prompt: ClassVar[str] = ""
4141
query_system_prompt: ClassVar[Optional[str]] = QUERY_SYSTEM_PROMPT
4242
target_results: int = Field(default_factory=lambda: random.randint(1, 10))
43-
timeout: int = Field(default_factory=lambda: random.randint(5, 20))
43+
timeout: int = Field(default_factory=lambda: random.randint(5, 15))
4444

4545
async def make_query(self, dataset_entry: DDGDatasetEntry) -> str:
4646
self.query = await self.generate_query(

shared/epistula.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -124,7 +124,7 @@ async def query_miners(
124124
for uid in uids:
125125
try:
126126
timeout_connect = 10
127-
timeout_postprocess = 5
127+
timeout_postprocess = 1
128128
response = asyncio.wait_for(
129129
asyncio.create_task(
130130
make_openai_query(

0 commit comments

Comments
 (0)