Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 2 additions & 11 deletions privacy_guard/shadow_model_training/__init__.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,15 @@
# pyre-strict
"""
Shadow model training package for privacy attacks.

This package provides utilities for training shadow models and performing privacy attacks.
"""
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Shadow model training package for privacy attacks.

from privacy_guard.shadow_model_training.dataset import (
create_shadow_datasets,
load_cifar10,
)
from privacy_guard.shadow_model_training.model import create_model
from privacy_guard.shadow_model_training.training import (
evaluate_model,
get_softmax_scores,
get_transformed_logits,
prepare_lira_data,
prepare_rmia_data,
train_model,
)
from privacy_guard.shadow_model_training.visualization import (
Expand All @@ -28,13 +22,10 @@
"analyze_attack",
"create_model",
"create_shadow_datasets",
"evaluate_model",
"get_softmax_scores",
"get_transformed_logits",
"load_cifar10",
"plot_roc_curve",
"plot_score_distributions",
"prepare_rmia_data",
"prepare_lira_data",
"train_model",
]
78 changes: 77 additions & 1 deletion privacy_guard/shadow_model_training/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
"""

from pathlib import Path
from typing import cast, List, Optional, Protocol, Sequence, Tuple, TypeVar
from typing import cast, List, Optional, Protocol, Sequence, Tuple, TypeVar, Union

import numpy as np
import torch
from torch.utils.data import Dataset, Subset
from torchvision import transforms
from torchvision.datasets import CIFAR10
Expand All @@ -38,6 +39,81 @@ def __getitem__(self, index: int) -> object: ...
DatasetT: TypeVar = TypeVar("DatasetT", bound=Dataset)


class CustomDataset(Dataset):
    """A general-purpose dataset wrapper for user-provided data.

    Wraps numpy arrays or PyTorch tensors into a PyTorch Dataset that is
    compatible with the PrivacyGuard pipeline (shadow model training,
    membership inference attacks, etc.).

    Args:
        data: Feature data as a numpy array or PyTorch tensor.
            Shape should be (n_samples, ...) where ... represents
            any feature dimensions (e.g., (n, d) for tabular data
            or (n, c, h, w) for images).
        targets: Labels as a numpy array or PyTorch tensor.
            Shape should be (n_samples,).
        transform: Optional callable applied to each data sample
            when __getitem__ is called.

    Raises:
        ValueError: If data and targets have different lengths, or data is empty.
        TypeError: If transform is provided but is not callable.

    Example:
        >>> import numpy as np
        >>> X = np.random.randn(1000, 10).astype(np.float32)
        >>> y = np.random.randint(0, 3, size=1000)
        >>> dataset = CustomDataset(X, y)
        >>> data, label = dataset[0]
    """

    def __init__(
        self,
        data: Union[np.ndarray, torch.Tensor],
        targets: Union[np.ndarray, torch.Tensor],
        transform: Optional[object] = None,
    ) -> None:
        # Normalize features to a tensor; torch.from_numpy shares memory
        # with the source array rather than copying it.
        if isinstance(data, np.ndarray):
            self.data: torch.Tensor = torch.from_numpy(data)
        else:
            self.data = data

        # Targets are always stored as int64 so they can be consumed
        # directly as class indices (e.g., by CrossEntropyLoss).
        if isinstance(targets, np.ndarray):
            self.targets: torch.Tensor = torch.from_numpy(targets).long()
        else:
            self.targets = targets.long()

        if len(self.data) != len(self.targets):
            raise ValueError(
                f"data and targets must have the same length, "
                f"got {len(self.data)} and {len(self.targets)}"
            )

        if len(self.data) == 0:
            raise ValueError("data must not be empty")

        # Fail fast on a non-callable transform instead of silently
        # skipping it on every __getitem__ call (the previous behavior).
        if transform is not None and not callable(transform):
            raise TypeError(
                f"transform must be callable, got {type(transform).__name__}"
            )
        self.transform = transform

    def __len__(self) -> int:
        """Return the number of samples in the dataset."""
        return len(self.data)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the (sample, target) pair at *index*, applying the transform if set."""
        sample = self.data[index]
        target = self.targets[index]

        # callable() re-check narrows the Optional[object] type for pyre;
        # construction already guaranteed any non-None transform is callable.
        if self.transform is not None and callable(self.transform):
            sample = self.transform(sample)

        return cast(torch.Tensor, sample), target

    @property
    def num_classes(self) -> int:
        """Return the number of unique classes in the dataset."""
        return int(self.targets.unique().numel())

    @property
    def input_shape(self) -> torch.Size:
        """Return the shape of a single data sample (excluding batch dim)."""
        return self.data.shape[1:]


def get_cifar10_transforms() -> Tuple[transforms.Compose, transforms.Compose]:
"""
Get transforms for CIFAR-10 dataset.
Expand Down
100 changes: 95 additions & 5 deletions privacy_guard/shadow_model_training/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
This module provides neural network model definitions for privacy attack experiments.
"""

from typing import List, Optional

import torch
import torch.nn as nn

Expand Down Expand Up @@ -92,12 +94,14 @@ class DeepCNN(nn.Module):
Architecture inspired by ResNet.
"""

def __init__(self, num_classes: int = 10) -> None:
def __init__(self, num_classes: int = 10, input_channels: int = 3) -> None:
super().__init__()

# Initial feature extraction
self.input_block: nn.Sequential = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
nn.Conv2d(
input_channels, 64, kernel_size=3, stride=1, padding=1, bias=False
),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
)
Expand Down Expand Up @@ -176,6 +180,92 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


def create_model() -> DeepCNN:
"""Create a deep CNN model for CIFAR-10 classification."""
return DeepCNN(num_classes=10)
class SimpleMLP(nn.Module):
    """A compact fully connected classifier for flat feature vectors.

    Intended for non-image data (tabular features, embeddings) used with
    CustomDataset. Each hidden block is Linear -> ReLU -> BatchNorm1d,
    followed by a final Linear projection to the class logits.

    Args:
        input_dim: Number of input features.
        num_classes: Number of output classes.
        hidden_dims: List of hidden layer sizes. Defaults to [256, 128].
    """

    def __init__(
        self,
        input_dim: int,
        num_classes: int,
        hidden_dims: Optional[List[int]] = None,
    ) -> None:
        super().__init__()

        dims = [256, 128] if hidden_dims is None else hidden_dims

        blocks: List[nn.Module] = []
        in_features = input_dim
        for out_features in dims:
            blocks.append(nn.Linear(in_features, out_features))
            blocks.append(nn.ReLU(inplace=True))
            blocks.append(nn.BatchNorm1d(out_features))
            in_features = out_features
        blocks.append(nn.Linear(in_features, num_classes))

        self.network: nn.Sequential = nn.Sequential(*blocks)
        self._initialize_weights()

    def _initialize_weights(self) -> None:
        """Kaiming-init every linear layer; unit weight / zero bias for batch norms."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm1d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Flatten all non-batch dimensions and run the layer stack."""
        flat = torch.flatten(x, 1)
        return self.network(flat)


def create_model(num_classes: int = 10, input_channels: int = 3) -> DeepCNN:
    """Factory for the DeepCNN image classifier.

    Args:
        num_classes: Number of output classes (size of the logit layer).
        input_channels: Number of input channels (e.g., 3 for RGB, 1 for grayscale).

    Returns:
        A freshly initialized DeepCNN model instance.
    """
    model = DeepCNN(num_classes=num_classes, input_channels=input_channels)
    return model


def create_mlp_model(
    input_dim: int,
    num_classes: int,
    hidden_dims: Optional[List[int]] = None,
) -> SimpleMLP:
    """Factory for the SimpleMLP tabular/flat-data classifier.

    Args:
        input_dim: Number of input features.
        num_classes: Number of output classes.
        hidden_dims: List of hidden layer sizes. Defaults to [256, 128].

    Returns:
        A freshly initialized SimpleMLP model instance.
    """
    model = SimpleMLP(
        input_dim=input_dim,
        num_classes=num_classes,
        hidden_dims=hidden_dims,
    )
    return model
55 changes: 55 additions & 0 deletions privacy_guard/shadow_model_training/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@

import unittest

import numpy as np
import torch
from privacy_guard.shadow_model_training.dataset import (
create_rmia_datasets,
create_shadow_datasets,
CustomDataset,
get_cifar10_transforms,
)
from torch.utils.data import Dataset
Expand Down Expand Up @@ -146,5 +148,58 @@ def test_create_rmia_datasets_minimum_references(self) -> None:
create_rmia_datasets(train_dataset, test_dataset, num_references=2)


class TestCustomDataset(unittest.TestCase):
    """Unit tests for CustomDataset construction, validation, and transforms."""

    def test_construction_and_properties(self) -> None:
        """Datasets built from numpy or tensors expose length, classes, and shape."""
        features = np.random.randn(100, 10).astype(np.float32)
        labels = np.array([0, 1, 2, 3, 4] * 20)
        dataset = CustomDataset(features, labels)

        self.assertEqual(len(dataset), 100)
        self.assertEqual(dataset.num_classes, 5)
        self.assertEqual(dataset.input_shape, torch.Size([10]))

        sample, label = dataset[0]
        self.assertIsInstance(sample, torch.Tensor)
        self.assertIsInstance(label, torch.Tensor)

        # Tensor inputs with image-like (n, c, h, w) shape are also accepted.
        image_dataset = CustomDataset(
            torch.randn(50, 3, 32, 32), torch.randint(0, 5, (50,))
        )
        self.assertEqual(len(image_dataset), 50)
        self.assertEqual(image_dataset.input_shape, torch.Size([3, 32, 32]))

    def test_validation_errors(self) -> None:
        """Mismatched lengths and empty data are rejected with ValueError."""
        with self.assertRaises(ValueError):
            CustomDataset(
                np.zeros((100, 10), dtype=np.float32), np.zeros(50, dtype=np.int64)
            )
        with self.assertRaises(ValueError):
            CustomDataset(
                np.zeros((0, 10), dtype=np.float32), np.array([], dtype=np.int64)
            )

    def test_transform_and_shadow_integration(self) -> None:
        """Transforms fire on access and the dataset feeds create_shadow_datasets."""
        features = np.random.randn(200, 10).astype(np.float32)
        labels = np.random.randint(0, 3, size=200)
        was_called: list[bool] = [False]

        def doubling_transform(x: torch.Tensor) -> torch.Tensor:
            was_called[0] = True
            return x * 2.0

        dataset = CustomDataset(features, labels, transform=doubling_transform)
        dataset[0]  # accessing any item must invoke the transform
        self.assertTrue(was_called[0])

        shadow_datasets, target_dataset = create_shadow_datasets(
            dataset, n_shadows=4, pkeep=0.5, seed=42
        )
        self.assertEqual(len(shadow_datasets), 3)
        self.assertIsInstance(target_dataset, tuple)


if __name__ == "__main__":
unittest.main()
37 changes: 37 additions & 0 deletions privacy_guard/shadow_model_training/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@

import torch
from privacy_guard.shadow_model_training.model import (
create_mlp_model,
create_model,
DeepCNN,
ResidualUnit,
SimpleMLP,
)


Expand Down Expand Up @@ -54,6 +56,41 @@ def test_create_model(self) -> None:
self.assertIsInstance(model, DeepCNN)
self.assertEqual(model.classifier.out_features, 10)

def test_create_model_custom_classes(self) -> None:
    """create_model honors a non-default num_classes."""
    five_class_model = create_model(num_classes=5)
    self.assertIsInstance(five_class_model, DeepCNN)
    self.assertEqual(five_class_model.classifier.out_features, 5)

def test_deep_cnn_custom_input_channels(self) -> None:
    """DeepCNN accepts single-channel (grayscale) input when configured for it."""
    model = DeepCNN(num_classes=3, input_channels=1)
    logits = model(torch.randn(2, 1, 32, 32))
    self.assertEqual(logits.shape, (2, 3))

def test_simple_mlp(self) -> None:
    """SimpleMLP's forward pass produces (batch, num_classes) logits."""
    mlp = SimpleMLP(input_dim=20, num_classes=5)
    logits = mlp(torch.randn(4, 20))
    self.assertEqual(logits.shape, (4, 5))

def test_simple_mlp_custom_hidden(self) -> None:
    """SimpleMLP works with a caller-specified hidden layer layout."""
    mlp = SimpleMLP(input_dim=50, num_classes=3, hidden_dims=[64, 32, 16])
    logits = mlp(torch.randn(4, 50))
    self.assertEqual(logits.shape, (4, 3))

def test_create_mlp_model(self) -> None:
    """create_mlp_model returns a working SimpleMLP instance."""
    mlp = create_mlp_model(input_dim=10, num_classes=4)
    self.assertIsInstance(mlp, SimpleMLP)
    logits = mlp(torch.randn(2, 10))
    self.assertEqual(logits.shape, (2, 4))


if __name__ == "__main__":
unittest.main()
Loading
Loading