Skip to content

Commit cd6578f

Browse files
fix lint and tests
1 parent 2b10b06 commit cd6578f

1 file changed

Lines changed: 157 additions & 63 deletions

File tree

inga/scm/random.py

Lines changed: 157 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
from dataclasses import dataclass
6-
from typing import Callable, Iterable, Literal
6+
from typing import Callable, Iterable, Literal, cast
77

88
import random
99
import torch
@@ -108,8 +108,7 @@ def to_dict(self) -> dict[str, object]:
108108
"temperature": self._temperature,
109109
"bias": list(self._bias),
110110
"gaussian_parent_weights": {
111-
key: list(value)
112-
for key, value in self._gaussian_parent_weights.items()
111+
key: list(value) for key, value in self._gaussian_parent_weights.items()
113112
},
114113
"categorical_parent_weights": {
115114
key: [list(row) for row in value]
@@ -120,28 +119,84 @@ def to_dict(self) -> dict[str, object]:
120119

121120
@classmethod
122121
def from_dict(cls, payload: dict[str, object]) -> "RandomCategoricalVariable":
122+
parent_names_raw = payload.get("parent_names", [])
123+
if not isinstance(parent_names_raw, list):
124+
raise ValueError("`parent_names` must be a list.")
125+
126+
parent_kinds_raw = payload.get("parent_kinds", {})
127+
if not isinstance(parent_kinds_raw, dict):
128+
raise ValueError("`parent_kinds` must be a dict.")
129+
parent_kinds: dict[str, Literal["gaussian", "categorical"]] = {}
130+
for key, value in parent_kinds_raw.items():
131+
value_str = str(value)
132+
if value_str not in ("gaussian", "categorical"):
133+
raise ValueError(
134+
"`parent_kinds` values must be 'gaussian' or 'categorical'."
135+
)
136+
parent_kinds[str(key)] = cast(
137+
Literal["gaussian", "categorical"],
138+
value_str,
139+
)
140+
141+
bias_raw = payload.get("bias", [])
142+
if not isinstance(bias_raw, list):
143+
raise ValueError("`bias` must be a list.")
144+
145+
gaussian_parent_weights_raw = payload.get("gaussian_parent_weights", {})
146+
if not isinstance(gaussian_parent_weights_raw, dict):
147+
raise ValueError("`gaussian_parent_weights` must be a dict.")
148+
gaussian_parent_weights: dict[str, list[float]] = {}
149+
for key, values in gaussian_parent_weights_raw.items():
150+
if not isinstance(values, list):
151+
raise ValueError(
152+
"`gaussian_parent_weights` entries must be lists of floats."
153+
)
154+
gaussian_parent_weights[str(key)] = [float(v) for v in values]
155+
156+
categorical_parent_weights_raw = payload.get("categorical_parent_weights", {})
157+
if not isinstance(categorical_parent_weights_raw, dict):
158+
raise ValueError("`categorical_parent_weights` must be a dict.")
159+
categorical_parent_weights: dict[str, list[list[float]]] = {}
160+
for key, values in categorical_parent_weights_raw.items():
161+
if not isinstance(values, list):
162+
raise ValueError(
163+
"`categorical_parent_weights` entries must be lists of lists."
164+
)
165+
rows: list[list[float]] = []
166+
for row in values:
167+
if not isinstance(row, list):
168+
raise ValueError(
169+
"`categorical_parent_weights` entries must be lists of lists."
170+
)
171+
rows.append([float(v) for v in row])
172+
categorical_parent_weights[str(key)] = rows
173+
174+
transforms_raw = payload.get("transforms", [])
175+
if not isinstance(transforms_raw, list):
176+
raise ValueError("`transforms` must be a list.")
177+
178+
num_categories_raw = payload.get("num_categories")
179+
if num_categories_raw is None:
180+
raise ValueError("`num_categories` is required.")
181+
if not isinstance(num_categories_raw, (int, float, str)):
182+
raise ValueError("`num_categories` must be int-like.")
183+
184+
temperature_raw = payload.get("temperature")
185+
if temperature_raw is None:
186+
raise ValueError("`temperature` is required.")
187+
if not isinstance(temperature_raw, (int, float, str)):
188+
raise ValueError("`temperature` must be float-like.")
189+
123190
return cls(
124191
name=str(payload["name"]),
125-
parent_names=[str(name) for name in payload.get("parent_names", [])],
126-
parent_kinds={
127-
str(key): str(value) # type: ignore[dict-item]
128-
for key, value in dict(payload.get("parent_kinds", {})).items()
129-
},
130-
num_categories=int(payload["num_categories"]),
131-
temperature=float(payload["temperature"]),
132-
bias=[float(value) for value in list(payload["bias"])],
133-
gaussian_parent_weights={
134-
str(key): [float(v) for v in list(values)]
135-
for key, values in dict(payload.get("gaussian_parent_weights", {})).items()
136-
},
137-
categorical_parent_weights={
138-
str(key): [
139-
[float(v) for v in list(row)]
140-
for row in list(values)
141-
]
142-
for key, values in dict(payload.get("categorical_parent_weights", {})).items()
143-
},
144-
transforms=[str(name) for name in list(payload.get("transforms", []))],
192+
parent_names=[str(name) for name in parent_names_raw],
193+
parent_kinds=parent_kinds,
194+
num_categories=int(num_categories_raw),
195+
temperature=float(temperature_raw),
196+
bias=[float(value) for value in bias_raw],
197+
gaussian_parent_weights=gaussian_parent_weights,
198+
categorical_parent_weights=categorical_parent_weights,
199+
transforms=[str(name) for name in transforms_raw],
145200
)
146201

147202

@@ -232,6 +287,50 @@ def f_mean(parents: dict[str, Tensor]) -> Tensor:
232287
return f_mean
233288

234289

290+
def _build_mixed_parent_mean(
291+
parent_names: list[str],
292+
intercept: float,
293+
scalar_coefs: dict[str, float],
294+
parent_projections: dict[str, list[float]],
295+
transforms: list[Callable[[Tensor], Tensor]] | None,
296+
) -> Callable[[dict[str, Tensor]], Tensor]:
297+
"""Build Gaussian mean function supporting both scalar and categorical parents."""
298+
parent_names_local = list(parent_names)
299+
intercept_local = intercept
300+
scalar_coefs_local = dict(scalar_coefs)
301+
parent_projections_local = {
302+
name: list(weights) for name, weights in parent_projections.items()
303+
}
304+
transforms_local = transforms
305+
306+
def _as_scalar(parent_name: str, parent_value: Tensor) -> Tensor:
307+
if parent_value.ndim <= 1:
308+
return parent_value
309+
projection = torch.tensor(
310+
parent_projections_local[parent_name],
311+
device=parent_value.device,
312+
dtype=parent_value.dtype,
313+
)
314+
return (parent_value * projection).sum(dim=-1)
315+
316+
def mixed_parent_mean(base_parents: dict[str, Tensor]) -> Tensor:
317+
if not parent_names_local or not base_parents:
318+
base = torch.tensor(intercept_local)
319+
else:
320+
ref = next(iter(base_parents.values()))
321+
base = (
322+
torch.zeros_like(ref[..., 0] if ref.ndim > 1 else ref) + intercept_local
323+
)
324+
for parent_name in parent_names_local:
325+
parent_scalar = _as_scalar(parent_name, base_parents[parent_name])
326+
base = base + scalar_coefs_local[parent_name] * parent_scalar
327+
if transforms_local:
328+
return _compose_transforms(base, transforms_local)
329+
return base
330+
331+
return mixed_parent_mean
332+
333+
235334
def _sample_parents(
236335
rng: random.Random,
237336
previous_names: list[str],
@@ -320,10 +419,7 @@ def random_scm(config: RandomSCMConfig) -> SCM:
320419
rng.uniform(*config.categorical_bias_range)
321420
for _ in range(num_categories)
322421
]
323-
parent_kinds = {
324-
parent: variable_kinds[parent]
325-
for parent in parents
326-
}
422+
parent_kinds = {parent: variable_kinds[parent] for parent in parents}
327423

328424
gaussian_parent_weights: dict[str, list[float]] = {}
329425
categorical_parent_weights: dict[str, list[list[float]]] = {}
@@ -343,8 +439,10 @@ def random_scm(config: RandomSCMConfig) -> SCM:
343439
for _ in range(parent_categories)
344440
]
345441

346-
transform_names = (
347-
_sample_transform_names(rng) if rng.random() < config.nonlinear_prob else []
442+
categorical_transform_names = (
443+
_sample_transform_names(rng)
444+
if rng.random() < config.nonlinear_prob
445+
else []
348446
)
349447

350448
variables.append(
@@ -357,76 +455,72 @@ def random_scm(config: RandomSCMConfig) -> SCM:
357455
bias=categorical_bias,
358456
gaussian_parent_weights=gaussian_parent_weights,
359457
categorical_parent_weights=categorical_parent_weights,
360-
transforms=transform_names,
458+
transforms=categorical_transform_names,
361459
)
362460
)
363461
variable_kinds[name] = "categorical"
364462
variable_categories[name] = num_categories
365463
continue
366464

367465
use_nonlinear = rng.random() < config.nonlinear_prob
368-
transform_names = _sample_transform_names(rng) if use_nonlinear else None
369-
transforms = resolve_transforms(transform_names) if transform_names else None
466+
gaussian_transform_names: list[str] | None = (
467+
_sample_transform_names(rng) if use_nonlinear else None
468+
)
469+
transforms = (
470+
resolve_transforms(gaussian_transform_names)
471+
if gaussian_transform_names
472+
else None
473+
)
474+
475+
parent_names_local = list(parents)
476+
intercept_local = intercept
370477

371478
parent_projections: dict[str, list[float]] = {
372479
parent: [
373480
rng.uniform(*config.coef_range)
374481
for _ in range(variable_categories[parent])
375482
]
376-
for parent in parents
483+
for parent in parent_names_local
377484
if variable_kinds.get(parent) == "categorical"
378485
}
379486

380-
def _as_scalar(parent_name: str, parent_value: Tensor) -> Tensor:
381-
if parent_value.ndim <= 1:
382-
return parent_value
383-
projection = torch.tensor(
384-
parent_projections[parent_name],
385-
device=parent_value.device,
386-
dtype=parent_value.dtype,
387-
)
388-
return (parent_value * projection).sum(dim=-1)
389-
390-
scalar_coefs = {parent: rng.uniform(*config.coef_range) for parent in parents}
391-
392-
def mixed_parent_mean(base_parents: dict[str, Tensor]) -> Tensor:
393-
if not parents:
394-
base = torch.tensor(intercept)
395-
else:
396-
ref = next(iter(base_parents.values()))
397-
base = torch.zeros_like(ref[..., 0] if ref.ndim > 1 else ref) + intercept
398-
for parent_name in parents:
399-
parent_scalar = _as_scalar(parent_name, base_parents[parent_name])
400-
base = base + scalar_coefs[parent_name] * parent_scalar
401-
if transforms:
402-
return _compose_transforms(base, transforms)
403-
return base
487+
scalar_coefs = {
488+
parent: rng.uniform(*config.coef_range) for parent in parent_names_local
489+
}
490+
mixed_parent_mean = _build_mixed_parent_mean(
491+
parent_names=parent_names_local,
492+
intercept=intercept_local,
493+
scalar_coefs=scalar_coefs,
494+
parent_projections=parent_projections,
495+
transforms=transforms,
496+
)
497+
gaussian_transform_names_local = gaussian_transform_names
404498

405499
has_categorical_parent = any(
406500
variable_kinds.get(parent_name) == "categorical"
407-
for parent_name in parents
501+
for parent_name in parent_names_local
408502
)
409503

410504
if use_nonlinear or has_categorical_parent:
411505
variables.append(
412506
FunctionalVariable(
413507
name=name,
414-
parent_names=parents,
508+
parent_names=parent_names_local,
415509
sigma=sigma,
416510
f_mean=mixed_parent_mean,
417511
coefs=scalar_coefs,
418-
intercept=intercept,
419-
transforms=transform_names,
512+
intercept=intercept_local,
513+
transforms=gaussian_transform_names_local,
420514
)
421515
)
422516
else:
423517
variables.append(
424518
LinearVariable(
425519
name=name,
426-
parent_names=parents,
520+
parent_names=parent_names_local,
427521
sigma=sigma,
428522
coefs=scalar_coefs,
429-
intercept=intercept,
523+
intercept=intercept_local,
430524
)
431525
)
432526
variable_kinds[name] = "gaussian"

0 commit comments

Comments
 (0)