Skip to content

Commit bf5a147

Browse files
lucamelis and meta-codesync[bot]
authored and committed
Custom dataset (#113)
Summary: Pull Request resolved: #113 This diff introduces support for custom datasets in PrivacyGuard and provides comprehensive tutorials—both script and notebook formats—on how to utilize this feature. * Modified `dataset.py` in the `privacy_guard/shadow_model_training` directory. Added a new `CustomDataset` class, enabling users to work with custom datasets in PrivacyGuard. * Added `custom_dataset_tutorial.ipynb` in the `privacy_guard/github/tutorial_notebooks` directory. This file provides a tutorial on integrating custom datasets with PrivacyGuard, including an example implementation of a custom dataset class. Reviewed By: iden-kalemaj Differential Revision: D94982552 fbshipit-source-id: cab5f8db639df468750ae76fb0d3e874a2d80f71
1 parent 286f534 commit bf5a147

File tree

7 files changed

+2179
-322
lines changed

7 files changed

+2179
-322
lines changed
Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,15 @@
11
# pyre-strict
2-
"""
3-
Shadow model training package for privacy attacks.
4-
5-
This package provides utilities for training shadow models and performing privacy attacks.
6-
"""
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# Shadow model training package for privacy attacks.
74

85
from privacy_guard.shadow_model_training.dataset import (
96
create_shadow_datasets,
107
load_cifar10,
118
)
129
from privacy_guard.shadow_model_training.model import create_model
1310
from privacy_guard.shadow_model_training.training import (
14-
evaluate_model,
15-
get_softmax_scores,
1611
get_transformed_logits,
1712
prepare_lira_data,
18-
prepare_rmia_data,
1913
train_model,
2014
)
2115
from privacy_guard.shadow_model_training.visualization import (
@@ -28,13 +22,10 @@
2822
"analyze_attack",
2923
"create_model",
3024
"create_shadow_datasets",
31-
"evaluate_model",
32-
"get_softmax_scores",
3325
"get_transformed_logits",
3426
"load_cifar10",
3527
"plot_roc_curve",
3628
"plot_score_distributions",
37-
"prepare_rmia_data",
3829
"prepare_lira_data",
3930
"train_model",
4031
]

privacy_guard/shadow_model_training/dataset.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@
1818
"""
1919

2020
from pathlib import Path
21-
from typing import cast, List, Optional, Protocol, Sequence, Tuple, TypeVar
21+
from typing import cast, List, Optional, Protocol, Sequence, Tuple, TypeVar, Union
2222

2323
import numpy as np
24+
import torch
2425
from torch.utils.data import Dataset, Subset
2526
from torchvision import transforms
2627
from torchvision.datasets import CIFAR10
@@ -38,6 +39,81 @@ def __getitem__(self, index: int) -> object: ...
3839
DatasetT: TypeVar = TypeVar("DatasetT", bound=Dataset)
3940

4041

42+
class CustomDataset(Dataset):
    """A general-purpose dataset wrapper for user-provided data.

    Wraps numpy arrays or PyTorch tensors into a PyTorch Dataset that is
    compatible with the PrivacyGuard pipeline (shadow model training,
    membership inference attacks, etc.).

    Args:
        data: Feature data as a numpy array or PyTorch tensor.
            Shape should be (n_samples, ...) where ... represents
            any feature dimensions (e.g., (n, d) for tabular data
            or (n, c, h, w) for images).
        targets: Labels as a numpy array or PyTorch tensor.
            Shape should be (n_samples,).
        transform: Optional callable applied to each data sample
            when __getitem__ is called.

    Raises:
        ValueError: If data and targets have different lengths, or data is empty.
        TypeError: If transform is neither None nor callable.

    Example:
        >>> import numpy as np
        >>> X = np.random.randn(1000, 10).astype(np.float32)
        >>> y = np.random.randint(0, 3, size=1000)
        >>> dataset = CustomDataset(X, y)
        >>> data, label = dataset[0]
    """

    def __init__(
        self,
        data: Union[np.ndarray, torch.Tensor],
        targets: Union[np.ndarray, torch.Tensor],
        transform: Optional[object] = None,
    ) -> None:
        # Normalize features to a torch.Tensor (from_numpy shares memory,
        # no copy is made for numpy input).
        if isinstance(data, np.ndarray):
            self.data: torch.Tensor = torch.from_numpy(data)
        else:
            self.data = data

        # Targets are always stored as int64 so they work directly with
        # loss functions such as nn.CrossEntropyLoss.
        if isinstance(targets, np.ndarray):
            self.targets: torch.Tensor = torch.from_numpy(targets).long()
        else:
            self.targets = targets.long()

        if len(self.data) != len(self.targets):
            raise ValueError(
                f"data and targets must have the same length, "
                f"got {len(self.data)} and {len(self.targets)}"
            )

        if len(self.data) == 0:
            raise ValueError("data must not be empty")

        # Fail fast on a non-callable transform instead of silently
        # skipping it at __getitem__ time.
        if transform is not None and not callable(transform):
            raise TypeError(
                f"transform must be callable or None, got {type(transform).__name__}"
            )
        self.transform = transform

    def __len__(self) -> int:
        """Return the number of samples in the dataset."""
        return len(self.data)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the (sample, target) pair at ``index``.

        Applies ``self.transform`` to the sample when one was provided.
        """
        sample = self.data[index]
        target = self.targets[index]

        if self.transform is not None:
            # Validated as callable in __init__.
            sample = self.transform(sample)

        return cast(torch.Tensor, sample), target

    @property
    def num_classes(self) -> int:
        """Return the number of unique classes in the dataset."""
        return int(self.targets.unique().numel())

    @property
    def input_shape(self) -> torch.Size:
        """Return the shape of a single data sample (excluding batch dim)."""
        return self.data.shape[1:]
116+
41117
def get_cifar10_transforms() -> Tuple[transforms.Compose, transforms.Compose]:
42118
"""
43119
Get transforms for CIFAR-10 dataset.

privacy_guard/shadow_model_training/model.py

Lines changed: 95 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
This module provides neural network model definitions for privacy attack experiments.
1717
"""
1818

19+
from typing import List, Optional
20+
1921
import torch
2022
import torch.nn as nn
2123

@@ -92,12 +94,14 @@ class DeepCNN(nn.Module):
9294
Architecture inspired by ResNet.
9395
"""
9496

95-
def __init__(self, num_classes: int = 10) -> None:
97+
def __init__(self, num_classes: int = 10, input_channels: int = 3) -> None:
9698
super().__init__()
9799

98100
# Initial feature extraction
99101
self.input_block: nn.Sequential = nn.Sequential(
100-
nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
102+
nn.Conv2d(
103+
input_channels, 64, kernel_size=3, stride=1, padding=1, bias=False
104+
),
101105
nn.BatchNorm2d(64),
102106
nn.ReLU(inplace=True),
103107
)
@@ -176,6 +180,92 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
176180
return x
177181

178182

179-
def create_model() -> DeepCNN:
180-
"""Create a deep CNN model for CIFAR-10 classification."""
181-
return DeepCNN(num_classes=10)
183+
class SimpleMLP(nn.Module):
    """A simple multi-layer perceptron for tabular/flat data.

    Suitable for use with CustomDataset when the data is not image-like
    (e.g., tabular features, embeddings).

    Args:
        input_dim: Number of input features.
        num_classes: Number of output classes.
        hidden_dims: List of hidden layer sizes. Defaults to [256, 128].
    """

    def __init__(
        self,
        input_dim: int,
        num_classes: int,
        hidden_dims: Optional[List[int]] = None,
    ) -> None:
        super().__init__()

        widths = [256, 128] if hidden_dims is None else hidden_dims

        # Each hidden stage is Linear -> ReLU -> BatchNorm1d.
        stages: List[nn.Module] = []
        in_features = input_dim
        for width in widths:
            stages.append(nn.Linear(in_features, width))
            stages.append(nn.ReLU(inplace=True))
            stages.append(nn.BatchNorm1d(width))
            in_features = width

        # Final classification head.
        stages.append(nn.Linear(in_features, num_classes))
        self.network: nn.Sequential = nn.Sequential(*stages)

        self._initialize_weights()

    def _initialize_weights(self) -> None:
        """Initialize model weights."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm1d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Flatten the input to 2D and run it through the MLP."""
        flat = torch.flatten(x, 1)
        return self.network(flat)
237+
238+
239+
def create_model(num_classes: int = 10, input_channels: int = 3) -> DeepCNN:
    """Build a deep CNN model for image classification.

    Args:
        num_classes: Number of output classes.
        input_channels: Number of input channels (e.g., 3 for RGB, 1 for grayscale).

    Returns:
        A DeepCNN model instance.
    """
    model = DeepCNN(num_classes=num_classes, input_channels=input_channels)
    return model
250+
251+
252+
def create_mlp_model(
    input_dim: int,
    num_classes: int,
    hidden_dims: Optional[List[int]] = None,
) -> SimpleMLP:
    """Build an MLP model for tabular/flat data classification.

    Args:
        input_dim: Number of input features.
        num_classes: Number of output classes.
        hidden_dims: List of hidden layer sizes. Defaults to [256, 128].

    Returns:
        A SimpleMLP model instance.
    """
    model = SimpleMLP(
        input_dim=input_dim,
        num_classes=num_classes,
        hidden_dims=hidden_dims,
    )
    return model

privacy_guard/shadow_model_training/tests/test_dataset.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@
1616

1717
import unittest
1818

19+
import numpy as np
1920
import torch
2021
from privacy_guard.shadow_model_training.dataset import (
2122
create_rmia_datasets,
2223
create_shadow_datasets,
24+
CustomDataset,
2325
get_cifar10_transforms,
2426
)
2527
from torch.utils.data import Dataset
@@ -146,5 +148,58 @@ def test_create_rmia_datasets_minimum_references(self) -> None:
146148
create_rmia_datasets(train_dataset, test_dataset, num_references=2)
147149

148150

151+
class TestCustomDataset(unittest.TestCase):
    """Test cases for the CustomDataset class."""

    def test_construction_and_properties(self) -> None:
        """Test creation from numpy/tensors and num_classes/input_shape."""
        # From numpy arrays
        features = np.random.randn(100, 10).astype(np.float32)
        labels = np.array([0, 1, 2, 3, 4] * 20)
        dataset = CustomDataset(features, labels)
        self.assertEqual(len(dataset), 100)
        self.assertEqual(dataset.num_classes, 5)
        self.assertEqual(dataset.input_shape, torch.Size([10]))
        sample, label = dataset[0]
        self.assertIsInstance(sample, torch.Tensor)
        self.assertIsInstance(label, torch.Tensor)

        # From torch tensors with image-like shape
        image_dataset = CustomDataset(
            torch.randn(50, 3, 32, 32), torch.randint(0, 5, (50,))
        )
        self.assertEqual(len(image_dataset), 50)
        self.assertEqual(image_dataset.input_shape, torch.Size([3, 32, 32]))

    def test_validation_errors(self) -> None:
        """Test that invalid inputs raise ValueError."""
        with self.assertRaises(ValueError):
            CustomDataset(
                np.zeros((100, 10), dtype=np.float32), np.zeros(50, dtype=np.int64)
            )
        with self.assertRaises(ValueError):
            CustomDataset(
                np.zeros((0, 10), dtype=np.float32), np.array([], dtype=np.int64)
            )

    def test_transform_and_shadow_integration(self) -> None:
        """Test transforms and compatibility with create_shadow_datasets."""
        features = np.random.randn(200, 10).astype(np.float32)
        labels = np.random.randint(0, 3, size=200)
        was_called: list[bool] = [False]

        def doubling_transform(x: torch.Tensor) -> torch.Tensor:
            was_called[0] = True
            return x * 2.0

        dataset = CustomDataset(features, labels, transform=doubling_transform)
        dataset[0]
        self.assertTrue(was_called[0])

        shadow_datasets, target_dataset = create_shadow_datasets(
            dataset, n_shadows=4, pkeep=0.5, seed=42
        )
        self.assertEqual(len(shadow_datasets), 3)
        self.assertIsInstance(target_dataset, tuple)
149204
if __name__ == "__main__":
150205
unittest.main()

privacy_guard/shadow_model_training/tests/test_model.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@
1818

1919
import torch
2020
from privacy_guard.shadow_model_training.model import (
21+
create_mlp_model,
2122
create_model,
2223
DeepCNN,
2324
ResidualUnit,
25+
SimpleMLP,
2426
)
2527

2628

@@ -54,6 +56,41 @@ def test_create_model(self) -> None:
5456
self.assertIsInstance(model, DeepCNN)
5557
self.assertEqual(model.classifier.out_features, 10)
5658

59+
def test_create_model_custom_classes(self) -> None:
    """Test that create_model accepts custom num_classes."""
    net = create_model(num_classes=5)
    self.assertIsInstance(net, DeepCNN)
    self.assertEqual(net.classifier.out_features, 5)
64+
65+
def test_deep_cnn_custom_input_channels(self) -> None:
    """Test DeepCNN with custom input channels."""
    net = DeepCNN(num_classes=3, input_channels=1)
    grayscale_batch = torch.randn(2, 1, 32, 32)
    logits = net(grayscale_batch)
    self.assertEqual(logits.shape, (2, 3))
71+
72+
def test_simple_mlp(self) -> None:
    """Test that SimpleMLP forward pass works."""
    net = SimpleMLP(input_dim=20, num_classes=5)
    batch = torch.randn(4, 20)
    logits = net(batch)
    self.assertEqual(logits.shape, (4, 5))
78+
79+
def test_simple_mlp_custom_hidden(self) -> None:
    """Test SimpleMLP with custom hidden dimensions."""
    net = SimpleMLP(input_dim=50, num_classes=3, hidden_dims=[64, 32, 16])
    batch = torch.randn(4, 50)
    logits = net(batch)
    self.assertEqual(logits.shape, (4, 3))
85+
86+
def test_create_mlp_model(self) -> None:
    """Test that create_mlp_model returns a SimpleMLP instance."""
    net = create_mlp_model(input_dim=10, num_classes=4)
    self.assertIsInstance(net, SimpleMLP)
    batch = torch.randn(2, 10)
    logits = net(batch)
    self.assertEqual(logits.shape, (2, 4))
93+
5794

5895
if __name__ == "__main__":
5996
unittest.main()

0 commit comments

Comments
 (0)