Skip to content
This repository was archived by the owner on Apr 29, 2026. It is now read-only.

ThornsW/RAEncoder

Repository files navigation

RAEncoder: A Label-Free Reversible Adversarial Examples Encoder

论文 "RAEncoder: A Label-Free Reversible Adversarial Examples Encoder for Dataset Intellectual Property Protection" (CVPR 2024) 的 PyTorch 复现。

项目结构

raencoder/
├── checkpoints     # 存储噪声编码器训练参数
├── data            # 存储数据集
├── LOG             # 存储训练的日志文件等
├── visual          # 可视化脚本结果文件
├── config.py       # 超参数配置 (对应论文 Section 4.1.2)
├── dataset.py      # 数据集加载 + SSL Encoder + 下游分类器
├── evaluate.py     # 评估工具 (ASRProtected / ASRAttacked / SSIM / PSNR)
├── log_utils.py    # 终端日志 tee 工具
├── losses.py       # 全部损失函数 (Eq. 8-17)
├── main.py         # 主入口
├── README.md       # 项目说明文件
├── models.py       # Noise Encoder E + 高频滤波器 H
├── test_evaluate_protocol.py # 评估协议与统计口径测试
├── train.py        # 训练流程 (Figure 3)
└── visualize.py    # 可视化数据脚本文件

各 Python 文件功能说明

  • main.py: 项目命令行主入口,负责解析参数,构建配置,并调度 train / protect / evaluate / full 四种运行模式。
  • config.py: 定义 RAEncoderConfig,集中保存训练、评估、数据加载、下游分类器和输出路径等默认超参数。
  • models.py: 实现论文里的核心网络模块,包括 NoiseEncoder 以及与高频约束相关的组件。
  • train.py: 实现 RAEncoderTrainer,负责初始化模型、生成和管理 signature noise、执行训练循环、生成扰动、保存与加载 checkpoint。
  • losses.py: 实现论文 Eq. 8-17 对应的损失函数,包括主攻击损失、干扰模式损失以及损失加权汇总逻辑。
  • dataset.py: 负责数据集读取、数据增强与图像变换、SSL encoder 加载、输入归一化,以及下游分类器构建。
  • evaluate.py: 负责论文口径评估,包含整数据集级 SSIM / PSNR 统计、ASRProtected / ASRAttacked 计算、固定错误噪声生成,以及线性下游头训练。
  • visualize.py: 用于导出可视化结果,包括签名噪声 S、扰动 δ / δ'、原图、protected 图和 recovered 图。
  • log_utils.py: 提供终端输出同步写入日志文件的工具,确保训练和评估过程可追溯、可复现。
  • test_evaluate_protocol.py: 用单元测试验证当前评估协议的关键统计口径,包括 ASR 分母、全局 PSNR 聚合、SSIM 加权和 eval_seed 复现性。

核心思路

RAEncoder 通过一个 Noise Encoder 将固定的 signature noise 映射为 universal adversarial perturbation,叠加到原始图像上生成受保护样本。

  • 未授权用户: 使用受保护样本训练模型,性能显著下降
  • 授权用户: 拥有正确的 signature noise,通过 Noise Encoder 获取扰动并还原原始图像 (100% fidelity)
  • 干扰模式: 确保使用错误密钥无法恢复原始样本

环境要求

Python >= 3.12
torch >= 2.0
torchvision >= 0.15

快速开始

# 使用真实 solo-learn SSL encoder 训练 + 评估
python main.py --mode full \
    --dataset imagenet \
    --pretrain_dataset imagenet \
    --ssl_method byol \
    --ssl_source solo_learn \
    --ssl_checkpoint /path/to/solo_learn/byol_imagenet.ckpt \
    --image_size 224

# 仅训练
python main.py --mode train \
    --dataset cifar10 \
    --pretrain_dataset cifar10 \
    --ssl_method byol \
    --ssl_source solo_learn \
    --ssl_checkpoint /path/to/solo_learn/byol_cifar10.ckpt \
    --epochs 20

# 从 checkpoint 评估 (论文默认口径)
python main.py --mode evaluate \
    --checkpoint checkpoints/raencoder_cifar10_byol.pt \
    --eval_seed 0

# 所有终端输出会自动保存到 LOG/ 目录
python visualize.py --checkpoint checkpoints/raencoder_cifar10_byol.pt

# 自定义参数
python main.py --mode full \
    --dataset imagenet \
    --pretrain_dataset imagenet \
    --ssl_method byol \
    --ssl_source solo_learn \
    --ssl_checkpoint /path/to/solo_learn/byol_imagenet.ckpt \
    --backbone resnet18 \
    --epochs 20 \
    --lr 0.0002 \
    --epsilon 0.0392 \
    --batch_size 128 \
    --image_size 224

代码与论文对应关系

论文章节 代码位置 说明
Eq. 5-7 models.py: NoiseEncoder E(S)=δ, x^p=x+δ, x^r=x^p-δ
Eq. 8 losses.py: compute_main_loss L_E = αL_E_adv + βL_E_H + λL_E_mse
Eq. 9 losses.py: InfoNCELoss 对抗损失 (negative pairing)
Eq. 10 losses.py: compute_main_loss 高频约束损失
Eq. 11 losses.py: compute_main_loss MSE 不可区分性损失
Eq. 12-17 losses.py: compute_interference_loss 干扰模式损失
Figure 3 train.py: train_epoch 完整训练流程
Section 4.1.2 config.py 所有超参数
Section 4.1.3 / 4.4 evaluate.py 整数据集级 SSIM, PSNR
Section 4.2 / 4.3 evaluate.py ASRProtected, ASRAttacked

注意事项

  1. SSL 预训练 Encoder: 论文复现建议直接提供 solo-learn 的 SSL checkpoint,并通过 --ssl_checkpoint 加载。

  2. 输入归一化: 图像与扰动始终在 [0,1] 像素空间中处理;只有在送入 SSL encoder 提取特征时,才会根据 pretrain_dataset 做 normalize。

  3. ImageNet 数据集: 需要手动下载并放置到 data/imagenet/ 目录。

  4. 终端日志: main.pyvisualize.py 的终端输出会自动写入 LOG/ 目录;如需自定义位置,可通过 --log_dir 指定。

  5. GPU 内存: 默认 batch_size=128 在单张 RTX 4090 上可正常运行。

  6. 评估口径: main.py --mode evaluate 默认对齐论文口径。视觉质量会按 train + test/val 全数据集统计 protected/recovered 的 SSIM 与 PSNR;ASR 会在评估集上输出 ASRProtectedASRAttacked

  7. ImageNet 命名: 当前仓库中的 dataset=imagenet 仍表示 ImageNet-100 子集,用于近似复现论文设置;对齐的是评估协议,而不是完整 ImageNet-1K 的数据规模。

  8. 训练环境: 并非在本机进行,检查代码正常后会将代码上传到第三方云GPU租赁平台进行训练。云平台环境默认 PyTorch 2.5.1 Python 3.12(ubuntu22.04) CUDA 12.4

当前任务进度

  1. 已完成目标编码器训练(ImageNet-100 + BYOL + ResNet18),文件目录(trained_models/byol-imagenet100-zu7pprin-ep=399.ckpt)
  2. 已完成噪声编码器E训练(以ImageNet-100 + BYOL + ResNet18为目标编码器,受保护的数据集是imagenet100),文件目录(checkpoints/raencoder_imagenet_byol.pt)

任务流程

https://app.xmind.com/wal7VVRa

About

借鉴生成对抗网络中的对抗设计,在生成可逆对抗样本上的应用

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages