@@ -484,4 +484,176 @@ def test_joint_distribution_with_multiple_inputs_model_has_correct_parameter_nam
484484
485485 assert joint_dist (x_dist = x_val , y_dist = y_val , data_dist = np .array ([2 ,2 ,3 ])).likelihood .get_parameter_names () == ['z_dist' ]
486486 assert joint_dist (x_dist = x_val , z_dist = z_val , data_dist = np .array ([2 ,2 ,3 ])).likelihood .get_parameter_names () == ['y_dist' ]
487- assert joint_dist (y_dist = y_val , z_dist = z_val , data_dist = np .array ([2 ,2 ,3 ])).likelihood .get_parameter_names () == ['x_dist' ]
487+ assert joint_dist (y_dist = y_val , z_dist = z_val , data_dist = np .array ([2 ,2 ,3 ])).likelihood .get_parameter_names () == ['x_dist' ]
488+
489+
490+ def test_FD_enabled_is_set_correctly ():
491+ """ Test that FD_enabled property is set correctly in JointDistribution """
492+
493+ # Create a joint distribution with two distributions
494+ d1 = cuqi .distribution .Normal (0 , 1 , name = "x" )
495+ d2 = cuqi .distribution .Gamma (lambda x : x ** 2 , 1 , name = "y" )
496+ J = cuqi .distribution .JointDistribution (d1 , d2 )
497+
498+ # Initially FD should be disabled for both
499+ assert J .FD_enabled == {"x" : False , "y" : False }
500+
501+ # Enable FD for x
502+ J .enable_FD (epsilon = {"x" : 1e-6 , "y" : None })
503+ assert J .FD_enabled == {"x" : True , "y" : False }
504+ assert J .FD_epsilon == {"x" : 1e-6 , "y" : None }
505+
506+ # Enable FD for y as well
507+ J .enable_FD (epsilon = {"x" : 1e-6 , "y" : 1e-5 })
508+ assert J .FD_enabled == {"x" : True , "y" : True }
509+ assert J .FD_epsilon == {"x" : 1e-6 , "y" : 1e-5 }
510+
511+ # Disable FD for x
512+ J .enable_FD (epsilon = {"x" : None , "y" : 1e-5 })
513+ assert J .FD_enabled == {"x" : False , "y" : True }
514+ assert J .FD_epsilon == {"x" : None , "y" : 1e-5 }
515+
516+ # Disable FD for all
517+ J .disable_FD ()
518+ assert J .FD_enabled == {"x" : False , "y" : False }
519+ assert J .FD_epsilon == {"x" : None , "y" : None }
520+
521+ # Enable FD and reduce to single density
522+ J .enable_FD () # Enable FD for all
523+ J_given_x = J (x = 0 )
524+ J_given_y = J (y = 1 )
525+
526+ # Check types and FD_enabled status of J_given_x
527+ assert isinstance (J_given_x , cuqi .distribution .Gamma )
528+ assert not J_given_x .FD_enabled # intentionally disabled for single remaining
529+ # distribution
530+ assert J_given_x .FD_epsilon == None
531+
532+ # Check types and FD_enabled status of J_given_y
533+ assert isinstance (J_given_y , cuqi .distribution .Posterior )
534+ assert J_given_y .FD_enabled
535+ assert J_given_y .FD_epsilon == 1e-8 # Default epsilon for remaining density
536+
537+ # Catch error if epsilon keys do not match parameter names
538+ with pytest .raises (ValueError , match = r"Keys of FD_epsilon must match" ):
539+ J .enable_FD (epsilon = {"x" : 1e-6 }) # Missing "y" key
540+
541+ def test_FD_enabled_is_set_correctly_for_stacked_joint_distribution ():
542+ """ Test that FD_enabled property is set correctly in JointDistribution """
543+
544+ # Create a joint distribution with two distributions
545+ x = cuqi .distribution .Normal (0 , 1 , name = "x" )
546+ y = cuqi .distribution .Uniform (1 , 2 , name = "y" )
547+ J = cuqi .distribution ._StackedJointDistribution (x , y )
548+ J .enable_FD (epsilon = {"x" : 1e-6 , "y" : None })
549+
550+ assert J .FD_enabled == {"x" : True , "y" : False }
551+ assert J .FD_epsilon == {"x" : 1e-6 , "y" : None }
552+
553+ # Reduce to single density (substitute y)
554+ J_given_y = J (y = 1.5 )
555+ assert isinstance (J_given_y , cuqi .distribution .Normal )
556+ assert J_given_y .FD_enabled == False # Intentionally disabled for
557+ # single remaining
558+ # distribution
559+ assert J_given_y .FD_epsilon is None
560+
561+ # Reduce to single density (substitute x)
562+ J_given_x = J (x = 0 )
563+ assert isinstance (J_given_x , cuqi .distribution .Uniform )
564+ assert J_given_x .FD_enabled == False
565+ assert J_given_x .FD_epsilon is None
566+
567+
568+
569+ @pytest .mark .parametrize (
570+ "densities,kwargs,fd_epsilon,expected_type,expected_fd_enabled" ,
571+ [
572+ # Case 0: Single Distribution, FD enabled
573+ (
574+ [cuqi .distribution .Normal (np .zeros (3 ), 1 , name = "x" )],
575+ {},
576+ {"x" : 1e-5 },
577+ cuqi .distribution .Normal ,
578+ False , # Intentionally disabled for single remaining distribution
579+ ),
580+ # Case 1: Single Distribution, FD disabled
581+ (
582+ [cuqi .distribution .Normal (np .zeros (3 ), 1 , name = "x" )],
583+ {},
584+ {"x" : None },
585+ cuqi .distribution .Normal ,
586+ False ,
587+ ),
588+ # Case 2: Distribution + Data distribution, substitute y
589+ (
590+ [
591+ cuqi .distribution .Normal (np .zeros (3 ), 1 , name = "x" ),
592+ cuqi .distribution .Gaussian (lambda x : x ** 2 , np .ones (3 ), name = "y" ),
593+ ],
594+ {"y" : np .ones (3 )},
595+ {"x" : 1e-6 , "y" : 1e-7 },
596+ cuqi .distribution .Posterior ,
597+ True ,
598+ ),
599+ # Case 3: Distribution + data distribution, substitute x
600+ (
601+ [
602+ cuqi .distribution .Normal (np .zeros (3 ), 1 , name = "x" ),
603+ cuqi .distribution .Gaussian (lambda x : x ** 2 , np .ones (3 ), name = "y" ),
604+ ],
605+ {"x" : np .ones (3 )},
606+ {"x" : 1e-5 , "y" : 1e-6 },
607+ cuqi .distribution .Distribution ,
608+ False , # Intentionally disabled for single remaining distribution
609+ ),
610+ # Case 4: Multiple data distributions + prior (MultipleLikelihoodPosterior)
611+ (
612+ [
613+ cuqi .distribution .Normal (np .zeros (3 ), 1 , name = "x" ),
614+ cuqi .distribution .Gaussian (lambda x : x , np .ones (3 ), name = "y1" ),
615+ cuqi .distribution .Gaussian (lambda x : x + 1 , np .ones (3 ), name = "y2" ),
616+ ],
617+ {"y1" : np .ones (3 ), "y2" : np .ones (3 )},
618+ {"x" : 1e-5 , "y1" : 1e-6 , "y2" : 1e-7 },
619+ cuqi .distribution .MultipleLikelihoodPosterior ,
620+ {"x" : True },
621+ ),
622+ # Case 5: Distribution, substitute x
623+ (
624+ [cuqi .distribution .Normal (np .zeros (3 ), 1 , name = "x" )],
625+ {"x" : np .ones (3 )},
626+ {"x" : 1e-8 },
627+ cuqi .distribution .JointDistribution ,
628+ {},
629+ ),
630+ ],
631+ )
632+ def test_fd_enabled_of_joint_distribution_after_substitution_is_correct (
633+ densities , kwargs , fd_epsilon , expected_type , expected_fd_enabled
634+ ):
635+ """ Test that FD_enabled and FD_epsilon properties are set correctly in JointDistribution even after substitution."""
636+ joint = cuqi .distribution .JointDistribution (* densities )
637+ joint .enable_FD (epsilon = fd_epsilon )
638+
639+ # Assert FD_epsilon is set correctly
640+ assert joint .FD_epsilon == fd_epsilon
641+
642+ # Substitute parameters (if any), which reduces the joint distribution
643+ reduced = joint (** kwargs )
644+
645+ # Assert the type and FD_enabled status of the reduced distribution
646+ assert isinstance (reduced , expected_type )
647+ assert reduced .FD_enabled == expected_fd_enabled
648+
649+ # Assert FD_epsilon is set correctly in the reduced distribution
650+ if expected_fd_enabled is not False :
651+ fd_epsilon_reduced = {
652+ k : v for k , v in fd_epsilon .items () if k not in kwargs .keys ()
653+ }
654+ if len (fd_epsilon_reduced ) == 1 and not isinstance (
655+ reduced , cuqi .distribution .MultipleLikelihoodPosterior
656+ ):
657+ # Single value instead of dict in this case
658+ fd_epsilon_reduced = list (fd_epsilon_reduced .values ())[0 ]
659+ assert reduced .FD_epsilon == fd_epsilon_reduced
0 commit comments