This repository provides the official implementation of the paper:
Virus Infection Attack on LLMs: Your Poisoning Can Spread "VIA" Synthetic Data
[[Paper Link (Coming Soon)]]
- Python version: >=3.10
- Install dependencies: pip install -r re.txt
- Hardware requirement: At least one GPU with 80GB memory is recommended to reproduce our main experiments.
- ./DistributionAnalysis/query_comparison.py: Generate query distribution visualizations and compute poisoning relevance.

- ./defense/defense_methods.py: Defense method implementations.
- ./defense/evaluate.py: Evaluation scripts for defense effectiveness.

- ./construct_direct_poisoning_dataset.py: Construct standard data poisoning baseline.
- ./construct_worm_sft.py: VIA-based backdoor attack generation.
- ./construct_worm_sft_new.py: VIA-based data poisoning generation.
- ./backdoor_dataset_construct.py: Build backdoor datasets for experiments.

- ./asr_infer.py: Evaluate ASR for data poisoning attacks.
- ./dpabackdoor_eval.py: Evaluate ASR for backdoor attacks.

- ./infer_new.py: Evaluate IR on synthetic data (data poisoning).
- ./infec_infer_backdoor.py: Evaluate IR on synthetic data (backdoor attacks).

- ./analyze_sftdataset.py: Perform Hijacking Point Search (HPS) analysis.
- ./getEmbedding.py: Extract embeddings for representation-level analysis.
- ./ChatwithAPI.py: Query external LLM APIs.
- ./train.py: Fine-tune LLMs using poisoned or synthetic data.
- ./seed.py: Set global random seed for reproducibility.

- ./plot_PPL_dist.py: Visualize PPL-based defense detection results.
- ./plot_multigeneration.py: Plot propagation through multiple generations.
- ./plot_varyInfectionRateComparison.py: Plot comparisons of IR under different settings (Figure 2).
- ./plot_varyNgram_experiment.py: Plot results on varying n-gram size.
- ./visualize_hps.py: Visualize HPS score distribution using bar plots.

All experiments follow a three-step pipeline:
- Construct the poisoned dataset
- Train the model
- Evaluate ASR / IR performance
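
The three steps map onto scripts named in this README. As a rough sketch (not part of the repository), the helper below only prints the commands for one attack type in pipeline order. `PIPELINES` and `plan` are hypothetical names, and the pairing of each attack type with its evaluation scripts is an assumption based on the file descriptions in this README.

```python
# Hypothetical dry-run helper: it lists commands and executes nothing.
PIPELINES = {
    "standard": {
        "construct": ["python ./construct_direct_poisoning_dataset.py"],
        "train": ["bash ./scripts/3.1.varyPoisoningRate.sh"],
        "evaluate": ["python asr_infer.py", "python infer_new.py"],
    },
    "via": {
        "construct": ["python ./construct_worm_sft_new.py"],
        "train": ["bash ./scripts/3.2.wormVaryPoisoningRate.sh"],
        "evaluate": ["python asr_infer.py", "python infer_new.py"],
    },
    "backdoor": {
        "construct": ["python ./construct_worm_sft.py"],
        "train": ["bash ./scripts/3.3.backdoorPoisoningVaryPR.sh"],
        # Assumed pairing: the backdoor ASR and IR evaluation scripts.
        "evaluate": ["python ./dpabackdoor_eval.py", "python ./infec_infer_backdoor.py"],
    },
}


def plan(attack: str) -> list[str]:
    """Return the commands for one attack type: construct, train, evaluate."""
    steps = PIPELINES[attack]
    return [cmd for stage in ("construct", "train", "evaluate") for cmd in steps[stage]]


if __name__ == "__main__":
    for cmd in plan("via"):
        print(cmd)
```

Running each printed command by hand keeps the long training step under your control; the evaluation scripts still need the relevant function enabled in their main().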
# Backdoor Poisoning
python ./construct_worm_sft.py
# VIA for Data Poisoning Attack
python ./construct_worm_sft_new.py
# Standard Data Poisoning Baseline
python ./construct_direct_poisoning_dataset.py

You can use the provided scripts to reproduce the experiments:
- Standard Poisoning: bash ./scripts/3.1.varyPoisoningRate.sh
- VIA-based Poisoning: bash ./scripts/3.2.wormVaryPoisoningRate.sh
- Backdoor Poisoning: bash ./scripts/3.3.backdoorPoisoningVaryPR.sh
You may also define your own experiment. Example:
export python=${HOME}/anaconda3/envs/worm/bin/python3
export TORCH_USE_CUDA_DSA="1"
export root_dir="${HOME}/wormInfection/"
export from_path="meta-llama/Meta-Llama-3-8B"
export CUDA_VISIBLE_DEVICES=0
# Poisoning configurations
export prefix_path=${root_dir}"saved_poison_dataset/"
export pr_ls=(0.025 0.05 0.1 0.2 0.4)
export train_time_ls=(1 2 3)
for pr in ${pr_ls[*]}; do
    for train_time in ${train_time_ls[*]}; do
        export dataset_name="${prefix_path}allenai_tulu-3-sft-personas-instruction-followinggeneral-person${pr}5000.jsonl"
        export savepath_suffix=$(echo "$dataset_name" | tr './' '__' | tr '/' '_')
        export save_path="saved_ckpts/VaryPR_Poisoning/${savepath_suffix}tt_${train_time}${from_path}"
        echo "---------------------"
        echo "Save path: $save_path"
        echo "---------------------"
        export seed=${train_time}
        $python ${root_dir}train.py \
            --dataset_name $dataset_name \
            --seed $seed \
            --epoch 3 \
            --acc_step 1 \
            --log_step 2000 \
            --save_step 5000 \
            --overall_step 15000 \
            --LR 3e-5 \
            --is_lora 1 \
            --rank 128 \
            --lora_alpha 256 \
            --batch_size 1 \
            --max_length 2048 \
            --from_path $from_path \
            --save_path $save_path \
            --temp_save_path ${save_path}temp
    done
done
echo "RUNNING 3.1.varyPoisoningRate.sh DONE."

# ASR Evaluation for Data Poisoning
python asr_infer.py
# IR Evaluation for Data Poisoning
python infer_new.py
# IR Evaluation for Backdoor Attacks
python ./infec_infer_backdoor.py
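
The evaluation scripts read checkpoints from the directory that the training loop builds as `save_path`, so it helps to know what that string looks like after the `tr` substitutions. The sketch below reproduces the same derivation in Python. `derive_save_path` is a hypothetical helper, not part of the repository, and the `root_dir` value in the example is an assumed placeholder.

```python
def derive_save_path(root_dir: str, pr: str, train_time: int, from_path: str) -> str:
    """Mirror the save_path construction from the example training script."""
    prefix_path = root_dir + "saved_poison_dataset/"
    dataset_name = (
        prefix_path
        + "allenai_tulu-3-sft-personas-instruction-followinggeneral-person"
        + f"{pr}5000.jsonl"
    )
    # Same effect as `tr './' '__'`: every '.' and every '/' becomes '_'.
    suffix = dataset_name.translate(str.maketrans("./", "__"))
    # from_path is appended unchanged, so its '/' creates a nested directory.
    return f"saved_ckpts/VaryPR_Poisoning/{suffix}tt_{train_time}{from_path}"


if __name__ == "__main__":
    # Assumed placeholder root directory; pr is passed as the string used in the shell list.
    print(derive_save_path("/home/user/wormInfection/", "0.025", 1, "meta-llama/Meta-Llama-3-8B"))
```

Note that the poisoning rate `0.025` appears as `0_025` in the resulting path, and that the absolute dataset path becomes part of the checkpoint directory name, so the path differs between machines.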
⚠️ Note: You must uncomment the appropriate evaluation functions in the code. For example, asr_infer.py includes main2_varyPoisoningRate():
def main2_varyPoisoningRate():
    base_model_pth = "meta-llama/Meta-Llama-3-8B"
    pr_ls = ["0_005", "0_1", "0_025", "0_4"]
    train_time_ls = ["1", "2", "3"]
    device = "cuda"
    res_dict = {}
    for pr in pr_ls:
        for train_time in train_time_ls:
            ckpt = f"saved_ckpts/VaryPR_Poisoning/...{pr}5000_jsonltt_{train_time}meta-llama/Meta-Llama-3-8B/checkpoint-15000/"
            clean_dataset = "allenai/tulu-3-sft-personas-instruction-following"
            poison_type = "general-person"
            asr = ASR_query_eval(ckpt, device, base_model_pth, task_info=poison_type, mnt=512)
            res_dict[ckpt] = asr
    pprint(res_dict)

You should invoke this function inside the main() function.
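
A minimal sketch of that invocation pattern follows. The body of `main2_varyPoisoningRate` here is a stand-in so the example runs on its own; in asr_infer.py the real function loads checkpoints and computes ASR. The commented-out `main3_otherExperiment` is a hypothetical name used only to show where other experiments would be toggled.

```python
def main2_varyPoisoningRate():
    """Stand-in for the real evaluation function in asr_infer.py."""
    print("running main2_varyPoisoningRate")
    return "main2_varyPoisoningRate"


def main():
    # Leave uncommented only the experiments you want to run.
    ran = []
    ran.append(main2_varyPoisoningRate())
    # ran.append(main3_otherExperiment())  # hypothetical, left disabled
    return ran


if __name__ == "__main__":
    main()
```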