@@ -1440,6 +1440,8 @@ def test_optimize_acqf_mixed_continuous_relaxation(self) -> None:
14401440 discrete_dims = discrete_dims ,
14411441 max_discrete_values = max_discrete_values or MAX_DISCRETE_VALUES ,
14421442 post_processing_func = post_processing_func ,
1443+ inequality_constraints = None ,
1444+ equality_constraints = None ,
14431445 )
14441446 discrete_call_args = wrapped_discrete .call_args .kwargs
14451447 expected_dims = [0 , 4 ] if max_discrete_values is None else [0 ]
@@ -1516,3 +1518,113 @@ def org_post_proc_func(X: Tensor) -> Tensor:
15161518 # Check that generated points are rounded.
15171519 self .assertEqual (X .shape , torch .Size ([4 , train_X .shape [- 1 ]]))
15181520 self .assertAllClose (X [..., all_integer_dims ], X [..., all_integer_dims ].round ())
1521+
1522+ def test_setup_continuous_relaxation_excludes_constrained_dims (self ) -> None :
1523+ """Test that _setup_continuous_relaxation keeps constrained discrete dims."""
1524+ for dtype in (torch .float , torch .double ):
1525+ # Setup: 3 discrete dimensions
1526+ # - Dim 0: Low cardinality (2 values) - kept regardless
1527+ # - Dim 1: High cardinality (50 values), participates in constraint - kept
1528+ # - Dim 2: High cardinality (50 values), not constrained - relaxed
1529+ discrete_dims : dict [int , list [float ]] = {
1530+ 0 : [0.0 , 1.0 ], # Low cardinality - should be kept
1531+ 1 : list (range (50 )), # High cardinality, constrained - should be kept
1532+ 2 : list (range (50 )), # High cardinality, not constrained - relaxed
1533+ }
1534+ max_discrete_values = 20
1535+ # Constraint on dim 1: x[1] >= 10
1536+ inequality_constraints = [
1537+ (
1538+ torch .tensor ([1 ], dtype = torch .long , device = self .device ),
1539+ torch .tensor ([1.0 ], dtype = dtype , device = self .device ),
1540+ 10.0 ,
1541+ )
1542+ ]
1543+ # Execute: call _setup_continuous_relaxation
1544+ dims_kept , post_processing_func = _setup_continuous_relaxation (
1545+ discrete_dims = discrete_dims ,
1546+ max_discrete_values = max_discrete_values ,
1547+ post_processing_func = None ,
1548+ inequality_constraints = inequality_constraints ,
1549+ )
1550+ # Assert: dims 0 and 1 are kept (low cardinality and constrained)
1551+ self .assertIn (0 , dims_kept )
1552+ self .assertIn (1 , dims_kept )
1553+ # Assert: dim 2 is NOT in dims_kept (relaxed)
1554+ self .assertNotIn (2 , dims_kept )
1555+ # Assert: post_processing_func is not None since dim 2 was relaxed
1556+ self .assertIsNotNone (post_processing_func )
1557+ # Assert: post_processing_func rounds dim 2 but not dims 0 or 1
1558+ X = torch .tensor (
1559+ [0.4 , 25.3 , 30.7 ], # dim 0, 1, 2 with non-integer values
1560+ dtype = dtype ,
1561+ device = self .device ,
1562+ )
1563+ X_processed = post_processing_func (X )
1564+ # Dim 0 and 1 should remain unchanged (not rounded by this func)
1565+ self .assertAllClose (
1566+ X_processed [0 ], torch .tensor (0.4 , dtype = dtype , device = self .device )
1567+ )
1568+ self .assertAllClose (
1569+ X_processed [1 ], torch .tensor (25.3 , dtype = dtype , device = self .device )
1570+ )
1571+ # Dim 2 should be rounded to nearest valid value
1572+ self .assertAllClose (
1573+ X_processed [2 ], torch .tensor (31.0 , dtype = dtype , device = self .device )
1574+ )
1575+
1576+ def test_optimize_acqf_mixed_alternating_constrained_discrete_dims (self ) -> None :
1577+ """Test full workflow produces valid discrete values with constrained dims.
1578+
1579+ Uses non-contiguous choices [8, 16, 24, 32, 40, 48] to exercise the failure
1580+ mode where rounding to nearest integer (e.g. 47) differs from rounding to
1581+ nearest valid choice (48).
1582+ """
1583+ for dtype in (torch .float , torch .double ):
1584+ # Setup: GP model with posterior mean as acquisition function
1585+ d = 2 # 1 continuous + 1 discrete dimension
1586+ train_X = torch .rand (5 , d , dtype = dtype , device = self .device )
1587+ # Non-contiguous discrete values: multiples of 8 from 8 to 48
1588+ valid_choices = [8.0 , 16.0 , 24.0 , 32.0 , 40.0 , 48.0 ]
1589+ train_X [:, 1 ] = torch .tensor (
1590+ [valid_choices [i % len (valid_choices )] for i in range (5 )],
1591+ dtype = dtype ,
1592+ device = self .device ,
1593+ )
1594+ train_Y = train_X .sum (dim = - 1 , keepdim = True )
1595+ model = SingleTaskGP (train_X , train_Y )
1596+ acqf = PosteriorMean (model = model )
1597+ # Define bounds: [0, 1] for continuous, [8, 48] for discrete
1598+ bounds = torch .tensor (
1599+ [[0.0 , 8.0 ], [1.0 , 48.0 ]], dtype = dtype , device = self .device
1600+ )
1601+ # Non-contiguous discrete dimension (6 values)
1602+ discrete_dims : dict [int , list [float ]] = {1 : valid_choices }
1603+ # Constraint: x[1] >= 20 (discrete dim must be at least 20)
1604+ inequality_constraints = [
1605+ (
1606+ torch .tensor ([1 ], dtype = torch .long , device = self .device ),
1607+ torch .tensor ([1.0 ], dtype = dtype , device = self .device ),
1608+ 20.0 ,
1609+ )
1610+ ]
1611+ X , _ = optimize_acqf_mixed_alternating (
1612+ acq_function = acqf ,
1613+ bounds = bounds ,
1614+ discrete_dims = discrete_dims ,
1615+ q = 1 ,
1616+ num_restarts = 2 ,
1617+ raw_samples = 32 ,
1618+ inequality_constraints = inequality_constraints ,
1619+ options = {"max_discrete_values" : 2 , "maxiter_alternating" : 4 },
1620+ )
1621+ # Assert: discrete value is within the valid set (not just rounded int)
1622+ valid_choices_tensor = torch .tensor (
1623+ valid_choices , dtype = dtype , device = self .device
1624+ )
1625+ self .assertTrue (
1626+ torch .all (torch .isin (X [..., 1 ], valid_choices_tensor )),
1627+ f"Returned candidate { X [..., 1 ].item ()} not in { valid_choices } " ,
1628+ )
1629+ # Assert: constraint is satisfied (x[1] >= 20)
1630+ self .assertTrue (torch .all (X [..., 1 ] >= 20.0 - 1e-6 ))
0 commit comments