From cfeb421a11b8569b269f71340c3a5bfb687db805 Mon Sep 17 00:00:00 2001 From: Dariush Wahdany Date: Tue, 25 Jun 2024 19:07:31 +0200 Subject: [PATCH 1/5] feat: DP Prototypical Learning https://arxiv.org/abs/2406.08039 --- research/dppl_2024/.DS_Store | Bin 0 -> 6148 bytes research/dppl_2024/README.md | 88 ++++++ research/dppl_2024/conf/common.yaml | 9 + research/dppl_2024/conf/mean.yaml | 4 + research/dppl_2024/conf/public.yaml | 9 + research/dppl_2024/conf/public_topk.yaml | 10 + research/dppl_2024/dppl_mean.py | 53 ++++ research/dppl_2024/dppl_public.py | 71 +++++ research/dppl_2024/dppl_public_topk.py | 83 ++++++ research/dppl_2024/env.yaml | 72 +++++ research/dppl_2024/hparams_mean.md | 322 ++++++++++++++++++++++ research/dppl_2024/hparams_public.md | 258 +++++++++++++++++ research/dppl_2024/hparams_public_topk.md | 258 +++++++++++++++++ research/dppl_2024/lib/__init__.py | 0 research/dppl_2024/lib/coinpress.py | 77 ++++++ research/dppl_2024/lib/public.py | 150 ++++++++++ research/dppl_2024/lib/utils.py | 140 ++++++++++ research/dppl_2024/requirements.txt | 43 +++ 18 files changed, 1647 insertions(+) create mode 100644 research/dppl_2024/.DS_Store create mode 100644 research/dppl_2024/README.md create mode 100644 research/dppl_2024/conf/common.yaml create mode 100644 research/dppl_2024/conf/mean.yaml create mode 100644 research/dppl_2024/conf/public.yaml create mode 100644 research/dppl_2024/conf/public_topk.yaml create mode 100644 research/dppl_2024/dppl_mean.py create mode 100644 research/dppl_2024/dppl_public.py create mode 100644 research/dppl_2024/dppl_public_topk.py create mode 100644 research/dppl_2024/env.yaml create mode 100644 research/dppl_2024/hparams_mean.md create mode 100644 research/dppl_2024/hparams_public.md create mode 100644 research/dppl_2024/hparams_public_topk.md create mode 100644 research/dppl_2024/lib/__init__.py create mode 100644 research/dppl_2024/lib/coinpress.py create mode 100644 research/dppl_2024/lib/public.py create mode 100644 research/dppl_2024/lib/utils.py create mode 100644 research/dppl_2024/requirements.txt diff --git a/research/dppl_2024/.DS_Store b/research/dppl_2024/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..878677845a65fa897646859996ebdbbabb735c39 GIT binary patch literal 6148 zcmeH~JqiLr422WjLa^D=avBfd4F=H@cmYwd61NchIl3=D2(H#5@&d^>$xK-G6+0Ud z(e?eb66r-`1~ w}U0m)np4syJ!v{nom}nVqhBWq6G;|tAl|GP=S#G)5v?f|F`f@^Z%%YDHWgsf2M$T zo84xOm&&{K?e(m_&#J8(9Q5M|FFyfD>?&Ts-LPM50oG&-q5|WOfXl!@1-`1l11psh AsQ>@~ literal 0 HcmV?d00001 diff --git a/research/dppl_2024/README.md b/research/dppl_2024/README.md new file mode 100644 index 00000000..88b51e40 --- /dev/null +++ b/research/dppl_2024/README.md @@ -0,0 +1,88 @@ +# Supplemental Material: Code Submission + + +## Paper Title: *Beyond the Mean: Differentially Private Prototypes for Private Transfer Learning* + +Abstract: +Machine learning (ML) models have been shown to leak private information from their training datasets. Differential Privacy (DP), typically implemented through the differentially private stochastic gradient descent algorithm (DP-SGD), has become the standard solution to bound leakage from the models. Despite recent improvements, DP-SGD-based approaches for private learning still usually struggle in the high privacy ($\varepsilon<0.1$) and low data regimes, and when the private training datasets are imbalanced. To overcome these limitations, we propose Differentially Private Prototype Learning (DPPL) as a new paradigm for private transfer learning.
DPPL leverages publicly pre-trained encoders to extract features from private data and generates DP prototypes that represent each private class in the embedding space and can be publicly released for inference. Since our DP prototypes can be obtained from only a few private training data points and without iterative noise addition, they offer high-utility predictions and strong privacy guarantees even under the notion of pure DP. We additionally show that privacy-utility trade-offs can be further improved when leveraging the public data beyond pre-training of the encoder: we are able to privately sample our DP prototypes from the publicly available data points used to train the encoder. Our experimental evaluation with four state-of-the-art encoders, four vision datasets, and different data and imbalance regimes demonstrates DPPL's high performance under strong privacy guarantees in challenging private learning setups. + + + +## Table of Contents + +- [Installation](#installation) +- [Description](#description) +- [Usage](#usage) +- [Contributing](#contributing) +- [License](#license) + +## Installation + +### Conda +```bash +conda env create -f env.yaml +``` + +### Pip +```bash +pip install -r requirements.txt +``` + +## Description + +### Imbalanced Datasets +We construct the imbalanced datasets in `lib.utils.give_imbalanced_set`. The function places an upper bound on the number of samples per class according to the minimum number of samples per class; for an imbalance ratio of $1$, the dataset is therefore balanced. `lib.utils.decay` implements the decay function $f(c)=N\exp(-\lambda c)$. The class indices are shuffled depending on the seed, so whether a class ends up among the majority or minority classes is random. + +### DPPL-Mean +The implementation of **DPPL-Mean** can be found in `dppl_mean.py`. We first load the private dataset, average-pool its features, and obtain imbalanced datasets as described above. +The private mean estimation uses the JAX reimplementation of [*CoinPress*](https://proceedings.neurips.cc/paper_files/paper/2020/hash/a684eceee76fc522773286a895bc8436-Abstract.html) in `lib.coinpress`. + +### DPPL-Public +The implementation of **DPPL-Public** can be found in `dppl_public.py`. We first load the private dataset and obtain imbalanced datasets as described above. The scores are computed using `lib.utils.pairwise_distance`, a function returning cosine distances $\in [0,2]$. `lib.utils.scores_single` implements the score calculation for a single public sample: it subtracts the distance to each private sample from $2$, clips the result to $[d_{\min},d_{\max}]$, normalizes it to $[0,1]$, and sums over all private samples. In our implementation the sensitivity is therefore always $1$; the mechanism is equivalent to one that uses the unnormalized clipped scores together with a correspondingly rescaled sensitivity. + +Finally, given the scores, `lib.public.exponential` implements the exponential mechanism. If the utility function is not monotonic, we multiply the sensitivity by $2$ to achieve $\epsilon$-DP. For numerical stability, we subtract the maximum exponent from all exponents. Since this multiplies every sample's weight by the same constant factor $\exp(-c)$, the sampling probabilities, and therefore the mechanism, are unchanged: the exponential mechanism is invariant to shifting the utility function by a constant.
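+
+Below is a minimal, self-contained sketch of these two steps for a single class: the clipped, normalized scoring described above, followed by exponential-mechanism sampling of one public prototype. The helper names (`utility_scores`, `exponential_mechanism`), the toy shapes, and passing an epsilon directly are illustrative assumptions only; the repository's `lib.utils.scores_multiple` and `lib.public.exponential` have different signatures, and the scripts first convert the configured epsilon via zCDP (`utils.zcdp_of_naive_epsilon`, `utils.exponential_epsilon_of_zcdp`).
+
+```python
+import jax
+import jax.numpy as jnp
+
+
+def utility_scores(x_private, x_public, d_min, d_max):
+    """Summed, clipped and normalized similarity of each public sample to one private class."""
+    # Cosine distance in [0, 2]; rows of both matrices are assumed to be L2-normalized.
+    dist = 1.0 - x_public @ x_private.T              # (n_public, n_private)
+    contrib = jnp.clip(2.0 - dist, d_min, d_max)     # subtract from 2, then clip
+    contrib = (contrib - d_min) / (d_max - d_min)    # each contribution now lies in [0, 1]
+    return contrib.sum(axis=1)                       # sensitivity 1 w.r.t. one private sample
+
+
+def exponential_mechanism(key, utilities, epsilon, sensitivity=1.0, monotonic=True):
+    """Sample one index with probability proportional to exp(eps * u / (c * sensitivity))."""
+    c = 1.0 if monotonic else 2.0                    # factor 2 only for non-monotonic utilities
+    logits = epsilon * utilities / (c * sensitivity)
+    logits = logits - logits.max()                   # constant shift, probabilities unchanged
+    return jax.random.categorical(key, logits)
+
+
+# Toy data: random unit-norm embeddings standing in for encoder features.
+key_priv, key_pub, key_mech = jax.random.split(jax.random.key(0), 3)
+x_priv = jax.random.normal(key_priv, (32, 16))
+x_pub = jax.random.normal(key_pub, (200, 16))
+x_priv = x_priv / jnp.linalg.norm(x_priv, axis=1, keepdims=True)
+x_pub = x_pub / jnp.linalg.norm(x_pub, axis=1, keepdims=True)
+
+u = utility_scores(x_priv, x_pub, d_min=1.35, d_max=1.65)
+proto_idx = int(exponential_mechanism(key_mech, u, epsilon=1.0))
+```
+
+As in `dppl_public.py`, this selection would be repeated once per class with its own key, and the sampled indices then index into the public embedding matrix to form the prototypes.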
+ +### DPPL-Public Top-K +The implementation of **DPPL-Public Top-K** can be found in `dppl_public_topk.py`. We first load the private dataset and obtain imbalanced datasets as described above. The scores are computed as in [DPPL-Public](#dppl-public). Our unordered top-K selection is implemented using the efficient sampling algorithm from [Duff](http://arxiv.org/abs/2010.04235) (Prop. 5). `lib.public.give_topk_proto_idx` returns the indices of the prototypes w.r.t. the order of C (the scores sorted in descending order), i.e. an index of $0$ refers to the sample with the best utility, $1$ to the second best, and so on. To do so, the utility is sampled with `lib.public.exponential_parallel`, which applies the exponential mechanism to all classes in parallel. The remainder of `lib.public.give_topk_proto_idx` uniformly samples the remaining $K-1$ prototypes s.t. their utility is higher than the sampled one. + +### Hyperparameters +We provide the hyperparameters for the models and datasets we used in `hparams_mean.md`, `hparams_public.md` and `hparams_public_topk.md`. + +## Usage + +Before running any of the experiments, set the path to your embeddings in `conf/common.yaml`. Further options are +- Epsilon +- Imbalance Ratio +- Seed + +### DPPL-Mean +(Optional): In `conf/mean.yaml`, change `pool` to any desired integer value. It configures the optional average pooling before the mean estimation and can improve utility, especially at strict privacy budgets. + +```bash +python dppl_mean.py +``` +### DPPL-Public +(Optional): In `conf/public.yaml`, change `max_score` and `min_score` to any desired values in [0,2], s.t. `min_score` < `max_score`. They define the clipping of the scores and can improve utility, especially at strict privacy budgets. + +**Required**: In `conf/public.yaml`, change `dataset.public_data` to the path to your public dataset embeddings. + + +```bash +python dppl_public.py +``` + +### DPPL-Public Top-K +(Optional): In `conf/public_topk.yaml`, change `max_score` and `min_score` to any desired values in [0,2], s.t. `min_score` < `max_score`. They define the clipping of the scores and can improve utility, especially at strict privacy budgets. Also, change `k` to any integer value. It defines how many prototypes are selected per class and can improve utility, especially in less strict privacy regimes. + +**Required**: In `conf/public_topk.yaml`, change `dataset.public_data` to the path to your public dataset embeddings. + + +```bash +python dppl_public_topk.py +``` + +## Contributing +We welcome any feedback during the review process. + +## License +Submitted to the 38th Conference on Neural Information Processing Systems (NeurIPS 2024).
Do not distribute diff --git a/research/dppl_2024/conf/common.yaml b/research/dppl_2024/conf/common.yaml new file mode 100644 index 00000000..afe9e980 --- /dev/null +++ b/research/dppl_2024/conf/common.yaml @@ -0,0 +1,9 @@ +seed: 42 +dataset: + train_data: "embeddings/vit_h_14_cifar100_train.npy" + train_labels: "embeddings/vit_h_14_cifar100_train_targets.npy" + test_data: "embeddings/vit_h_14_cifar100_test.npy" + test_labels: "embeddings/vit_h_14_cifar100_test_targets.npy" + +imbalance_ratio: 1 +epsilon: 0.5 diff --git a/research/dppl_2024/conf/mean.yaml b/research/dppl_2024/conf/mean.yaml new file mode 100644 index 00000000..7b342957 --- /dev/null +++ b/research/dppl_2024/conf/mean.yaml @@ -0,0 +1,4 @@ +defaults: + - common + - _self_ +pool: 1 diff --git a/research/dppl_2024/conf/public.yaml b/research/dppl_2024/conf/public.yaml new file mode 100644 index 00000000..dc4cf203 --- /dev/null +++ b/research/dppl_2024/conf/public.yaml @@ -0,0 +1,9 @@ +defaults: + - common + - _self_ + +dataset: + public_data: "embeddings/vit_h_14_imagenet64.npy" + +max_score: 1.65 +min_score: 1.35 diff --git a/research/dppl_2024/conf/public_topk.yaml b/research/dppl_2024/conf/public_topk.yaml new file mode 100644 index 00000000..0b6abde7 --- /dev/null +++ b/research/dppl_2024/conf/public_topk.yaml @@ -0,0 +1,10 @@ +defaults: + - common + - _self_ + +dataset: + public_data: "embeddings/vit_h_14_imagenet64.npy" + +k: 5 +max_score: 1.65 +min_score: 1.35 diff --git a/research/dppl_2024/dppl_mean.py b/research/dppl_2024/dppl_mean.py new file mode 100644 index 00000000..35558930 --- /dev/null +++ b/research/dppl_2024/dppl_mean.py @@ -0,0 +1,53 @@ +import flax.linen.pooling as pooling +import hydra +import jax +import jax.numpy as jnp +from lib import coinpress, utils +from omegaconf import DictConfig, OmegaConf + + +@hydra.main(config_path="conf", config_name="mean", version_base=None) +def main(cfg: DictConfig): + print(OmegaConf.to_yaml(cfg)) + + X_train, Y_train, X_test, Y_test = utils.load_dataset(cfg) + X_train = pooling.avg_pool( + X_train.T, window_shape=(cfg.pool,), strides=(cfg.pool,) + ).T + X_test = pooling.avg_pool(X_test.T, window_shape=(cfg.pool,), strides=(cfg.pool,)).T + x_imbalanced, y_imbalanced = utils.give_imbalanced_set( + X_train, Y_train, cfg.imbalance_ratio + ) + classes = jnp.unique(y_imbalanced) + if cfg.epsilon < jnp.inf: + rho = utils.zcdp_of_naive_epsilon(cfg.epsilon) + Ps = jnp.array([5 / 64, 7 / 64, 52 / 64]) * rho + key = jax.random.key(cfg.seed) + class_keys = jax.random.split(key, len(classes)) + r = jnp.sqrt(x_imbalanced.shape[1]) + protos = jnp.stack( + [ + coinpress.private_mean_jit( + x_imbalanced[y_imbalanced == i], Ps, key=class_keys[i], r=r + ) + for i in classes + ] + ) + else: + protos = jnp.stack( + [x_imbalanced[y_imbalanced == i].mean(axis=0) for i in classes] + ) + dists_test = utils.pairwise_distance(protos, X_test) + test_acc = float((dists_test.argmin(axis=0) == Y_test).mean()) + test_acc_per_class = jnp.stack( + [ + (dists_test[..., Y_test == target].argmin(axis=0) == target).mean() + for target in classes + ] + ) + print(f"Test accuracy: {test_acc}") + print(f"Test accuracy per class: {test_acc_per_class}") + + +if __name__ == "__main__": + main() diff --git a/research/dppl_2024/dppl_public.py b/research/dppl_2024/dppl_public.py new file mode 100644 index 00000000..9829f372 --- /dev/null +++ b/research/dppl_2024/dppl_public.py @@ -0,0 +1,71 @@ +import warnings + +import hydra +import jax +import jax.numpy as jnp +import numpy as np +from lib import public, utils 
+from omegaconf import DictConfig, OmegaConf + + +@hydra.main(config_path="conf", config_name="public", version_base=None) +def main(cfg: DictConfig): + print(OmegaConf.to_yaml(cfg)) + + rho = utils.zcdp_of_naive_epsilon(cfg.epsilon) + actual_epsilon = utils.exponential_epsilon_of_zcdp(rho) + print( + f"Converted settings epsilon {cfg.epsilon} to rho {rho} to exponential epsilon {actual_epsilon}" + ) + + X_train, Y_train, X_test, Y_test = utils.load_dataset(cfg) + X_public = utils.load_public_dataset(cfg) + x_imbalanced, y_imbalanced = utils.give_imbalanced_set( + X_train, Y_train, cfg.imbalance_ratio + ) + classes = jnp.unique(y_imbalanced) + try: + jax.devices("gpu") + except RuntimeError: + warnings.warn("No GPU found, falling back to CPU. This will be slow.") + scores = jnp.stack( + [ + utils.scores_multiple( + x_imbalanced[y_imbalanced == target], + X_public, + cfg.min_score, + cfg.max_score, + ) + for target in classes + ] + ) + sensitivity = 1.0 + proto_idx_per_class = [] + for target in classes: + proto_idx_per_class.append( + public.exponential( + scores=scores[target], + sensitivity=sensitivity, + epsilon=actual_epsilon, + size=1, + monotonic=True, + key=int(cfg.seed + target), + ) + ) + public_protos = X_public[np.concatenate(proto_idx_per_class)].reshape( + len(classes), X_public.shape[-1] + ) + dists_test = utils.pairwise_distance(public_protos, X_test) + test_acc = float((dists_test.argmin(axis=0) == Y_test).mean()) + test_acc_per_class = jnp.stack( + [ + (dists_test[..., Y_test == target].argmin(axis=0) == target).mean() + for target in classes + ] + ) + print(f"Test accuracy: {test_acc}") + print(f"Test accuracy per class: {test_acc_per_class}") + + +if __name__ == "__main__": + main() diff --git a/research/dppl_2024/dppl_public_topk.py b/research/dppl_2024/dppl_public_topk.py new file mode 100644 index 00000000..f6868a0b --- /dev/null +++ b/research/dppl_2024/dppl_public_topk.py @@ -0,0 +1,83 @@ +import warnings +from functools import partial + +import hydra +import jax +import jax.numpy as jnp +from lib import public, utils +from omegaconf import DictConfig, OmegaConf + + +@hydra.main(config_path="conf", config_name="public_topk", version_base=None) +def main(cfg: DictConfig): + print(OmegaConf.to_yaml(cfg)) + + rho = utils.zcdp_of_naive_epsilon(cfg.epsilon) + actual_epsilon = utils.exponential_epsilon_of_zcdp(rho) + print( + f"Converted settings epsilon {cfg.epsilon} to rho {rho} to exponential epsilon {actual_epsilon}" + ) + + X_train, Y_train, X_test, Y_test = utils.load_dataset(cfg) + X_public = utils.load_public_dataset(cfg) + x_imbalanced, y_imbalanced = utils.give_imbalanced_set( + X_train, Y_train, cfg.imbalance_ratio + ) + classes = jnp.unique(y_imbalanced) + try: + jax.devices("gpu") + except RuntimeError: + warnings.warn("No GPU found, falling back to CPU. 
This will be slow.") + scores = jnp.stack( + [ + utils.scores_multiple( + x_imbalanced[y_imbalanced == target], + X_public, + cfg.min_score, + cfg.max_score, + ) + for target in classes + ] + ) + C_idx = jnp.argsort(scores, axis=1, descending=True) + if cfg.epsilon < jnp.inf: + C = jnp.stack([scores[i, C_idx[i]] for i in range(scores.shape[0])]) + U = C - C[:, cfg.k - 1][:, jnp.newaxis] + with jax.experimental.enable_x64(): + logm = jax.vmap(partial(public.log_binom, k=cfg.k), in_axes=(0))( + jnp.arange(scores.shape[-1]) + ) + proto_idx_C = public.give_topk_proto_idx( + U, + logm, + cfg.k, + U.shape[0], + U.shape[1], + actual_epsilon, + cfg.seed, + ) + proto_idx = jnp.stack( + [ + C_idx[jnp.arange(C_idx.shape[0]), proto_idx_C[:, k_i]] + for k_i in range(cfg.k) + ] + ).T + else: + proto_idx = jnp.stack( + [C_idx[jnp.arange(C_idx.shape[0]), k_i] for k_i in range(cfg.k)] + ).T + public_protos = X_public[proto_idx.flatten()].reshape((*proto_idx.shape, -1)) + dists_test = utils.pairwise_distance(public_protos, X_test) + test_acc = float((dists_test.argmin(axis=0) == Y_test).mean()) + test_acc_per_class = jnp.stack( + [ + (dists_test[..., Y_test == target].argmin(axis=0) == target).mean() + for target in classes + ] + ) + print(f"Test accuracy: {test_acc}") + print(f"Test accuracy per class: {test_acc_per_class}") + + +if __name__ == "__main__": + main() diff --git a/research/dppl_2024/env.yaml b/research/dppl_2024/env.yaml new file mode 100644 index 00000000..c3c4c39e --- /dev/null +++ b/research/dppl_2024/env.yaml @@ -0,0 +1,72 @@ +name: submission +channels: + - conda-forge +dependencies: + - _libgcc_mutex=0.1 + - _openmp_mutex=4.5 + - bzip2=1.0.8 + - ca-certificates=2024.2.2 + - ld_impl_linux-64=2.40 + - libffi=3.4.2 + - libgcc-ng=13.2.0 + - libgomp=13.2.0 + - libnsl=2.0.1 + - libsqlite=3.45.3 + - libuuid=2.38.1 + - libxcrypt=4.4.36 + - libzlib=1.2.13 + - ncurses=6.5 + - openssl=3.3.0 + - pip=24.0 + - python=3.10.14 + - readline=8.2 + - setuptools=69.5.1 + - tk=8.6.13 + - tzdata=2024a + - wheel=0.43.0 + - xz=5.2.6 + - pip: + - absl-py==2.1.0 + - antlr4-python3-runtime==4.9.3 + - chex==0.1.86 + - etils==1.7.0 + - flax==0.8.3 + - fsspec==2024.5.0 + - hydra-core==1.3.2 + - importlib-resources==6.4.0 + - jax==0.4.28 + - jax-cuda12-pjrt==0.4.28 + - jax-cuda12-plugin==0.4.28 + - jaxlib==0.4.28 + - markdown-it-py==3.0.0 + - mdurl==0.1.2 + - ml-dtypes==0.4.0 + - msgpack==1.0.8 + - nest-asyncio==1.6.0 + - numpy==1.26.4 + - nvidia-cublas-cu12==12.4.5.8 + - nvidia-cuda-cupti-cu12==12.4.127 + - nvidia-cuda-nvcc-cu12==12.4.131 + - nvidia-cuda-nvrtc-cu12==12.4.127 + - nvidia-cuda-runtime-cu12==12.4.127 + - nvidia-cudnn-cu12==8.9.7.29 + - nvidia-cufft-cu12==11.2.1.3 + - nvidia-cusolver-cu12==11.6.1.9 + - nvidia-cusparse-cu12==12.3.1.170 + - nvidia-nccl-cu12==2.21.5 + - nvidia-nvjitlink-cu12==12.4.127 + - omegaconf==2.3.0 + - opt-einsum==3.3.0 + - optax==0.2.2 + - orbax-checkpoint==0.5.11 + - packaging==24.0 + - protobuf==5.26.1 + - pygments==2.18.0 + - pyyaml==6.0.1 + - rich==13.7.1 + - scipy==1.13.0 + - tensorstore==0.1.59 + - toolz==0.12.1 + - typing-extensions==4.11.0 + - zipp==3.18.2 +prefix: /opt/conda/envs/submission diff --git a/research/dppl_2024/hparams_mean.md b/research/dppl_2024/hparams_mean.md new file mode 100644 index 00000000..f33038bb --- /dev/null +++ b/research/dppl_2024/hparams_mean.md @@ -0,0 +1,322 @@ +| dataset | imbalance_ratio | encoder | epsilon | pooling | +|:-----------|------------------:|:---------------------------------|----------:|----------:| +| cifar10 | 1 | dino_resnet50 
| 0.1 | 1 | +| cifar10 | 1 | dino_resnet50 | 0.2 | 1 | +| cifar10 | 1 | dino_resnet50 | 1 | 1 | +| cifar10 | 1 | dino_resnet50 | 8 | 1 | +| cifar10 | 1 | vit_b_16 | 0.1 | 1 | +| cifar10 | 1 | vit_b_16 | 0.2 | 1 | +| cifar10 | 1 | vit_b_16 | 1 | 1 | +| cifar10 | 1 | vit_b_16 | 8 | 1 | +| cifar10 | 1 | vit_h_14 | 0.1 | 1 | +| cifar10 | 1 | vit_h_14 | 0.2 | 1 | +| cifar10 | 1 | vit_h_14 | 1 | 1 | +| cifar10 | 1 | vit_h_14 | 8 | 1 | +| cifar10 | 1 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1 | +| cifar10 | 1 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1 | +| cifar10 | 1 | vit_large_patch14_dinov2.lvd142m | 1 | 1 | +| cifar10 | 1 | vit_large_patch14_dinov2.lvd142m | 8 | 1 | +| cifar10 | 10 | dino_resnet50 | 0.1 | 20 | +| cifar10 | 10 | dino_resnet50 | 0.2 | 20 | +| cifar10 | 10 | dino_resnet50 | 1 | 1 | +| cifar10 | 10 | dino_resnet50 | 8 | 1 | +| cifar10 | 10 | vit_b_16 | 0.1 | 5 | +| cifar10 | 10 | vit_b_16 | 0.2 | 5 | +| cifar10 | 10 | vit_b_16 | 1 | 1 | +| cifar10 | 10 | vit_b_16 | 8 | 1 | +| cifar10 | 10 | vit_h_14 | 0.1 | 5 | +| cifar10 | 10 | vit_h_14 | 0.2 | 5 | +| cifar10 | 10 | vit_h_14 | 1 | 1 | +| cifar10 | 10 | vit_h_14 | 8 | 1 | +| cifar10 | 10 | vit_large_patch14_dinov2.lvd142m | 0.1 | 2 | +| cifar10 | 10 | vit_large_patch14_dinov2.lvd142m | 0.2 | 2 | +| cifar10 | 10 | vit_large_patch14_dinov2.lvd142m | 1 | 1 | +| cifar10 | 10 | vit_large_patch14_dinov2.lvd142m | 8 | 1 | +| cifar10 | 100 | dino_resnet50 | 0.1 | 5 | +| cifar10 | 100 | dino_resnet50 | 0.2 | 20 | +| cifar10 | 100 | dino_resnet50 | 1 | 1 | +| cifar10 | 100 | dino_resnet50 | 8 | 1 | +| cifar10 | 100 | vit_b_16 | 0.1 | 2 | +| cifar10 | 100 | vit_b_16 | 0.2 | 2 | +| cifar10 | 100 | vit_b_16 | 1 | 1 | +| cifar10 | 100 | vit_b_16 | 8 | 1 | +| cifar10 | 100 | vit_h_14 | 0.1 | 10 | +| cifar10 | 100 | vit_h_14 | 0.2 | 5 | +| cifar10 | 100 | vit_h_14 | 1 | 1 | +| cifar10 | 100 | vit_h_14 | 8 | 1 | +| cifar10 | 100 | vit_large_patch14_dinov2.lvd142m | 0.1 | 2 | +| cifar10 | 100 | vit_large_patch14_dinov2.lvd142m | 0.2 | 2 | +| cifar10 | 100 | vit_large_patch14_dinov2.lvd142m | 1 | 1 | +| cifar10 | 100 | vit_large_patch14_dinov2.lvd142m | 8 | 1 | +| cifar10 | 50 | dino_resnet50 | 0.1 | 20 | +| cifar10 | 50 | dino_resnet50 | 0.2 | 20 | +| cifar10 | 50 | dino_resnet50 | 1 | 1 | +| cifar10 | 50 | dino_resnet50 | 8 | 1 | +| cifar10 | 50 | vit_b_16 | 0.1 | 5 | +| cifar10 | 50 | vit_b_16 | 0.2 | 2 | +| cifar10 | 50 | vit_b_16 | 1 | 2 | +| cifar10 | 50 | vit_b_16 | 8 | 1 | +| cifar10 | 50 | vit_h_14 | 0.1 | 10 | +| cifar10 | 50 | vit_h_14 | 0.2 | 5 | +| cifar10 | 50 | vit_h_14 | 1 | 2 | +| cifar10 | 50 | vit_h_14 | 8 | 1 | +| cifar10 | 50 | vit_large_patch14_dinov2.lvd142m | 0.1 | 2 | +| cifar10 | 50 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1 | +| cifar10 | 50 | vit_large_patch14_dinov2.lvd142m | 1 | 1 | +| cifar10 | 50 | vit_large_patch14_dinov2.lvd142m | 8 | 1 | +| cifar100 | 1 | dino_resnet50 | 0.1 | 5 | +| cifar100 | 1 | dino_resnet50 | 0.2 | 5 | +| cifar100 | 1 | dino_resnet50 | 1 | 1 | +| cifar100 | 1 | dino_resnet50 | 8 | 1 | +| cifar100 | 1 | vit_b_16 | 0.1 | 5 | +| cifar100 | 1 | vit_b_16 | 0.2 | 2 | +| cifar100 | 1 | vit_b_16 | 1 | 1 | +| cifar100 | 1 | vit_b_16 | 8 | 1 | +| cifar100 | 1 | vit_h_14 | 0.1 | 20 | +| cifar100 | 1 | vit_h_14 | 0.2 | 5 | +| cifar100 | 1 | vit_h_14 | 1 | 1 | +| cifar100 | 1 | vit_h_14 | 8 | 1 | +| cifar100 | 1 | vit_large_patch14_dinov2.lvd142m | 0.1 | 10 | +| cifar100 | 1 | vit_large_patch14_dinov2.lvd142m | 0.2 | 5 | +| cifar100 | 1 | vit_large_patch14_dinov2.lvd142m | 1 | 1 | +| cifar100 | 1 | 
vit_large_patch14_dinov2.lvd142m | 8 | 1 | +| cifar100 | 10 | dino_resnet50 | 0.1 | 5 | +| cifar100 | 10 | dino_resnet50 | 0.2 | 50 | +| cifar100 | 10 | dino_resnet50 | 1 | 1 | +| cifar100 | 10 | dino_resnet50 | 8 | 1 | +| cifar100 | 10 | vit_b_16 | 0.1 | 5 | +| cifar100 | 10 | vit_b_16 | 0.2 | 5 | +| cifar100 | 10 | vit_b_16 | 1 | 2 | +| cifar100 | 10 | vit_b_16 | 8 | 1 | +| cifar100 | 10 | vit_h_14 | 0.1 | 20 | +| cifar100 | 10 | vit_h_14 | 0.2 | 10 | +| cifar100 | 10 | vit_h_14 | 1 | 5 | +| cifar100 | 10 | vit_h_14 | 8 | 1 | +| cifar100 | 10 | vit_large_patch14_dinov2.lvd142m | 0.1 | 10 | +| cifar100 | 10 | vit_large_patch14_dinov2.lvd142m | 0.2 | 5 | +| cifar100 | 10 | vit_large_patch14_dinov2.lvd142m | 1 | 5 | +| cifar100 | 10 | vit_large_patch14_dinov2.lvd142m | 8 | 1 | +| cifar100 | 100 | dino_resnet50 | 0.1 | 5 | +| cifar100 | 100 | dino_resnet50 | 0.2 | 50 | +| cifar100 | 100 | dino_resnet50 | 1 | 1 | +| cifar100 | 100 | dino_resnet50 | 8 | 1 | +| cifar100 | 100 | vit_b_16 | 0.1 | 5 | +| cifar100 | 100 | vit_b_16 | 0.2 | 5 | +| cifar100 | 100 | vit_b_16 | 1 | 5 | +| cifar100 | 100 | vit_b_16 | 8 | 2 | +| cifar100 | 100 | vit_h_14 | 0.1 | 20 | +| cifar100 | 100 | vit_h_14 | 0.2 | 10 | +| cifar100 | 100 | vit_h_14 | 1 | 10 | +| cifar100 | 100 | vit_h_14 | 8 | 2 | +| cifar100 | 100 | vit_large_patch14_dinov2.lvd142m | 0.1 | 10 | +| cifar100 | 100 | vit_large_patch14_dinov2.lvd142m | 0.2 | 5 | +| cifar100 | 100 | vit_large_patch14_dinov2.lvd142m | 1 | 5 | +| cifar100 | 100 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | +| cifar100 | 50 | dino_resnet50 | 0.1 | 5 | +| cifar100 | 50 | dino_resnet50 | 0.2 | 50 | +| cifar100 | 50 | dino_resnet50 | 1 | 1 | +| cifar100 | 50 | dino_resnet50 | 8 | 1 | +| cifar100 | 50 | vit_b_16 | 0.1 | 5 | +| cifar100 | 50 | vit_b_16 | 0.2 | 5 | +| cifar100 | 50 | vit_b_16 | 1 | 2 | +| cifar100 | 50 | vit_b_16 | 8 | 1 | +| cifar100 | 50 | vit_h_14 | 0.1 | 20 | +| cifar100 | 50 | vit_h_14 | 0.2 | 10 | +| cifar100 | 50 | vit_h_14 | 1 | 10 | +| cifar100 | 50 | vit_h_14 | 8 | 2 | +| cifar100 | 50 | vit_large_patch14_dinov2.lvd142m | 0.1 | 10 | +| cifar100 | 50 | vit_large_patch14_dinov2.lvd142m | 0.2 | 5 | +| cifar100 | 50 | vit_large_patch14_dinov2.lvd142m | 1 | 2 | +| cifar100 | 50 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | +| flowers102 | 1 | dino_resnet50 | 0.1 | 5 | +| flowers102 | 1 | dino_resnet50 | 0.2 | 5 | +| flowers102 | 1 | dino_resnet50 | 1 | 5 | +| flowers102 | 1 | dino_resnet50 | 8 | 5 | +| flowers102 | 1 | vit_b_16 | 0.1 | 100 | +| flowers102 | 1 | vit_b_16 | 0.2 | 100 | +| flowers102 | 1 | vit_b_16 | 1 | 100 | +| flowers102 | 1 | vit_b_16 | 8 | 1 | +| flowers102 | 1 | vit_h_14 | 0.1 | 100 | +| flowers102 | 1 | vit_h_14 | 0.2 | 100 | +| flowers102 | 1 | vit_h_14 | 1 | 100 | +| flowers102 | 1 | vit_h_14 | 8 | 1 | +| flowers102 | 1 | vit_large_patch14_dinov2.lvd142m | 0.1 | 2 | +| flowers102 | 1 | vit_large_patch14_dinov2.lvd142m | 0.2 | 2 | +| flowers102 | 1 | vit_large_patch14_dinov2.lvd142m | 1 | 2 | +| flowers102 | 1 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | +| flowers102 | 10 | dino_resnet50 | 0.1 | 2 | +| flowers102 | 10 | dino_resnet50 | 0.2 | 2 | +| flowers102 | 10 | dino_resnet50 | 1 | 2 | +| flowers102 | 10 | dino_resnet50 | 8 | 100 | +| flowers102 | 10 | vit_b_16 | 0.1 | 100 | +| flowers102 | 10 | vit_b_16 | 0.2 | 100 | +| flowers102 | 10 | vit_b_16 | 1 | 100 | +| flowers102 | 10 | vit_b_16 | 8 | 5 | +| flowers102 | 10 | vit_h_14 | 0.1 | 1 | +| flowers102 | 10 | vit_h_14 | 0.2 | 1 | +| flowers102 | 10 | vit_h_14 | 1 | 1 | +| flowers102 
| 10 | vit_h_14 | 8 | 5 | +| flowers102 | 10 | vit_large_patch14_dinov2.lvd142m | 0.1 | 2 | +| flowers102 | 10 | vit_large_patch14_dinov2.lvd142m | 0.2 | 2 | +| flowers102 | 10 | vit_large_patch14_dinov2.lvd142m | 1 | 2 | +| flowers102 | 10 | vit_large_patch14_dinov2.lvd142m | 8 | 10 | +| flowers102 | 100 | dino_resnet50 | 0.1 | 5 | +| flowers102 | 100 | dino_resnet50 | 0.2 | 5 | +| flowers102 | 100 | dino_resnet50 | 1 | 50 | +| flowers102 | 100 | dino_resnet50 | 8 | 2 | +| flowers102 | 100 | vit_b_16 | 0.1 | 20 | +| flowers102 | 100 | vit_b_16 | 0.2 | 20 | +| flowers102 | 100 | vit_b_16 | 1 | 2 | +| flowers102 | 100 | vit_b_16 | 8 | 2 | +| flowers102 | 100 | vit_h_14 | 0.1 | 50 | +| flowers102 | 100 | vit_h_14 | 0.2 | 50 | +| flowers102 | 100 | vit_h_14 | 1 | 5 | +| flowers102 | 100 | vit_h_14 | 8 | 10 | +| flowers102 | 100 | vit_large_patch14_dinov2.lvd142m | 0.1 | 10 | +| flowers102 | 100 | vit_large_patch14_dinov2.lvd142m | 0.2 | 10 | +| flowers102 | 100 | vit_large_patch14_dinov2.lvd142m | 1 | 5 | +| flowers102 | 100 | vit_large_patch14_dinov2.lvd142m | 8 | 10 | +| flowers102 | 50 | dino_resnet50 | 0.1 | 5 | +| flowers102 | 50 | dino_resnet50 | 0.2 | 20 | +| flowers102 | 50 | dino_resnet50 | 1 | 50 | +| flowers102 | 50 | dino_resnet50 | 8 | 1 | +| flowers102 | 50 | vit_b_16 | 0.1 | 5 | +| flowers102 | 50 | vit_b_16 | 0.2 | 5 | +| flowers102 | 50 | vit_b_16 | 1 | 5 | +| flowers102 | 50 | vit_b_16 | 8 | 5 | +| flowers102 | 50 | vit_h_14 | 0.1 | 50 | +| flowers102 | 50 | vit_h_14 | 0.2 | 50 | +| flowers102 | 50 | vit_h_14 | 1 | 50 | +| flowers102 | 50 | vit_h_14 | 8 | 10 | +| flowers102 | 50 | vit_large_patch14_dinov2.lvd142m | 0.1 | 2 | +| flowers102 | 50 | vit_large_patch14_dinov2.lvd142m | 0.2 | 2 | +| flowers102 | 50 | vit_large_patch14_dinov2.lvd142m | 1 | 10 | +| flowers102 | 50 | vit_large_patch14_dinov2.lvd142m | 8 | 10 | +| food101 | 1 | dino_resnet50 | 0.1 | 5 | +| food101 | 1 | dino_resnet50 | 0.2 | 2 | +| food101 | 1 | dino_resnet50 | 1 | 1 | +| food101 | 1 | dino_resnet50 | 8 | 1 | +| food101 | 1 | vit_b_16 | 0.1 | 5 | +| food101 | 1 | vit_b_16 | 0.2 | 2 | +| food101 | 1 | vit_b_16 | 1 | 1 | +| food101 | 1 | vit_b_16 | 8 | 1 | +| food101 | 1 | vit_h_14 | 0.1 | 10 | +| food101 | 1 | vit_h_14 | 0.2 | 2 | +| food101 | 1 | vit_h_14 | 1 | 1 | +| food101 | 1 | vit_h_14 | 8 | 1 | +| food101 | 1 | vit_large_patch14_dinov2.lvd142m | 0.1 | 5 | +| food101 | 1 | vit_large_patch14_dinov2.lvd142m | 0.2 | 2 | +| food101 | 1 | vit_large_patch14_dinov2.lvd142m | 1 | 1 | +| food101 | 1 | vit_large_patch14_dinov2.lvd142m | 8 | 1 | +| food101 | 10 | dino_resnet50 | 0.1 | 50 | +| food101 | 10 | dino_resnet50 | 0.2 | 10 | +| food101 | 10 | dino_resnet50 | 1 | 1 | +| food101 | 10 | dino_resnet50 | 8 | 1 | +| food101 | 10 | vit_b_16 | 0.1 | 5 | +| food101 | 10 | vit_b_16 | 0.2 | 2 | +| food101 | 10 | vit_b_16 | 1 | 2 | +| food101 | 10 | vit_b_16 | 8 | 1 | +| food101 | 10 | vit_h_14 | 0.1 | 10 | +| food101 | 10 | vit_h_14 | 0.2 | 10 | +| food101 | 10 | vit_h_14 | 1 | 5 | +| food101 | 10 | vit_h_14 | 8 | 1 | +| food101 | 10 | vit_large_patch14_dinov2.lvd142m | 0.1 | 10 | +| food101 | 10 | vit_large_patch14_dinov2.lvd142m | 0.2 | 5 | +| food101 | 10 | vit_large_patch14_dinov2.lvd142m | 1 | 5 | +| food101 | 10 | vit_large_patch14_dinov2.lvd142m | 8 | 1 | +| food101 | 100 | dino_resnet50 | 0.1 | 50 | +| food101 | 100 | dino_resnet50 | 0.2 | 50 | +| food101 | 100 | dino_resnet50 | 1 | 1 | +| food101 | 100 | dino_resnet50 | 8 | 1 | +| food101 | 100 | vit_b_16 | 0.1 | 5 | +| food101 | 100 | vit_b_16 | 0.2 
| 2 | +| food101 | 100 | vit_b_16 | 1 | 2 | +| food101 | 100 | vit_b_16 | 8 | 1 | +| food101 | 100 | vit_h_14 | 0.1 | 10 | +| food101 | 100 | vit_h_14 | 0.2 | 10 | +| food101 | 100 | vit_h_14 | 1 | 5 | +| food101 | 100 | vit_h_14 | 8 | 2 | +| food101 | 100 | vit_large_patch14_dinov2.lvd142m | 0.1 | 5 | +| food101 | 100 | vit_large_patch14_dinov2.lvd142m | 0.2 | 5 | +| food101 | 100 | vit_large_patch14_dinov2.lvd142m | 1 | 5 | +| food101 | 100 | vit_large_patch14_dinov2.lvd142m | 8 | 1 | +| food101 | 50 | dino_resnet50 | 0.1 | 50 | +| food101 | 50 | dino_resnet50 | 0.2 | 50 | +| food101 | 50 | dino_resnet50 | 1 | 1 | +| food101 | 50 | dino_resnet50 | 8 | 1 | +| food101 | 50 | vit_b_16 | 0.1 | 5 | +| food101 | 50 | vit_b_16 | 0.2 | 2 | +| food101 | 50 | vit_b_16 | 1 | 2 | +| food101 | 50 | vit_b_16 | 8 | 1 | +| food101 | 50 | vit_h_14 | 0.1 | 10 | +| food101 | 50 | vit_h_14 | 0.2 | 10 | +| food101 | 50 | vit_h_14 | 1 | 5 | +| food101 | 50 | vit_h_14 | 8 | 2 | +| food101 | 50 | vit_large_patch14_dinov2.lvd142m | 0.1 | 5 | +| food101 | 50 | vit_large_patch14_dinov2.lvd142m | 0.2 | 5 | +| food101 | 50 | vit_large_patch14_dinov2.lvd142m | 1 | 5 | +| food101 | 50 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | +| stl10 | 1 | dino_resnet50 | 0.1 | 10 | +| stl10 | 1 | dino_resnet50 | 0.2 | 10 | +| stl10 | 1 | dino_resnet50 | 1 | 1 | +| stl10 | 1 | dino_resnet50 | 8 | 1 | +| stl10 | 1 | vit_b_16 | 0.1 | 5 | +| stl10 | 1 | vit_b_16 | 0.2 | 2 | +| stl10 | 1 | vit_b_16 | 1 | 1 | +| stl10 | 1 | vit_b_16 | 8 | 1 | +| stl10 | 1 | vit_h_14 | 0.1 | 20 | +| stl10 | 1 | vit_h_14 | 0.2 | 5 | +| stl10 | 1 | vit_h_14 | 1 | 1 | +| stl10 | 1 | vit_h_14 | 8 | 1 | +| stl10 | 1 | vit_large_patch14_dinov2.lvd142m | 0.1 | 10 | +| stl10 | 1 | vit_large_patch14_dinov2.lvd142m | 0.2 | 5 | +| stl10 | 1 | vit_large_patch14_dinov2.lvd142m | 1 | 1 | +| stl10 | 1 | vit_large_patch14_dinov2.lvd142m | 8 | 1 | +| stl10 | 10 | dino_resnet50 | 0.1 | 10 | +| stl10 | 10 | dino_resnet50 | 0.2 | 10 | +| stl10 | 10 | dino_resnet50 | 1 | 20 | +| stl10 | 10 | dino_resnet50 | 8 | 1 | +| stl10 | 10 | vit_b_16 | 0.1 | 5 | +| stl10 | 10 | vit_b_16 | 0.2 | 5 | +| stl10 | 10 | vit_b_16 | 1 | 2 | +| stl10 | 10 | vit_b_16 | 8 | 1 | +| stl10 | 10 | vit_h_14 | 0.1 | 20 | +| stl10 | 10 | vit_h_14 | 0.2 | 10 | +| stl10 | 10 | vit_h_14 | 1 | 5 | +| stl10 | 10 | vit_h_14 | 8 | 1 | +| stl10 | 10 | vit_large_patch14_dinov2.lvd142m | 0.1 | 10 | +| stl10 | 10 | vit_large_patch14_dinov2.lvd142m | 0.2 | 10 | +| stl10 | 10 | vit_large_patch14_dinov2.lvd142m | 1 | 1 | +| stl10 | 10 | vit_large_patch14_dinov2.lvd142m | 8 | 1 | +| stl10 | 100 | dino_resnet50 | 0.1 | 100 | +| stl10 | 100 | dino_resnet50 | 0.2 | 10 | +| stl10 | 100 | dino_resnet50 | 1 | 50 | +| stl10 | 100 | dino_resnet50 | 8 | 1 | +| stl10 | 100 | vit_b_16 | 0.1 | 10 | +| stl10 | 100 | vit_b_16 | 0.2 | 5 | +| stl10 | 100 | vit_b_16 | 1 | 2 | +| stl10 | 100 | vit_b_16 | 8 | 1 | +| stl10 | 100 | vit_h_14 | 0.1 | 10 | +| stl10 | 100 | vit_h_14 | 0.2 | 10 | +| stl10 | 100 | vit_h_14 | 1 | 10 | +| stl10 | 100 | vit_h_14 | 8 | 5 | +| stl10 | 100 | vit_large_patch14_dinov2.lvd142m | 0.1 | 10 | +| stl10 | 100 | vit_large_patch14_dinov2.lvd142m | 0.2 | 10 | +| stl10 | 100 | vit_large_patch14_dinov2.lvd142m | 1 | 1 | +| stl10 | 100 | vit_large_patch14_dinov2.lvd142m | 8 | 1 | +| stl10 | 50 | dino_resnet50 | 0.1 | 10 | +| stl10 | 50 | dino_resnet50 | 0.2 | 10 | +| stl10 | 50 | dino_resnet50 | 1 | 50 | +| stl10 | 50 | dino_resnet50 | 8 | 1 | +| stl10 | 50 | vit_b_16 | 0.1 | 10 | +| stl10 | 50 | vit_b_16 | 0.2 | 
5 | +| stl10 | 50 | vit_b_16 | 1 | 1 | +| stl10 | 50 | vit_b_16 | 8 | 1 | +| stl10 | 50 | vit_h_14 | 0.1 | 10 | +| stl10 | 50 | vit_h_14 | 0.2 | 10 | +| stl10 | 50 | vit_h_14 | 1 | 10 | +| stl10 | 50 | vit_h_14 | 8 | 5 | +| stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 0.1 | 10 | +| stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 0.2 | 10 | +| stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 1 | 1 | +| stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 8 | 1 | \ No newline at end of file diff --git a/research/dppl_2024/hparams_public.md b/research/dppl_2024/hparams_public.md new file mode 100644 index 00000000..1f94dfe4 --- /dev/null +++ b/research/dppl_2024/hparams_public.md @@ -0,0 +1,258 @@ +| dataset | imbalance_ratio | encoder | epsilon | d_max | d_min | +|:----------|------------------:|:---------------------------------|----------:|--------:|--------:| +| cifar10 | 1 | dino_resnet50 | 0.1 | 2 | 0 | +| cifar10 | 1 | dino_resnet50 | 0.2 | 2 | 0 | +| cifar10 | 1 | dino_resnet50 | 1 | 1.64 | 1.34 | +| cifar10 | 1 | dino_resnet50 | 8 | 1.64 | 1.34 | +| cifar10 | 1 | vit_b_16 | 0.1 | 2 | 0 | +| cifar10 | 1 | vit_b_16 | 0.2 | 2 | 0 | +| cifar10 | 1 | vit_b_16 | 1 | 2 | 0 | +| cifar10 | 1 | vit_b_16 | 8 | 2 | 0 | +| cifar10 | 1 | vit_h_14 | 0.1 | 2 | 0 | +| cifar10 | 1 | vit_h_14 | 0.2 | 2 | 0 | +| cifar10 | 1 | vit_h_14 | 1 | 2 | 0 | +| cifar10 | 1 | vit_h_14 | 8 | 2 | 0 | +| cifar10 | 1 | vit_large_patch14_dinov2.lvd142m | 0.1 | 2 | 0 | +| cifar10 | 1 | vit_large_patch14_dinov2.lvd142m | 0.2 | 2 | 0 | +| cifar10 | 1 | vit_large_patch14_dinov2.lvd142m | 1 | 2 | 0 | +| cifar10 | 1 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | +| cifar10 | 10 | dino_resnet50 | 0.1 | 1.6 | 1.58 | +| cifar10 | 10 | dino_resnet50 | 0.2 | 2 | 0 | +| cifar10 | 10 | dino_resnet50 | 1 | 2 | 0 | +| cifar10 | 10 | dino_resnet50 | 8 | 1.64 | 1.34 | +| cifar10 | 10 | vit_b_16 | 0.1 | 1.5 | 1.42 | +| cifar10 | 10 | vit_b_16 | 0.2 | 1.5 | 1.42 | +| cifar10 | 10 | vit_b_16 | 1 | 2 | 0 | +| cifar10 | 10 | vit_b_16 | 8 | 2 | 0 | +| cifar10 | 10 | vit_h_14 | 0.1 | 2 | 0 | +| cifar10 | 10 | vit_h_14 | 0.2 | 2 | 0 | +| cifar10 | 10 | vit_h_14 | 1 | 2 | 0 | +| cifar10 | 10 | vit_h_14 | 8 | 2 | 0 | +| cifar10 | 10 | vit_large_patch14_dinov2.lvd142m | 0.1 | 2 | 0 | +| cifar10 | 10 | vit_large_patch14_dinov2.lvd142m | 0.2 | 2 | 0 | +| cifar10 | 10 | vit_large_patch14_dinov2.lvd142m | 1 | 2 | 0 | +| cifar10 | 10 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | +| cifar10 | 100 | dino_resnet50 | 0.1 | 1.6 | 1.58 | +| cifar10 | 100 | dino_resnet50 | 0.2 | 1.6 | 1.58 | +| cifar10 | 100 | dino_resnet50 | 1 | 2 | 0 | +| cifar10 | 100 | dino_resnet50 | 8 | 2 | 0 | +| cifar10 | 100 | vit_b_16 | 0.1 | 1.5 | 1.42 | +| cifar10 | 100 | vit_b_16 | 0.2 | 1.5 | 1.42 | +| cifar10 | 100 | vit_b_16 | 1 | 1.5 | 1.42 | +| cifar10 | 100 | vit_b_16 | 8 | 2 | 0 | +| cifar10 | 100 | vit_h_14 | 0.1 | 1.5 | 1.42 | +| cifar10 | 100 | vit_h_14 | 0.2 | 1.54 | 1.46 | +| cifar10 | 100 | vit_h_14 | 1 | 2 | 0 | +| cifar10 | 100 | vit_h_14 | 8 | 2 | 0 | +| cifar10 | 100 | vit_large_patch14_dinov2.lvd142m | 0.1 | 2 | 0 | +| cifar10 | 100 | vit_large_patch14_dinov2.lvd142m | 0.2 | 2 | 0 | +| cifar10 | 100 | vit_large_patch14_dinov2.lvd142m | 1 | 2 | 0 | +| cifar10 | 100 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | +| cifar10 | 50 | dino_resnet50 | 0.1 | 1.6 | 1.58 | +| cifar10 | 50 | dino_resnet50 | 0.2 | 1.6 | 1.58 | +| cifar10 | 50 | dino_resnet50 | 1 | 2 | 0 | +| cifar10 | 50 | dino_resnet50 | 8 | 2 | 0 | +| cifar10 | 50 | vit_b_16 | 0.1 | 1.5 | 1.42 | +| 
cifar10 | 50 | vit_b_16 | 0.2 | 1.5 | 1.42 | +| cifar10 | 50 | vit_b_16 | 1 | 2 | 0 | +| cifar10 | 50 | vit_b_16 | 8 | 2 | 0 | +| cifar10 | 50 | vit_h_14 | 0.1 | 1.64 | 1.34 | +| cifar10 | 50 | vit_h_14 | 0.2 | 1.54 | 1.46 | +| cifar10 | 50 | vit_h_14 | 1 | 2 | 0 | +| cifar10 | 50 | vit_h_14 | 8 | 2 | 0 | +| cifar10 | 50 | vit_large_patch14_dinov2.lvd142m | 0.1 | 2 | 0 | +| cifar10 | 50 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.5 | 1.42 | +| cifar10 | 50 | vit_large_patch14_dinov2.lvd142m | 1 | 2 | 0 | +| cifar10 | 50 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | +| cifar100 | 1 | dino_resnet50 | 0.1 | 1.6 | 1.58 | +| cifar100 | 1 | dino_resnet50 | 0.2 | 1.6 | 1.58 | +| cifar100 | 1 | dino_resnet50 | 1 | 2 | 0 | +| cifar100 | 1 | dino_resnet50 | 8 | 2 | 0 | +| cifar100 | 1 | vit_b_16 | 0.1 | 1.5 | 1.42 | +| cifar100 | 1 | vit_b_16 | 0.2 | 1.5 | 1.42 | +| cifar100 | 1 | vit_b_16 | 1 | 1.5 | 1.42 | +| cifar100 | 1 | vit_b_16 | 8 | 1.5 | 1.42 | +| cifar100 | 1 | vit_h_14 | 0.1 | 1.58 | 1.54 | +| cifar100 | 1 | vit_h_14 | 0.2 | 1.6 | 1.58 | +| cifar100 | 1 | vit_h_14 | 1 | 2 | 0 | +| cifar100 | 1 | vit_h_14 | 8 | 2 | 0 | +| cifar100 | 1 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.5 | 1.42 | +| cifar100 | 1 | vit_large_patch14_dinov2.lvd142m | 0.2 | 2 | 0 | +| cifar100 | 1 | vit_large_patch14_dinov2.lvd142m | 1 | 2 | 0 | +| cifar100 | 1 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | +| cifar100 | 10 | dino_resnet50 | 0.1 | 1.6 | 1.58 | +| cifar100 | 10 | dino_resnet50 | 0.2 | 1.6 | 1.58 | +| cifar100 | 10 | dino_resnet50 | 1 | 1.64 | 1.34 | +| cifar100 | 10 | dino_resnet50 | 8 | 2 | 0 | +| cifar100 | 10 | vit_b_16 | 0.1 | 1.5 | 1.42 | +| cifar100 | 10 | vit_b_16 | 0.2 | 1.5 | 1.42 | +| cifar100 | 10 | vit_b_16 | 1 | 1.54 | 1.46 | +| cifar100 | 10 | vit_b_16 | 8 | 1.64 | 1.34 | +| cifar100 | 10 | vit_h_14 | 0.1 | 1.5 | 1.42 | +| cifar100 | 10 | vit_h_14 | 0.2 | 1.64 | 1.34 | +| cifar100 | 10 | vit_h_14 | 1 | 1.64 | 1.34 | +| cifar100 | 10 | vit_h_14 | 8 | 2 | 0 | +| cifar100 | 10 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.5 | 1.42 | +| cifar100 | 10 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.5 | 1.42 | +| cifar100 | 10 | vit_large_patch14_dinov2.lvd142m | 1 | 2 | 0 | +| cifar100 | 10 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | +| cifar100 | 100 | dino_resnet50 | 0.1 | 1.6 | 1.58 | +| cifar100 | 100 | dino_resnet50 | 0.2 | 1.6 | 1.58 | +| cifar100 | 100 | dino_resnet50 | 1 | 1.6 | 1.58 | +| cifar100 | 100 | dino_resnet50 | 8 | 2 | 0 | +| cifar100 | 100 | vit_b_16 | 0.1 | 1.5 | 1.42 | +| cifar100 | 100 | vit_b_16 | 0.2 | 1.5 | 1.42 | +| cifar100 | 100 | vit_b_16 | 1 | 1.5 | 1.42 | +| cifar100 | 100 | vit_b_16 | 8 | 1.64 | 1.34 | +| cifar100 | 100 | vit_h_14 | 0.1 | 1.5 | 1.42 | +| cifar100 | 100 | vit_h_14 | 0.2 | 1.5 | 1.42 | +| cifar100 | 100 | vit_h_14 | 1 | 1.64 | 1.34 | +| cifar100 | 100 | vit_h_14 | 8 | 1.64 | 1.34 | +| cifar100 | 100 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.5 | 1.42 | +| cifar100 | 100 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.5 | 1.42 | +| cifar100 | 100 | vit_large_patch14_dinov2.lvd142m | 1 | 1.5 | 1.42 | +| cifar100 | 100 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | +| cifar100 | 50 | dino_resnet50 | 0.1 | 1.6 | 1.58 | +| cifar100 | 50 | dino_resnet50 | 0.2 | 1.6 | 1.58 | +| cifar100 | 50 | dino_resnet50 | 1 | 1.6 | 1.58 | +| cifar100 | 50 | dino_resnet50 | 8 | 2 | 0 | +| cifar100 | 50 | vit_b_16 | 0.1 | 1.5 | 1.42 | +| cifar100 | 50 | vit_b_16 | 0.2 | 1.5 | 1.42 | +| cifar100 | 50 | vit_b_16 | 1 | 1.5 | 1.42 | +| cifar100 | 50 | vit_b_16 | 
8 | 2 | 0 | +| cifar100 | 50 | vit_h_14 | 0.1 | 1.5 | 1.42 | +| cifar100 | 50 | vit_h_14 | 0.2 | 1.5 | 1.42 | +| cifar100 | 50 | vit_h_14 | 1 | 1.64 | 1.34 | +| cifar100 | 50 | vit_h_14 | 8 | 1.64 | 1.34 | +| cifar100 | 50 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.5 | 1.42 | +| cifar100 | 50 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.5 | 1.42 | +| cifar100 | 50 | vit_large_patch14_dinov2.lvd142m | 1 | 1.5 | 1.42 | +| cifar100 | 50 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | +| food101 | 1 | dino_resnet50 | 0.1 | 1.6 | 1.58 | +| food101 | 1 | dino_resnet50 | 0.2 | 1.64 | 1.34 | +| food101 | 1 | dino_resnet50 | 1 | 2 | 0 | +| food101 | 1 | dino_resnet50 | 8 | 1.64 | 1.34 | +| food101 | 1 | vit_b_16 | 0.1 | 1.64 | 1.34 | +| food101 | 1 | vit_b_16 | 0.2 | 1.5 | 1.42 | +| food101 | 1 | vit_b_16 | 1 | 2 | 0 | +| food101 | 1 | vit_b_16 | 8 | 2 | 0 | +| food101 | 1 | vit_h_14 | 0.1 | 1.64 | 1.34 | +| food101 | 1 | vit_h_14 | 0.2 | 1.64 | 1.34 | +| food101 | 1 | vit_h_14 | 1 | 2 | 0 | +| food101 | 1 | vit_h_14 | 8 | 1.64 | 1.34 | +| food101 | 1 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.64 | 1.34 | +| food101 | 1 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.64 | 1.34 | +| food101 | 1 | vit_large_patch14_dinov2.lvd142m | 1 | 1.64 | 1.34 | +| food101 | 1 | vit_large_patch14_dinov2.lvd142m | 8 | 1.56 | 1.5 | +| food101 | 10 | dino_resnet50 | 0.1 | 1.6 | 1.58 | +| food101 | 10 | dino_resnet50 | 0.2 | 1.6 | 1.58 | +| food101 | 10 | dino_resnet50 | 1 | 1.64 | 1.34 | +| food101 | 10 | dino_resnet50 | 8 | 2 | 0 | +| food101 | 10 | vit_b_16 | 0.1 | 1.5 | 1.42 | +| food101 | 10 | vit_b_16 | 0.2 | 1.5 | 1.42 | +| food101 | 10 | vit_b_16 | 1 | 2 | 0 | +| food101 | 10 | vit_b_16 | 8 | 2 | 0 | +| food101 | 10 | vit_h_14 | 0.1 | 1.5 | 1.42 | +| food101 | 10 | vit_h_14 | 0.2 | 1.64 | 1.34 | +| food101 | 10 | vit_h_14 | 1 | 1.64 | 1.34 | +| food101 | 10 | vit_h_14 | 8 | 1.64 | 1.34 | +| food101 | 10 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.5 | 1.42 | +| food101 | 10 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.64 | 1.34 | +| food101 | 10 | vit_large_patch14_dinov2.lvd142m | 1 | 1.64 | 1.34 | +| food101 | 10 | vit_large_patch14_dinov2.lvd142m | 8 | 1.64 | 1.34 | +| food101 | 100 | dino_resnet50 | 0.1 | 1.6 | 1.58 | +| food101 | 100 | dino_resnet50 | 0.2 | 1.6 | 1.58 | +| food101 | 100 | dino_resnet50 | 1 | 1.6 | 1.58 | +| food101 | 100 | dino_resnet50 | 8 | 1.64 | 1.34 | +| food101 | 100 | vit_b_16 | 0.1 | 1.5 | 1.42 | +| food101 | 100 | vit_b_16 | 0.2 | 1.5 | 1.42 | +| food101 | 100 | vit_b_16 | 1 | 1.64 | 1.34 | +| food101 | 100 | vit_b_16 | 8 | 2 | 0 | +| food101 | 100 | vit_h_14 | 0.1 | 1.5 | 1.42 | +| food101 | 100 | vit_h_14 | 0.2 | 1.5 | 1.42 | +| food101 | 100 | vit_h_14 | 1 | 1.64 | 1.34 | +| food101 | 100 | vit_h_14 | 8 | 1.64 | 1.34 | +| food101 | 100 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.5 | 1.42 | +| food101 | 100 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.5 | 1.42 | +| food101 | 100 | vit_large_patch14_dinov2.lvd142m | 1 | 1.64 | 1.34 | +| food101 | 100 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | +| food101 | 50 | dino_resnet50 | 0.1 | 1.6 | 1.58 | +| food101 | 50 | dino_resnet50 | 0.2 | 1.6 | 1.58 | +| food101 | 50 | dino_resnet50 | 1 | 1.6 | 1.58 | +| food101 | 50 | dino_resnet50 | 8 | 2 | 0 | +| food101 | 50 | vit_b_16 | 0.1 | 1.5 | 1.42 | +| food101 | 50 | vit_b_16 | 0.2 | 1.5 | 1.42 | +| food101 | 50 | vit_b_16 | 1 | 1.64 | 1.34 | +| food101 | 50 | vit_b_16 | 8 | 2 | 0 | +| food101 | 50 | vit_h_14 | 0.1 | 1.5 | 1.42 | +| food101 | 50 | vit_h_14 | 0.2 | 1.5 | 1.42 | +| 
food101 | 50 | vit_h_14 | 1 | 1.64 | 1.34 | +| food101 | 50 | vit_h_14 | 8 | 2 | 0 | +| food101 | 50 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.5 | 1.42 | +| food101 | 50 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.5 | 1.42 | +| food101 | 50 | vit_large_patch14_dinov2.lvd142m | 1 | 1.64 | 1.34 | +| food101 | 50 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | +| stl10 | 1 | dino_resnet50 | 0.1 | 1.6 | 1.58 | +| stl10 | 1 | dino_resnet50 | 0.2 | 1.6 | 1.58 | +| stl10 | 1 | dino_resnet50 | 1 | 2 | 0 | +| stl10 | 1 | dino_resnet50 | 8 | 2 | 0 | +| stl10 | 1 | vit_b_16 | 0.1 | 1.5 | 1.42 | +| stl10 | 1 | vit_b_16 | 0.2 | 2 | 0 | +| stl10 | 1 | vit_b_16 | 1 | 2 | 0 | +| stl10 | 1 | vit_b_16 | 8 | 2 | 0 | +| stl10 | 1 | vit_h_14 | 0.1 | 1.64 | 1.34 | +| stl10 | 1 | vit_h_14 | 0.2 | 1.64 | 1.34 | +| stl10 | 1 | vit_h_14 | 1 | 1.64 | 1.34 | +| stl10 | 1 | vit_h_14 | 8 | 1.64 | 1.34 | +| stl10 | 1 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.5 | 1.42 | +| stl10 | 1 | vit_large_patch14_dinov2.lvd142m | 0.2 | 2 | 0 | +| stl10 | 1 | vit_large_patch14_dinov2.lvd142m | 1 | 2 | 0 | +| stl10 | 1 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | +| stl10 | 10 | dino_resnet50 | 0.1 | 1.6 | 1.58 | +| stl10 | 10 | dino_resnet50 | 0.2 | 1.6 | 1.58 | +| stl10 | 10 | dino_resnet50 | 1 | 1.6 | 1.58 | +| stl10 | 10 | dino_resnet50 | 8 | 2 | 0 | +| stl10 | 10 | vit_b_16 | 0.1 | 1.5 | 1.42 | +| stl10 | 10 | vit_b_16 | 0.2 | 1.5 | 1.42 | +| stl10 | 10 | vit_b_16 | 1 | 2 | 0 | +| stl10 | 10 | vit_b_16 | 8 | 2 | 0 | +| stl10 | 10 | vit_h_14 | 0.1 | 1.5 | 1.42 | +| stl10 | 10 | vit_h_14 | 0.2 | 1.64 | 1.34 | +| stl10 | 10 | vit_h_14 | 1 | 1.64 | 1.34 | +| stl10 | 10 | vit_h_14 | 8 | 1.64 | 1.34 | +| stl10 | 10 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.5 | 1.42 | +| stl10 | 10 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.5 | 1.42 | +| stl10 | 10 | vit_large_patch14_dinov2.lvd142m | 1 | 2 | 0 | +| stl10 | 10 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | +| stl10 | 100 | dino_resnet50 | 0.1 | 1.6 | 1.58 | +| stl10 | 100 | dino_resnet50 | 0.2 | 1.6 | 1.58 | +| stl10 | 100 | dino_resnet50 | 1 | 1.6 | 1.58 | +| stl10 | 100 | dino_resnet50 | 8 | 1.64 | 1.34 | +| stl10 | 100 | vit_b_16 | 0.1 | 1.5 | 1.42 | +| stl10 | 100 | vit_b_16 | 0.2 | 1.5 | 1.42 | +| stl10 | 100 | vit_b_16 | 1 | 1.5 | 1.42 | +| stl10 | 100 | vit_b_16 | 8 | 2 | 0 | +| stl10 | 100 | vit_h_14 | 0.1 | 1.5 | 1.42 | +| stl10 | 100 | vit_h_14 | 0.2 | 1.5 | 1.42 | +| stl10 | 100 | vit_h_14 | 1 | 1.5 | 1.42 | +| stl10 | 100 | vit_h_14 | 8 | 1.5 | 1.42 | +| stl10 | 100 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.5 | 1.42 | +| stl10 | 100 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.5 | 1.42 | +| stl10 | 100 | vit_large_patch14_dinov2.lvd142m | 1 | 1.5 | 1.42 | +| stl10 | 100 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | +| stl10 | 50 | dino_resnet50 | 0.1 | 1.6 | 1.58 | +| stl10 | 50 | dino_resnet50 | 0.2 | 1.6 | 1.58 | +| stl10 | 50 | dino_resnet50 | 1 | 1.6 | 1.58 | +| stl10 | 50 | dino_resnet50 | 8 | 1.64 | 1.34 | +| stl10 | 50 | vit_b_16 | 0.1 | 1.5 | 1.42 | +| stl10 | 50 | vit_b_16 | 0.2 | 1.5 | 1.42 | +| stl10 | 50 | vit_b_16 | 1 | 1.5 | 1.42 | +| stl10 | 50 | vit_b_16 | 8 | 2 | 0 | +| stl10 | 50 | vit_h_14 | 0.1 | 1.5 | 1.42 | +| stl10 | 50 | vit_h_14 | 0.2 | 1.5 | 1.42 | +| stl10 | 50 | vit_h_14 | 1 | 1.64 | 1.34 | +| stl10 | 50 | vit_h_14 | 8 | 1.5 | 1.42 | +| stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.5 | 1.42 | +| stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.5 | 1.42 | +| stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 
1 | 1.5 | 1.42 | +| stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | \ No newline at end of file diff --git a/research/dppl_2024/hparams_public_topk.md b/research/dppl_2024/hparams_public_topk.md new file mode 100644 index 00000000..1f42375a --- /dev/null +++ b/research/dppl_2024/hparams_public_topk.md @@ -0,0 +1,258 @@ +| dataset | imbalance_ratio | encoder | epsilon | d_max | d_min | K | +|:----------|------------------:|:---------------------------------|----------:|--------:|--------:|----:| +| cifar10 | 1 | dino_resnet50 | 0.1 | 1.64 | 1.34 | 10 | +| cifar10 | 1 | dino_resnet50 | 0.2 | 1.64 | 1.34 | 20 | +| cifar10 | 1 | dino_resnet50 | 1 | 2 | 0 | 10 | +| cifar10 | 1 | dino_resnet50 | 8 | 2 | 0 | 20 | +| cifar10 | 1 | vit_b_16 | 0.1 | 1.64 | 1.34 | 5 | +| cifar10 | 1 | vit_b_16 | 0.2 | 2 | 0 | 3 | +| cifar10 | 1 | vit_b_16 | 1 | 2 | 0 | 20 | +| cifar10 | 1 | vit_b_16 | 8 | 2 | 0 | 20 | +| cifar10 | 1 | vit_h_14 | 0.1 | 1.506 | 1.42 | 5 | +| cifar10 | 1 | vit_h_14 | 0.2 | 1.506 | 1.42 | 10 | +| cifar10 | 1 | vit_h_14 | 1 | 1.506 | 1.42 | 20 | +| cifar10 | 1 | vit_h_14 | 8 | 1.506 | 1.42 | 20 | +| cifar10 | 1 | vit_large_patch14_dinov2.lvd142m | 0.1 | 2 | 0 | 3 | +| cifar10 | 1 | vit_large_patch14_dinov2.lvd142m | 0.2 | 2 | 0 | 5 | +| cifar10 | 1 | vit_large_patch14_dinov2.lvd142m | 1 | 2 | 0 | 10 | +| cifar10 | 1 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | 20 | +| cifar10 | 10 | dino_resnet50 | 0.1 | 1.64 | 1.34 | 3 | +| cifar10 | 10 | dino_resnet50 | 0.2 | 1.64 | 1.34 | 3 | +| cifar10 | 10 | dino_resnet50 | 1 | 1.64 | 1.34 | 5 | +| cifar10 | 10 | dino_resnet50 | 8 | 2 | 0 | 10 | +| cifar10 | 10 | vit_b_16 | 0.1 | 1.64 | 1.34 | 1 | +| cifar10 | 10 | vit_b_16 | 0.2 | 1.64 | 1.34 | 1 | +| cifar10 | 10 | vit_b_16 | 1 | 2 | 0 | 3 | +| cifar10 | 10 | vit_b_16 | 8 | 2 | 0 | 20 | +| cifar10 | 10 | vit_h_14 | 0.1 | 1.506 | 1.42 | 3 | +| cifar10 | 10 | vit_h_14 | 0.2 | 1.506 | 1.42 | 10 | +| cifar10 | 10 | vit_h_14 | 1 | 1.506 | 1.42 | 20 | +| cifar10 | 10 | vit_h_14 | 8 | 2 | 0 | 20 | +| cifar10 | 10 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.64 | 1.34 | 1 | +| cifar10 | 10 | vit_large_patch14_dinov2.lvd142m | 0.2 | 2 | 0 | 1 | +| cifar10 | 10 | vit_large_patch14_dinov2.lvd142m | 1 | 2 | 0 | 5 | +| cifar10 | 10 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | 20 | +| cifar10 | 100 | dino_resnet50 | 0.1 | 1.64 | 1.34 | 3 | +| cifar10 | 100 | dino_resnet50 | 0.2 | 1.64 | 1.34 | 2 | +| cifar10 | 100 | dino_resnet50 | 1 | 1.64 | 1.34 | 3 | +| cifar10 | 100 | dino_resnet50 | 8 | 1.64 | 1.34 | 10 | +| cifar10 | 100 | vit_b_16 | 0.1 | 1.64 | 1.34 | 1 | +| cifar10 | 100 | vit_b_16 | 0.2 | 1.64 | 1.34 | 1 | +| cifar10 | 100 | vit_b_16 | 1 | 1.64 | 1.34 | 1 | +| cifar10 | 100 | vit_b_16 | 8 | 1.64 | 1.34 | 5 | +| cifar10 | 100 | vit_h_14 | 0.1 | 1.506 | 1.42 | 1 | +| cifar10 | 100 | vit_h_14 | 0.2 | 1.506 | 1.42 | 1 | +| cifar10 | 100 | vit_h_14 | 1 | 1.506 | 1.42 | 3 | +| cifar10 | 100 | vit_h_14 | 8 | 1.506 | 1.42 | 5 | +| cifar10 | 100 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.64 | 1.34 | 1 | +| cifar10 | 100 | vit_large_patch14_dinov2.lvd142m | 0.2 | 2 | 0 | 1 | +| cifar10 | 100 | vit_large_patch14_dinov2.lvd142m | 1 | 2 | 0 | 1 | +| cifar10 | 100 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | 5 | +| cifar10 | 50 | dino_resnet50 | 0.1 | 1.64 | 1.34 | 3 | +| cifar10 | 50 | dino_resnet50 | 0.2 | 1.64 | 1.34 | 2 | +| cifar10 | 50 | dino_resnet50 | 1 | 1.64 | 1.34 | 5 | +| cifar10 | 50 | dino_resnet50 | 8 | 2 | 0 | 3 | +| cifar10 | 50 | vit_b_16 | 0.1 | 1.64 | 1.34 | 1 | +| 
cifar10 | 50 | vit_b_16 | 0.2 | 1.64 | 1.34 | 1 | +| cifar10 | 50 | vit_b_16 | 1 | 1.64 | 1.34 | 1 | +| cifar10 | 50 | vit_b_16 | 8 | 2 | 0 | 3 | +| cifar10 | 50 | vit_h_14 | 0.1 | 1.506 | 1.42 | 1 | +| cifar10 | 50 | vit_h_14 | 0.2 | 1.506 | 1.42 | 1 | +| cifar10 | 50 | vit_h_14 | 1 | 1.506 | 1.42 | 10 | +| cifar10 | 50 | vit_h_14 | 8 | 2 | 0 | 20 | +| cifar10 | 50 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.64 | 1.34 | 1 | +| cifar10 | 50 | vit_large_patch14_dinov2.lvd142m | 0.2 | 2 | 0 | 2 | +| cifar10 | 50 | vit_large_patch14_dinov2.lvd142m | 1 | 2 | 0 | 1 | +| cifar10 | 50 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | 10 | +| cifar100 | 1 | dino_resnet50 | 0.1 | 1.64 | 1.34 | 2 | +| cifar100 | 1 | dino_resnet50 | 0.2 | 1.64 | 1.34 | 2 | +| cifar100 | 1 | dino_resnet50 | 1 | 1.64 | 1.34 | 5 | +| cifar100 | 1 | dino_resnet50 | 8 | 2 | 0 | 10 | +| cifar100 | 1 | vit_b_16 | 0.1 | 1.64 | 1.34 | 1 | +| cifar100 | 1 | vit_b_16 | 0.2 | 1.64 | 1.34 | 1 | +| cifar100 | 1 | vit_b_16 | 1 | 1.64 | 1.34 | 5 | +| cifar100 | 1 | vit_b_16 | 8 | 2 | 0 | 10 | +| cifar100 | 1 | vit_h_14 | 0.1 | 1.64 | 1.34 | 1 | +| cifar100 | 1 | vit_h_14 | 0.2 | 1.64 | 1.34 | 2 | +| cifar100 | 1 | vit_h_14 | 1 | 1.506 | 1.42 | 10 | +| cifar100 | 1 | vit_h_14 | 8 | 1.506 | 1.42 | 20 | +| cifar100 | 1 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.64 | 1.34 | 1 | +| cifar100 | 1 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.64 | 1.34 | 1 | +| cifar100 | 1 | vit_large_patch14_dinov2.lvd142m | 1 | 2 | 0 | 3 | +| cifar100 | 1 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | 20 | +| cifar100 | 10 | dino_resnet50 | 0.1 | 1.64 | 1.34 | 2 | +| cifar100 | 10 | dino_resnet50 | 0.2 | 1.64 | 1.34 | 2 | +| cifar100 | 10 | dino_resnet50 | 1 | 1.64 | 1.34 | 3 | +| cifar100 | 10 | dino_resnet50 | 8 | 1.64 | 1.34 | 5 | +| cifar100 | 10 | vit_b_16 | 0.1 | 1.64 | 1.34 | 1 | +| cifar100 | 10 | vit_b_16 | 0.2 | 1.64 | 1.34 | 1 | +| cifar100 | 10 | vit_b_16 | 1 | 1.64 | 1.34 | 1 | +| cifar100 | 10 | vit_b_16 | 8 | 1.64 | 1.34 | 5 | +| cifar100 | 10 | vit_h_14 | 0.1 | 1.506 | 1.42 | 1 | +| cifar100 | 10 | vit_h_14 | 0.2 | 1.506 | 1.42 | 1 | +| cifar100 | 10 | vit_h_14 | 1 | 1.64 | 1.34 | 2 | +| cifar100 | 10 | vit_h_14 | 8 | 1.64 | 1.34 | 10 | +| cifar100 | 10 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.64 | 1.34 | 1 | +| cifar100 | 10 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.64 | 1.34 | 1 | +| cifar100 | 10 | vit_large_patch14_dinov2.lvd142m | 1 | 1.64 | 1.34 | 1 | +| cifar100 | 10 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | 3 | +| cifar100 | 100 | dino_resnet50 | 0.1 | 1.64 | 1.34 | 2 | +| cifar100 | 100 | dino_resnet50 | 0.2 | 1.64 | 1.34 | 3 | +| cifar100 | 100 | dino_resnet50 | 1 | 1.64 | 1.34 | 3 | +| cifar100 | 100 | dino_resnet50 | 8 | 1.64 | 1.34 | 3 | +| cifar100 | 100 | vit_b_16 | 0.1 | 1.64 | 1.34 | 1 | +| cifar100 | 100 | vit_b_16 | 0.2 | 1.64 | 1.34 | 1 | +| cifar100 | 100 | vit_b_16 | 1 | 1.64 | 1.34 | 1 | +| cifar100 | 100 | vit_b_16 | 8 | 1.64 | 1.34 | 2 | +| cifar100 | 100 | vit_h_14 | 0.1 | 1.506 | 1.42 | 1 | +| cifar100 | 100 | vit_h_14 | 0.2 | 1.506 | 1.42 | 1 | +| cifar100 | 100 | vit_h_14 | 1 | 1.54 | 1.46 | 1 | +| cifar100 | 100 | vit_h_14 | 8 | 1.54 | 1.46 | 2 | +| cifar100 | 100 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.64 | 1.34 | 1 | +| cifar100 | 100 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.64 | 1.34 | 1 | +| cifar100 | 100 | vit_large_patch14_dinov2.lvd142m | 1 | 1.64 | 1.34 | 1 | +| cifar100 | 100 | vit_large_patch14_dinov2.lvd142m | 8 | 1.64 | 1.34 | 1 | +| cifar100 | 50 | dino_resnet50 | 0.1 | 
1.64 | 1.34 | 2 | +| cifar100 | 50 | dino_resnet50 | 0.2 | 1.64 | 1.34 | 2 | +| cifar100 | 50 | dino_resnet50 | 1 | 1.64 | 1.34 | 3 | +| cifar100 | 50 | dino_resnet50 | 8 | 1.64 | 1.34 | 5 | +| cifar100 | 50 | vit_b_16 | 0.1 | 1.64 | 1.34 | 1 | +| cifar100 | 50 | vit_b_16 | 0.2 | 1.64 | 1.34 | 1 | +| cifar100 | 50 | vit_b_16 | 1 | 1.64 | 1.34 | 1 | +| cifar100 | 50 | vit_b_16 | 8 | 1.64 | 1.34 | 2 | +| cifar100 | 50 | vit_h_14 | 0.1 | 1.506 | 1.42 | 1 | +| cifar100 | 50 | vit_h_14 | 0.2 | 1.506 | 1.42 | 1 | +| cifar100 | 50 | vit_h_14 | 1 | 1.506 | 1.42 | 1 | +| cifar100 | 50 | vit_h_14 | 8 | 1.54 | 1.46 | 3 | +| cifar100 | 50 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.64 | 1.34 | 1 | +| cifar100 | 50 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.64 | 1.34 | 1 | +| cifar100 | 50 | vit_large_patch14_dinov2.lvd142m | 1 | 1.64 | 1.34 | 1 | +| cifar100 | 50 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | 1 | +| food101 | 1 | dino_resnet50 | 0.1 | 1.64 | 1.34 | 2 | +| food101 | 1 | dino_resnet50 | 0.2 | 1.64 | 1.34 | 2 | +| food101 | 1 | dino_resnet50 | 1 | 1.64 | 1.34 | 10 | +| food101 | 1 | dino_resnet50 | 8 | 1.64 | 1.34 | 20 | +| food101 | 1 | vit_b_16 | 0.1 | 1.64 | 1.34 | 1 | +| food101 | 1 | vit_b_16 | 0.2 | 1.64 | 1.34 | 2 | +| food101 | 1 | vit_b_16 | 1 | 1.64 | 1.34 | 5 | +| food101 | 1 | vit_b_16 | 8 | 2 | 0 | 20 | +| food101 | 1 | vit_h_14 | 0.1 | 1.64 | 1.34 | 1 | +| food101 | 1 | vit_h_14 | 0.2 | 1.64 | 1.34 | 2 | +| food101 | 1 | vit_h_14 | 1 | 1.64 | 1.34 | 5 | +| food101 | 1 | vit_h_14 | 8 | 2 | 0 | 10 | +| food101 | 1 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.64 | 1.34 | 1 | +| food101 | 1 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.64 | 1.34 | 1 | +| food101 | 1 | vit_large_patch14_dinov2.lvd142m | 1 | 1.64 | 1.34 | 5 | +| food101 | 1 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | 20 | +| food101 | 10 | dino_resnet50 | 0.1 | 1.64 | 1.34 | 2 | +| food101 | 10 | dino_resnet50 | 0.2 | 1.64 | 1.34 | 2 | +| food101 | 10 | dino_resnet50 | 1 | 1.64 | 1.34 | 3 | +| food101 | 10 | dino_resnet50 | 8 | 1.64 | 1.34 | 10 | +| food101 | 10 | vit_b_16 | 0.1 | 1.64 | 1.34 | 1 | +| food101 | 10 | vit_b_16 | 0.2 | 1.64 | 1.34 | 1 | +| food101 | 10 | vit_b_16 | 1 | 1.64 | 1.34 | 2 | +| food101 | 10 | vit_b_16 | 8 | 1.64 | 1.34 | 5 | +| food101 | 10 | vit_h_14 | 0.1 | 1.64 | 1.34 | 1 | +| food101 | 10 | vit_h_14 | 0.2 | 1.64 | 1.34 | 1 | +| food101 | 10 | vit_h_14 | 1 | 1.64 | 1.34 | 2 | +| food101 | 10 | vit_h_14 | 8 | 1.64 | 1.34 | 5 | +| food101 | 10 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.64 | 1.34 | 1 | +| food101 | 10 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.64 | 1.34 | 1 | +| food101 | 10 | vit_large_patch14_dinov2.lvd142m | 1 | 1.64 | 1.34 | 2 | +| food101 | 10 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | 5 | +| food101 | 100 | dino_resnet50 | 0.1 | 1.64 | 1.34 | 2 | +| food101 | 100 | dino_resnet50 | 0.2 | 1.64 | 1.34 | 2 | +| food101 | 100 | dino_resnet50 | 1 | 1.64 | 1.34 | 3 | +| food101 | 100 | dino_resnet50 | 8 | 1.64 | 1.34 | 5 | +| food101 | 100 | vit_b_16 | 0.1 | 1.64 | 1.34 | 1 | +| food101 | 100 | vit_b_16 | 0.2 | 1.64 | 1.34 | 1 | +| food101 | 100 | vit_b_16 | 1 | 1.64 | 1.34 | 1 | +| food101 | 100 | vit_b_16 | 8 | 1.64 | 1.34 | 2 | +| food101 | 100 | vit_h_14 | 0.1 | 1.64 | 1.34 | 1 | +| food101 | 100 | vit_h_14 | 0.2 | 1.64 | 1.34 | 1 | +| food101 | 100 | vit_h_14 | 1 | 1.64 | 1.34 | 1 | +| food101 | 100 | vit_h_14 | 8 | 1.64 | 1.34 | 2 | +| food101 | 100 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.64 | 1.34 | 1 | +| food101 | 100 | 
vit_large_patch14_dinov2.lvd142m | 0.2 | 1.64 | 1.34 | 1 | +| food101 | 100 | vit_large_patch14_dinov2.lvd142m | 1 | 1.64 | 1.34 | 1 | +| food101 | 100 | vit_large_patch14_dinov2.lvd142m | 8 | 1.64 | 1.34 | 2 | +| food101 | 50 | dino_resnet50 | 0.1 | 1.64 | 1.34 | 2 | +| food101 | 50 | dino_resnet50 | 0.2 | 1.64 | 1.34 | 2 | +| food101 | 50 | dino_resnet50 | 1 | 1.64 | 1.34 | 3 | +| food101 | 50 | dino_resnet50 | 8 | 1.64 | 1.34 | 10 | +| food101 | 50 | vit_b_16 | 0.1 | 1.64 | 1.34 | 1 | +| food101 | 50 | vit_b_16 | 0.2 | 1.64 | 1.34 | 1 | +| food101 | 50 | vit_b_16 | 1 | 1.64 | 1.34 | 2 | +| food101 | 50 | vit_b_16 | 8 | 1.64 | 1.34 | 3 | +| food101 | 50 | vit_h_14 | 0.1 | 1.64 | 1.34 | 1 | +| food101 | 50 | vit_h_14 | 0.2 | 1.64 | 1.34 | 1 | +| food101 | 50 | vit_h_14 | 1 | 1.64 | 1.34 | 1 | +| food101 | 50 | vit_h_14 | 8 | 1.64 | 1.34 | 3 | +| food101 | 50 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.64 | 1.34 | 1 | +| food101 | 50 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.64 | 1.34 | 1 | +| food101 | 50 | vit_large_patch14_dinov2.lvd142m | 1 | 1.64 | 1.34 | 1 | +| food101 | 50 | vit_large_patch14_dinov2.lvd142m | 8 | 1.64 | 1.34 | 3 | +| stl10 | 1 | dino_resnet50 | 0.1 | 1.64 | 1.34 | 1 | +| stl10 | 1 | dino_resnet50 | 0.2 | 1.64 | 1.34 | 2 | +| stl10 | 1 | dino_resnet50 | 1 | 1.64 | 1.34 | 5 | +| stl10 | 1 | dino_resnet50 | 8 | 1.64 | 1.34 | 20 | +| stl10 | 1 | vit_b_16 | 0.1 | 1.64 | 1.34 | 1 | +| stl10 | 1 | vit_b_16 | 0.2 | 1.64 | 1.34 | 1 | +| stl10 | 1 | vit_b_16 | 1 | 2 | 0 | 3 | +| stl10 | 1 | vit_b_16 | 8 | 2 | 0 | 10 | +| stl10 | 1 | vit_h_14 | 0.1 | 1.64 | 1.34 | 1 | +| stl10 | 1 | vit_h_14 | 0.2 | 1.64 | 1.34 | 1 | +| stl10 | 1 | vit_h_14 | 1 | 1.64 | 1.34 | 3 | +| stl10 | 1 | vit_h_14 | 8 | 2 | 0 | 20 | +| stl10 | 1 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.64 | 1.34 | 1 | +| stl10 | 1 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.64 | 1.34 | 1 | +| stl10 | 1 | vit_large_patch14_dinov2.lvd142m | 1 | 2 | 0 | 2 | +| stl10 | 1 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | 20 | +| stl10 | 10 | dino_resnet50 | 0.1 | 1.64 | 1.34 | 1 | +| stl10 | 10 | dino_resnet50 | 0.2 | 1.64 | 1.34 | 2 | +| stl10 | 10 | dino_resnet50 | 1 | 1.64 | 1.34 | 2 | +| stl10 | 10 | dino_resnet50 | 8 | 1.64 | 1.34 | 10 | +| stl10 | 10 | vit_b_16 | 0.1 | 1.64 | 1.34 | 1 | +| stl10 | 10 | vit_b_16 | 0.2 | 1.64 | 1.34 | 1 | +| stl10 | 10 | vit_b_16 | 1 | 1.64 | 1.34 | 1 | +| stl10 | 10 | vit_b_16 | 8 | 2 | 0 | 3 | +| stl10 | 10 | vit_h_14 | 0.1 | 1.64 | 1.34 | 2 | +| stl10 | 10 | vit_h_14 | 0.2 | 1.64 | 1.34 | 1 | +| stl10 | 10 | vit_h_14 | 1 | 1.64 | 1.34 | 2 | +| stl10 | 10 | vit_h_14 | 8 | 1.64 | 1.34 | 20 | +| stl10 | 10 | vit_large_patch14_dinov2.lvd142m | 0.1 | 2 | 0 | 20 | +| stl10 | 10 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.64 | 1.34 | 1 | +| stl10 | 10 | vit_large_patch14_dinov2.lvd142m | 1 | 1.64 | 1.34 | 1 | +| stl10 | 10 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | 3 | +| stl10 | 100 | dino_resnet50 | 0.1 | 1.64 | 1.34 | 1 | +| stl10 | 100 | dino_resnet50 | 0.2 | 1.64 | 1.34 | 2 | +| stl10 | 100 | dino_resnet50 | 1 | 1.64 | 1.34 | 2 | +| stl10 | 100 | dino_resnet50 | 8 | 1.64 | 1.34 | 2 | +| stl10 | 100 | vit_b_16 | 0.1 | 1.64 | 1.34 | 1 | +| stl10 | 100 | vit_b_16 | 0.2 | 1.64 | 1.34 | 1 | +| stl10 | 100 | vit_b_16 | 1 | 1.64 | 1.34 | 1 | +| stl10 | 100 | vit_b_16 | 8 | 1.64 | 1.34 | 1 | +| stl10 | 100 | vit_h_14 | 0.1 | 1.64 | 1.34 | 1 | +| stl10 | 100 | vit_h_14 | 0.2 | 1.64 | 1.34 | 3 | +| stl10 | 100 | vit_h_14 | 1 | 1.64 | 1.34 | 2 | +| stl10 | 100 | vit_h_14 | 8 | 1.64 
| 1.34 | 2 | +| stl10 | 100 | vit_large_patch14_dinov2.lvd142m | 0.1 | 2 | 0 | 20 | +| stl10 | 100 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.64 | 1.34 | 1 | +| stl10 | 100 | vit_large_patch14_dinov2.lvd142m | 1 | 1.64 | 1.34 | 1 | +| stl10 | 100 | vit_large_patch14_dinov2.lvd142m | 8 | 1.64 | 1.34 | 1 | +| stl10 | 50 | dino_resnet50 | 0.1 | 1.64 | 1.34 | 1 | +| stl10 | 50 | dino_resnet50 | 0.2 | 1.64 | 1.34 | 2 | +| stl10 | 50 | dino_resnet50 | 1 | 1.64 | 1.34 | 1 | +| stl10 | 50 | dino_resnet50 | 8 | 1.64 | 1.34 | 2 | +| stl10 | 50 | vit_b_16 | 0.1 | 1.64 | 1.34 | 1 | +| stl10 | 50 | vit_b_16 | 0.2 | 1.64 | 1.34 | 1 | +| stl10 | 50 | vit_b_16 | 1 | 1.64 | 1.34 | 1 | +| stl10 | 50 | vit_b_16 | 8 | 2 | 0 | 1 | +| stl10 | 50 | vit_h_14 | 0.1 | 1.64 | 1.34 | 1 | +| stl10 | 50 | vit_h_14 | 0.2 | 1.64 | 1.34 | 2 | +| stl10 | 50 | vit_h_14 | 1 | 1.64 | 1.34 | 1 | +| stl10 | 50 | vit_h_14 | 8 | 1.64 | 1.34 | 5 | +| stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 0.1 | 2 | 0 | 20 | +| stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.64 | 1.34 | 1 | +| stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 1 | 1.64 | 1.34 | 1 | +| stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | 1 | \ No newline at end of file diff --git a/research/dppl_2024/lib/__init__.py b/research/dppl_2024/lib/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/research/dppl_2024/lib/coinpress.py b/research/dppl_2024/lib/coinpress.py new file mode 100644 index 00000000..f9d3df6a --- /dev/null +++ b/research/dppl_2024/lib/coinpress.py @@ -0,0 +1,77 @@ +from functools import partial + +import jax +import numpy as np +from jax import numpy as jnp + + +@jax.jit +def gaussian_tailbound_jit(d, b): + return (d + 2 * (d * jnp.log(1 / b)) ** 0.5 + 2 * jnp.log(1 / b)) ** 0.5 + + +@partial(jax.jit, static_argnames=("d",)) +def multivariate_mean_step_jit(X, c, r, p, n, d, subkey): + ## Determine a good clipping threshold + gamma = gaussian_tailbound_jit(d, 0.01) + clip_thresh = jnp.minimum( + (r**2 + 2 * r * 3 + gamma**2) ** 0.5, r + gamma + ) # 3 in place of sqrt(log(2/beta)) + + ## Round each of X1,...,Xn to the nearest point in the ball B2(c,clip_thresh) + x = X - c + mag_x = jnp.linalg.norm(x, axis=1) + + outside_ball_bool = mag_x > clip_thresh + x_hat = (x.T / mag_x).T + X = jnp.where( + outside_ball_bool[:, jnp.newaxis], + c + (x_hat * clip_thresh), + X, + ) + + ## Compute sensitivity + delta = 2 * clip_thresh / n.astype(float) + sd = delta / (2 * p) ** 0.5 + + ## Add noise calibrated to sensitivity + Y = sd * jax.random.normal(subkey, (d,)) + c = jnp.sum(X, axis=0) / n.astype(float) + Y + r = (1 / n.astype(float) + sd**2) ** 0.5 * gaussian_tailbound_jit(d, 0.01) + return c, r + + +def multivariate_mean_iterative_jit_inner(i, val, X, Ps, n, d, subkeys): + c, r = val + c, r = multivariate_mean_step_jit(X, c, r, Ps[i], n, d, subkeys[i]) + return (c, r) + + +@partial(jax.jit, static_argnames=("d", "t")) +def multivariate_mean_iterative_jit(X, c, r, t, Ps, n, d, key): + subkeys = jax.random.split(key, t) + init_val = c, r + (c, r) = jax.lax.fori_loop( + 0, + t, + partial( + multivariate_mean_iterative_jit_inner, X=X, Ps=Ps, n=n, d=d, subkeys=subkeys + ), + init_val, + ) + return c + + +def private_mean_jit(X, Ps, key=jax.random.key(42), r=None, c=None): + if len(X.shape) != 2: + raise ValueError("X must be a 2D array, but received shape: {}".format(X.shape)) + d = X.shape[1] + if r is None: + r = np.sqrt(d) * 0.9 + if c is None: + c = np.zeros(d) + t = len(Ps) + mean = multivariate_mean_iterative_jit( + 
X, c=c, r=r, t=t, Ps=Ps, n=X.shape[0], d=d, key=key + ) + return mean diff --git a/research/dppl_2024/lib/public.py b/research/dppl_2024/lib/public.py new file mode 100644 index 00000000..f56899f7 --- /dev/null +++ b/research/dppl_2024/lib/public.py @@ -0,0 +1,150 @@ +from functools import partial + +import jax +import numpy as np +from jax import numpy as jnp +from jax import scipy as jsc + + +def exponential( + scores: np.ndarray, + sensitivity: float, + epsilon: float, + size: int = 1, + max_fix: bool = True, + monotonic: bool = False, + key: int = 0, +) -> np.ndarray: + """Perform exponential sampling on the scores. + + Args: + scores (np.ndarray): The scores of the elements in R. + sensitivity (float): Sensitivity of the score function w.r.t. the private data. + epsilon (float): pure-differential privacy parameter. + size (int, optional): Number of independent samplings to perform (e.g. for reporting avg/std of accuracy). Defaults to 1. + max_fix (bool, optional): Perform a numeric fix by multiplying all probablities with exp(-max_exponent). Defaults to True. + monotonic (bool, optional): Use lower privacy bound when the score function is monotonic w.r.t. to the private dataset. Defaults to False. + key (int, optional): Random key for reproducibility. Defaults to 0. + + Returns: + np.ndarray: array of indice(s) of the sampled element(s). + """ + if np.isposinf(epsilon): + max_idx = scores.argmax() + max_idx = max_idx.repeat(size) + return max_idx + + sensitivity_factor = 1 if monotonic else 2 + + # Substract maximum exponent to avoid overflow + if max_fix: + max_exponent = epsilon * scores.max() / (sensitivity_factor * sensitivity) + else: + max_exponent = 0 + # Calculate the probability for each element, based on its score + probabilities = np.exp( + epsilon * scores / (sensitivity_factor * sensitivity) - max_exponent + ) + # Normalize the probabilties so they sum to 1 + probabilities = probabilities / np.linalg.norm(probabilities, ord=1) + + # Choose an element from R based on the probabilities + rng = np.random.default_rng(key) + return rng.choice(len(scores), size, p=probabilities, replace=True) + + +@jax.jit +def log_binom(n: int, k: int) -> float: + """Calculate log(n choose k) + + Args: + n (int): n + k (int): k + + Returns: + float: log(n choose k) + """ + return ( + jsc.special.gammaln(n + 1) + - jsc.special.gammaln(k + 1) + - jsc.special.gammaln(n - k + 1) + ) + + +@partial( + jax.jit, + static_argnames=["total_rows", "total_cols"], +) +def exponential_parallel( + U: jnp.ndarray, + logm: jnp.ndarray, + total_rows: int, + total_cols: int, + epsilon: float, + key: int = 42, +) -> jnp.ndarray: + """Perform parallel exponential sampling of all classes. + + Args: + U (jnp.ndarray): Scores, shape (total_rows, total_cols) = (n_classes, n_public_samples). + logm (jnp.ndarray): Log-Counts of the Utilities, shape (total_cols,) = (n_public_samples,). + total_rows (int): U.shape[0] = n_classes. (needed for jit compilation) + total_cols (int): U.shape[1] = n_public_samples. (needed for jit compilation) + epsilon (float): pure-differential privacy parameter. + key (int, optional): PRNG-initialization. Defaults to 42. + + Returns: + jnp.ndarray: array of indice(s) of the sampled element(s), shape (total_rows,) = (n_classes,). 
+ """ + rng = jax.random.key(key) + choices = ( + jnp.log(jnp.log(1 / jax.random.uniform(rng, (total_rows, total_cols)))) + - logm + - epsilon * U / 2 + ).argmin(axis=-1) + return choices + + +@partial(jax.jit, static_argnames=("total_rows", "total_cols", "k")) +def give_topk_proto_idx( + U: jnp.ndarray, + logm: jnp.ndarray, + k: int, + total_rows: int, + total_cols: int, + epsilon: float, + key: int = 42, +): + """Perform the private top-k prototyping. First, perform exponential sampling on the utilities. + Then, uniformly sample the remaining k-1 prototypes, s.t. their utility is equal or better. + + Args: + U (jnp.ndarray): Scores, shape (total_rows, total_cols) = (n_classes, n_public_samples). + logm (jnp.ndarray): Log-Counts of the Utilities, shape (total_cols,) = (n_public_samples,). + k (int): Number of prototypes per class to sample. + total_rows (int): U.shape[0] = n_classes. (needed for jit compilation) + total_cols (int): U.shape[1] = n_public_samples. (needed for jit compilation) + epsilon (float): pure-differential privacy parameter. + key (int, optional): PRNG-initialization. Defaults to 42. + + Returns: + jnp.ndarray: array of indice(s) of the sampled element(s), shape (total_rows, k) = (n_classes, k). + """ + choices = exponential_parallel( + U, logm, total_rows, total_cols, epsilon, key + ).astype(int) + + proto_idx_C = jnp.concatenate( + [ + jax.lax.select( + jnp.arange(total_cols)[jnp.newaxis, :].repeat(total_rows, axis=0) + < choices[:, jnp.newaxis], + -jax.random.uniform(jax.random.key(key), (total_rows, total_cols)), + jnp.stack([jnp.zeros((total_cols)) for row in jnp.arange(total_rows)]), + ).argsort(axis=-1)[:, : k - 1], + choices[:, jnp.newaxis], + ], + axis=1, + ) + + return proto_idx_C diff --git a/research/dppl_2024/lib/utils.py b/research/dppl_2024/lib/utils.py new file mode 100644 index 00000000..cc882c25 --- /dev/null +++ b/research/dppl_2024/lib/utils.py @@ -0,0 +1,140 @@ +from functools import partial + +import numpy as np +from jax import jit, vmap +from jax import numpy as jnp +from omegaconf import DictConfig + + +def load_dataset(cfg: DictConfig): + X_train = np.load(cfg.dataset.train_data) + Y_train = np.load(cfg.dataset.train_labels) + X_test = np.load(cfg.dataset.test_data) + Y_test = np.load(cfg.dataset.test_labels) + + return X_train, Y_train, X_test, Y_test + + +def load_public_dataset(cfg: DictConfig): + X_public = np.load(cfg.dataset.public_data) + return X_public + + +def decay(cls: int | np.ndarray, max_samples: int, num_classes: int, ratio: float = 10): + decay = -np.log(ratio) / num_classes + return np.round(max_samples * np.exp(decay * cls)).astype(int) + + +def give_imbalanced_set(x, y, imbalance_ratio: float = 10, seed: int = 42): + classes = np.unique(y) + X_classes = [x[y == i] for i in classes] + rng = np.random.default_rng(seed) + input_samples_per_class = np.asarray([(y == i).sum() for i in classes]) + + output_samples_per_class = decay( + np.linspace(0, len(classes), len(classes)), + max_samples=input_samples_per_class.min(), + num_classes=len(classes), + ratio=imbalance_ratio, + ) + rng.shuffle(output_samples_per_class) + x = np.concatenate( + [ + X_classes[i][:num_samples] + for i, num_samples in enumerate(output_samples_per_class) + ] + ) + y = np.concatenate( + [ + np.repeat(i, num_samples) + for i, num_samples in enumerate(output_samples_per_class) + ] + ) + return x, y + + +def zcdp_of_naive_epsilon(epsilon): + return epsilon**2 / 2 + + +def exponential_epsilon_of_zcdp(rho): + return np.sqrt(8 * rho) + + +@jit +def 
pairwise_distance(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: + """Calculate 1-cosine_similarity between x and y. + + Args: + x (jnp.ndarray): x + y (jnp.ndarray): y + + Returns: + jnp.ndarray: pairwise distance(s) + """ + x = x / jnp.linalg.norm(x, axis=-1, keepdims=True) + y = y / jnp.linalg.norm(y, axis=-1, keepdims=True) + return 1 - jnp.dot(x, y.T) + + +@jit +def scores_single( + X: jnp.ndarray, y: jnp.ndarray, min_score: float, max_score: float +) -> jnp.ndarray: + """Score Calculation for a single public sample. The score is calculated as the sum + of the clipped pairwise distances between the public sample and the private samples. + + Args: + X (jnp.ndarray): private dataset + y (jnp.ndarray): public sample + min_score (float): minimum score (in [0,2)) + max_score (float): maximum score (in (min_score, 2]) + + Returns: + jnp.ndarray: Score of the public sample + """ + return jnp.sum( + ( + ( + jnp.clip( + 2 - vmap(pairwise_distance, in_axes=(0, None))(X, y), + min_score, + max_score, + ) + - min_score + ) + / (max_score - min_score) + ), + axis=0, + ) + + +@partial(jit, static_argnames=["batch_size_y"]) +def scores_multiple( + X: jnp.ndarray, + Y: jnp.ndarray, + min_score: float = 0.0, + max_score: float = 2.0, + batch_size_y: int = 5000, +) -> jnp.ndarray: + """Perform the score calculation batched over the public samples. + + Args: + X (jnp.ndarray): private dataset + Y (jnp.ndarray): public dataset + min_score (float, optional): minimum score (in [0,2)). Defaults to 0.0. + max_score (float, optional): maximum score (in (min_score, 2]). Defaults to 2.0. + batch_size_y (int, optional): batch size (impacts VRAM usage). Defaults to 5000. + + Returns: + jnp.ndarray: Scores of all public samples in Y + """ + return jnp.concatenate( + [ + vmap( + partial(scores_single, min_score=min_score, max_score=max_score), + in_axes=(None, 0), + )(X, Y[i : min(i + batch_size_y, len(Y))]) + for i in range(0, len(Y), batch_size_y) + ], + ) diff --git a/research/dppl_2024/requirements.txt b/research/dppl_2024/requirements.txt new file mode 100644 index 00000000..0e83ed25 --- /dev/null +++ b/research/dppl_2024/requirements.txt @@ -0,0 +1,43 @@ +absl-py==2.1.0 +antlr4-python3-runtime==4.9.3 +chex==0.1.86 +etils==1.7.0 +flax==0.8.3 +fsspec==2024.5.0 +hydra-core==1.3.2 +importlib_resources==6.4.0 +jax==0.4.28 +jax-cuda12-pjrt==0.4.28 +jax-cuda12-plugin==0.4.28 +jaxlib==0.4.28 +markdown-it-py==3.0.0 +mdurl==0.1.2 +ml-dtypes==0.4.0 +msgpack==1.0.8 +nest-asyncio==1.6.0 +numpy==1.26.4 +nvidia-cublas-cu12==12.4.5.8 +nvidia-cuda-cupti-cu12==12.4.127 +nvidia-cuda-nvcc-cu12==12.4.131 +nvidia-cuda-nvrtc-cu12==12.4.127 +nvidia-cuda-runtime-cu12==12.4.127 +nvidia-cudnn-cu12==8.9.7.29 +nvidia-cufft-cu12==11.2.1.3 +nvidia-cusolver-cu12==11.6.1.9 +nvidia-cusparse-cu12==12.3.1.170 +nvidia-nccl-cu12==2.21.5 +nvidia-nvjitlink-cu12==12.4.127 +omegaconf==2.3.0 +opt-einsum==3.3.0 +optax==0.2.2 +orbax-checkpoint==0.5.11 +packaging==24.0 +protobuf==5.26.1 +Pygments==2.18.0 +PyYAML==6.0.1 +rich==13.7.1 +scipy==1.13.0 +tensorstore==0.1.59 +toolz==0.12.1 +typing_extensions==4.11.0 +zipp==3.18.2 From dd991c1bfda7216b19d2a868ffc2118f4fc83b50 Mon Sep 17 00:00:00 2001 From: Dariush Wahdany Date: Tue, 25 Jun 2024 19:09:21 +0200 Subject: [PATCH 2/5] fix: remove macOS file --- research/dppl_2024/.DS_Store | Bin 6148 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 research/dppl_2024/.DS_Store diff --git a/research/dppl_2024/.DS_Store b/research/dppl_2024/.DS_Store deleted file mode 100644 index 
878677845a65fa897646859996ebdbbabb735c39..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeH~JqiLr422WjLa^D=avBfd4F=H@cmYwd61NchIl3=D2(H#5@&d^>$xK-G6+0Ud z(e?eb66r-`1~ zw}U0m)np4syJ!v{nom}nVqhBWq6G;|tAl|GP=S#G)5v?f|F`f@^Z%%YDHWgsf2M$T zo84xOm&&{K?e(m_&#J8(9Q5M|FFyfD>?&Ts-LPM50oG&-q5|WOfXl!@1-`1l11psh AsQ>@~ From 10e77d4fe8d0f298a61d7c40677aadb9926cf239 Mon Sep 17 00:00:00 2001 From: Dariush Wahdany Date: Tue, 25 Jun 2024 19:12:50 +0200 Subject: [PATCH 3/5] feat: DPPL Readme --- research/dppl_2024/README.md | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/research/dppl_2024/README.md b/research/dppl_2024/README.md index 88b51e40..bf54fa18 100644 --- a/research/dppl_2024/README.md +++ b/research/dppl_2024/README.md @@ -1,7 +1,9 @@ -# Supplemental Material: Code Submission +# Beyond the Mean: Differentially Private Prototypes for Private Transfer Learning +This folder contains the code for - -## Paper Title: *Beyond the Mean: Differentially Private Prototypes for Private Transfer Learning* +**Beyond the Mean: Differentially Private Prototypes for Private Transfer Learning** +by Dariush Wahdany, Matthew Jagielski, Adam Dziedzic, Franziska Boenisch +https://arxiv.org/abs/2406.08039 Abstract: Machine learning (ML) models have been shown to leak private information from their training datasets. Differential Privacy (DP), typically implemented through the differential private stochastic gradient descent algorithm (DP-SGD), has become the standard solution to bound leakage from the models. Despite recent improvments, DP-SGD-based approaches for private learning still usually struggle in the high privacy ($\varepsilon<0.1$) and low data regimes, and when the private training datasets are imbalanced. To overcome these limitations, we propose Differentially Private Prototype Learning (DPPL) as a new paradigm for private transfer learning. DPPL leverages publicly pre-trained encoders to extract features from private data and generates DP prototypes that represent each private class in the embedding space and can be publicly released for inference. Since our DP prototypes can be obtained from only a few private training data points and without iterative noise addition, they offer high-utility predictions and strong privacy guarantees even under the notion of pure DP. We additionally show that privacy-utility trade-offs can be further improved when leveraging the public data beyond pre-training of the encoder: we are able to privately sample our DP prototypes from the publicly available data points used to train the encoder. Our experimental evaluation with four state-of-the-art encoders, four vision datasets, and under different data and unbalancedness regimes demonstrate DPPL's high performance under strong privacy guarantees in challenging private learning setups. @@ -81,8 +83,3 @@ python dppl_mean.py python dppl_public_topk.py ``` -## Contributing -We welcome any feedback during the review process. - -## License -Submitted to 38th Conference on Neural Information Processing Systems (NeurIPS 2024). 
Do not distribute From a60483277966d6788429fab6d3edca72fc2abc23 Mon Sep 17 00:00:00 2001 From: Dariush Wahdany Date: Tue, 25 Jun 2024 19:15:35 +0200 Subject: [PATCH 4/5] feat: link to embeddings --- research/dppl_2024/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/research/dppl_2024/README.md b/research/dppl_2024/README.md index bf54fa18..32c19132 100644 --- a/research/dppl_2024/README.md +++ b/research/dppl_2024/README.md @@ -57,6 +57,8 @@ Before running any of the experiments, set the path to your embeddings in `confi - Imbalance Ratio - Seed +We provide the required embeddings as a [huggingface dataset](https://huggingface.co/datasets/lsc64/DPPL-embeddings). + ### DPPL-Mean (Optional): In `config/mean.yaml`, change `pool` to any desired integer value. It configures the optional average pooling before the mean estimation and can improve utility especially at strict privacy budgets. From 9985230a1595d270948a0572b4fe61eeb1e74ea9 Mon Sep 17 00:00:00 2001 From: Dariush Wahdany Date: Thu, 27 Jun 2024 11:40:13 +0200 Subject: [PATCH 5/5] style: fixed (some?) checks --- research/dppl_2024/README.md | 4 +- research/dppl_2024/dppl_mean.py | 79 +++---- research/dppl_2024/dppl_public.py | 110 ++++----- research/dppl_2024/dppl_public_topk.py | 134 +++++------ research/dppl_2024/hparams_mean.md | 2 +- research/dppl_2024/hparams_public.md | 2 +- research/dppl_2024/hparams_public_topk.md | 2 +- research/dppl_2024/lib/coinpress.py | 113 +++++----- research/dppl_2024/lib/public.py | 260 ++++++++++++---------- research/dppl_2024/lib/utils.py | 205 ++++++++--------- 10 files changed, 472 insertions(+), 439 deletions(-) diff --git a/research/dppl_2024/README.md b/research/dppl_2024/README.md index 32c19132..0dbbdcf6 100644 --- a/research/dppl_2024/README.md +++ b/research/dppl_2024/README.md @@ -1,8 +1,8 @@ # Beyond the Mean: Differentially Private Prototypes for Private Transfer Learning This folder contains the code for -**Beyond the Mean: Differentially Private Prototypes for Private Transfer Learning** -by Dariush Wahdany, Matthew Jagielski, Adam Dziedzic, Franziska Boenisch +**Beyond the Mean: Differentially Private Prototypes for Private Transfer Learning** +by Dariush Wahdany, Matthew Jagielski, Adam Dziedzic, Franziska Boenisch https://arxiv.org/abs/2406.08039 Abstract: diff --git a/research/dppl_2024/dppl_mean.py b/research/dppl_2024/dppl_mean.py index 35558930..9d8baca5 100644 --- a/research/dppl_2024/dppl_mean.py +++ b/research/dppl_2024/dppl_mean.py @@ -2,52 +2,55 @@ import hydra import jax import jax.numpy as jnp -from lib import coinpress, utils from omegaconf import DictConfig, OmegaConf +from lib import coinpress, utils + @hydra.main(config_path="conf", config_name="mean", version_base=None) def main(cfg: DictConfig): - print(OmegaConf.to_yaml(cfg)) + print(OmegaConf.to_yaml(cfg)) - X_train, Y_train, X_test, Y_test = utils.load_dataset(cfg) - X_train = pooling.avg_pool( - X_train.T, window_shape=(cfg.pool,), strides=(cfg.pool,) - ).T - X_test = pooling.avg_pool(X_test.T, window_shape=(cfg.pool,), strides=(cfg.pool,)).T - x_imbalanced, y_imbalanced = utils.give_imbalanced_set( - X_train, Y_train, cfg.imbalance_ratio - ) - classes = jnp.unique(y_imbalanced) - if cfg.epsilon < jnp.inf: - rho = utils.zcdp_of_naive_epsilon(cfg.epsilon) - Ps = jnp.array([5 / 64, 7 / 64, 52 / 64]) * rho - key = jax.random.key(cfg.seed) - class_keys = jax.random.split(key, len(classes)) - r = jnp.sqrt(x_imbalanced.shape[1]) - protos = jnp.stack( - [ - coinpress.private_mean_jit( - 
x_imbalanced[y_imbalanced == i], Ps, key=class_keys[i], r=r - ) - for i in classes - ] - ) - else: - protos = jnp.stack( - [x_imbalanced[y_imbalanced == i].mean(axis=0) for i in classes] + x_train, y_train, x_test, y_test = utils.load_dataset(cfg) + x_train = pooling.avg_pool( + x_train.T, window_shape=(cfg.pool,), strides=(cfg.pool,) + ).T + x_test = pooling.avg_pool( + x_test.T, window_shape=(cfg.pool,), strides=(cfg.pool,) + ).T + x_imbalanced, y_imbalanced = utils.give_imbalanced_set( + x_train, y_train, cfg.imbalance_ratio + ) + classes = jnp.unique(y_imbalanced) + if cfg.epsilon < jnp.inf: + rho = utils.zcdp_of_naive_epsilon(cfg.epsilon) + ps = jnp.array([5 / 64, 7 / 64, 52 / 64]) * rho + key = jax.random.key(cfg.seed) + class_keys = jax.random.split(key, len(classes)) + r = jnp.sqrt(x_imbalanced.shape[1]) + protos = jnp.stack( + [ + coinpress.private_mean_jit( + x_imbalanced[y_imbalanced == i], ps, key=class_keys[i], r=r ) - dists_test = utils.pairwise_distance(protos, X_test) - test_acc = float((dists_test.argmin(axis=0) == Y_test).mean()) - test_acc_per_class = jnp.stack( - [ - (dists_test[..., Y_test == target].argmin(axis=0) == target).mean() - for target in classes - ] + for i in classes + ] + ) + else: + protos = jnp.stack( + [x_imbalanced[y_imbalanced == i].mean(axis=0) for i in classes] ) - print(f"Test accuracy: {test_acc}") - print(f"Test accuracy per class: {test_acc_per_class}") + dists_test = utils.pairwise_distance(protos, x_test) + test_acc = float((dists_test.argmin(axis=0) == y_test).mean()) + test_acc_per_class = jnp.stack( + [ + (dists_test[..., y_test == target].argmin(axis=0) == target).mean() + for target in classes + ] + ) + print(f"Test accuracy: {test_acc}") + print(f"Test accuracy per class: {test_acc_per_class}") if __name__ == "__main__": - main() + main() diff --git a/research/dppl_2024/dppl_public.py b/research/dppl_2024/dppl_public.py index 9829f372..0ae029ed 100644 --- a/research/dppl_2024/dppl_public.py +++ b/research/dppl_2024/dppl_public.py @@ -4,68 +4,70 @@ import jax import jax.numpy as jnp import numpy as np -from lib import public, utils from omegaconf import DictConfig, OmegaConf +from lib import public, utils + @hydra.main(config_path="conf", config_name="public", version_base=None) def main(cfg: DictConfig): - print(OmegaConf.to_yaml(cfg)) + print(OmegaConf.to_yaml(cfg)) - rho = utils.zcdp_of_naive_epsilon(cfg.epsilon) - actual_epsilon = utils.exponential_epsilon_of_zcdp(rho) - print( - f"Converted settings epsilon {cfg.epsilon} to rho {rho} to exponential epsilon {actual_epsilon}" - ) + rho = utils.zcdp_of_naive_epsilon(cfg.epsilon) + actual_epsilon = utils.exponential_epsilon_of_zcdp(rho) + print( + f"Converted settings epsilon {cfg.epsilon} to rho {rho} to \ + exponential epsilon {actual_epsilon}" + ) - X_train, Y_train, X_test, Y_test = utils.load_dataset(cfg) - X_public = utils.load_public_dataset(cfg) - x_imbalanced, y_imbalanced = utils.give_imbalanced_set( - X_train, Y_train, cfg.imbalance_ratio - ) - classes = jnp.unique(y_imbalanced) - try: - jax.devices("gpu") - except RuntimeError: - warnings.warn("No GPU found, falling back to CPU. 
This will be slow.") - scores = jnp.stack( - [ - utils.scores_multiple( - x_imbalanced[y_imbalanced == target], - X_public, - cfg.min_score, - cfg.max_score, - ) - for target in classes - ] - ) - sensitivity = 1.0 - proto_idx_per_class = [] - for target in classes: - proto_idx_per_class.append( - public.exponential( - scores=scores[target], - sensitivity=sensitivity, - epsilon=actual_epsilon, - size=1, - monotonic=True, - key=int(cfg.seed + target), - ) - ) - public_protos = X_public[np.concatenate(proto_idx_per_class)].reshape( - len(classes), X_public.shape[-1] - ) - dists_test = utils.pairwise_distance(public_protos, X_test) - test_acc = float((dists_test.argmin(axis=0) == Y_test).mean()) - test_acc_per_class = jnp.stack( - [ - (dists_test[..., Y_test == target].argmin(axis=0) == target).mean() - for target in classes - ] + x_train, y_train, x_test, y_test = utils.load_dataset(cfg) + x_public = utils.load_public_dataset(cfg) + x_imbalanced, y_imbalanced = utils.give_imbalanced_set( + x_train, y_train, cfg.imbalance_ratio + ) + classes = jnp.unique(y_imbalanced) + try: + jax.devices("gpu") + except RuntimeError: + warnings.warn("No GPU found, falling back to CPU. This will be slow.") + scores = jnp.stack( + [ + utils.scores_multiple( + x_imbalanced[y_imbalanced == target], + x_public, + cfg.min_score, + cfg.max_score, + ) + for target in classes + ] + ) + sensitivity = 1.0 + proto_idx_per_class = [] + for target in classes: + proto_idx_per_class.append( + public.exponential( + scores=scores[target], + sensitivity=sensitivity, + epsilon=actual_epsilon, + size=1, + monotonic=True, + key=int(cfg.seed + target), + ) ) - print(f"Test accuracy: {test_acc}") - print(f"Test accuracy per class: {test_acc_per_class}") + public_protos = x_public[np.concatenate(proto_idx_per_class)].reshape( + len(classes), x_public.shape[-1] + ) + dists_test = utils.pairwise_distance(public_protos, x_test) + test_acc = float((dists_test.argmin(axis=0) == y_test).mean()) + test_acc_per_class = jnp.stack( + [ + (dists_test[..., y_test == target].argmin(axis=0) == target).mean() + for target in classes + ] + ) + print(f"Test accuracy: {test_acc}") + print(f"Test accuracy per class: {test_acc_per_class}") if __name__ == "__main__": - main() + main() diff --git a/research/dppl_2024/dppl_public_topk.py b/research/dppl_2024/dppl_public_topk.py index f6868a0b..8e4d1a31 100644 --- a/research/dppl_2024/dppl_public_topk.py +++ b/research/dppl_2024/dppl_public_topk.py @@ -4,80 +4,82 @@ import hydra import jax import jax.numpy as jnp -from lib import public, utils from omegaconf import DictConfig, OmegaConf +from lib import public, utils + @hydra.main(config_path="conf", config_name="public_topk", version_base=None) def main(cfg: DictConfig): - print(OmegaConf.to_yaml(cfg)) + print(OmegaConf.to_yaml(cfg)) - rho = utils.zcdp_of_naive_epsilon(cfg.epsilon) - actual_epsilon = utils.exponential_epsilon_of_zcdp(rho) - print( - f"Converted settings epsilon {cfg.epsilon} to rho {rho} to exponential epsilon {actual_epsilon}" - ) + rho = utils.zcdp_of_naive_epsilon(cfg.epsilon) + actual_epsilon = utils.exponential_epsilon_of_zcdp(rho) + print( + f"Converted settings epsilon {cfg.epsilon} to rho {rho} to exponential \ + epsilon {actual_epsilon}" + ) - X_train, Y_train, X_test, Y_test = utils.load_dataset(cfg) - X_public = utils.load_public_dataset(cfg) - x_imbalanced, y_imbalanced = utils.give_imbalanced_set( - X_train, Y_train, cfg.imbalance_ratio - ) - classes = jnp.unique(y_imbalanced) - try: - jax.devices("gpu") - except RuntimeError: 
- warnings.warn("No GPU found, falling back to CPU. This will be slow.") - scores = jnp.stack( - [ - utils.scores_multiple( - x_imbalanced[y_imbalanced == target], - X_public, - cfg.min_score, - cfg.max_score, - ) - for target in classes - ] - ) - C_idx = jnp.argsort(scores, axis=1, descending=True) - if cfg.epsilon < jnp.inf: - C = jnp.stack([scores[i, C_idx[i]] for i in range(scores.shape[0])]) - U = C - C[:, cfg.k - 1][:, jnp.newaxis] - with jax.experimental.enable_x64(): - logm = jax.vmap(partial(public.log_binom, k=cfg.k), in_axes=(0))( - jnp.arange(scores.shape[-1]) - ) - proto_idx_C = public.give_topk_proto_idx( - U, - logm, - cfg.k, - U.shape[0], - U.shape[1], - actual_epsilon, - cfg.seed, - ) - proto_idx = jnp.stack( - [ - C_idx[jnp.arange(C_idx.shape[0]), proto_idx_C[:, k_i]] - for k_i in range(cfg.k) - ] - ).T - else: - proto_idx = jnp.stack( - [C_idx[jnp.arange(C_idx.shape[0]), k_i] for k_i in range(cfg.k)] - ).T - public_protos = X_public[proto_idx.flatten()].reshape((*proto_idx.shape, -1)) - dists_test = utils.pairwise_distance(public_protos, X_test) - test_acc = float((dists_test.argmin(axis=0) == Y_test).mean()) - test_acc_per_class = jnp.stack( - [ - (dists_test[..., Y_test == target].argmin(axis=0) == target).mean() - for target in classes - ] + x_train, y_train, x_test, y_test = utils.load_dataset(cfg) + x_public = utils.load_public_dataset(cfg) + x_imbalanced, y_imbalanced = utils.give_imbalanced_set( + x_train, y_train, cfg.imbalance_ratio + ) + classes = jnp.unique(y_imbalanced) + try: + jax.devices("gpu") + except RuntimeError: + warnings.warn("No GPU found, falling back to CPU. This will be slow.") + scores = jnp.stack( + [ + utils.scores_multiple( + x_imbalanced[y_imbalanced == target], + x_public, + cfg.min_score, + cfg.max_score, + ) + for target in classes + ] + ) + c_idx = jnp.argsort(scores, axis=1, descending=True) + if cfg.epsilon < jnp.inf: + c = jnp.stack([scores[i, c_idx[i]] for i in range(scores.shape[0])]) + u = c - c[:, cfg.k - 1][:, jnp.newaxis] + with jax.experimental.enable_x64(): + logm = jax.vmap(partial(public.log_binom, k=cfg.k), in_axes=(0))( + jnp.arange(scores.shape[-1]) + ) + proto_idx_c = public.give_topk_proto_idx( + u, + logm, + cfg.k, + u.shape[0], + u.shape[1], + actual_epsilon, + cfg.seed, ) - print(f"Test accuracy: {test_acc}") - print(f"Test accuracy per class: {test_acc_per_class}") + proto_idx = jnp.stack( + [ + c_idx[jnp.arange(c_idx.shape[0]), proto_idx_c[:, k_i]] + for k_i in range(cfg.k) + ] + ).T + else: + proto_idx = jnp.stack( + [c_idx[jnp.arange(c_idx.shape[0]), k_i] for k_i in range(cfg.k)] + ).T + public_protos = x_public[proto_idx.flatten()].reshape((*proto_idx.shape, -1)) + dists_test = utils.pairwise_distance(public_protos, x_test) + test_acc = float((dists_test.argmin(axis=0) == y_test).mean()) + test_acc_per_class = jnp.stack( + [ + (dists_test[..., y_test == target].argmin(axis=0) == target).mean() + for target in classes + ] + ) + print(f"Test accuracy: {test_acc}") + print(f"Test accuracy per class: {test_acc_per_class}") if __name__ == "__main__": - main() + main() diff --git a/research/dppl_2024/hparams_mean.md b/research/dppl_2024/hparams_mean.md index f33038bb..f6367c51 100644 --- a/research/dppl_2024/hparams_mean.md +++ b/research/dppl_2024/hparams_mean.md @@ -319,4 +319,4 @@ | stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 0.1 | 10 | | stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 0.2 | 10 | | stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 1 | 1 | -| stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 8 | 1 
| \ No newline at end of file +| stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 8 | 1 | diff --git a/research/dppl_2024/hparams_public.md b/research/dppl_2024/hparams_public.md index 1f94dfe4..55964b70 100644 --- a/research/dppl_2024/hparams_public.md +++ b/research/dppl_2024/hparams_public.md @@ -255,4 +255,4 @@ | stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 0.1 | 1.5 | 1.42 | | stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.5 | 1.42 | | stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 1 | 1.5 | 1.42 | -| stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | \ No newline at end of file +| stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | diff --git a/research/dppl_2024/hparams_public_topk.md b/research/dppl_2024/hparams_public_topk.md index 1f42375a..c78f74a6 100644 --- a/research/dppl_2024/hparams_public_topk.md +++ b/research/dppl_2024/hparams_public_topk.md @@ -255,4 +255,4 @@ | stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 0.1 | 2 | 0 | 20 | | stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 0.2 | 1.64 | 1.34 | 1 | | stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 1 | 1.64 | 1.34 | 1 | -| stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | 1 | \ No newline at end of file +| stl10 | 50 | vit_large_patch14_dinov2.lvd142m | 8 | 2 | 0 | 1 | diff --git a/research/dppl_2024/lib/coinpress.py b/research/dppl_2024/lib/coinpress.py index f9d3df6a..b7133bda 100644 --- a/research/dppl_2024/lib/coinpress.py +++ b/research/dppl_2024/lib/coinpress.py @@ -7,71 +7,78 @@ @jax.jit def gaussian_tailbound_jit(d, b): - return (d + 2 * (d * jnp.log(1 / b)) ** 0.5 + 2 * jnp.log(1 / b)) ** 0.5 + return (d + 2 * (d * jnp.log(1 / b)) ** 0.5 + 2 * jnp.log(1 / b)) ** 0.5 @partial(jax.jit, static_argnames=("d",)) -def multivariate_mean_step_jit(X, c, r, p, n, d, subkey): - ## Determine a good clipping threshold - gamma = gaussian_tailbound_jit(d, 0.01) - clip_thresh = jnp.minimum( - (r**2 + 2 * r * 3 + gamma**2) ** 0.5, r + gamma - ) # 3 in place of sqrt(log(2/beta)) +def multivariate_mean_step_jit(x, c, r, p, n, d, subkey): + ## Determine a good clipping threshold + gamma = gaussian_tailbound_jit(d, 0.01) + clip_thresh = jnp.minimum( + (r**2 + 2 * r * 3 + gamma**2) ** 0.5, r + gamma + ) # 3 in place of sqrt(log(2/beta)) - ## Round each of X1,...,Xn to the nearest point in the ball B2(c,clip_thresh) - x = X - c - mag_x = jnp.linalg.norm(x, axis=1) + ## Round each of X1,...,Xn to the nearest point in the ball B2(c,clip_thresh) + x = x - c + mag_x = jnp.linalg.norm(x, axis=1) - outside_ball_bool = mag_x > clip_thresh - x_hat = (x.T / mag_x).T - X = jnp.where( - outside_ball_bool[:, jnp.newaxis], - c + (x_hat * clip_thresh), - X, - ) + outside_ball_bool = mag_x > clip_thresh + x_hat = (x.T / mag_x).T + x = jnp.where( + outside_ball_bool[:, jnp.newaxis], + c + (x_hat * clip_thresh), + x, + ) - ## Compute sensitivity - delta = 2 * clip_thresh / n.astype(float) - sd = delta / (2 * p) ** 0.5 + ## Compute sensitivity + delta = 2 * clip_thresh / n.astype(float) + sd = delta / (2 * p) ** 0.5 - ## Add noise calibrated to sensitivity - Y = sd * jax.random.normal(subkey, (d,)) - c = jnp.sum(X, axis=0) / n.astype(float) + Y - r = (1 / n.astype(float) + sd**2) ** 0.5 * gaussian_tailbound_jit(d, 0.01) - return c, r + ## Add noise calibrated to sensitivity + y = sd * jax.random.normal(subkey, (d,)) + c = jnp.sum(x, axis=0) / n.astype(float) + y + r = (1 / n.astype(float) + sd**2) ** 0.5 * gaussian_tailbound_jit(d, 0.01) + return c, r -def multivariate_mean_iterative_jit_inner(i, val, X, Ps, n, 
d, subkeys): - c, r = val - c, r = multivariate_mean_step_jit(X, c, r, Ps[i], n, d, subkeys[i]) - return (c, r) +def multivariate_mean_iterative_jit_inner(i, val, x, ps, n, d, subkeys): + c, r = val + c, r = multivariate_mean_step_jit(x, c, r, ps[i], n, d, subkeys[i]) + return (c, r) @partial(jax.jit, static_argnames=("d", "t")) -def multivariate_mean_iterative_jit(X, c, r, t, Ps, n, d, key): - subkeys = jax.random.split(key, t) - init_val = c, r - (c, r) = jax.lax.fori_loop( - 0, - t, - partial( - multivariate_mean_iterative_jit_inner, X=X, Ps=Ps, n=n, d=d, subkeys=subkeys - ), - init_val, - ) - return c +def multivariate_mean_iterative_jit(x, c, r, t, ps, n, d, key): + subkeys = jax.random.split(key, t) + init_val = c, r + (c, r) = jax.lax.fori_loop( + 0, + t, + partial( + multivariate_mean_iterative_jit_inner, + x=x, + ps=ps, + n=n, + d=d, + subkeys=subkeys, + ), + init_val, + ) + return c -def private_mean_jit(X, Ps, key=jax.random.key(42), r=None, c=None): - if len(X.shape) != 2: - raise ValueError("X must be a 2D array, but received shape: {}".format(X.shape)) - d = X.shape[1] - if r is None: - r = np.sqrt(d) * 0.9 - if c is None: - c = np.zeros(d) - t = len(Ps) - mean = multivariate_mean_iterative_jit( - X, c=c, r=r, t=t, Ps=Ps, n=X.shape[0], d=d, key=key +def private_mean_jit(x, ps, key=jax.random.key(42), r=None, c=None): + if len(x.shape) != 2: + raise ValueError( + "X must be a 2D array, but received shape: {}".format(x.shape) ) - return mean + d = x.shape[1] + if r is None: + r = np.sqrt(d) * 0.9 + if c is None: + c = np.zeros(d) + t = len(ps) + mean = multivariate_mean_iterative_jit( + x, c=c, r=r, t=t, ps=ps, n=x.shape[0], d=d, key=key + ) + return mean diff --git a/research/dppl_2024/lib/public.py b/research/dppl_2024/lib/public.py index f56899f7..0742f757 100644 --- a/research/dppl_2024/lib/public.py +++ b/research/dppl_2024/lib/public.py @@ -7,144 +7,158 @@ def exponential( - scores: np.ndarray, - sensitivity: float, - epsilon: float, - size: int = 1, - max_fix: bool = True, - monotonic: bool = False, - key: int = 0, + scores: np.ndarray, + sensitivity: float, + epsilon: float, + size: int = 1, + max_fix: bool = True, + monotonic: bool = False, + key: int = 0, ) -> np.ndarray: - """Perform exponential sampling on the scores. - - Args: - scores (np.ndarray): The scores of the elements in R. - sensitivity (float): Sensitivity of the score function w.r.t. the private data. - epsilon (float): pure-differential privacy parameter. - size (int, optional): Number of independent samplings to perform (e.g. for reporting avg/std of accuracy). Defaults to 1. - max_fix (bool, optional): Perform a numeric fix by multiplying all probablities with exp(-max_exponent). Defaults to True. - monotonic (bool, optional): Use lower privacy bound when the score function is monotonic w.r.t. to the private dataset. Defaults to False. - key (int, optional): Random key for reproducibility. Defaults to 0. - - Returns: - np.ndarray: array of indice(s) of the sampled element(s). 
- """ - if np.isposinf(epsilon): - max_idx = scores.argmax() - max_idx = max_idx.repeat(size) - return max_idx - - sensitivity_factor = 1 if monotonic else 2 - - # Substract maximum exponent to avoid overflow - if max_fix: - max_exponent = epsilon * scores.max() / (sensitivity_factor * sensitivity) - else: - max_exponent = 0 - # Calculate the probability for each element, based on its score - probabilities = np.exp( - epsilon * scores / (sensitivity_factor * sensitivity) - max_exponent - ) - # Normalize the probabilties so they sum to 1 - probabilities = probabilities / np.linalg.norm(probabilities, ord=1) - - # Choose an element from R based on the probabilities - rng = np.random.default_rng(key) - return rng.choice(len(scores), size, p=probabilities, replace=True) + """Perform exponential sampling on the scores. + + Args: + scores (np.ndarray): The scores of the elements in R. + sensitivity (float): Sensitivity of the score function w.r.t. \ + the private data. + epsilon (float): pure-differential privacy parameter. + size (int, optional): Number of independent samplings to perform (e.g. \ + for reporting avg/std of accuracy). Defaults to 1. + max_fix (bool, optional): Perform a numeric fix by multiplying all \ + probablities with exp(-max_exponent). Defaults to True. + monotonic (bool, optional): Use lower privacy bound when the score \ + function is monotonic w.r.t. to the private dataset. Defaults to False. + key (int, optional): Random key for reproducibility. Defaults to 0. + + Returns: + np.ndarray: array of indice(s) of the sampled element(s). + """ + if np.isposinf(epsilon): + max_idx = scores.argmax() + max_idx = max_idx.repeat(size) + return max_idx + + sensitivity_factor = 1 if monotonic else 2 + + # Substract maximum exponent to avoid overflow + if max_fix: + max_exponent = epsilon * scores.max() / (sensitivity_factor * sensitivity) + else: + max_exponent = 0 + # Calculate the probability for each element, based on its score + probabilities = np.exp( + epsilon * scores / (sensitivity_factor * sensitivity) - max_exponent + ) + # Normalize the probabilties so they sum to 1 + probabilities = probabilities / np.linalg.norm(probabilities, ord=1) + + # Choose an element from R based on the probabilities + rng = np.random.default_rng(key) + return rng.choice(len(scores), size, p=probabilities, replace=True) @jax.jit def log_binom(n: int, k: int) -> float: - """Calculate log(n choose k) + """Calculate log(n choose k) - Args: - n (int): n - k (int): k + Args: + n (int): n + k (int): k - Returns: - float: log(n choose k) - """ - return ( - jsc.special.gammaln(n + 1) - - jsc.special.gammaln(k + 1) - - jsc.special.gammaln(n - k + 1) - ) + Returns: + float: log(n choose k) + """ + return ( + jsc.special.gammaln(n + 1) + - jsc.special.gammaln(k + 1) + - jsc.special.gammaln(n - k + 1) + ) @partial( - jax.jit, - static_argnames=["total_rows", "total_cols"], + jax.jit, + static_argnames=["total_rows", "total_cols"], ) def exponential_parallel( - U: jnp.ndarray, - logm: jnp.ndarray, - total_rows: int, - total_cols: int, - epsilon: float, - key: int = 42, + u: jnp.ndarray, + logm: jnp.ndarray, + total_rows: int, + total_cols: int, + epsilon: float, + key: int = 42, ) -> jnp.ndarray: - """Perform parallel exponential sampling of all classes. - - Args: - U (jnp.ndarray): Scores, shape (total_rows, total_cols) = (n_classes, n_public_samples). - logm (jnp.ndarray): Log-Counts of the Utilities, shape (total_cols,) = (n_public_samples,). - total_rows (int): U.shape[0] = n_classes. 
(needed for jit compilation) - total_cols (int): U.shape[1] = n_public_samples. (needed for jit compilation) - epsilon (float): pure-differential privacy parameter. - key (int, optional): PRNG-initialization. Defaults to 42. - - Returns: - jnp.ndarray: array of indice(s) of the sampled element(s), shape (total_rows,) = (n_classes,). - """ - rng = jax.random.key(key) - choices = ( - jnp.log(jnp.log(1 / jax.random.uniform(rng, (total_rows, total_cols)))) - - logm - - epsilon * U / 2 - ).argmin(axis=-1) - return choices + """Perform parallel exponential sampling of all classes. + + Args: + u (jnp.ndarray): Scores, shape (total_rows, total_cols) \ + = (n_classes, n_public_samples). + logm (jnp.ndarray): Log-Counts of the Utilities, \ + shape (total_cols,) = (n_public_samples,). + total_rows (int): U.shape[0] = n_classes. (needed for jit compilation) + total_cols (int): U.shape[1] = n_public_samples. \ + (needed for jit compilation) + epsilon (float): pure-differential privacy parameter. + key (int, optional): PRNG-initialization. Defaults to 42. + + Returns: + jnp.ndarray: array of indice(s) of the sampled element(s), \ + shape (total_rows,) = (n_classes,). + """ + rng = jax.random.key(key) + choices = ( + jnp.log(jnp.log(1 / jax.random.uniform(rng, (total_rows, total_cols)))) + - logm + - epsilon * u / 2 + ).argmin(axis=-1) + return choices @partial(jax.jit, static_argnames=("total_rows", "total_cols", "k")) def give_topk_proto_idx( - U: jnp.ndarray, - logm: jnp.ndarray, - k: int, - total_rows: int, - total_cols: int, - epsilon: float, - key: int = 42, + u: jnp.ndarray, + logm: jnp.ndarray, + k: int, + total_rows: int, + total_cols: int, + epsilon: float, + key: int = 42, ): - """Perform the private top-k prototyping. First, perform exponential sampling on the utilities. - Then, uniformly sample the remaining k-1 prototypes, s.t. their utility is equal or better. - - Args: - U (jnp.ndarray): Scores, shape (total_rows, total_cols) = (n_classes, n_public_samples). - logm (jnp.ndarray): Log-Counts of the Utilities, shape (total_cols,) = (n_public_samples,). - k (int): Number of prototypes per class to sample. - total_rows (int): U.shape[0] = n_classes. (needed for jit compilation) - total_cols (int): U.shape[1] = n_public_samples. (needed for jit compilation) - epsilon (float): pure-differential privacy parameter. - key (int, optional): PRNG-initialization. Defaults to 42. - - Returns: - jnp.ndarray: array of indice(s) of the sampled element(s), shape (total_rows, k) = (n_classes, k). - """ - choices = exponential_parallel( - U, logm, total_rows, total_cols, epsilon, key - ).astype(int) - - proto_idx_C = jnp.concatenate( - [ - jax.lax.select( - jnp.arange(total_cols)[jnp.newaxis, :].repeat(total_rows, axis=0) - < choices[:, jnp.newaxis], - -jax.random.uniform(jax.random.key(key), (total_rows, total_cols)), - jnp.stack([jnp.zeros((total_cols)) for row in jnp.arange(total_rows)]), - ).argsort(axis=-1)[:, : k - 1], - choices[:, jnp.newaxis], - ], - axis=1, - ) - - return proto_idx_C + """Perform the private top-k prototyping. First, perform exponential sampling + on the utilities. + Then, uniformly sample the remaining k-1 prototypes, s.t. their utility is + equal or better. + + Args: + u (jnp.ndarray): Scores, shape (total_rows, total_cols) \ + = (n_classes, n_public_samples). + logm (jnp.ndarray): Log-Counts of the Utilities, \ + shape (total_cols,) = (n_public_samples,). + k (int): Number of prototypes per class to sample. + total_rows (int): U.shape[0] = n_classes. 
(needed for jit compilation) + total_cols (int): U.shape[1] = n_public_samples. \ + (needed for jit compilation) + epsilon (float): pure-differential privacy parameter. + key (int, optional): PRNG-initialization. Defaults to 42. + + Returns: + jnp.ndarray: array of indice(s) of the sampled element(s), \ + shape (total_rows, k) = (n_classes, k). + """ + choices = exponential_parallel( + u, logm, total_rows, total_cols, epsilon, key + ).astype(int) + + proto_idx_c = jnp.concatenate( + [ + jax.lax.select( + jnp.arange(total_cols)[jnp.newaxis, :].repeat(total_rows, axis=0) + < choices[:, jnp.newaxis], + -jax.random.uniform(jax.random.key(key), (total_rows, total_cols)), + jnp.stack([jnp.zeros((total_cols)) for row in jnp.arange(total_rows)]), + ).argsort(axis=-1)[:, : k - 1], + choices[:, jnp.newaxis], + ], + axis=1, + ) + + return proto_idx_c diff --git a/research/dppl_2024/lib/utils.py b/research/dppl_2024/lib/utils.py index cc882c25..4da22696 100644 --- a/research/dppl_2024/lib/utils.py +++ b/research/dppl_2024/lib/utils.py @@ -7,134 +7,139 @@ def load_dataset(cfg: DictConfig): - X_train = np.load(cfg.dataset.train_data) - Y_train = np.load(cfg.dataset.train_labels) - X_test = np.load(cfg.dataset.test_data) - Y_test = np.load(cfg.dataset.test_labels) + x_train = np.load(cfg.dataset.train_data) + y_train = np.load(cfg.dataset.train_labels) + x_test = np.load(cfg.dataset.test_data) + y_test = np.load(cfg.dataset.test_labels) - return X_train, Y_train, X_test, Y_test + return x_train, y_train, x_test, y_test def load_public_dataset(cfg: DictConfig): - X_public = np.load(cfg.dataset.public_data) - return X_public + x_public = np.load(cfg.dataset.public_data) + return x_public -def decay(cls: int | np.ndarray, max_samples: int, num_classes: int, ratio: float = 10): - decay = -np.log(ratio) / num_classes - return np.round(max_samples * np.exp(decay * cls)).astype(int) +def decay( + cls: int | np.ndarray, max_samples: int, num_classes: int, ratio: float = 10 +): + decay = -np.log(ratio) / num_classes + return np.round(max_samples * np.exp(decay * cls)).astype(int) def give_imbalanced_set(x, y, imbalance_ratio: float = 10, seed: int = 42): - classes = np.unique(y) - X_classes = [x[y == i] for i in classes] - rng = np.random.default_rng(seed) - input_samples_per_class = np.asarray([(y == i).sum() for i in classes]) - - output_samples_per_class = decay( - np.linspace(0, len(classes), len(classes)), - max_samples=input_samples_per_class.min(), - num_classes=len(classes), - ratio=imbalance_ratio, - ) - rng.shuffle(output_samples_per_class) - x = np.concatenate( - [ - X_classes[i][:num_samples] - for i, num_samples in enumerate(output_samples_per_class) - ] - ) - y = np.concatenate( - [ - np.repeat(i, num_samples) - for i, num_samples in enumerate(output_samples_per_class) - ] - ) - return x, y + classes = np.unique(y) + x_classes = [x[y == i] for i in classes] + rng = np.random.default_rng(seed) + input_samples_per_class = np.asarray([(y == i).sum() for i in classes]) + + output_samples_per_class = decay( + np.linspace(0, len(classes), len(classes)), + max_samples=input_samples_per_class.min(), + num_classes=len(classes), + ratio=imbalance_ratio, + ) + rng.shuffle(output_samples_per_class) + x = np.concatenate( + [ + x_classes[i][:num_samples] + for i, num_samples in enumerate(output_samples_per_class) + ] + ) + y = np.concatenate( + [ + np.repeat(i, num_samples) + for i, num_samples in enumerate(output_samples_per_class) + ] + ) + return x, y def zcdp_of_naive_epsilon(epsilon): - return epsilon**2 / 2 
+ return epsilon**2 / 2 def exponential_epsilon_of_zcdp(rho): - return np.sqrt(8 * rho) + return np.sqrt(8 * rho) @jit def pairwise_distance(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: - """Calculate 1-cosine_similarity between x and y. + """Calculate 1-cosine_similarity between x and y. - Args: - x (jnp.ndarray): x - y (jnp.ndarray): y + Args: + x (jnp.ndarray): x + y (jnp.ndarray): y - Returns: - jnp.ndarray: pairwise distance(s) - """ - x = x / jnp.linalg.norm(x, axis=-1, keepdims=True) - y = y / jnp.linalg.norm(y, axis=-1, keepdims=True) - return 1 - jnp.dot(x, y.T) + Returns: + jnp.ndarray: pairwise distance(s) + """ + x = x / jnp.linalg.norm(x, axis=-1, keepdims=True) + y = y / jnp.linalg.norm(y, axis=-1, keepdims=True) + return 1 - jnp.dot(x, y.T) @jit def scores_single( - X: jnp.ndarray, y: jnp.ndarray, min_score: float, max_score: float + x: jnp.ndarray, y: jnp.ndarray, min_score: float, max_score: float ) -> jnp.ndarray: - """Score Calculation for a single public sample. The score is calculated as the sum - of the clipped pairwise distances between the public sample and the private samples. - - Args: - X (jnp.ndarray): private dataset - y (jnp.ndarray): public sample - min_score (float): minimum score (in [0,2)) - max_score (float): maximum score (in (min_score, 2]) - - Returns: - jnp.ndarray: Score of the public sample - """ - return jnp.sum( - ( - ( - jnp.clip( - 2 - vmap(pairwise_distance, in_axes=(0, None))(X, y), - min_score, - max_score, - ) - - min_score - ) - / (max_score - min_score) - ), - axis=0, - ) + """Score Calculation for a single public sample. The score is calculated as + the sum of the clipped pairwise distances between the public sample and the + private samples. + + Args: + x (jnp.ndarray): private dataset + y (jnp.ndarray): public sample + min_score (float): minimum score (in [0,2)) + max_score (float): maximum score (in (min_score, 2]) + + Returns: + jnp.ndarray: Score of the public sample + """ + return jnp.sum( + ( + ( + jnp.clip( + 2 - vmap(pairwise_distance, in_axes=(0, None))(x, y), + min_score, + max_score, + ) + - min_score + ) + / (max_score - min_score) + ), + axis=0, + ) @partial(jit, static_argnames=["batch_size_y"]) def scores_multiple( - X: jnp.ndarray, - Y: jnp.ndarray, - min_score: float = 0.0, - max_score: float = 2.0, - batch_size_y: int = 5000, + x: jnp.ndarray, + y: jnp.ndarray, + min_score: float = 0.0, + max_score: float = 2.0, + batch_size_y: int = 5000, ) -> jnp.ndarray: - """Perform the score calculation batched over the public samples. - - Args: - X (jnp.ndarray): private dataset - Y (jnp.ndarray): public dataset - min_score (float, optional): minimum score (in [0,2)). Defaults to 0.0. - max_score (float, optional): maximum score (in (min_score, 2]). Defaults to 2.0. - batch_size_y (int, optional): batch size (impacts VRAM usage). Defaults to 5000. - - Returns: - jnp.ndarray: Scores of all public samples in Y - """ - return jnp.concatenate( - [ - vmap( - partial(scores_single, min_score=min_score, max_score=max_score), - in_axes=(None, 0), - )(X, Y[i : min(i + batch_size_y, len(Y))]) - for i in range(0, len(Y), batch_size_y) - ], - ) + """Perform the score calculation batched over the public samples. + + Args: + x (jnp.ndarray): private dataset + y (jnp.ndarray): public dataset + min_score (float, optional): minimum score (in [0,2)). Defaults to 0.0. + max_score (float, optional): maximum score (in (min_score, 2]). \ + Defaults to 2.0. + batch_size_y (int, optional): batch size (impacts VRAM usage). \ + Defaults to 5000. 
+ + Returns: + jnp.ndarray: Scores of all public samples in Y + """ + return jnp.concatenate( + [ + vmap( + partial(scores_single, min_score=min_score, max_score=max_score), + in_axes=(None, 0), + )(x, y[i : min(i + batch_size_y, len(y))]) + for i in range(0, len(y), batch_size_y) + ], + )
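
For orientation, the sketch below (not part of the patches above) shows how the pieces from `lib/utils.py` and `lib/public.py` compose into a single DPPL-Public selection step on toy data. The array shapes, the privacy budget, and the use of the default clipping bounds are illustrative placeholders; tuned per-setting values for `min_score`, `max_score`, and `k` are the ones listed in the hparams tables above.

```python
import numpy as np
from jax import numpy as jnp

from lib import public, utils

# Toy stand-ins for the private samples of one class and the public candidate
# pool (random values; real runs use encoder embeddings loaded from disk).
rng = np.random.default_rng(0)
x_private_class = jnp.asarray(rng.normal(size=(64, 32)))
x_public = jnp.asarray(rng.normal(size=(500, 32)))

# Score every public candidate against the private samples of this class,
# using the default clipping bounds [0, 2] rather than tuned ones.
scores = utils.scores_multiple(x_private_class, x_public)

# Exponential mechanism: privately select one public prototype for the class.
# Each per-sample contribution is normalized to [0, 1], so the sensitivity is 1;
# dppl_public.py additionally converts the nominal epsilon through zCDP
# (zcdp_of_naive_epsilon / exponential_epsilon_of_zcdp) before sampling.
proto_idx = public.exponential(
    scores=np.asarray(scores),
    sensitivity=1.0,
    epsilon=1.0,  # illustrative budget only
    size=1,
    monotonic=True,
    key=0,
)
prototype = x_public[int(proto_idx[0])]
```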