-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy pathmain.py
More file actions
94 lines (81 loc) · 3.34 KB
/
Copy pathmain.py
File metadata and controls
94 lines (81 loc) · 3.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
import pandas as pd
import jax
import torchvision
import reprieve
from reprieve.representations import mnist_vae
from reprieve.mnist_noisy_label import MNISTNoisyLabelDataset
from reprieve.algorithms import mlp as alg
def main(args):
    """Compute loss-data curves for three MNIST setups and report them.

    Three curves are produced with ``reprieve.LossDataEstimator``: one on
    the normalised MNIST images, one on the same dataset passed through the
    representation built by ``mnist_vae.build_repr(8)``, and one on
    ``MNISTNoisyLabelDataset`` with ``p_corrupt=0.05``. The curves are
    tagged with a ``name`` column, concatenated, and handed to
    ``reprieve.render_curve``, ``reprieve.compute_metrics`` and
    ``reprieve.render_latex``. The metrics are also printed.

    Args:
        args: Namespace providing ``name``, ``train_steps``, ``seeds``,
            ``points``, ``use_vmap`` and ``cache_data``.
    """

    def curve_for(algorithm, dataset, **estimator_extras):
        # One estimator run; every setup shares these settings and differs
        # only in the algorithm triple, the dataset and optional extras.
        init_fn, train_step_fn, eval_fn = algorithm
        estimator = reprieve.LossDataEstimator(
            init_fn, train_step_fn, eval_fn, dataset,
            train_steps=args.train_steps, n_seeds=args.seeds,
            use_vmap=args.use_vmap, cache_data=args.cache_data,
            verbose=True, **estimator_extras)
        return estimator.compute_curve(n_points=args.points)

    # Raw images: the algorithm is built for inputs of shape (1, 28, 28).
    raw_algorithm = alg.make_algorithm((1, 28, 28), 10)
    to_normalised_tensor = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ])
    mnist = torchvision.datasets.MNIST(
        './data', train=True, download=True,
        transform=to_normalised_tensor)
    raw_curve = curve_for(raw_algorithm, mnist)

    # Same dataset, fed through the VAE representation; the algorithm is
    # built for inputs of shape (8,).
    encoder = mnist_vae.build_repr(8)
    vae_curve = curve_for(
        alg.make_algorithm((8,), 10), mnist, representation_fn=encoder)

    # Noisy-label dataset; the algorithm is built for inputs of shape (784,).
    noisy_mnist = MNISTNoisyLabelDataset(train=True, p_corrupt=0.05)
    noisy_curve = curve_for(alg.make_algorithm((784,), 10), noisy_mnist)

    # Tag each curve so the rows stay distinguishable after concatenation.
    labelled_curves = (
        ('Raw', raw_curve),
        ('VAE', vae_curve),
        ('Noisy labels', noisy_curve),
    )
    for label, curve in labelled_curves:
        curve['name'] = label
    outcome_df = pd.concat([curve for _, curve in labelled_curves])

    os.makedirs('results', exist_ok=True)
    save_path = (f'results/{args.name}_train{args.train_steps}'
                 f'_seed{args.seeds}_point{args.points}')

    # NOTE(review): presumably dataset sizes and loss thresholds at which
    # the curves are evaluated -- confirm against the reprieve docs.
    ns = [60, 20000]
    epsilons = [1, 0.2]

    reprieve.render_curve(
        outcome_df, ns, epsilons, save_path=f'{save_path}.pdf')
    metrics_df = reprieve.compute_metrics(outcome_df, ns, epsilons)
    print(metrics_df)
    reprieve.render_latex(metrics_df, save_path=f'{save_path}.tex')
if __name__ == '__main__':
    # Script entry point: parse the command line, run main() (with JIT
    # disabled when --debug is given) and print the elapsed wall time.
    import argparse
    import time

    # Registered in this order so the generated --help text keeps the
    # same layout.
    cli_options = (
        ('--name', {'type': str, 'default': 'jax'}),
        ('--debug', {'action': 'store_true'}),
        # The two --no_* switches turn off options that default to True.
        ('--no_vmap', {'dest': 'use_vmap', 'action': 'store_false'}),
        ('--no_cache', {'dest': 'cache_data', 'action': 'store_false'}),
        # Parsed as float, so values such as "4e3" are accepted.
        ('--train_steps', {'type': float, 'default': 4e3}),
        ('--seeds', {'type': int, 'default': 5}),
        ('--points', {'type': int, 'default': 10}),
    )
    parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    start = time.time()
    if not args.debug:
        main(args)
    else:
        # Debug mode runs the whole experiment with JAX's JIT turned off.
        with jax.disable_jit():
            main(args)
    elapsed = time.time() - start
    print(f"Time: {elapsed:.3f} seconds")