33from __future__ import annotations
44
55from dataclasses import dataclass
6- from typing import Callable , Iterable , Literal
6+ from typing import Callable , Iterable , Literal , cast
77
88import random
99import torch
@@ -108,8 +108,7 @@ def to_dict(self) -> dict[str, object]:
108108 "temperature" : self ._temperature ,
109109 "bias" : list (self ._bias ),
110110 "gaussian_parent_weights" : {
111- key : list (value )
112- for key , value in self ._gaussian_parent_weights .items ()
111+ key : list (value ) for key , value in self ._gaussian_parent_weights .items ()
113112 },
114113 "categorical_parent_weights" : {
115114 key : [list (row ) for row in value ]
@@ -120,28 +119,84 @@ def to_dict(self) -> dict[str, object]:
120119
@classmethod
def from_dict(cls, payload: dict[str, object]) -> "RandomCategoricalVariable":
    """Rebuild a variable from a plain-dict payload.

    Every container field is type-checked before conversion, so a
    malformed payload fails with a ``ValueError`` that names the field.

    Args:
        payload: Mapping with the keys ``name``, ``num_categories`` and
            ``temperature`` (required) and ``parent_names``,
            ``parent_kinds``, ``bias``, ``gaussian_parent_weights``,
            ``categorical_parent_weights`` and ``transforms`` (optional;
            each defaults to an empty list or dict).

    Returns:
        A new instance built by calling ``cls`` with the validated fields
        as keyword arguments.

    Raises:
        KeyError: If ``name`` is missing from ``payload``.
        ValueError: If a field has the wrong container type, a parent kind
            is not ``'gaussian'`` or ``'categorical'``, or
            ``num_categories`` / ``temperature`` is missing or not
            convertible to a number.
    """

    def _list_field(key: str) -> list[object]:
        # Optional list-valued field; absent means empty.
        raw = payload.get(key, [])
        if not isinstance(raw, list):
            raise ValueError(f"`{key}` must be a list.")
        return raw

    def _dict_field(key: str) -> dict[object, object]:
        # Optional dict-valued field; absent means empty.
        raw = payload.get(key, {})
        if not isinstance(raw, dict):
            raise ValueError(f"`{key}` must be a dict.")
        return raw

    parent_names_raw = _list_field("parent_names")

    parent_kinds: dict[str, Literal["gaussian", "categorical"]] = {}
    for key, value in _dict_field("parent_kinds").items():
        value_str = str(value)
        if value_str not in ("gaussian", "categorical"):
            raise ValueError(
                "`parent_kinds` values must be 'gaussian' or 'categorical'."
            )
        parent_kinds[str(key)] = cast(
            Literal["gaussian", "categorical"],
            value_str,
        )

    bias_raw = _list_field("bias")

    gaussian_parent_weights: dict[str, list[float]] = {}
    for key, values in _dict_field("gaussian_parent_weights").items():
        if not isinstance(values, list):
            raise ValueError(
                "`gaussian_parent_weights` entries must be lists of floats."
            )
        gaussian_parent_weights[str(key)] = [float(v) for v in values]

    categorical_parent_weights: dict[str, list[list[float]]] = {}
    for key, values in _dict_field("categorical_parent_weights").items():
        if not isinstance(values, list):
            raise ValueError(
                "`categorical_parent_weights` entries must be lists of lists."
            )
        rows: list[list[float]] = []
        for row in values:
            if not isinstance(row, list):
                raise ValueError(
                    "`categorical_parent_weights` entries must be lists of lists."
                )
            rows.append([float(v) for v in row])
        categorical_parent_weights[str(key)] = rows

    transforms_raw = _list_field("transforms")

    num_categories_raw = payload.get("num_categories")
    if num_categories_raw is None:
        raise ValueError("`num_categories` is required.")
    # bool is a subclass of int, so it has to be rejected explicitly;
    # otherwise `True` would quietly become a single-category variable.
    if isinstance(num_categories_raw, bool) or not isinstance(
        num_categories_raw, (int, float, str)
    ):
        raise ValueError("`num_categories` must be int-like.")
    # int() truncates floats; refuse values such as 2.5 (and nan/inf)
    # instead of silently changing the category count.
    if isinstance(num_categories_raw, float) and not num_categories_raw.is_integer():
        raise ValueError("`num_categories` must be int-like.")
    try:
        num_categories = int(num_categories_raw)
    except ValueError as err:
        raise ValueError("`num_categories` must be int-like.") from err

    temperature_raw = payload.get("temperature")
    if temperature_raw is None:
        raise ValueError("`temperature` is required.")
    if isinstance(temperature_raw, bool) or not isinstance(
        temperature_raw, (int, float, str)
    ):
        raise ValueError("`temperature` must be float-like.")
    try:
        temperature = float(temperature_raw)
    except ValueError as err:
        raise ValueError("`temperature` must be float-like.") from err

    return cls(
        name=str(payload["name"]),
        parent_names=[str(name) for name in parent_names_raw],
        parent_kinds=parent_kinds,
        num_categories=num_categories,
        temperature=temperature,
        bias=[float(value) for value in bias_raw],
        gaussian_parent_weights=gaussian_parent_weights,
        categorical_parent_weights=categorical_parent_weights,
        transforms=[str(name) for name in transforms_raw],
    )
146201
147202
@@ -232,6 +287,50 @@ def f_mean(parents: dict[str, Tensor]) -> Tensor:
232287 return f_mean
233288
234289
290+ def _build_mixed_parent_mean (
291+ parent_names : list [str ],
292+ intercept : float ,
293+ scalar_coefs : dict [str , float ],
294+ parent_projections : dict [str , list [float ]],
295+ transforms : list [Callable [[Tensor ], Tensor ]] | None ,
296+ ) -> Callable [[dict [str , Tensor ]], Tensor ]:
297+ """Build Gaussian mean function supporting both scalar and categorical parents."""
298+ parent_names_local = list (parent_names )
299+ intercept_local = intercept
300+ scalar_coefs_local = dict (scalar_coefs )
301+ parent_projections_local = {
302+ name : list (weights ) for name , weights in parent_projections .items ()
303+ }
304+ transforms_local = transforms
305+
306+ def _as_scalar (parent_name : str , parent_value : Tensor ) -> Tensor :
307+ if parent_value .ndim <= 1 :
308+ return parent_value
309+ projection = torch .tensor (
310+ parent_projections_local [parent_name ],
311+ device = parent_value .device ,
312+ dtype = parent_value .dtype ,
313+ )
314+ return (parent_value * projection ).sum (dim = - 1 )
315+
316+ def mixed_parent_mean (base_parents : dict [str , Tensor ]) -> Tensor :
317+ if not parent_names_local or not base_parents :
318+ base = torch .tensor (intercept_local )
319+ else :
320+ ref = next (iter (base_parents .values ()))
321+ base = (
322+ torch .zeros_like (ref [..., 0 ] if ref .ndim > 1 else ref ) + intercept_local
323+ )
324+ for parent_name in parent_names_local :
325+ parent_scalar = _as_scalar (parent_name , base_parents [parent_name ])
326+ base = base + scalar_coefs_local [parent_name ] * parent_scalar
327+ if transforms_local :
328+ return _compose_transforms (base , transforms_local )
329+ return base
330+
331+ return mixed_parent_mean
332+
333+
235334def _sample_parents (
236335 rng : random .Random ,
237336 previous_names : list [str ],
@@ -320,10 +419,7 @@ def random_scm(config: RandomSCMConfig) -> SCM:
320419 rng .uniform (* config .categorical_bias_range )
321420 for _ in range (num_categories )
322421 ]
323- parent_kinds = {
324- parent : variable_kinds [parent ]
325- for parent in parents
326- }
422+ parent_kinds = {parent : variable_kinds [parent ] for parent in parents }
327423
328424 gaussian_parent_weights : dict [str , list [float ]] = {}
329425 categorical_parent_weights : dict [str , list [list [float ]]] = {}
@@ -343,8 +439,10 @@ def random_scm(config: RandomSCMConfig) -> SCM:
343439 for _ in range (parent_categories )
344440 ]
345441
346- transform_names = (
347- _sample_transform_names (rng ) if rng .random () < config .nonlinear_prob else []
442+ categorical_transform_names = (
443+ _sample_transform_names (rng )
444+ if rng .random () < config .nonlinear_prob
445+ else []
348446 )
349447
350448 variables .append (
@@ -357,76 +455,72 @@ def random_scm(config: RandomSCMConfig) -> SCM:
357455 bias = categorical_bias ,
358456 gaussian_parent_weights = gaussian_parent_weights ,
359457 categorical_parent_weights = categorical_parent_weights ,
360- transforms = transform_names ,
458+ transforms = categorical_transform_names ,
361459 )
362460 )
363461 variable_kinds [name ] = "categorical"
364462 variable_categories [name ] = num_categories
365463 continue
366464
367465 use_nonlinear = rng .random () < config .nonlinear_prob
368- transform_names = _sample_transform_names (rng ) if use_nonlinear else None
369- transforms = resolve_transforms (transform_names ) if transform_names else None
466+ gaussian_transform_names : list [str ] | None = (
467+ _sample_transform_names (rng ) if use_nonlinear else None
468+ )
469+ transforms = (
470+ resolve_transforms (gaussian_transform_names )
471+ if gaussian_transform_names
472+ else None
473+ )
474+
475+ parent_names_local = list (parents )
476+ intercept_local = intercept
370477
371478 parent_projections : dict [str , list [float ]] = {
372479 parent : [
373480 rng .uniform (* config .coef_range )
374481 for _ in range (variable_categories [parent ])
375482 ]
376- for parent in parents
483+ for parent in parent_names_local
377484 if variable_kinds .get (parent ) == "categorical"
378485 }
379486
380- def _as_scalar (parent_name : str , parent_value : Tensor ) -> Tensor :
381- if parent_value .ndim <= 1 :
382- return parent_value
383- projection = torch .tensor (
384- parent_projections [parent_name ],
385- device = parent_value .device ,
386- dtype = parent_value .dtype ,
387- )
388- return (parent_value * projection ).sum (dim = - 1 )
389-
390- scalar_coefs = {parent : rng .uniform (* config .coef_range ) for parent in parents }
391-
392- def mixed_parent_mean (base_parents : dict [str , Tensor ]) -> Tensor :
393- if not parents :
394- base = torch .tensor (intercept )
395- else :
396- ref = next (iter (base_parents .values ()))
397- base = torch .zeros_like (ref [..., 0 ] if ref .ndim > 1 else ref ) + intercept
398- for parent_name in parents :
399- parent_scalar = _as_scalar (parent_name , base_parents [parent_name ])
400- base = base + scalar_coefs [parent_name ] * parent_scalar
401- if transforms :
402- return _compose_transforms (base , transforms )
403- return base
487+ scalar_coefs = {
488+ parent : rng .uniform (* config .coef_range ) for parent in parent_names_local
489+ }
490+ mixed_parent_mean = _build_mixed_parent_mean (
491+ parent_names = parent_names_local ,
492+ intercept = intercept_local ,
493+ scalar_coefs = scalar_coefs ,
494+ parent_projections = parent_projections ,
495+ transforms = transforms ,
496+ )
497+ gaussian_transform_names_local = gaussian_transform_names
404498
405499 has_categorical_parent = any (
406500 variable_kinds .get (parent_name ) == "categorical"
407- for parent_name in parents
501+ for parent_name in parent_names_local
408502 )
409503
410504 if use_nonlinear or has_categorical_parent :
411505 variables .append (
412506 FunctionalVariable (
413507 name = name ,
414- parent_names = parents ,
508+ parent_names = parent_names_local ,
415509 sigma = sigma ,
416510 f_mean = mixed_parent_mean ,
417511 coefs = scalar_coefs ,
418- intercept = intercept ,
419- transforms = transform_names ,
512+ intercept = intercept_local ,
513+ transforms = gaussian_transform_names_local ,
420514 )
421515 )
422516 else :
423517 variables .append (
424518 LinearVariable (
425519 name = name ,
426- parent_names = parents ,
520+ parent_names = parent_names_local ,
427521 sigma = sigma ,
428522 coefs = scalar_coefs ,
429- intercept = intercept ,
523+ intercept = intercept_local ,
430524 )
431525 )
432526 variable_kinds [name ] = "gaussian"
0 commit comments