Commit: model_dump
liquidcarbon committed Sep 27, 2024
1 parent 063b3bb commit 6fd7210
Showing 2 changed files with 41 additions and 35 deletions.
70 changes: 38 additions & 32 deletions affinity.py
@@ -6,8 +6,7 @@
import pandas as pd
from importlib import import_module
from time import time
from typing import TYPE_CHECKING, List, Optional, Union

def try_import(module) -> Optional[object]:
try:
@@ -23,15 +22,15 @@ def try_import(module) -> Optional[object]:
import polars as pl # type: ignore
else:
duckdb = try_import("duckdb")
pl = try_import("polars")
pa = try_import("pyarrow")
pq = try_import("pyarrow.parquet")
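
The body of `try_import` is truncated in this diff; the following is a plausible reconstruction of the pattern, offered as a hedged sketch (the `pyarrow.parquet` check is illustrative):

```python
from importlib import import_module
from typing import Optional

def try_import(module: str) -> Optional[object]:
    """Return the imported module, or None if it is not installed."""
    try:
        return import_module(module)
    except ImportError:
        return None

# Optional dependencies degrade gracefully: callers check for None
# before using the module.
pq = try_import("pyarrow.parquet")
if pq is None:
    print("pyarrow is not installed; Parquet support is unavailable")
```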


class Descriptor:
def __get__(self, instance, owner):
return self if not instance else instance.__dict__[self.name]

def __set__(self, instance, values):
try:
_values = self.array_class(
@@ -75,7 +74,7 @@ def __init__(self, dtype, value=None, comment=None, array_class=np.array):

def __len__(self):
return 1

def __repr__(self):
return self.info

@@ -96,7 +95,7 @@ def __init__(self, dtype, values=None, comment=None, array_class=np.array):

def __getitem__(self, key):
return self._values[key]

def __setitem__(self, key, value):
self._values[key] = value

@@ -109,7 +108,7 @@ def __getattr__(self, attr):

def __repr__(self):
return "\n".join([f"{self.info} | len {len(self)}", repr(self._values)])

def __str__(self):
return self.__repr__()

@@ -127,7 +126,7 @@ def __repr__(cls) -> str:

class Dataset(metaclass=DatasetMeta):
"""Base class for typed, annotated datasets."""

@classmethod
def get_scalars(cls):
return {k: None for k, v in cls.__dict__.items() if isinstance(v, Scalar)}
@@ -142,11 +141,11 @@ def get_dict(cls):

def __init__(self, **fields: Union[Scalar, Vector]):
"""Create dataset, dynamically setting field values.
Vectors are initialized first, ensuring all are of equal length.
Scalars are filled in afterwards.
"""

self.origin = {"created_ts": int(time() * 1000)}
_sizes = {}
self._vectors = self.__class__.get_vectors()
@@ -171,14 +170,14 @@ def __init__(self, **fields: Union[Scalar, Vector]):
_vector_from_scalar = Vector.from_scalar(_scalar, self._max_size)
setattr(self, scalar_name, _vector_from_scalar)
self._scalars[scalar_name] = _value

if len(self.origin) == 1: # only after direct __init__
self.origin["source"] = "manual"

@classmethod
def build(cls, query=None, dataframe=None, **kwargs):
"""Build from DuckDB query or a dataframe.
Build kwargs:
- rename: how to handle source with differently named fields:
None|False: field names in source must match class declaration
@@ -188,7 +187,7 @@ def build(cls, query=None, dataframe=None, **kwargs):
return cls.from_sql(query, **kwargs)
if isinstance(dataframe, (pd.DataFrame,)):
return cls.from_dataframe(dataframe, **kwargs)

@classmethod
def from_dataframe(cls, dataframe: Optional[Union[pd.DataFrame, 'pl.DataFrame']], **kwargs):
instance = cls()
@@ -199,20 +198,20 @@ def from_dataframe(cls, dataframe: Optional[Union[pd.DataFrame, 'pl.DataFrame']], **kw
setattr(instance, k, dataframe[dataframe.columns[i]])
instance.origin["source"] = f"dataframe, shape {dataframe.shape}"
return instance

@classmethod
def from_sql(cls, query: str, **kwargs):
if kwargs.get("method") in (None, "pandas"):
query_results = duckdb.sql(query).df()
if kwargs.get("method") in ("polars",):
query_results = duckdb.sql(query).pl()
instance = cls.from_dataframe(query_results, **kwargs)
instance.origin["source"] += f'\nquery:\n{query}'
return instance
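
A hedged sketch of both build paths, mirroring the `aDataset` classes in the tests below (the field comments are illustrative). With `method` unset, `from_sql` fetches results through pandas; `method="polars"` routes them through `duckdb.sql(query).pl()` instead:

```python
import affinity as af
import pandas as pd

class aDataset(af.Dataset):
    v1 = af.VectorBool("flag")
    v2 = af.VectorF32("value")
    v3 = af.VectorI16("count")

source_df = pd.DataFrame({"v1": [True, False], "v2": [1.5, 2.5], "v3": [1, 2]})

d1 = aDataset.build(dataframe=source_df)     # from a dataframe
d2 = aDataset.build(query="FROM source_df")  # DuckDB finds source_df via replacement scan
assert d1 == d2  # both coerce the source columns to the declared dtypes
```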

def __eq__(self, other):
return self.df.equals(other.df)

def __len__(self) -> int:
return max(len(field[1]) for field in self)

@@ -231,7 +230,7 @@ def __repr__(self):
for k, v in dict_list.items():
lines.append(f"{k} = {v}".replace(", '...',", " ..."))
return "\n".join(lines)

def is_dataset(self, key):
attr = getattr(self, key, None)
if attr is None or len(attr) == 0 or isinstance(attr, Scalar):
@@ -241,7 +240,7 @@ def is_dataset(self, key):

def sql(self, query, **replacements):
"""Query the dataset with DuckDB.
DuckDB uses replacement scans to query Python objects.
Class instance attributes such as `self.df` in `FROM self.df` must first be registered as views;
that is what the **replacements kwargs are for.
@@ -254,6 +253,13 @@
duckdb.register(k, v)
return duckdb.sql(query)
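
A short sketch of the `**replacements` mechanism, reusing the hypothetical `aDataset` above: `data.df` is invisible to DuckDB by name, so it is registered as a view under whatever keyword you choose:

```python
data = aDataset(v1=[True, False], v2=[1.5, 2.5], v3=[1, 2])
# register data.df as the DuckDB view `self_df`, then query it
result = data.sql("SELECT v2, v3 FROM self_df WHERE v1", self_df=data.df)
print(result)
```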

def flatten(self):
"""List of dicts? Dict of lists? TBD"""
raise NotImplementedError

def model_dump(self) -> dict:
"""Similar to Pydantic's model_dump; alias for dict."""
return self.dict
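
As the updated test at the bottom of this diff shows, `model_dump()` is a drop-in alias for the `dict` property (scalars as scalars, vectors as lists):

```python
data = aDataset(v1=[True, False], v2=[1.5, 2.5], v3=[1, 2])
assert data.model_dump() == data.dict
```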

def to_parquet(self, path, engine="duckdb", **kwargs):
if engine == "arrow":
@@ -276,7 +282,7 @@ def to_parquet(self, path, engine="duckdb", **kwargs):
@property
def shape(self):
return len(self), len(self._vectors) + len(self._scalars)

@property
def dict(self) -> dict:
"""JSON-like dict, with scalars as scalars and vectors as lists."""
@@ -304,7 +310,7 @@ def df(self) -> pd.DataFrame:
}
return pd.DataFrame(_dict)


@property
def df4(self) -> pd.DataFrame:
if len(self) > 4:
@@ -313,7 +319,7 @@ def df4(self) -> pd.DataFrame:
return df.sort_index()
else:
return self.df

@property
def arrow(self) -> "pa.Table":
metadata = {str(k): str(v) for k, v in self.metadata.items()}
@@ -322,25 +328,25 @@ def arrow(self) -> "pa.Table":
for k, vector in self
}
return pa.table(_dict, metadata=metadata)

@property
def pl(self) -> "pl.DataFrame":
return pl.DataFrame(dict(self))

ScalarObject = Scalar.factory(object, cls_name="ScalarObject")
ScalarBool = Scalar.factory("boolean", cls_name="ScalarBool")
ScalarI8 = Scalar.factory(pd.Int8Dtype(), cls_name="ScalarI8")
ScalarI16 = Scalar.factory(pd.Int16Dtype(), cls_name="ScalarI16")
ScalarI32 = Scalar.factory(pd.Int32Dtype(), cls_name="ScalarI32")
ScalarI64 = Scalar.factory(pd.Int64Dtype(), cls_name="ScalarI64")
ScalarF32 = Scalar.factory(np.float32, cls_name="ScalarF32")
ScalarF64 = Scalar.factory(np.float64, cls_name="ScalarF64")
VectorObject = Vector.factory(object, cls_name="VectorObject")
VectorBool = Vector.factory("boolean", cls_name="VectorBool")
VectorI8 = Vector.factory(pd.Int8Dtype(), cls_name="VectorI8")
VectorI16 = Vector.factory(pd.Int16Dtype(), cls_name="VectorI16")
VectorI32 = Vector.factory(pd.Int32Dtype(), cls_name="VectorI32")
VectorI64 = Vector.factory(pd.Int64Dtype(), cls_name="VectorI64")
VectorF16 = Vector.factory(np.float16, cls_name="VectorF16")
VectorF32 = Vector.factory(np.float32, cls_name="VectorF32")
VectorF64 = Vector.factory(np.float64, cls_name="VectorF64")
6 changes: 3 additions & 3 deletions test_affinity.py
@@ -179,7 +179,7 @@ class aDataset(af.Dataset):
assert data.origin.get("source") == "dataframe, shape (2, 3)"
default_dtypes = source_df.dtypes
desired_dtypes = {"v1": "boolean", "v2": np.float32, "v3": pd.Int16Dtype()}
pd.testing.assert_frame_equal(data.df, source_df.astype(desired_dtypes))
with pytest.raises(AssertionError):
pd.testing.assert_frame_equal(data.df, source_df.astype(default_dtypes))

@@ -198,7 +198,7 @@ class aDataset(af.Dataset):
assert data.origin.get("source") == "dataframe, shape (2, 3)\nquery:\nFROM source_df"
default_dtypes = source_df.dtypes
desired_dtypes = {"v1": "boolean", "v2": np.float32, "v3": pd.Int16Dtype()}
pd.testing.assert_frame_equal(data.df, source_df.astype(desired_dtypes))
with pytest.raises(AssertionError):
pd.testing.assert_frame_equal(data.df, source_df.astype(default_dtypes))

@@ -350,4 +350,4 @@ class Task(af.Dataset):
],
'hours': [3, 5]
}
assert t1.model_dump() == expected_dict
