diff --git a/affinity.py b/affinity.py
index 7eadb4a..ad7713e 100644
--- a/affinity.py
+++ b/affinity.py
@@ -6,8 +6,7 @@
 import pandas as pd
 from importlib import import_module
 from time import time
-from typing import TYPE_CHECKING, Optional, Union
-
+from typing import TYPE_CHECKING, List, Optional, Union
 
 def try_import(module) -> Optional[object]:
     try:
@@ -23,7 +22,7 @@ def try_import(module) -> Optional[object]:
     import polars as pl  # type: ignore
 else:
     duckdb = try_import("duckdb")
-    pl = try_import("polars") 
+    pl = try_import("polars")
 pa = try_import("pyarrow")
 pq = try_import("pyarrow.parquet")
 
@@ -31,7 +30,7 @@ def try_import(module) -> Optional[object]:
 class Descriptor:
     def __get__(self, instance, owner):
         return self if not instance else instance.__dict__[self.name]
-    
+
     def __set__(self, instance, values):
         try:
             _values = self.array_class(
@@ -75,7 +74,7 @@ def __init__(self, dtype, value=None, comment=None, array_class=np.array):
 
     def __len__(self):
         return 1
-    
+
     def __repr__(self):
         return self.info
 
@@ -96,7 +95,7 @@ def __init__(self, dtype, values=None, comment=None, array_class=np.array):
 
     def __getitem__(self, key):
         return self._values[key]
-    
+
     def __setitem__(self, key, value):
         self._values[key] = value
 
@@ -109,7 +108,7 @@ def __getattr__(self, attr):
 
     def __repr__(self):
         return "\n".join([f"{self.info} | len {len(self)}", repr(self._values)])
-    
+
     def __str__(self):
         return self.__repr__()
 
@@ -127,7 +126,7 @@ def __repr__(cls) -> str:
 
 class Dataset(metaclass=DatasetMeta):
     """Base class for typed, annotated datasets."""
-    
+
     @classmethod
     def get_scalars(cls):
         return {k: None for k,v in cls.__dict__.items() if isinstance(v, Scalar)}
@@ -142,11 +141,11 @@ def get_dict(cls):
 
     def __init__(self, **fields: Union[Scalar|Vector]):
         """Create dataset, dynamically setting field values.
-        
+
         Vectors are initialized first, ensuring all are of equal length.
         Scalars are filled in afterwards.
         """
-        
+
         self.origin = {"created_ts": int(time() * 1000)}
         _sizes = {}
         self._vectors = self.__class__.get_vectors()
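For orientation, here is a rough sketch of how the constructor documented in this hunk behaves. The class, field names, and comments below are illustrative, not from this diff; the sketch assumes the factory-generated field classes accept a comment keyword, which the underlying __init__ signatures above (comment=None) suggest.

import affinity as af

class Sensor(af.Dataset):
    """Hypothetical dataset for illustration."""
    unit = af.ScalarObject(comment="unit of measurement")  # one value per dataset
    t = af.VectorF64(comment="timestamps")                 # one value per row
    reading = af.VectorF32(comment="readings")

# Vectors are set first and length-checked against each other; the scalar
# is then broadcast via Vector.from_scalar to the max vector length (2).
data = Sensor(unit="degC", t=[0.0, 1.0], reading=[21.5, 21.7])
assert len(data) == 2
assert data.origin["source"] == "manual"  # set only after a direct __init__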
@@ -171,14 +170,14 @@ def __init__(self, **fields: Union[Scalar|Vector]):
                 _vector_from_scalar = Vector.from_scalar(_scalar, self._max_size)
                 setattr(self, scalar_name, _vector_from_scalar)
                 self._scalars[scalar_name] = _value
-        
+
         if len(self.origin) == 1:  # only after direct __init__
             self.origin["source"] = "manual"
 
     @classmethod
     def build(cls, query=None, dataframe=None, **kwargs):
         """Build from DuckDB query or a dataframe.
-        
+
         Build kwargs:
         - rename: how to handle source with differently named fields:
             None|False: field names in source must match class declaration
@@ -188,7 +187,7 @@ def build(cls, query=None, dataframe=None, **kwargs):
             return cls.from_sql(query, **kwargs)
         if isinstance(dataframe, (pd.DataFrame,)):
             return cls.from_dataframe(dataframe, **kwargs)
-    
+
     @classmethod
     def from_dataframe(cls, dataframe: pd.DataFrame | Optional['pl.DataFrame'], **kwargs):
         instance = cls()
@@ -199,20 +198,20 @@ def from_dataframe(cls, dataframe: pd.DataFrame | Optional['pl.DataFrame'], **kw
             setattr(instance, k, dataframe[dataframe.columns[i]])
         instance.origin["source"] = f"dataframe, shape {dataframe.shape}"
         return instance
-    
+
     @classmethod
     def from_sql(cls, query: str, **kwargs):
         if kwargs.get("method") in (None, "pandas"):
             query_results = duckdb.sql(query).df()
         if kwargs.get("method") in ("polars",):
-            query_results = duckdb.sql(query).pl() 
+            query_results = duckdb.sql(query).pl()
         instance = cls.from_dataframe(query_results, **kwargs)
         instance.origin["source"] += f'\nquery:\n{query}'
         return instance
 
     def __eq__(self, other):
         return self.df.equals(other.df)
-    
+
     def __len__(self) -> int:
         return max(len(field[1]) for field in self)
 
@@ -231,7 +230,7 @@ def __repr__(self):
         for k, v in dict_list.items():
             lines.append(f"{k} = {v}".replace(", '...',", " ..."))
         return "\n".join(lines)
-    
+
     def is_dataset(self, key):
         attr = getattr(self, key, None)
         if attr is None or len(attr) == 0 or isinstance(attr, Scalar):
@@ -241,7 +240,7 @@ def is_dataset(self, key):
 
     def sql(self, query, **replacements):
         """Query the dataset with DuckDB.
-        
+
         DuckDB uses replacement scans to query python objects.
         Class instance attributes like `FROM self.df` must be
         registered as views. This is what **replacements kwargs are for.
@@ -254,6 +253,13 @@ def sql(self, query, **replacements):
             duckdb.register(k, v)
         return duckdb.sql(query)
 
+    def flatten(self):
+        """List of dicts? Dict of lists? TBD"""
+        raise NotImplementedError
+
+    def model_dump(self) -> dict:
+        """Similar to Pydantic's model_dump; alias for dict."""
+        return self.dict
 
     def to_parquet(self, path, engine="duckdb", **kwargs):
         if engine == "arrow":
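A sketch of the two methods this hunk touches, used together. The dataset, field, and query are hypothetical; calling .df() on the returned relation follows the same duckdb.sql(...).df() pattern used in from_sql above.

import affinity as af

class Numbers(af.Dataset):
    n = af.VectorI64(comment="some integers")

data = Numbers(n=[1, 2, 3])

# model_dump() is a drop-in alias for the `dict` property
assert data.model_dump() == {"n": [1, 2, 3]}

# sql() registers each **replacements kwarg as a DuckDB view, so `FROM df`
# resolves even though data.df is an instance attribute, not a local name.
total = data.sql("SELECT SUM(n) AS total FROM df", df=data.df).df()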
TBD""" + raise NotImplementedError + + def model_dump(self) -> dict: + """Similar to Pydantic's model_dump; alias for dict.""" + return self.dict def to_parquet(self, path, engine="duckdb", **kwargs): if engine == "arrow": @@ -276,7 +282,7 @@ def to_parquet(self, path, engine="duckdb", **kwargs): @property def shape(self): return len(self), len(self._vectors) + len(self._scalars) - + @property def dict(self) -> dict: """JSON-like dict, with scalars as scalars and vectors as lists.""" @@ -304,7 +310,7 @@ def df(self) -> pd.DataFrame: } return pd.DataFrame(_dict) - + @property def df4(self) -> pd.DataFrame: if len(self) > 4: @@ -313,7 +319,7 @@ def df4(self) -> pd.DataFrame: return df.sort_index() else: return self.df - + @property def arrow(self) -> "pa.Table": metadata = {str(k): str(v) for k, v in self.metadata.items()} @@ -322,25 +328,25 @@ def arrow(self) -> "pa.Table": for k, vector in self } return pa.table(_dict, metadata=metadata) - + @property def pl(self) -> "pl.DataFrame": return pl.DataFrame(dict(self)) ScalarObject = Scalar.factory(object, cls_name="ScalarObject") ScalarBool = Scalar.factory("boolean", cls_name="ScalarBool") -ScalarI8 = Scalar.factory(pd.Int8Dtype(), cls_name="ScalarI8") -ScalarI16 = Scalar.factory(pd.Int16Dtype(), cls_name="ScalarI16") -ScalarI32 = Scalar.factory(pd.Int32Dtype(), cls_name="ScalarI32") -ScalarI64 = Scalar.factory(pd.Int64Dtype(), cls_name="ScalarI64") +ScalarI8 = Scalar.factory(pd.Int8Dtype(), cls_name="ScalarI8") +ScalarI16 = Scalar.factory(pd.Int16Dtype(), cls_name="ScalarI16") +ScalarI32 = Scalar.factory(pd.Int32Dtype(), cls_name="ScalarI32") +ScalarI64 = Scalar.factory(pd.Int64Dtype(), cls_name="ScalarI64") ScalarF32 = Scalar.factory(np.float32, cls_name="ScalarF32") ScalarF64 = Scalar.factory(np.float64, cls_name="ScalarF64") -VectorObject = Vector.factory(object, cls_name="VectorObject") +VectorObject = Vector.factory(object, cls_name="VectorObject") VectorBool = Vector.factory("boolean", cls_name="VectorBool") -VectorI8 = Vector.factory(pd.Int8Dtype(), cls_name="VectorI8") -VectorI16 = Vector.factory(pd.Int16Dtype(), cls_name="VectorI16") -VectorI32 = Vector.factory(pd.Int32Dtype(), cls_name="VectorI32") -VectorI64 = Vector.factory(pd.Int64Dtype(), cls_name="VectorI64") -VectorF16 = Vector.factory(np.float16, cls_name="VectorF16") +VectorI8 = Vector.factory(pd.Int8Dtype(), cls_name="VectorI8") +VectorI16 = Vector.factory(pd.Int16Dtype(), cls_name="VectorI16") +VectorI32 = Vector.factory(pd.Int32Dtype(), cls_name="VectorI32") +VectorI64 = Vector.factory(pd.Int64Dtype(), cls_name="VectorI64") +VectorF16 = Vector.factory(np.float16, cls_name="VectorF16") VectorF32 = Vector.factory(np.float32, cls_name="VectorF32") VectorF64 = Vector.factory(np.float64, cls_name="VectorF64") diff --git a/test_affinity.py b/test_affinity.py index 821ca81..47c326d 100644 --- a/test_affinity.py +++ b/test_affinity.py @@ -179,7 +179,7 @@ class aDataset(af.Dataset): assert data.origin.get("source") == "dataframe, shape (2, 3)" default_dtypes = source_df.dtypes desired_dtypes = {"v1": "boolean", "v2": np.float32, "v3": pd.Int16Dtype()} - pd.testing.assert_frame_equal(data.df, source_df.astype(desired_dtypes)) + pd.testing.assert_frame_equal(data.df, source_df.astype(desired_dtypes)) with pytest.raises(AssertionError): pd.testing.assert_frame_equal(data.df, source_df.astype(default_dtypes)) @@ -198,7 +198,7 @@ class aDataset(af.Dataset): assert data.origin.get("source") == "dataframe, shape (2, 3)\nquery:\nFROM source_df" default_dtypes = source_df.dtypes 
desired_dtypes = {"v1": "boolean", "v2": np.float32, "v3": pd.Int16Dtype()} - pd.testing.assert_frame_equal(data.df, source_df.astype(desired_dtypes)) + pd.testing.assert_frame_equal(data.df, source_df.astype(desired_dtypes)) with pytest.raises(AssertionError): pd.testing.assert_frame_equal(data.df, source_df.astype(default_dtypes)) @@ -350,4 +350,4 @@ class Task(af.Dataset): ], 'hours': [3, 5] } - assert t1.dict == expected_dict + assert t1.model_dump() == expected_dict