
Commit 8552610

Add narwhals materializer for dataframe agnosticism. (#189)
Co-authored-by: Marco Edward Gorelli <[email protected]>
1 parent 80c3bc1 commit 8552610
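
For context, a minimal usage sketch of what this change enables. This is not part of the commit; it assumes polars is installed and uses Formulaic's existing `model_matrix` entry point, leaving the output type to the materializer's default:

    import polars as pl
    from formulaic import model_matrix

    # Any eager frame that narwhals recognises (e.g. polars, pyarrow) can now be
    # materialized without a dedicated, input-specific materializer.
    df = pl.DataFrame({"x": [1.0, 2.0, 3.0], "g": ["a", "b", "a"]})
    mm = model_matrix("x + g", df)  # resolved via NarwhalsMaterializer.SUPPORTS_INPUT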

11 files changed (+484 -123 lines)


formulaic/materializers/__init__.py (+2 -2)

@@ -1,11 +1,11 @@
-from .arrow import ArrowMaterializer
 from .base import FormulaMaterializer
+from .narwhals import NarwhalsMaterializer
 from .pandas import PandasMaterializer
 from .types import ClusterBy, FactorValues, NAAction

 __all__ = [
-    "ArrowMaterializer",
     "FormulaMaterializer",
+    "NarwhalsMaterializer",
     "PandasMaterializer",
     # Useful types
     "ClusterBy",

formulaic/materializers/arrow.py (-50)

This file was deleted.

formulaic/materializers/base.py (+61 -9)

@@ -33,6 +33,7 @@
 from formulaic.transforms import TRANSFORMS
 from formulaic.utils.cast import as_columns
 from formulaic.utils.layered_mapping import LayeredMapping
+from formulaic.utils.null_handling import find_nulls
 from formulaic.utils.stateful_transforms import stateful_eval
 from formulaic.utils.variables import Variable

@@ -85,26 +86,44 @@ def for_data(cls, data: Any, output: Hashable = None) -> type[FormulaMaterializer]:
         datacls = data.__class__
         input_type = f"{datacls.__module__}.{datacls.__qualname__}"

-        if input_type not in cls.REGISTERED_INPUTS:
+        materializers_supporting_input = []
+
+        if input_type in cls.REGISTERED_INPUTS:
+            materializers_supporting_input.extend(cls.REGISTERED_INPUTS[input_type])
+
+        if output is None and materializers_supporting_input:
+            return materializers_supporting_input[0]
+
+        for materializer in sorted(
+            set(cls.REGISTERED_NAMES.values()),
+            key=lambda x: x.REGISTER_PRECEDENCE,
+            reverse=True,
+        ):
+            if materializer.SUPPORTS_INPUT(data):
+                materializers_supporting_input.append(materializer)
+
+        if not materializers_supporting_input:
             raise FormulaMaterializerNotFoundError(
-                f"No materializer has been registered for input type {repr(input_type)}. Available input types are: {set(cls.REGISTER_INPUTS)}."
+                f"No materializer is available for input type {repr(input_type)}. Explicitly registered input types are: {tuple(sorted(cls.REGISTERED_INPUTS))}."
             )

         if output is None:
-            return cls.REGISTERED_INPUTS[input_type][0]
+            return materializers_supporting_input[0]

-        for materializer in cls.REGISTERED_INPUTS[input_type]:
+        for materializer in materializers_supporting_input:
             if output in materializer.REGISTER_OUTPUTS:
                 return materializer

         output_types: set[Hashable] = set(
-            *itertools.chain(
-                materializer.REGISTER_OUTPUTS
-                for materializer in cls.REGISTERED_INPUTS[input_type]
+            itertools.chain(
+                *[
+                    materializer.REGISTER_OUTPUTS
+                    for materializer in materializers_supporting_input
+                ]
             )
         )
         raise FormulaMaterializerNotFoundError(
-            f"No materializer has been registered for input type {repr(input_type)} that supports output type {repr(output)}. Available output types for {repr(input_type)} are: {output_types}."
+            f"No materializer is available for input type {repr(input_type)} that also supports output type {repr(output)}. Available output types for {repr(input_type)} are: {tuple(sorted(output_types, key=lambda x: str(x)))}."
         )

@@ -114,6 +133,19 @@ class FormulaMaterializer(metaclass=FormulaMaterializerMeta):
     REGISTER_OUTPUTS: Sequence[Hashable] = ()
     REGISTER_PRECEDENCE: float = 100

+    @classmethod
+    def SUPPORTS_INPUT(cls, data: Any) -> bool:
+        """
+        Check whether this materializer supports the given data.
+        This allows for non-explicit input registration where additional
+        dynamism is required, or where this materializer should act as a
+        fallback.
+
+        Note: materializers with explicitly registered inputs will always take
+        priority.
+        """
+        return False
+
     # Public API

     @inherit_docs(method="_init")
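
To illustrate the new hook: a downstream materializer could opt into the dynamic fallback by overriding SUPPORTS_INPUT instead of listing REGISTER_INPUTS. The class and detection logic below are invented for illustration; only the hook itself and the precedence/registration attributes come from this codebase:

    class MyFrameMaterializer(FormulaMaterializer):
        REGISTER_NAME = "my_frame"  # hypothetical
        REGISTER_PRECEDENCE = 50  # lower precedence: consulted after other fallbacks

        @classmethod
        def SUPPORTS_INPUT(cls, data: Any) -> bool:
            # for_data() consults this only after any explicitly registered
            # materializers for the input type, so explicit registrations
            # always take priority (as noted in the docstring above).
            return type(data).__name__ == "MyFrame"  # hypothetical duck check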
@@ -619,7 +651,27 @@ def _is_categorical(self, values: Any) -> bool:
     def _check_for_nulls(
         self, name: str, values: Any, na_action: NAAction, drop_rows: set[int]
     ) -> None:
-        pass  # pragma: no cover
+        if na_action is NAAction.IGNORE:
+            return
+
+        try:
+            null_indices = find_nulls(values)
+
+            if na_action is NAAction.RAISE:
+                if null_indices:
+                    raise ValueError(f"`{name}` contains null values after evaluation.")
+
+            elif na_action is NAAction.DROP:
+                drop_rows.update(null_indices)
+
+            else:
+                raise ValueError(
+                    f"Do not know how to interpret `na_action` = {repr(na_action)}."
+                )  # pragma: no cover; this is currently impossible to reach
+        except ValueError as e:
+            raise ValueError(
+                f"Error encountered while checking for nulls in `{name}`: {e}"
+            ) from e

     def _encode_evaled_factor(
         self,
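
A rough sketch of the user-facing behaviour this implements. NAAction is importable from formulaic.materializers (it is imported in the __init__.py above); passing na_action through model_matrix as a spec override is an assumption here:

    import pandas
    from formulaic import model_matrix
    from formulaic.materializers import NAAction

    df = pandas.DataFrame({"x": [1.0, None, 3.0]})

    model_matrix("x", df, na_action=NAAction.DROP)    # row with the null is dropped
    model_matrix("x", df, na_action=NAAction.IGNORE)  # nulls pass through untouched
    model_matrix("x", df, na_action=NAAction.RAISE)   # ValueError about nulls in `x`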

formulaic/materializers/narwhals.py (+207)

@@ -0,0 +1,207 @@
+# pragma: no cover; TODO: experimental
+
+from __future__ import annotations
+
+import functools
+import itertools
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, Any
+
+import narwhals.stable.v1 as nw
+import numpy
+import pandas
+import scipy.sparse as spsparse
+from interface_meta import override
+
+from formulaic.utils.cast import as_columns
+from formulaic.utils.null_handling import drop_rows as drop_nulls
+
+from .base import FormulaMaterializer
+
+if TYPE_CHECKING:  # pragma: no cover
+    from formulaic.model_spec import ModelSpec
+
+
+class NarwhalsMaterializer(FormulaMaterializer):
+    REGISTER_NAME = "narwhals"
+    REGISTER_INPUTS: Sequence[str] = (
+        "narwhals.DataFrame",
+        "narwhals.stable.v1.DataFrame",
+    )
+    REGISTER_OUTPUTS: Sequence[str] = ("narwhals", "pandas", "numpy", "sparse")
+
+    @override
+    @classmethod
+    def SUPPORTS_INPUT(cls, data: Any) -> bool:
+        return nw.dependencies.is_into_dataframe(data)
+
+    @override
+    def _init(self) -> None:
+        self.__narwhals_data = nw.from_native(self.data, eager_only=True)
+        self.__data_context = self.__narwhals_data.to_dict()
+
+    @override  # type: ignore
+    @property
+    def data_context(self):
+        return self.__data_context
+
+    @override
+    def _is_categorical(self, values: Any) -> bool:
+        if nw.dependencies.is_narwhals_series(values):
+            if not values.dtype.is_numeric():
+                return True
+        return super()._is_categorical(values)
+
+    @override
+    def _encode_constant(
+        self,
+        value: Any,
+        metadata: Any,
+        encoder_state: dict[str, Any],
+        spec: ModelSpec,
+        drop_rows: Sequence[int],
+    ) -> Any:
+        nrows = self.nrows - len(drop_rows)
+        if spec.output == "sparse":
+            return spsparse.csc_matrix(numpy.array([value] * nrows).reshape((nrows, 1)))
+        series = value * numpy.ones(nrows)
+        return series
+
+    @override
+    def _encode_numerical(
+        self,
+        values: Any,
+        metadata: Any,
+        encoder_state: dict[str, Any],
+        spec: ModelSpec,
+        drop_rows: Sequence[int],
+    ) -> Any:
+        if drop_rows:
+            values = drop_nulls(values, indices=drop_rows)
+        if spec.output == "sparse":
+            return spsparse.csc_matrix(
+                numpy.array(values).reshape((values.shape[0], 1))
+            )
+        return values
+
+    @override
+    def _encode_categorical(
+        self,
+        values: Any,
+        metadata: Any,
+        encoder_state: dict[str, Any],
+        spec: ModelSpec,
+        drop_rows: Sequence[int],
+        reduced_rank: bool = False,
+    ) -> Any:
+        # Even though we could reduce rank here, we do not, so that the same
+        # encoding can be cached for both reduced and unreduced rank. The
+        # rank will be reduced in the _encode_evaled_factor method.
+        from formulaic.transforms import encode_contrasts
+
+        if drop_rows:
+            values = drop_nulls(values, indices=drop_rows)
+        if nw.dependencies.is_narwhals_series(values):
+            values = values.to_pandas()
+
+        return as_columns(
+            encode_contrasts(
+                values,
+                reduced_rank=False,
+                output="pandas" if spec.output == "narwhals" else spec.output,
+                _metadata=metadata,
+                _state=encoder_state,
+                _spec=spec,
+            )
+        )
+
+    @override
+    def _get_columns_for_term(
+        self, factors: list[dict[str, Any]], spec: ModelSpec, scale: float = 1
+    ) -> dict[str, Any]:
+        out = {}
+
+        names = [
+            ":".join(reversed(product))
+            for product in itertools.product(*reversed(factors))
+        ]
+
+        # Pre-multiply factors with only one set of values (improves performance)
+        solo_factors = {}
+        indices = []
+        for i, factor in enumerate(factors):
+            if len(factor) == 1:
+                solo_factors.update(factor)
+                indices.append(i)
+        if solo_factors:
+            for index in reversed(indices):
+                factors.pop(index)
+            if spec.output == "sparse":
+                factors.append(
+                    {
+                        ":".join(solo_factors): functools.reduce(
+                            spsparse.csc_matrix.multiply, solo_factors.values()
+                        )
+                    }
+                )
+            else:
+                factors.append(
+                    {
+                        ":".join(solo_factors): functools.reduce(
+                            numpy.multiply,
+                            (numpy.asanyarray(p) for p in solo_factors.values()),
+                        )
+                    }
+                )
+
+        for i, reversed_product in enumerate(
+            itertools.product(*(factor.items() for factor in reversed(factors)))
+        ):
+            if spec.output == "sparse":
+                out[names[i]] = scale * functools.reduce(
+                    spsparse.csc_matrix.multiply,
+                    (p[1] for p in reversed(reversed_product)),
+                )
+            else:
+                out[names[i]] = scale * functools.reduce(
+                    numpy.multiply,
+                    (numpy.array(p[1]) for p in reversed(reversed_product)),
+                )
+        return out
+
+    @override
+    def _combine_columns(
+        self, cols: Sequence[tuple[str, Any]], spec: ModelSpec, drop_rows: Sequence[int]
+    ) -> pandas.DataFrame:
+        # Special case no columns to empty csc_matrix, array, or DataFrame
+        if not cols:
+            values = numpy.empty((self.data.shape[0], 0))
+            if spec.output == "sparse":
+                return spsparse.csc_matrix(values)
+            if spec.output == "narwhals":
+                # TODO: Inconsistent with non-empty case below (where we use to-native)
+                return nw.from_native(values, eager_only=True)
+            if spec.output == "numpy":
+                return values
+            return pandas.DataFrame(values)
+
+        # Otherwise, concatenate columns into model matrix
+        if spec.output == "sparse":
+            return spsparse.hstack([col[1] for col in cols])
+
+        # TODO: Can we do better than this? Having to reconstitute raw data
+        # does not seem ideal.
+        combined = nw.from_dict(
+            {name: nw.to_native(col, pass_through=True) for name, col in cols},
+            native_namespace=nw.get_native_namespace(self.__narwhals_data),
+        )
+        if spec.output == "narwhals":
+            if nw.dependencies.is_narwhals_dataframe(self.data):
+                return combined
+            return combined.to_native()
+        if spec.output == "pandas":
+            df = combined.to_pandas()
+            return df
+        if spec.output == "numpy":
+            return combined.to_numpy()
+        raise ValueError(f"Invalid output type: {spec.output}")
