-
Notifications
You must be signed in to change notification settings - Fork 197
Description
Hi!
I am trying to use the hybrid model with multiple submodules to be evaluated in FHE.
I tried putting two submodules and the compilation of the circuit is ok.
But when I run inference, the first submodule evaluates successfully; when the second one is evaluated, however, the client side raises an AssertionError at `assert repr_input_shape in self.clients`.
So roughly, on the client side, I did something like
# model.pth is the file saved by save_and_clear_private_info
model = torch.load('./compiled_models/net/model.pth')
hybrid_model = HybridFHEModel(
model,
module_names,
server_remote_address="http://0.0.0.0:8000",
model_name=model_name,
verbose=False,
)
path_to_clients = Path(__file__).parent / "clients"
hybrid_model.init_client(path_to_clients=path_to_clients)
hybrid_model.set_fhe_mode(HybridFHEMode.REMOTE)
I am using the hybrid model for a U-Net. The submodules I specified are `up1.pixel_shuffle` and `final`. My U-Net is as follows:
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
# Define a customized upsample module for standard use
class CustomPixelShuffle(nn.Module):
def __init__(self, upscale_factor):
super(CustomPixelShuffle, self).__init__()
self.upscale_factor = upscale_factor
def forward(self, x):
batch_size, channels, height, width = x.size()
upscale_factor = self.upscale_factor
channels //= (upscale_factor ** 2)
x = x.view(batch_size, channels, upscale_factor, upscale_factor, height, width)
x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
x = x.view(batch_size, channels, height * upscale_factor, width * upscale_factor)
return x
class CustomUpsample(nn.Module):
    """Pixel-shuffle upsampling followed by a 3x3 convolution.

    The shuffle divides the channel count by ``scale_factor ** 2`` while
    multiplying H and W by ``scale_factor``; the convolution then maps the
    shuffled channels to ``out_channels`` without changing spatial size.
    """

    def __init__(self, in_channels, out_channels, scale_factor):
        super().__init__()
        self.scale_factor = scale_factor
        self.pixel_shuffle = CustomPixelShuffle(scale_factor)
        # Channel count after the shuffle, fed into the conv.
        shuffled_channels = in_channels // (scale_factor ** 2)
        self.conv = nn.Conv2d(
            shuffled_channels, out_channels, kernel_size=3, padding=1
        )

    def forward(self, x):
        return self.conv(self.pixel_shuffle(x))
# Define a standard convolutional layer with batch normalization and ReLU activation
class Conv2d(nn.Module):
    """Conv -> BatchNorm -> ReLU building block used throughout the U-Net.

    NOTE(review): the name shadows ``nn.Conv2d`` within this module; kept
    as-is because the rest of the listing refers to it by this name.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        # Three separate attributes (conv/bn/relu) so state_dict keys stay
        # stable for checkpoints saved from this definition.
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.bn = nn.BatchNorm2d(num_features=out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))
# Define a UNet architecture with standard components
class UNet(nn.Module):
    """Five-level U-Net: 1 input channel, 1 sigmoid-activated output channel.

    The encoder halves the spatial size five times while growing channels
    from 32 up to a 1024-channel bottleneck; the decoder mirrors it using
    pixel-shuffle upsampling (CustomUpsample) with a skip connection from
    each encoder stage.
    """

    def __init__(self):
        super().__init__()

        def double_conv(cin, cout):
            # Two stacked conv-bn-relu blocks. nn.Sequential indices (0, 1)
            # match the original layout, so state_dict keys are unchanged.
            return nn.Sequential(
                Conv2d(in_channels=cin, out_channels=cout),
                Conv2d(in_channels=cout, out_channels=cout),
            )

        # Encoder path.
        self.enc1 = double_conv(1, 32)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.enc2 = double_conv(32, 64)
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.enc3 = double_conv(64, 128)
        self.pool3 = nn.MaxPool2d(kernel_size=2)
        self.enc4 = double_conv(128, 256)
        self.pool4 = nn.MaxPool2d(kernel_size=2)
        self.enc5 = double_conv(256, 512)
        self.pool5 = nn.MaxPool2d(kernel_size=2)
        # Bottleneck (central) layer.
        self.bottleneck = double_conv(512, 1024)
        # Decoder path: each stage upsamples, then a double conv consumes
        # the concatenation of the upsampled tensor and the matching skip.
        self.up5 = CustomUpsample(in_channels=1024, out_channels=512, scale_factor=2)
        self.dec5 = double_conv(1024, 512)
        self.up4 = CustomUpsample(in_channels=512, out_channels=256, scale_factor=2)
        self.dec4 = double_conv(512, 256)
        self.up3 = CustomUpsample(in_channels=256, out_channels=128, scale_factor=2)
        self.dec3 = double_conv(256, 128)
        self.up2 = CustomUpsample(in_channels=128, out_channels=64, scale_factor=2)
        self.dec2 = double_conv(128, 64)
        self.up1 = CustomUpsample(in_channels=64, out_channels=32, scale_factor=2)
        self.dec1 = double_conv(64, 32)
        # Final 1x1 projection down to a single output channel.
        self.final = nn.Conv2d(
            in_channels=32, out_channels=1, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        # Encoder: keep each stage's activation for the skip connections.
        skips = []
        out = x
        for enc, pool in (
            (self.enc1, self.pool1),
            (self.enc2, self.pool2),
            (self.enc3, self.pool3),
            (self.enc4, self.pool4),
            (self.enc5, self.pool5),
        ):
            out = enc(out)
            skips.append(out)
            out = pool(out)
        out = self.bottleneck(out)
        # Decoder: upsample, concatenate the matching encoder activation
        # along the channel axis, then refine with a double conv.
        for up, dec, skip in (
            (self.up5, self.dec5, skips[4]),
            (self.up4, self.dec4, skips[3]),
            (self.up3, self.dec3, skips[2]),
            (self.up2, self.dec2, skips[1]),
            (self.up1, self.dec1, skips[0]),
        ):
            out = dec(torch.cat((up(out), skip), dim=1))
        return torch.sigmoid(self.final(out))
#model = UNet()
#print(model)
#for (k, _) in model.named_modules():
# print(k)
For the server setup and compilation of the circuit, I am basically following this example https://github.com/zama-ai/concrete-ml/tree/main/use_case_examples/hybrid_model. Could you please check why this happens? Thanks!
P.S. If you need the full code from me, just DM me at Gan in the CML channel of the FHE Discord — I am pretty active there. Thanks!
Best,
Gan