Skip to content

[Public issue] AssertionError assert repr_input_shape in self.clients in hybrid model with multiple submodules #844

@gy-cao

Description

@gy-cao

Hi!

I am trying to use the hybrid model with multiple submodules to be evaluated in FHE.

I tried specifying two submodules, and the compilation of the circuits succeeds.

But when I run the inference, the first submodule is evaluated successfully.

But when evaluating the second one, the client side raises `AssertionError: assert repr_input_shape in self.clients`.

So roughly, on the client side, I did something like

# model.pth is the file saved by save_and_clear_private_info
# NOTE(review): torch.load unpickles the file — only load model files you trust.
model = torch.load('./compiled_models/net/model.pth')  
# Wrap the loaded model; module_names selects which submodules run in FHE
# on the remote server (presumably the same list used at compilation time
# — verify it matches, mismatches would break per-module client lookup).
hybrid_model = HybridFHEModel(
        model,
        module_names,
        server_remote_address="http://0.0.0.0:8000",
        model_name=model_name,
        verbose=False,
    )
# Initialize one FHE client per remote submodule from the local "clients" dir.
path_to_clients = Path(__file__).parent / "clients"
hybrid_model.init_client(path_to_clients=path_to_clients)
# Execute the selected submodules remotely (encrypted) during inference.
hybrid_model.set_fhe_mode(HybridFHEMode.REMOTE)

I am using the hybrid model for a UNet. The submodules I selected are `up1.pixel_shuffle` and `final`. My UNet is as follows:

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# Pixel-shuffle upsampling expressed with plain view/permute ops so it stays
# easy to trace/export (equivalent to nn.PixelShuffle).
class CustomPixelShuffle(nn.Module):
    """Rearrange a (B, C*r^2, H, W) tensor into (B, C, H*r, W*r)."""

    def __init__(self, upscale_factor):
        super(CustomPixelShuffle, self).__init__()
        self.upscale_factor = upscale_factor

    def forward(self, x):
        batch, in_channels, height, width = x.size()
        r = self.upscale_factor
        out_channels = in_channels // (r ** 2)

        # Split the channel axis into (out_channels, r, r), interleave the two
        # r-axes with the spatial axes, then collapse into the upscaled map.
        shuffled = x.view(batch, out_channels, r, r, height, width)
        shuffled = shuffled.permute(0, 1, 4, 2, 5, 3).contiguous()
        return shuffled.view(batch, out_channels, height * r, width * r)

class CustomUpsample(nn.Module):
    """Upsample by pixel shuffle, then refine with a 3x3 convolution.

    Spatial size grows by ``scale_factor`` while the shuffle divides the
    channel count by ``scale_factor**2``; the conv maps the shuffled
    channels to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, scale_factor):
        super(CustomUpsample, self).__init__()
        self.scale_factor = scale_factor
        self.pixel_shuffle = CustomPixelShuffle(scale_factor)
        # Channels seen by the conv are what remains after the shuffle.
        shuffled_channels = in_channels // (scale_factor ** 2)
        self.conv = nn.Conv2d(
            shuffled_channels, out_channels, kernel_size=3, padding=1
        )

    def forward(self, x):
        return self.conv(self.pixel_shuffle(x))

# Composite conv -> batch-norm -> ReLU building block.
# NOTE: deliberately shadows the name ``nn.Conv2d`` inside this module;
# the UNet below instantiates this composite block, not the raw layer.
class Conv2d(nn.Module):
    """3x3 convolution (by default) with batch normalization and ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(Conv2d, self).__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.bn = nn.BatchNorm2d(num_features=out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        # Apply the three stages in order: convolve, normalize, activate.
        out = self.conv(x)
        out = self.bn(out)
        return self.relu(out)

# Define a UNet architecture with standard components
class UNet(nn.Module):
    """Five-level encoder/decoder UNet with pixel-shuffle upsampling.

    Encoder channels: 1 -> 32 -> 64 -> 128 -> 256 -> 512, bottleneck 1024;
    the decoder mirrors this via CustomUpsample plus skip concatenations.
    NOTE(review): the skip concatenations only line up if the input H/W are
    divisible by 32 (five 2x pools must be exactly undone by five 2x
    upsamples) — confirm intended input sizes with the caller.
    Output: single-channel sigmoid map with the same spatial size as input.
    """

    def __init__(self): 
        super(UNet, self).__init__()

        # Encoder path
        self.enc1 = nn.Sequential(
            Conv2d(in_channels=1, out_channels=32),  
            Conv2d(in_channels=32, out_channels=32)
        )
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        self.enc2 = nn.Sequential(
            Conv2d(in_channels=32, out_channels=64),
            Conv2d(in_channels=64, out_channels=64) 
        )
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.enc3 = nn.Sequential(
            Conv2d(in_channels=64, out_channels=128), 
            Conv2d(in_channels=128, out_channels=128) 
        )
        self.pool3 = nn.MaxPool2d(kernel_size=2)

        self.enc4 = nn.Sequential(
            Conv2d(in_channels=128, out_channels=256), 
            Conv2d(in_channels=256, out_channels=256) 
        )
        self.pool4 = nn.MaxPool2d(kernel_size=2)

        self.enc5 = nn.Sequential(
            Conv2d(in_channels=256, out_channels=512),
            Conv2d(in_channels=512, out_channels=512)
        )
        self.pool5 = nn.MaxPool2d(kernel_size=2)

        # Bottleneck (central) layer
        self.bottleneck = nn.Sequential(
            Conv2d(in_channels=512, out_channels=1024), 
            Conv2d(in_channels=1024, out_channels=1024) 
        )

        # Upsampling path with CustomUpsample and decoder layers.
        # Each dec* takes (upsampled + skip) channels, i.e. double the
        # upsample output, because of the torch.cat in forward().
        self.up5 = CustomUpsample(in_channels=1024, out_channels=512, scale_factor=2)
        self.dec5 = nn.Sequential(
            Conv2d(in_channels=1024, out_channels=512), 
            Conv2d(in_channels=512, out_channels=512)  
        )

        self.up4 = CustomUpsample(in_channels=512, out_channels=256, scale_factor=2)
        self.dec4 = nn.Sequential(
            Conv2d(in_channels=512, out_channels=256), 
            Conv2d(in_channels=256, out_channels=256)  
        )

        self.up3 = CustomUpsample(in_channels=256, out_channels=128, scale_factor=2)
        self.dec3 = nn.Sequential(
            Conv2d(in_channels=256, out_channels=128), 
            Conv2d(in_channels=128, out_channels=128)  
        )

        self.up2 = CustomUpsample(in_channels=128, out_channels=64, scale_factor=2)
        self.dec2 = nn.Sequential(
            Conv2d(in_channels=128, out_channels=64),  
            Conv2d(in_channels=64, out_channels=64)
        )

        self.up1 = CustomUpsample(in_channels=64, out_channels=32, scale_factor=2)
        self.dec1 = nn.Sequential(
            Conv2d(in_channels=64, out_channels=32), 
            Conv2d(in_channels=32, out_channels=32) 
        )

        # Final convolutional layer for output (1x1, no padding: keeps the
        # spatial size and reduces 32 channels to a single output channel)
        self.final = nn.Conv2d(
            in_channels=32, out_channels=1, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        # Forward pass through the network
        # Encoder: each level convolves, then pools the previous level's output.
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool1(enc1))
        enc3 = self.enc3(self.pool2(enc2))
        enc4 = self.enc4(self.pool3(enc3))
        enc5 = self.enc5(self.pool4(enc4))

        bottleneck = self.bottleneck(self.pool5(enc5))

        # Decoder: upsample, concatenate the matching encoder feature map
        # along the channel axis (skip connection), then convolve.
        up5 = self.up5(bottleneck)
        dec5 = self.dec5(torch.cat((up5, enc5), dim=1))

        up4 = self.up4(dec5)
        dec4 = self.dec4(torch.cat((up4, enc4), dim=1))

        up3 = self.up3(dec4)
        dec3 = self.dec3(torch.cat((up3, enc3), dim=1))

        up2 = self.up2(dec3)
        dec2 = self.dec2(torch.cat((up2, enc2), dim=1))

        up1 = self.up1(dec2)
        dec1 = self.dec1(torch.cat((up1, enc1), dim=1))

        # Sigmoid squashes the single-channel map into [0, 1].
        output = torch.sigmoid(self.final(dec1))
        return output

#model = UNet()
#print(model)
#for (k, _) in model.named_modules():
#    print(k)

For the server setup and compilation of the circuit, I am basically following this example https://github.com/zama-ai/concrete-ml/tree/main/use_case_examples/hybrid_model. Could you please check why this happens? Thanks!

P.S. If you need the full code from me, just DM me at Gan in CML channel of FHE discord...i am pretty active there....Thanks!

Best,
Gan

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions