better optional dependencies handling; release
liquidcarbon committed Dec 13, 2024
1 parent 8af9b3b commit f8d5d1c
Showing 3 changed files with 38 additions and 47 deletions.
48 changes: 25 additions & 23 deletions affinity.py
@@ -8,28 +8,32 @@
from time import time
from typing import TYPE_CHECKING, List, Optional, Tuple

import duckdb
import numpy as np
import pandas as pd


def try_import(module: str) -> object | None:
try:
return import_module(module)
except ImportError:
# print(f"{module} not found in the current environment")
return
class _modules:
"""Stores modules imported conditionally."""

def try_import(modules: List[str]) -> None:
"""Conditional imports."""
for module in modules:
try:
_module = import_module(module)
globals()[module] = _module # used here
setattr(_modules, module, _module) # used in tests
except ImportError:
setattr(_modules, module, False)


if TYPE_CHECKING:
import duckdb # type: ignore
import polars as pl # type: ignore
import pyarrow as pa # type: ignore
import pyarrow.parquet as pq # type: ignore
import awswrangler # type: ignore
import polars # type: ignore
import pyarrow # type: ignore
import pyarrow.parquet # type: ignore
else:
duckdb = try_import("duckdb")
pl = try_import("polars")
pa = try_import("pyarrow")
pq = try_import("pyarrow.parquet")
_modules.try_import(["awswrangler", "polars", "pyarrow", "pyarrow.parquet"])


@dataclass
@@ -213,7 +217,7 @@ def build(cls, query=None, dataframe=None, **kwargs):

@classmethod
def from_dataframe(
cls, dataframe: pd.DataFrame | Optional["pl.DataFrame"], **kwargs
cls, dataframe: pd.DataFrame | Optional["polars.DataFrame"], **kwargs
):
instance = cls()
for i, k in enumerate(dict(instance)):
@@ -237,9 +241,7 @@ def from_sql(cls, query: str, **kwargs):
@property
def athena_types(self):
"""Convert pandas types to SQL types for loading into AWS Athena."""

wr = try_import("awswrangler")
columns_types, partition_types = wr.catalog.extract_athena_types(
columns_types, partition_types = awswrangler.catalog.extract_athena_types(
df=self.df,
partition_cols=self.LOCATION.partition_by,
)
@@ -365,17 +367,17 @@ def df4(self) -> pd.DataFrame:
return self.df

@property
def arrow(self) -> "pa.Table":
def arrow(self) -> "pyarrow.Table":
metadata = {str(k): str(v) for k, v in self.metadata.items()}
_dict = {
k: [v.dict for v in vector] if self.is_dataset(k) else vector
for k, vector in self
}
return pa.table(_dict, metadata=metadata)
return pyarrow.table(_dict, metadata=metadata)

@property
def pl(self) -> "pl.DataFrame":
return pl.DataFrame(dict(self))
def pl(self) -> "polars.DataFrame":
return polars.DataFrame(dict(self))

def is_dataset(self, key):
attr = getattr(self, key, None)
@@ -428,7 +430,7 @@ def to_parquet(self, path, engine="duckdb", **kwargs):
if engine == "pandas":
self.df.to_parquet(path)
elif engine == "arrow":
pq.write_table(self.arrow, path)
pyarrow.parquet.write_table(self.arrow, path)
elif engine == "duckdb":
kv_metadata = []
for k, v in self.metadata.items():
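For reference, the new pattern in affinity.py reduces to the standalone sketch below. The module names passed to try_import are illustrative, and placing try_import inside the _modules class is inferred from the call _modules.try_import(...) in the diff, since the rendered diff drops indentation.

from importlib import import_module
from typing import List


class _modules:
    """Stores modules imported conditionally."""

    def try_import(modules: List[str]) -> None:  # called on the class, so no `self`
        """Import each named module if available; record False otherwise."""
        for module in modules:
            try:
                _module = import_module(module)
                globals()[module] = _module          # usable directly in this file
                setattr(_modules, module, _module)   # inspectable from tests
            except ImportError:
                setattr(_modules, module, False)


_modules.try_import(["polars", "pyarrow"])

# Downstream code (and the updated tests) can branch on the stored flag:
if _modules.polars:
    print("polars available:", _modules.polars.__version__)
else:
    print("polars not installed; polars-backed features are skipped")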
6 changes: 4 additions & 2 deletions pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "affinity"
version = "1.0.0"
version = "1.0.1"
description = "Module for creating well-documented datasets, with types and annotations."
authors = [
{ name = "Alex Kislukhin" }
@@ -13,7 +13,6 @@ readme = "README.md"
requires-python = ">=3.11"

dependencies = [
"awswrangler>=3.10.1",
"duckdb>=1",
"pandas",
]
@@ -24,6 +23,9 @@ dev = [
"pyarrow>=17",
"pytest>=8",
]
aws = [
"awswrangler>=3.10.1",
]

[tool.hatch.build]
include = [
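With awswrangler moved from the core dependencies into the new optional aws extra, Athena-related helpers such as Dataset.athena_types are only usable when that extra is installed, presumably via something like pip install "affinity[aws]" (assuming the distribution name matches the name field in pyproject.toml). A minimal, hedged sketch of how a caller might check for it:

import affinity as af

# af._modules.awswrangler holds the module object when the "aws" extra is
# installed, and False otherwise (see the try_import helper above).
if af._modules.awswrangler:
    print("awswrangler available; Athena helpers like Dataset.athena_types can be used")
else:
    print("awswrangler not installed; skipping Athena-specific functionality")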
31 changes: 9 additions & 22 deletions test_affinity.py
@@ -11,20 +11,6 @@
# https://github.com/duckdb/duckdb/issues/14179
duckdb.sql("SET python_scan_all_frames=true")

try:
import polars # noqa: F401

NO_POLARS = False
except ImportError:
NO_POLARS = True

try:
import pyarrow

NO_PYARROW = False
except ImportError:
NO_PYARROW = True


def test_location_default():
loc = af.Location()
@@ -266,8 +252,8 @@ class aDataset(af.Dataset):
pd.testing.assert_frame_equal(data.df, source_df.astype(default_dtypes))


@pytest.mark.skipif(NO_POLARS, reason="polars is not installed")
@pytest.mark.skipif(NO_PYARROW, reason="pyarrow is not installed")
@pytest.mark.skipif(not af._modules.polars, reason="polars is not installed")
@pytest.mark.skipif(not af._modules.pyarrow, reason="pyarrow is not installed")
def test_to_polars():
class aDataset(af.Dataset):
v1 = af.VectorBool("")
@@ -280,7 +266,7 @@ class aDataset(af.Dataset):
assert str(polars_df.dtypes) == "[Boolean, Float32, Int16]"


@pytest.mark.skipif(NO_PYARROW, reason="pyarrow is not installed")
@pytest.mark.skipif(not af._modules.pyarrow, reason="pyarrow is not installed")
def test_to_pyarrow():
class aDataset(af.Dataset):
v1 = af.VectorBool("")
@@ -329,6 +315,7 @@ class cDataset(af.Dataset):
cDataset().sql("SELECT v2 FROM df") # "df" != last test's data_a.df


@pytest.mark.skipif(not af._modules.awswrangler, reason="awswrangler is not installed")
def test_kwargs_for_create_athena_table():
class aDataset(af.Dataset):
"""Document me!"""
@@ -352,7 +339,7 @@ class aDataset(af.Dataset):
}


@pytest.mark.skipif(NO_PYARROW, reason="pyarrow is not installed")
@pytest.mark.skipif(not af._modules.pyarrow, reason="pyarrow is not installed")
def test_objects_as_metadata():
class aDataset(af.Dataset):
"""Objects other than strings can go into metadata."""
@@ -369,7 +356,7 @@ def try_ast_literal_eval(x: str):
data = aDataset(v1=[True], v2=[1 / 2], v3=[3])
test_file_arrow = Path("test_arrow.parquet")
data.to_parquet(test_file_arrow, engine="arrow")
pf = pyarrow.parquet.ParquetFile(test_file_arrow)
pf = af._modules.pyarrow.parquet.ParquetFile(test_file_arrow)
pf_metadata = pf.schema_arrow.metadata
decoded_metadata = {
k.decode(): try_ast_literal_eval(v.decode()) for k, v in pf_metadata.items()
@@ -378,8 +365,8 @@ def try_ast_literal_eval(x: str):
assert decoded_metadata.get("v2") == aDataset.v2.comment


@pytest.mark.skipif(NO_POLARS, reason="polars is not installed")
@pytest.mark.skipif(NO_PYARROW, reason="pyarrow is not installed")
@pytest.mark.skipif(not af._modules.polars, reason="polars is not installed")
@pytest.mark.skipif(not af._modules.pyarrow, reason="pyarrow is not installed")
def test_to_parquet_with_metadata():
class aDataset(af.Dataset):
"""Delightful data."""
@@ -441,7 +428,7 @@ class KeyValueMetadata(af.Dataset):
)


@pytest.mark.skipif(NO_PYARROW, reason="pyarrow is not installed")
@pytest.mark.skipif(not af._modules.pyarrow, reason="pyarrow is not installed")
def test_parquet_roundtrip_with_rename():
class IsotopeData(af.Dataset):
symbol = af.VectorObject("Element")
