-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy patheval.py
More file actions
37 lines (28 loc) · 787 Bytes
/
eval.py
File metadata and controls
37 lines (28 loc) · 787 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import argparse
from services.cifar_data import CifarData
from services.cifar_trainer import CifarTrainer
from services.neural_network import NeuralNetwork
# Command-line interface for evaluating a saved CIFAR model.
# Parsing is deferred to main() so that importing this module has no
# side effects. Previously, parse_args() ran at import time and exited
# whenever --bucket/--path were absent from sys.argv.
parser = argparse.ArgumentParser(description="Evaluate cifar neural network")
parser.add_argument(
    "-u",
    "--bucket",
    required=True,
    type=str,
    help="Bucket to load the model from",
)
parser.add_argument(
    "-p",
    "--path",
    required=True,
    type=str,
    help="Path to load the model from",
)


def main(argv=None):
    """Load a saved model and evaluate it on the CIFAR test split.

    Args:
        argv: Optional list of command-line arguments to parse instead of
            ``sys.argv[1:]`` (useful for calling ``main`` programmatically).
            ``None`` keeps the original behaviour of reading the real
            command line.

    Raises:
        SystemExit: If the arguments are invalid or ``--bucket``/``--path``
            is missing (raised by argparse).
    """
    args = parser.parse_args(argv)

    neural_network = NeuralNetwork()
    model = CifarTrainer(neural_network=neural_network)
    # Restore weights from storage. The exact storage backend is defined by
    # CifarTrainer.load (not visible here). Presumably an object-store bucket; verify.
    model.load(bucket=args.bucket, path=args.path)

    dataloader = CifarData(split="test").dataloader()
    # NOTE(review): CifarTrainer.eval's reporting behaviour (printing or
    # logging metrics) is not visible here; its return value is discarded.
    model.eval(dataloader=dataloader)


if __name__ == "__main__":
    main()