论文 "RAEncoder: A Label-Free Reversible Adversarial Examples Encoder for Dataset Intellectual Property Protection" (CVPR 2024) 的 PyTorch 复现。
raencoder/
├── checkpoints # 存储噪声编码器训练参数
├── data # 存储数据集
├── LOG # 存储训练的日志文件等
├── visual # 可视化脚本结果文件
├── config.py # 超参数配置 (对应论文 Section 4.1.2)
├── dataset.py # 数据集加载 + SSL Encoder + 下游分类器
├── evaluate.py # 评估工具 (ASRProtected / ASRAttacked / SSIM / PSNR)
├── log_utils.py # 终端日志 tee 工具
├── losses.py # 全部损失函数 (Eq. 8-17)
├── main.py # 主入口
├── README.md # 项目说明文件
├── models.py # Noise Encoder E + 高频滤波器 H
├── test_evaluate_protocol.py # 评估协议与统计口径测试
├── train.py # 训练流程 (Figure 3)
└── visualize.py # 可视化数据脚本文件
main.py: 项目命令行主入口,负责解析参数,构建配置,并调度train / protect / evaluate / full四种运行模式。config.py: 定义RAEncoderConfig,集中保存训练、评估、数据加载、下游分类器和输出路径等默认超参数。models.py: 实现论文里的核心网络模块,包括NoiseEncoder以及与高频约束相关的组件。train.py: 实现RAEncoderTrainer,负责初始化模型、生成和管理signature noise、执行训练循环、生成扰动、保存与加载 checkpoint。losses.py: 实现论文 Eq. 8-17 对应的损失函数,包括主攻击损失、干扰模式损失以及损失加权汇总逻辑。dataset.py: 负责数据集读取、数据增强与图像变换、SSL encoder 加载、输入归一化,以及下游分类器构建。evaluate.py: 负责论文口径评估,包含整数据集级SSIM / PSNR统计、ASRProtected / ASRAttacked计算、固定错误噪声生成,以及线性下游头训练。visualize.py: 用于导出可视化结果,包括签名噪声S、扰动δ / δ'、原图、protected 图和 recovered 图。log_utils.py: 提供终端输出同步写入日志文件的工具,确保训练和评估过程可追溯、可复现。test_evaluate_protocol.py: 用单元测试验证当前评估协议的关键统计口径,包括 ASR 分母、全局 PSNR 聚合、SSIM 加权和eval_seed复现性。
RAEncoder 通过一个 Noise Encoder 将固定的 signature noise 映射为 universal adversarial perturbation,叠加到原始图像上生成受保护样本。
- 未授权用户: 使用受保护样本训练模型,性能显著下降
- 授权用户: 拥有正确的 signature noise,通过 Noise Encoder 获取扰动并还原原始图像 (100% fidelity)
- 干扰模式: 确保使用错误密钥无法恢复原始样本
Python >= 3.12
torch >= 2.0
torchvision >= 0.15
# 使用真实 solo-learn SSL encoder 训练 + 评估
python main.py --mode full \
--dataset imagenet \
--pretrain_dataset imagenet \
--ssl_method byol \
--ssl_source solo_learn \
--ssl_checkpoint /path/to/solo_learn/byol_imagenet.ckpt \
--image_size 224
# 仅训练
python main.py --mode train \
--dataset cifar10 \
--pretrain_dataset cifar10 \
--ssl_method byol \
--ssl_source solo_learn \
--ssl_checkpoint /path/to/solo_learn/byol_cifar10.ckpt \
--epochs 20
# 从 checkpoint 评估 (论文默认口径)
python main.py --mode evaluate \
--checkpoint checkpoints/raencoder_cifar10_byol.pt \
--eval_seed 0
# 所有终端输出会自动保存到 LOG/ 目录
python visualize.py --checkpoint checkpoints/raencoder_cifar10_byol.pt
# 自定义参数
python main.py --mode full \
--dataset imagenet \
--pretrain_dataset imagenet \
--ssl_method byol \
--ssl_source solo_learn \
--ssl_checkpoint /path/to/solo_learn/byol_imagenet.ckpt \
--backbone resnet18 \
--epochs 20 \
--lr 0.0002 \
--epsilon 0.0392 \
--batch_size 128 \
--image_size 224| 论文章节 | 代码位置 | 说明 |
|---|---|---|
| Eq. 5-7 | models.py: NoiseEncoder |
E(S)=δ, x^p=x+δ, x^r=x^p-δ |
| Eq. 8 | losses.py: compute_main_loss |
L_E = αL_E_adv + βL_E_H + λL_E_mse |
| Eq. 9 | losses.py: InfoNCELoss |
对抗损失 (negative pairing) |
| Eq. 10 | losses.py: compute_main_loss |
高频约束损失 |
| Eq. 11 | losses.py: compute_main_loss |
MSE 不可区分性损失 |
| Eq. 12-17 | losses.py: compute_interference_loss |
干扰模式损失 |
| Figure 3 | train.py: train_epoch |
完整训练流程 |
| Section 4.1.2 | config.py |
所有超参数 |
| Section 4.1.3 / 4.4 | evaluate.py |
整数据集级 SSIM, PSNR |
| Section 4.2 / 4.3 | evaluate.py |
ASRProtected, ASRAttacked |
-
SSL 预训练 Encoder: 论文复现建议直接提供 solo-learn 的 SSL checkpoint,并通过
--ssl_checkpoint加载。 -
输入归一化: 图像与扰动始终在
[0,1]像素空间中处理;只有在送入 SSL encoder 提取特征时,才会根据pretrain_dataset做 normalize。 -
ImageNet 数据集: 需要手动下载并放置到
data/imagenet/目录。 -
终端日志:
main.py与visualize.py的终端输出会自动写入LOG/目录;如需自定义位置,可通过--log_dir指定。 -
GPU 内存: 默认 batch_size=128 在单张 RTX 4090 上可正常运行。
-
评估口径:
main.py --mode evaluate默认对齐论文口径。视觉质量会按train + test/val全数据集统计protected/recovered的 SSIM 与 PSNR;ASR 会在评估集上输出ASRProtected与ASRAttacked。 -
ImageNet 命名: 当前仓库中的
dataset=imagenet仍表示 ImageNet-100 子集,用于近似复现论文设置;对齐的是评估协议,而不是完整 ImageNet-1K 的数据规模。 -
训练环境: 并非在本机进行,检查代码正常后会将代码上传到第三方云GPU租赁平台进行训练。云平台环境默认 PyTorch 2.5.1 Python 3.12(ubuntu22.04) CUDA 12.4
- 已完成目标编码器训练(ImageNet-100 + BYOL + ResNet18),文件目录(trained_models/byol-imagenet100-zu7pprin-ep=399.ckpt)
- 已完成噪声编码器E训练(以ImageNet-100 + BYOL + ResNet18为目标编码器,受保护的数据集是imagenet100),文件目录(checkpoints/raencoder_imagenet_byol.pt)