Safe optimization in the Service API #2563

Open
@Abrikosoff

Description

Hi Ax Team,

I am trying to implement a Service API version of the safe optimization idea floated by @Balandat here. So far I've come up with the following snippet:

```python
from typing import Callable, List

import torch


def estimate_probabilities_of_satisfaction(
    model, points: torch.Tensor, constraints: List[Callable]
) -> torch.Tensor:
    """
    Estimate the probability of satisfying the given nonlinear outcome constraints g(y) >= 0,
    where y is the model output at each point.

    Args:
        model (Model): A trained BoTorch model
        points (torch.Tensor): Points at which to estimate constraint satisfaction, shape (n, d)
        constraints (List[Callable]): List of constraint functions, each should take a tensor
                                      of model outputs of shape (n, m) and return a tensor of shape (n,)

    Returns:
        torch.Tensor: Probabilities of satisfying all constraints, shape (n,)
    """
    posterior = model.posterior(X=points)
    # Marginal posterior means and variances of the m outputs, each of shape (n, m)
    mus, sigma2s = posterior.mean, posterior.variance

    probs = torch.ones(points.shape[0], dtype=points.dtype, device=points.device)

    for constraint in constraints:
        # Approximate the mean of g(y) by g evaluated at the posterior mean
        g_mus = constraint(mus)

        # Compute the gradients of g with respect to the outputs at the posterior mean.
        # requires_grad_ can only be toggled on leaf tensors, so use a detached copy.
        # Summing over rows is valid because g acts on each row independently.
        with torch.enable_grad():
            mus_leaf = mus.detach().requires_grad_(True)
            grads = torch.autograd.grad(constraint(mus_leaf).sum(), mus_leaf)[0]

        # Delta method: Var[g(y)] ~= sum_j (dg/dy_j)^2 * Var[y_j]
        # (uses only marginal variances, i.e. assumes independent outputs)
        g_vars = (grads.pow(2) * sigma2s).sum(dim=-1)

        # Normal approximation of g(y); clamp so the standard deviation is never zero
        dist = torch.distributions.Normal(g_mus, g_vars.clamp_min(1e-12).sqrt())

        # Compute the probability of g(y) >= 0
        prob_constraint = 1 - dist.cdf(torch.zeros_like(g_mus))

        # Update the overall probability (treats the constraints as independent)
        probs = probs * prob_constraint

    return probs
```
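The delta-method step can be sanity-checked outside Ax and BoTorch. The following pure-Python sketch (all names and numbers are made up for illustration) approximates P(g(y) >= 0) for independent Gaussian outputs y ~ N(mu, diag(sigma2)) using the first-order approximation Var[g(y)] ~= sum_j (dg/dy_j)^2 * sigma2_j, and compares it with a Monte Carlo estimate. For a linear g the approximation is exact, so the two numbers should agree up to sampling noise; for strongly nonlinear g they can diverge.

```python
import math
import random
from statistics import NormalDist
from typing import Callable, List, Sequence

Outcome = Sequence[float]


def numerical_grad(g: Callable[[Outcome], float], y: Outcome, eps: float = 1e-6) -> List[float]:
    """Central finite-difference gradient of g at y."""
    grad = []
    for j in range(len(y)):
        up, down = list(y), list(y)
        up[j] += eps
        down[j] -= eps
        grad.append((g(up) - g(down)) / (2 * eps))
    return grad


def delta_method_prob(g: Callable[[Outcome], float], mu: Outcome, sigma2: Outcome) -> float:
    """Delta-method estimate of P(g(y) >= 0) for y ~ N(mu, diag(sigma2))."""
    g_mu = g(mu)
    grad = numerical_grad(g, mu)
    g_var = sum(d * d * s for d, s in zip(grad, sigma2))
    g_sd = math.sqrt(max(g_var, 1e-24))  # NormalDist.cdf needs sigma > 0
    return 1.0 - NormalDist(g_mu, g_sd).cdf(0.0)


def monte_carlo_prob(g: Callable[[Outcome], float], mu: Outcome, sigma2: Outcome,
                     n: int = 100_000, seed: int = 0) -> float:
    """Monte Carlo estimate of the same probability, for comparison."""
    rng = random.Random(seed)
    hits = 0
    for _ in range(n):
        y = [rng.gauss(m, math.sqrt(s)) for m, s in zip(mu, sigma2)]
        hits += g(y) >= 0
    return hits / n


def g_linear(y: Outcome) -> float:
    """Hypothetical linear outcome constraint: 1 - y1 - 2*y2 >= 0."""
    return 1.0 - y[0] - 2.0 * y[1]


mu, sigma2 = [0.2, 0.1], [0.04, 0.01]
print(delta_method_prob(g_linear, mu, sigma2), monte_carlo_prob(g_linear, mu, sigma2))
```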

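```python
def probs_constraint(
        gamma: float,
        model,
        X: torch.Tensor,
        constraints: List[Callable],
) -> torch.Tensor:
    # BoTorch treats a nonlinear inequality constraint as satisfied when its value is >= 0,
    # so this is feasible when P(all constraints satisfied) >= gamma
    return estimate_probabilities_of_satisfaction(model, X, constraints) - gamma
```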
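One thing worth double-checking is the sign. BoTorch's `optimize_acqf` treats each entry of `nonlinear_inequality_constraints` as satisfied when `callable(x) >= 0`, so a constraint "P(all constraints satisfied) >= gamma" should return the probability minus `gamma`; returning `gamma - P` would instead mark the safest points as infeasible. The pure-Python sketch below (`probability_constraint` and `fake_prob` are hypothetical stand-ins, not part of Ax or BoTorch) illustrates the convention, and also shows `functools.partial` binding the extra arguments so that only the candidate point is left free. The exact shape of `nonlinear_inequality_constraints` entries (bare callable versus a tuple with an intra-point flag) has changed across BoTorch versions, so it is worth checking the docstring of the installed version.

```python
from functools import partial
from typing import Callable


def probability_constraint(prob_fn: Callable[[float], float], gamma: float, x: float) -> float:
    """Return a value that is >= 0 exactly when prob_fn(x) >= gamma."""
    return prob_fn(x) - gamma


def fake_prob(x: float) -> float:
    """Hypothetical stand-in for estimate_probabilities_of_satisfaction on a 1-d input:
    the satisfaction probability falls linearly from 1 at x = 0 to 0 at x = 1."""
    return max(0.0, min(1.0, 1.0 - x))


# Bind everything except x, so the optimizer only has to supply the candidate point
constraint = partial(probability_constraint, fake_prob, 0.9)

for x in (0.05, 0.5):
    # x = 0.05 has probability 0.95 (safe), x = 0.5 has probability 0.5 (unsafe)
    print(x, "feasible" if constraint(x) >= 0 else "infeasible")
```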

This is where I'm stuck: I'm thinking of passing `probs_constraint` as one of the `nonlinear_inequality_constraints` in a `GenerationStrategy`, but I'm not sure how to retrieve the currently fitted model to pass into it. Any ideas?
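Assuming the model is refit whenever new data arrives, a constraint built once with a fixed `model` argument would keep using a stale model. A generic Python workaround, sketched below under that assumption, is to have the constraint look the model up at call time from a mutable holder that is updated after each refit. `ModelHolder`, `make_late_bound_constraint` and `prob_from_model` are hypothetical names; this is not an Ax API, it does not show how to obtain the fitted model from Ax in the first place, and it assumes Ax calls the constraint object that was passed in rather than a deep copy of it.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class ModelHolder:
    """Hypothetical container whose `model` attribute is replaced after each refit."""
    model: Optional[Any] = None


def make_late_bound_constraint(
    holder: ModelHolder, gamma: float, prob_fn: Callable[[Any, Any], float]
) -> Callable[[Any], float]:
    """Build c(x) = prob_fn(model, x) - gamma, reading the model from `holder` on every call."""

    def constraint(x: Any) -> float:
        if holder.model is None:
            raise RuntimeError("no fitted model has been stored in the holder yet")
        return prob_fn(holder.model, x) - gamma

    return constraint


def prob_from_model(model: Callable[[Any], float], x: Any) -> float:
    """Stand-in for estimate_probabilities_of_satisfaction: here a 'model' is just a callable."""
    return model(x)


holder = ModelHolder()
constraint = make_late_bound_constraint(holder, gamma=0.9, prob_fn=prob_from_model)

holder.model = lambda x: 0.95  # stand-in for the first fitted model
print(constraint(0.0))  # positive: feasible
holder.model = lambda x: 0.5   # stand-in for the model after a refit
print(constraint(0.0))  # negative: infeasible
```

If the fitted BoTorch model is taken from inside Ax, keep in mind that Ax usually applies transforms (for example normalizing inputs and standardizing outcomes) before fitting, so both `X` and the outcome constraints `g` may need to be expressed in that transformed space.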

Labels: question (Further information is requested)
