11"""Parse CLI arguments"""
22
3- from typing import TypedDict , Optional
4-
53import sys
64import argparse
5+ from typing import Optional
76
87from utils .constants import (
98 DEFAULT_EPOCHS ,
1716 DEFAULT_TRAIN_PERCENTAGE ,
1817 DEFAULT_TEST_PERCENTAGE ,
1918 DEFAULT_TARGET_FOLDER ,
19+ DEFAULT_CHECKPOINT ,
2020)
2121from utils .datatypes import Dimensions , FilePath
2222
2323
24- class Arguments ( TypedDict ) :
24+ class Arguments :
2525 """Parsed args types"""
2626
27- epochs : int
28- batch_size : int
29- train_size : int
30- test_size : int
31- threads : int
32- shots : int
33- n_qubits : int
34- max_gates : int
35- dataset_size : int
36- target_folder : FilePath
37- checkpoint : Optional [FilePath ]
38- new_image_dim : Dimensions
27+ def __init__ (self ):
28+ """set default arguments"""
29+
30+ self ._epochs = DEFAULT_EPOCHS
31+ self ._batch_size = DEFAULT_BATCH_SIZE
32+ self ._train_size = DEFAULT_TRAIN_PERCENTAGE
33+ self ._test_size = DEFAULT_TEST_PERCENTAGE
34+ self ._threads = DEFAULT_THREADS
35+ self ._shots = DEFAULT_SHOTS
36+ self ._n_qubits = DEFAULT_NUM_QUBITS
37+ self ._max_gates = DEFAULT_MAX_TOTAL_GATES
38+ self ._dataset_size = DEFAULT_DATASET_SIZE
39+ self ._target_folder = DEFAULT_TARGET_FOLDER
40+ self ._checkpoint = DEFAULT_CHECKPOINT
41+ self ._new_image_dim = DEFAULT_NEW_DIM
42+
43+ def parse (self , args : argparse .Namespace ):
44+ """Parse arguments from argparse"""
45+ self ._epochs = args .epochs
46+ self ._batch_size = args .batch_size
47+ self ._train_size = args .train_size
48+ self ._test_size = args .test_size
49+ self ._threads = args .threads
50+ self ._shots = args .shots
51+ self ._n_qubits = args .n_qubits
52+ self ._max_gates = args .max_gates
53+ self ._dataset_size = args .dataset_size
54+ self ._target_folder = args .target_folder
55+ self ._checkpoint = args .checkpoint
56+ self ._new_image_dim = args .new_image_dim
57+
58+ @property
59+ def epochs (self ) -> int :
60+ """Get epochs data"""
61+ return self ._epochs # type: ignore
62+
63+ @epochs .setter
64+ def epochs (self , value : int ):
65+ """Set epochs data"""
66+ self ._epochs = value
67+
68+ @property
69+ def batch_size (self ) -> int :
70+ """Get batch_size data"""
71+ return self ._batch_size # type: ignore
72+
73+ @batch_size .setter
74+ def batch_size (self , value : int ):
75+ """Set batch_size data"""
76+ self ._batch_size
77+
78+ @property
79+ def train_size (self ) -> int :
80+ """Get train_size data"""
81+ return self ._train_size # type: ignore
82+
83+ @train_size .setter
84+ def train_size (self , value : int ):
85+ """Set train_size data"""
86+ self ._train_size = value
87+
88+ @property
89+ def test_size (self ) -> int :
90+ """Get test_size data"""
91+ return self ._test_size # type: ignore
92+
93+ @test_size .setter
94+ def test_size (self , value : int ):
95+ """Set test_size data"""
96+ self ._test_size = value
97+
98+ @property
99+ def threads (self ) -> int :
100+ """Get threads data"""
101+ return self ._threads # type: ignore
102+
103+ @threads .setter
104+ def threads (self , value : int ):
105+ """Set threads data"""
106+ self ._threads = value
107+
108+ @property
109+ def shots (self ) -> int :
110+ """Get shots data"""
111+ return self ._shots # type: ignore
112+
113+ @shots .setter
114+ def shots (self , value : int ):
115+ """Set shots data"""
116+ self ._shots = value
117+
118+ @property
119+ def n_qubits (self ) -> int :
120+ """Get n_qubits data"""
121+ return self ._n_qubits # type: ignore
122+
123+ @n_qubits .setter
124+ def n_qubits (self , value : int ):
125+ """Set n_qubits data"""
126+ self ._n_qubits = value
127+
128+ @property
129+ def max_gates (self ) -> int :
130+ """Get max_gates data"""
131+ return self ._max_gates # type: ignore
132+
133+ @max_gates .setter
134+ def max_gates (self , value : int ):
135+ """Set max_gates data"""
136+ self ._max_gates = value
137+
138+ @property
139+ def dataset_size (self ) -> int :
140+ """Get dataset_size data"""
141+ return self ._dataset_size # type: ignore
142+
143+ @dataset_size .setter
144+ def dataset_size (self , value : int ):
145+ """Set dataset_size data"""
146+ self ._dataset_size = value
147+
148+ @property
149+ def target_folder (self ) -> FilePath :
150+ """Get target_folder data"""
151+ return self ._target_folder # type: ignore
152+
153+ @target_folder .setter
154+ def target_folder (self , value : FilePath ):
155+ """Set target_folder data"""
156+ self ._target_folder = value
157+
158+ @property
159+ def checkpoint (self ) -> Optional [FilePath ]:
160+ """Get checkpoint data"""
161+ return self ._checkpoint # type: ignore
162+
163+ @checkpoint .setter
164+ def checkpoint (self , value : Optional [FilePath ]):
165+ """Set checkpoint data"""
166+ self ._checkpoint = value
167+
168+ @property
169+ def new_image_dim (self ) -> Dimensions :
170+ """Get new_image_dim data"""
171+ return self ._new_image_dim # type: ignore
172+
173+ @new_image_dim .setter
174+ def new_image_dim (self , value : Dimensions ):
175+ """Set new_image_dim data"""
176+ self ._new_image_dim = value
177+
178+ def __str__ (self ) -> str :
179+ string = f"epochs: { self ._epochs } \n "
180+ string += f"batch size: { self ._batch_size } \n "
181+ string += f"train size: { self ._train_size } \n "
182+ string += f"teste size: { self ._test_size } \n "
183+ string += f"threads: { self ._threads } \n "
184+ string += f"shots: { self ._shots } \n "
185+ string += f"n qubits: { self ._n_qubits } \n "
186+ string += f"max gates: { self ._max_gates } \n "
187+ string += f"dataset size: { self ._dataset_size } \n "
188+ string += f"target_folder: { self ._target_folder } \n "
189+ string += f"checkpoint: { self ._checkpoint } \n "
190+ string += f"new image dim: { self ._new_image_dim } \n "
191+
192+ return string
39193
40194
41195def parse_args () -> Arguments :
@@ -48,7 +202,7 @@ def parse_args() -> Arguments:
48202 parser .add_argument ("--batch-size" , type = int , default = DEFAULT_BATCH_SIZE )
49203 parser .add_argument ("--train-size" , type = float , default = DEFAULT_TRAIN_PERCENTAGE )
50204 parser .add_argument ("--test-size" , type = float , default = DEFAULT_TEST_PERCENTAGE )
51- parser .add_argument ("--checkpoint" , type = str , default = None )
205+ parser .add_argument ("--checkpoint" , type = str , default = DEFAULT_CHECKPOINT )
52206
53207 parser .add_argument ("--threads" , type = int , default = DEFAULT_THREADS )
54208
@@ -62,19 +216,7 @@ def parse_args() -> Arguments:
62216
63217 args = parser .parse_args (sys .argv [1 :])
64218
65- parsed_arguments : Arguments = {
66- "epochs" : args .epochs ,
67- "batch_size" : args .batch_size ,
68- "train_size" : args .train_size ,
69- "test_size" : args .test_size ,
70- "checkpoint" : args .checkpoint ,
71- "threads" : args .threads ,
72- "shots" : args .shots ,
73- "n_qubits" : args .n_qubits ,
74- "max_gates" : args .max_gates ,
75- "dataset_size" : args .dataset_size ,
76- "target_folder" : args .target_folder ,
77- "new_image_dim" : args .new_image_dim ,
78- }
219+ parsed_arguments = Arguments ()
220+ parsed_arguments .parse (args )
79221
80222 return parsed_arguments
0 commit comments