From f8d5d1c099b4c397c1601b6a6c7427658096e23f Mon Sep 17 00:00:00 2001
From: liquidcarbon
Date: Fri, 13 Dec 2024 00:14:51 -0700
Subject: [PATCH] better optional dependencies handling; release

---
 affinity.py      | 48 +++++++++++++++++++++++++-----------------------
 pyproject.toml   |  6 ++++--
 test_affinity.py | 31 +++++++++----------------------
 3 files changed, 38 insertions(+), 47 deletions(-)

diff --git a/affinity.py b/affinity.py
index f40b07b..e1ccdcd 100644
--- a/affinity.py
+++ b/affinity.py
@@ -8,28 +8,32 @@
 from time import time
 from typing import TYPE_CHECKING, List, Optional, Tuple
 
+import duckdb
 import numpy as np
 import pandas as pd
 
 
-def try_import(module: str) -> object | None:
-    try:
-        return import_module(module)
-    except ImportError:
-        # print(f"{module} not found in the current environment")
-        return
+class _modules:
+    """Stores modules imported conditionally."""
+
+    def try_import(modules: List[str]) -> None:
+        """Conditional imports."""
+        for module in modules:
+            try:
+                _module = import_module(module)
+                globals()[module] = _module  # used here
+                setattr(_modules, module, _module)  # used in tests
+            except ImportError:
+                setattr(_modules, module, False)
 
 
 if TYPE_CHECKING:
-    import duckdb  # type: ignore
-    import polars as pl  # type: ignore
-    import pyarrow as pa  # type: ignore
-    import pyarrow.parquet as pq  # type: ignore
+    import awswrangler  # type: ignore
+    import polars  # type: ignore
+    import pyarrow  # type: ignore
+    import pyarrow.parquet  # type: ignore
 else:
-    duckdb = try_import("duckdb")
-    pl = try_import("polars")
-    pa = try_import("pyarrow")
-    pq = try_import("pyarrow.parquet")
+    _modules.try_import(["awswrangler", "polars", "pyarrow", "pyarrow.parquet"])
 
 
 @dataclass
@@ -213,7 +217,7 @@ def build(cls, query=None, dataframe=None, **kwargs):
 
     @classmethod
     def from_dataframe(
-        cls, dataframe: pd.DataFrame | Optional["pl.DataFrame"], **kwargs
+        cls, dataframe: pd.DataFrame | Optional["polars.DataFrame"], **kwargs
     ):
         instance = cls()
         for i, k in enumerate(dict(instance)):
@@ -237,9 +241,7 @@ def from_sql(cls, query: str, **kwargs):
     @property
     def athena_types(self):
         """Convert pandas types to SQL types for loading into AWS Athena."""
-
-        wr = try_import("awswrangler")
-        columns_types, partition_types = wr.catalog.extract_athena_types(
+        columns_types, partition_types = awswrangler.catalog.extract_athena_types(
             df=self.df,
             partition_cols=self.LOCATION.partition_by,
         )
@@ -365,17 +367,17 @@ def df4(self) -> pd.DataFrame:
         return self.df
 
     @property
-    def arrow(self) -> "pa.Table":
+    def arrow(self) -> "pyarrow.Table":
         metadata = {str(k): str(v) for k, v in self.metadata.items()}
         _dict = {
             k: [v.dict for v in vector] if self.is_dataset(k) else vector
             for k, vector in self
         }
-        return pa.table(_dict, metadata=metadata)
+        return pyarrow.table(_dict, metadata=metadata)
 
     @property
-    def pl(self) -> "pl.DataFrame":
-        return pl.DataFrame(dict(self))
+    def pl(self) -> "polars.DataFrame":
+        return polars.DataFrame(dict(self))
 
     def is_dataset(self, key):
         attr = getattr(self, key, None)
@@ -428,7 +430,7 @@ def to_parquet(self, path, engine="duckdb", **kwargs):
         if engine == "pandas":
             self.df.to_parquet(path)
         elif engine == "arrow":
-            pq.write_table(self.arrow, path)
+            pyarrow.parquet.write_table(self.arrow, path)
         elif engine == "duckdb":
             kv_metadata = []
             for k, v in self.metadata.items():
diff --git a/pyproject.toml b/pyproject.toml
index 49e1092..2bf9b78 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "affinity"
-version = "1.0.0"
+version = "1.0.1" description = "Module for creating well-documented datasets, with types and annotations." authors = [ { name = "Alex Kislukhin" } @@ -13,7 +13,6 @@ readme = "README.md" requires-python = ">=3.11" dependencies = [ - "awswrangler>=3.10.1", "duckdb>=1", "pandas", ] @@ -24,6 +23,9 @@ dev = [ "pyarrow>=17", "pytest>=8", ] +aws = [ + "awswrangler>=3.10.1", +] [tool.hatch.build] include = [ diff --git a/test_affinity.py b/test_affinity.py index e7c3374..9da0202 100644 --- a/test_affinity.py +++ b/test_affinity.py @@ -11,20 +11,6 @@ # https://github.com/duckdb/duckdb/issues/14179 duckdb.sql("SET python_scan_all_frames=true") -try: - import polars # noqa: F401 - - NO_POLARS = False -except ImportError: - NO_POLARS = True - -try: - import pyarrow - - NO_PYARROW = False -except ImportError: - NO_PYARROW = True - def test_location_default(): loc = af.Location() @@ -266,8 +252,8 @@ class aDataset(af.Dataset): pd.testing.assert_frame_equal(data.df, source_df.astype(default_dtypes)) -@pytest.mark.skipif(NO_POLARS, reason="polars is not installed") -@pytest.mark.skipif(NO_PYARROW, reason="pyarrow is not installed") +@pytest.mark.skipif(not af._modules.polars, reason="polars is not installed") +@pytest.mark.skipif(not af._modules.pyarrow, reason="pyarrow is not installed") def test_to_polars(): class aDataset(af.Dataset): v1 = af.VectorBool("") @@ -280,7 +266,7 @@ class aDataset(af.Dataset): assert str(polars_df.dtypes) == "[Boolean, Float32, Int16]" -@pytest.mark.skipif(NO_PYARROW, reason="pyarrow is not installed") +@pytest.mark.skipif(not af._modules.pyarrow, reason="pyarrow is not installed") def test_to_pyarrow(): class aDataset(af.Dataset): v1 = af.VectorBool("") @@ -329,6 +315,7 @@ class cDataset(af.Dataset): cDataset().sql("SELECT v2 FROM df") # "df" != last test's data_a.df +@pytest.mark.skipif(not af._modules.awswrangler, reason="awswrangler is not installed") def test_kwargs_for_create_athena_table(): class aDataset(af.Dataset): """Document me!""" @@ -352,7 +339,7 @@ class aDataset(af.Dataset): } -@pytest.mark.skipif(NO_PYARROW, reason="pyarrow is not installed") +@pytest.mark.skipif(not af._modules.pyarrow, reason="pyarrow is not installed") def test_objects_as_metadata(): class aDataset(af.Dataset): """Objects other than strings can go into metadata.""" @@ -369,7 +356,7 @@ def try_ast_literal_eval(x: str): data = aDataset(v1=[True], v2=[1 / 2], v3=[3]) test_file_arrow = Path("test_arrow.parquet") data.to_parquet(test_file_arrow, engine="arrow") - pf = pyarrow.parquet.ParquetFile(test_file_arrow) + pf = af._modules.pyarrow.parquet.ParquetFile(test_file_arrow) pf_metadata = pf.schema_arrow.metadata decoded_metadata = { k.decode(): try_ast_literal_eval(v.decode()) for k, v in pf_metadata.items() @@ -378,8 +365,8 @@ def try_ast_literal_eval(x: str): assert decoded_metadata.get("v2") == aDataset.v2.comment -@pytest.mark.skipif(NO_POLARS, reason="polars is not installed") -@pytest.mark.skipif(NO_PYARROW, reason="pyarrow is not installed") +@pytest.mark.skipif(not af._modules.polars, reason="polars is not installed") +@pytest.mark.skipif(not af._modules.pyarrow, reason="pyarrow is not installed") def test_to_parquet_with_metadata(): class aDataset(af.Dataset): """Delightful data.""" @@ -441,7 +428,7 @@ class KeyValueMetadata(af.Dataset): ) -@pytest.mark.skipif(NO_PYARROW, reason="pyarrow is not installed") +@pytest.mark.skipif(not af._modules.pyarrow, reason="pyarrow is not installed") def test_parquet_roundtrip_with_rename(): class IsotopeData(af.Dataset): symbol = 
af.VectorObject("Element")
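
Note on the import pattern above: `_modules.try_import` registers each optional
dependency in two places, as a module-level global (so library code can call
`polars.DataFrame(...)` or `awswrangler.catalog...` directly) and as a `_modules`
class attribute holding either the module object or `False` (so callers and tests
can check availability). A minimal standalone sketch of the same approach; only
`_modules` and `try_import` mirror the patch, and the module names passed in here
are illustrative so the snippet runs anywhere:

    from importlib import import_module
    from typing import List


    class _modules:
        """Registry of optional imports: module object if present, False if not."""

        def try_import(modules: List[str]) -> None:  # called on the class, as in the patch
            for module in modules:
                try:
                    _module = import_module(module)
                    globals()[module] = _module  # bind as a plain module-level name
                    setattr(_modules, module, _module)  # expose for introspection
                except ImportError:
                    setattr(_modules, module, False)  # falsy marker: not installed


    _modules.try_import(["json", "not_a_real_package"])

    assert _modules.json.dumps({"ok": 1}) == '{"ok": 1}'  # stdlib: imported fine
    assert _modules.not_a_real_package is False  # missing: recorded as False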
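
On the test side, the module-level NO_POLARS / NO_PYARROW flags are replaced by
direct checks of that registry, and the awswrangler-dependent test gains the same
guard now that awswrangler lives behind the optional `aws` extra (installable as
`pip install "affinity[aws]"`). A sketch of a test written against the registry;
the test body is illustrative, only the skipif idiom comes from the patch:

    import pytest

    import affinity as af


    @pytest.mark.skipif(not af._modules.polars, reason="polars is not installed")
    def test_optional_polars_example():
        # af._modules.polars is the imported module when polars is available
        # and False otherwise, so this test is skipped rather than erroring
        # in environments without polars.
        df = af._modules.polars.DataFrame({"v": [1, 2, 3]})
        assert df.height == 3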