Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add verify as an option to LoadFromAPI #1608

Merged
merged 8 commits into from
Feb 18, 2025
4 changes: 2 additions & 2 deletions src/unitxt/db_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,11 +288,11 @@ def get_table_schema(
self,
) -> str:
"""Retrieves the schema of a database."""
cur_api_url = f"{self.api_url}/datasource/{self.database_id}"
cur_api_url = f"{self.api_url}/datasources/{self.database_id}"
response = requests.get(
cur_api_url,
headers=self.headers,
verify=True,
verify=False,
timeout=self.timeout,
)
if response.status_code == 200:
Expand Down
24 changes: 12 additions & 12 deletions src/unitxt/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def process(self) -> MultiStream:
def get_splits(self):
return list(self().keys())


class LazyLoader(Loader):
split: Optional[str] = NonPositionalField(default=None)

Expand All @@ -193,9 +194,7 @@ def get_splits(self) -> List[str]:
def split_generator(self, split: str) -> Generator:
pass

def load_iterables(
self
) -> Union[Dict[str, DynamicStream], IterableDatasetDict]:
def load_iterables(self) -> Union[Dict[str, DynamicStream], IterableDatasetDict]:
if self.split is not None:
splits = [self.split]
else:
Expand Down Expand Up @@ -345,7 +344,6 @@ def get_splits(self):
dataset = self.load_dataset(split=None, streaming=False)
return list(dataset.keys())


def split_generator(self, split: str) -> Generator:
if self.get_limit() is not None:
self.log_limited_loading()
Expand Down Expand Up @@ -438,16 +436,14 @@ def split_generator(self, split: str) -> Generator:
self.log_limited_loading()

try:
dataset = reader(
self.files[split], **self.get_args()
).to_dict("records")
dataset = reader(self.files[split], **self.get_args()).to_dict(
"records"
)
except ValueError:
import fsspec

with fsspec.open(self.files[split], mode="rt") as f:
dataset = reader(
f, **self.get_args()
).to_dict("records")
dataset = reader(f, **self.get_args()).to_dict("records")
except Exception as e:
logger.debug(f"Attempt csv load {attempt + 1} failed: {e}")
if attempt < settings.loaders_max_retries - 1:
Expand Down Expand Up @@ -988,6 +984,9 @@ class LoadFromAPI(Loader):
Defaults to "data".
method (str, optional):
The HTTP method to use for API requests. Defaults to "GET".
verify_cert (bool, optional):
Whether to verify the SSL certificate when making API requests.
Defaults to True.
"""

urls: Dict[str, str]
Expand All @@ -998,6 +997,7 @@ class LoadFromAPI(Loader):
headers: Optional[Dict[str, Any]] = None
data_field: str = "data"
method: str = "GET"
verify_cert: bool = True

# class level shared cache:
_loader_cache = LRUCache(max_size=settings.loader_cache_size)
Expand Down Expand Up @@ -1029,13 +1029,13 @@ def load_iterables(self) -> Dict[str, Iterable]:
response = requests.get(
url,
headers=base_headers,
verify=True,
verify=self.verify_cert,
)
elif self.method == "POST":
response = requests.post(
url,
headers=base_headers,
verify=True,
verify=self.verify_cert,
json={},
)
else:
Expand Down
2 changes: 1 addition & 1 deletion utils/.secrets.baseline
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@
"filename": "src/unitxt/loaders.py",
"hashed_secret": "840268f77a57d5553add023cfa8a4d1535f49742",
"is_verified": false,
"line_number": 597
"line_number": 593
}
],
"src/unitxt/metrics.py": [
Expand Down
Loading