make LoadFromDictionary and LoadFromAPI lazy too, and return a dataset so that there is no need to recursive_copy each instance from the loader

Signed-off-by: dafnapension <[email protected]>
dafnapension committed Feb 23, 2025
1 parent 070690e commit ff9a435
Showing 1 changed file with 107 additions and 139 deletions.
246 changes: 107 additions & 139 deletions src/unitxt/loaders.py
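In broad strokes, the change turns both loaders into lazy loaders that expose their splits up front and materialize each split as a Hugging Face Dataset on demand, so downstream code consumes fresh rows instead of deep-copying cached instances. A minimal sketch of the pattern (illustrative only, not the committed code; class and variable names are made up):

# Toy stand-in for the lazy-loader pattern this commit applies to
# LoadFromDictionary and LoadFromAPI (illustrative only, not the committed code).
from typing import Any, Dict, Generator, List

from datasets import Dataset


class LazyDictLoader:
    def __init__(self, data: Dict[str, List[Dict[str, Any]]]):
        self.data = data

    def get_splits(self) -> List[str]:
        # Splits are discovered from the keys of the input mapping.
        return sorted(self.data.keys())

    def split_generator(self, split: str) -> Generator[Dict[str, Any], None, None]:
        # Build a Hugging Face Dataset on demand and yield its rows; each yielded
        # row is a fresh dict, so the consumer does not need to recursive_copy
        # a cached instance.
        dataset = Dataset.from_generator(generator=self.data[split].__iter__)
        yield from dataset


loader = LazyDictLoader({"test": [{"question": "2+2?", "answer": "4"}]})
print(loader.get_splits())            # ['test']
for instance in loader.split_generator("test"):
    print(instance)                   # {'question': '2+2?', 'answer': '4'}

Because each yielded row comes straight out of the Dataset, callers receive independent objects and the recursive_copy step mentioned in the commit message becomes unnecessary.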
@@ -288,68 +288,53 @@ def is_streaming(self) -> bool:
def load_dataset(
self, split: str, streaming=None, disable_memory_caching=False
) -> Union[Dataset, DatasetDict, IterableDataset, IterableDatasetDict]:
dataset = None #self.__class__._loader_cache.get(str(self) + "_" + str(split), None)
if dataset is None:
if streaming is None:
streaming = self.is_streaming()

with tempfile.TemporaryDirectory() as dir_to_be_deleted:
if settings.disable_hf_datasets_cache:
cache_dir = dir_to_be_deleted
else:
cache_dir = None
kwargs = {
"path": self.path,
"name": self.name,
"data_dir": self.data_dir,
"data_files": self.data_files,
"revision": self.revision,
"streaming": streaming,
"cache_dir": cache_dir,
"verification_mode": "no_checks",
"split": split,
"trust_remote_code": settings.allow_unverified_code,
"num_proc": self.num_proc,
"download_config" : DownloadConfig(
max_retries=settings.loaders_max_retries,
# extract_on_the_fly=True,
),
}
try:
# load the dataset and verify that it is loaded safe and sound
dataset = hf_load_dataset(**kwargs)
if split is not None:
next(iter(dataset))
else:
for k in dataset:
next(iter(dataset[k]))
if streaming is None:
streaming = self.is_streaming()

except Exception as e:
if "trust_remote_code" in str(e):
raise ValueError(
f"{self.__class__.__name__} cannot run remote code from huggingface without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE."
) from e

# try the opposite way of streaming
current_streaming = kwargs["streaming"]
logger.info(
f"needed to swap streaming from {current_streaming} to {not current_streaming} for path {self.path}"
)
kwargs["streaming"] = not kwargs["streaming"]
dataset = hf_load_dataset(**kwargs)
if split is not None:
next(iter(dataset))
else:
for k in dataset:
next(iter(dataset[k]))
with tempfile.TemporaryDirectory() as dir_to_be_deleted:
if settings.disable_hf_datasets_cache:
cache_dir = dir_to_be_deleted
else:
cache_dir = None
kwargs = {
"data_dir": self.data_dir,
"data_files": self.data_files,
"revision": self.revision,
"streaming": streaming,
"cache_dir": cache_dir,
"split": split,
"num_proc": self.num_proc,
}
try:
# load the dataset and verify that it is loaded safe and sound
dataset = hf_load_dataset(self.path, self.name, **kwargs)
if split is not None:
next(iter(dataset))
else:
for k in dataset:
next(iter(dataset[k]))

if self.filtering_lambda is not None:
dataset = dataset.filter(eval(self.filtering_lambda))
except Exception as e:
if "trust_remote_code" in str(e):
raise ValueError(
f"{self.__class__.__name__} cannot run remote code from huggingface without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE."
) from e

# try the opposite way of streaming
current_streaming = kwargs["streaming"]
logger.info(
f"needed to swap streaming from {current_streaming} to {not current_streaming} for path {self.path}"
)
kwargs["streaming"] = not kwargs["streaming"]
dataset = hf_load_dataset(self.path, self.name, **kwargs)
if split is not None:
next(iter(dataset))
else:
for k in dataset:
next(iter(dataset[k]))

# if not disable_memory_caching:
# self.__class__._loader_cache.max_size = settings.loader_cache_size
# self.__class__._loader_cache[str(self) + "_" + str(split)] = dataset
if self.filtering_lambda is not None:
dataset = dataset.filter(eval(self.filtering_lambda))

return dataset

@@ -472,33 +457,28 @@ def get_splits(self) -> List[str]:
def split_generator(self, split: str) -> Generator:
import fsspec
fsspec.core.DEFAULT_EXPAND = True
dataset = None #self.__class__._loader_cache.get(str(self) + "_" + split, None)
if dataset is None:
reader = self.get_reader()
for attempt in range(settings.loaders_max_retries):
reader = self.get_reader()
for attempt in range(settings.loaders_max_retries):
try:
dataset = reader(
self.files[split], **self.get_args()
).to_dict("records")
next(iter(dataset))
break
except:
try:
dataset = reader(
self.files[split], **self.get_args()
).to_dict("records")
with fsspec.open(self.files[split], mode="rt", expand=True) as f:
dataset = reader(
f, **self.get_args()
).to_dict("records")
next(iter(dataset))
break
except:
try:
with fsspec.open(self.files[split], mode="rt", expand=True) as f:
dataset = reader(
f, **self.get_args()
).to_dict("records")
next(iter(dataset))
break
except Exception as e:
logger.info(f"Attempt csv load {attempt + 1} failed: {e}")
if attempt < settings.loaders_max_retries - 1:
time.sleep(2)
else:
raise e
# self.__class__._loader_cache.max_size = settings.loader_cache_size
# self.__class__._loader_cache[str(self) + "_" + split] = dataset

except Exception as e:
logger.info(f"Attempt csv load {attempt + 1} failed: {e}")
if attempt < settings.loaders_max_retries - 1:
time.sleep(2)
else:
raise e
yield from dataset


@@ -547,15 +527,11 @@ def get_splits(self):
return self.splits

def split_generator(self, split: str) -> Generator:
dataset = None #self.__class__._loader_cache.get(str(self) + "_" + split, None)
if dataset is None:
split_data = self.downloader(subset=split)
targets = [split_data["target_names"][t] for t in split_data["target"]]
df = pd.DataFrame([split_data["data"], targets]).T
df.columns = ["data", "target"]
dataset = df.to_dict("records")
# self.__class__._loader_cache.max_size = settings.loader_cache_size
# self.__class__._loader_cache[str(self) + "_" + split] = dataset
split_data = self.downloader(subset=split)
targets = [split_data["target_names"][t] for t in split_data["target"]]
df = pd.DataFrame([split_data["data"], targets]).T
df.columns = ["data", "target"]
dataset = df.to_dict("records")
yield from dataset


@@ -816,7 +792,7 @@ class MultipleSourceLoader(LazyLoader):
MultipleSourceLoader(sources = [ LoadCSV(files={"test": "mytest1.csv"}), LoadCSV(files={"test": "mytest2.csv"}) ])
"""

sources: List[Loader]
sources: List[LazyLoader]

def add_data_classification(self, multi_stream: MultiStream) -> MultiStream:
if self.data_classification_policy is None:
@@ -837,7 +813,7 @@ def split_generator(self, split: str) -> Generator[Any, None, None]:
)()[split]


class LoadFromDictionary(Loader):
class LoadFromDictionary(LazyLoader):
"""Allows loading data from a dictionary of constants.
The loader can be used, for example, when debugging or working with small datasets.
@@ -885,9 +861,12 @@ def _maybe_set_classification_policy(self):
["proprietary"], "when loading from python dictionary"
)

def load_iterables(self) -> Dict[str, List[Dict[str, Any]]]:
return self.data
def get_splits(self) -> List[str]:
return sorted(self.data.keys())

def split_generator(self, split: str) -> Generator:
dataset = Dataset.from_generator(generator = self.data[split].__iter__)
yield from dataset
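For context, a rough usage sketch of LoadFromDictionary after this change (illustrative only: the field names in the data dictionary are invented, and in practice the loader is typically plugged into a Unitxt card rather than called directly):

# Hypothetical direct use of LoadFromDictionary (field names are made up).
from unitxt.loaders import LoadFromDictionary

data = {
    "train": [{"question": "What is 2+2?", "answer": "4"}],
    "test": [{"question": "What is 3+3?", "answer": "6"}],
}

loader = LoadFromDictionary(data=data)
print(loader.get_splits())                 # ['test', 'train'], derived lazily from the keys
for instance in loader.split_generator("test"):
    print(instance)                        # each row comes out of a freshly built Dataset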

class LoadFromHFSpace(LazyLoader):
"""Used to load data from HuggingFace Spaces lazily.
@@ -1009,7 +988,7 @@ def split_generator(self, split: str) -> Generator:



class LoadFromAPI(Loader):
class LoadFromAPI(LazyLoader):
"""Loads data from from API.
This loader is designed to fetch data from an API endpoint,
@@ -1054,7 +1033,11 @@ class LoadFromAPI(Loader):
def _maybe_set_classification_policy(self):
self.set_default_data_classification(["proprietary"], "when loading from API")

def load_iterables(self) -> Dict[str, Iterable]:
def get_splits(self) -> List[str]:
return sorted(self.urls.keys())

def split_generator(self, split: str) -> Generator:
# def load_iterables(self) -> Dict[str, Iterable]:
api_key = os.getenv(self.api_key_env_var, None)
if not api_key:
raise ValueError(
@@ -1069,50 +1052,35 @@ def load_iterables(self) -> Dict[str, Iterable]:
if self.headers:
base_headers.update(self.headers)

iterables = {}
for split_name, url in self.urls.items():
if self.get_limit() is not None:
self.log_limited_loading()

if self.method == "GET":
response = requests.get(
url,
headers=base_headers,
verify=self.verify_cert,
)
elif self.method == "POST":
response = requests.post(
url,
headers=base_headers,
verify=self.verify_cert,
json={},
)
else:
raise ValueError(f"Method {self.method} not supported")

response.raise_for_status()

data = json.loads(response.text)
if self.method == "GET":
response = requests.get(
self.urls[split],
headers=base_headers,
verify=self.verify_cert,
)
elif self.method == "POST":
response = requests.post(
self.urls[split],
headers=base_headers,
verify=self.verify_cert,
json={},
)
else:
raise ValueError(f"Method {self.method} not supported")

if self.data_field:
if self.data_field not in data:
raise ValueError(
f"Data field '{self.data_field}' not found in API response."
)
data = data[self.data_field]
response.raise_for_status()

if self.get_limit() is not None:
data = data[: self.get_limit()]
data = json.loads(response.text)

iterables[split_name] = data
if self.data_field:
if self.data_field not in data:
raise ValueError(
f"Data field '{self.data_field}' not found in API response."
)
data = data[self.data_field]

return iterables
if self.get_limit() is not None:
data = data[: self.get_limit()]

def process(self) -> MultiStream:
self._maybe_set_classification_policy()
iterables = None #self.__class__._loader_cache.get(str(self), None)
if iterables is None:
iterables = self.load_iterables()
# self.__class__._loader_cache.max_size = settings.loader_cache_size
# self.__class__._loader_cache[str(self)] = iterables
return MultiStream.from_iterables(iterables)
dataset = Dataset.from_generator(generator = data.__iter__)
yield from dataset
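A rough configuration sketch for the lazy LoadFromAPI (illustrative only: the endpoint URL, environment-variable name, and data_field value are placeholders, not documented defaults):

# Hypothetical configuration of LoadFromAPI; no request is issued until a
# split is actually iterated.
import os

from unitxt.loaders import LoadFromAPI

os.environ["MY_API_KEY"] = "..."  # placeholder; read via api_key_env_var

loader = LoadFromAPI(
    urls={"test": "https://example.com/api/items"},  # placeholder endpoint
    api_key_env_var="MY_API_KEY",
    method="GET",
    data_field="data",  # JSON field of the response that holds the records
)

print(loader.get_splits())   # ['test'] -- derived lazily from the configured URLs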
