Skip to content

Commit 0df9e32

Browse files
committed
fixed train pipeline
1 parent be6276d commit 0df9e32

File tree

14 files changed

+249
-72
lines changed

14 files changed

+249
-72
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,8 @@ repos:
55
hooks:
66
- id: ruff
77
args: [ --fix ]
8-
- id: ruff-format
98

109
- repo: https://github.com/pre-commit/mirrors-mypy
1110
rev: v1.16.0
1211
hooks:
13-
- id: mypy
12+
- id: mypy

Makefile

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,28 @@
11
SHELL := /bin/bash
2+
TARGET_PATH ?= "."
23

34
clean-all: clean-dataset clean-pred clean-ghz clean-model clean-checkpoints clean-history
45

56
clean-dataset:
6-
rm -rf dataset/ dataset.csv *.h5 dataset-images.zip
7+
rm -rf $(TARGET_PATH)/dataset/ $(TARGET_PATH)/dataset.csv $(TARGET_PATH)/*.h5 $(TARGET_PATH)/dataset-images.zip
78

89
clean-pred:
9-
rm -rf ghz-prediction.pth
10+
rm -rf $(TARGET_PATH)/ghz-prediction.pth
1011

1112
clean-ghz:
12-
rm -rf ghz.pth ghz.jpeg
13+
rm -rf $(TARGET_PATH)/ghz.pth $(TARGET_PATH)/ghz.jpeg
1314

1415
clean-model:
15-
rm -rf model_*
16+
rm -rf $(TARGET_PATH)/model_*
1617

1718
clean-checkpoints:
18-
rm -rf checkpoint_*
19+
rm -rf $(TARGET_PATH)/checkpoint_*
1920

2021
clean-history:
21-
rm -rf history.json
22+
rm -rf $(TARGET_PATH)/history.json
2223

2324
pack:
24-
zip -r dataset-images.zip dataset/
25+
zip -r $(TARGET_PATH)/dataset-images.zip $(TARGET_PATH)/dataset/
2526

2627
lock:
2728
conda-lock -f environment.yml

ai/__init__.py

Whitespace-only changes.

airflow-entrypoint.sh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@ done
88
echo "Migrating...."
99
airflow db migrate
1010

11-
# USER, PASSWORD and EMAIl come from env variables
12-
echo "Creating Airflow user ${USER}..."
11+
# USER, PASSWORD and EMAIL come from env variables
12+
echo "Creating Airflow user $USER..."
1313
airflow users create \
1414
--username $USER \
1515
--firstname user \
1616
--lastname user \
1717
--role Admin \
18-
--email ${EMAIL} \
19-
--password ${PASSWORD}
18+
--email $EMAIL \
19+
--password $PASSWORD
2020

2121
echo "Running scheduler...."
2222
airflow scheduler &

airflow.Dockerfile

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@ RUN ${PIPENV} install -r requirements.txt
77

88

99

10-
FROM busybox:1.37.0 AS entry
10+
FROM debian:bookworm-slim AS entry
11+
12+
RUN apt update && apt install zip make -y
13+
1114
WORKDIR /
1215
COPY airflow-entrypoint.sh entrypoint.sh
1316
RUN chmod +x entrypoint.sh
@@ -19,6 +22,10 @@ COPY --from=setup /proj-venv/lib/python3.12/site-packages/ /home/airflow/.local/
1922
WORKDIR /home/airflow/project
2023
COPY . .
2124

25+
WORKDIR /home/airflow/.local/bin
26+
COPY --from=entry --chown=airflow:root /usr/bin/zip zip
27+
COPY --from=entry --chown=airflow:root /usr/bin/make make
28+
2229
WORKDIR /
2330
COPY --from=entry /entrypoint.sh .
2431

airflow.Dockerfile.dockerignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
*
2-
!ai/
2+
!train.py
33
!export/
44
!generate/
55
!utils/

args/parser.py

Lines changed: 172 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
"""Parse CLI arguments"""
22

3-
from typing import TypedDict, Optional
4-
53
import sys
64
import argparse
5+
from typing import Optional
76

87
from utils.constants import (
98
DEFAULT_EPOCHS,
@@ -17,25 +16,180 @@
1716
DEFAULT_TRAIN_PERCENTAGE,
1817
DEFAULT_TEST_PERCENTAGE,
1918
DEFAULT_TARGET_FOLDER,
19+
DEFAULT_CHECKPOINT,
2020
)
2121
from utils.datatypes import Dimensions, FilePath
2222

2323

24-
class Arguments(TypedDict):
24+
class Arguments:
2525
"""Parsed args types"""
2626

27-
epochs: int
28-
batch_size: int
29-
train_size: int
30-
test_size: int
31-
threads: int
32-
shots: int
33-
n_qubits: int
34-
max_gates: int
35-
dataset_size: int
36-
target_folder: FilePath
37-
checkpoint: Optional[FilePath]
38-
new_image_dim: Dimensions
27+
def __init__(self):
28+
"""set default arguments"""
29+
30+
self._epochs = DEFAULT_EPOCHS
31+
self._batch_size = DEFAULT_BATCH_SIZE
32+
self._train_size = DEFAULT_TRAIN_PERCENTAGE
33+
self._test_size = DEFAULT_TEST_PERCENTAGE
34+
self._threads = DEFAULT_THREADS
35+
self._shots = DEFAULT_SHOTS
36+
self._n_qubits = DEFAULT_NUM_QUBITS
37+
self._max_gates = DEFAULT_MAX_TOTAL_GATES
38+
self._dataset_size = DEFAULT_DATASET_SIZE
39+
self._target_folder = DEFAULT_TARGET_FOLDER
40+
self._checkpoint = DEFAULT_CHECKPOINT
41+
self._new_image_dim = DEFAULT_NEW_DIM
42+
43+
def parse(self, args: argparse.Namespace):
44+
"""Parse arguments from argparse"""
45+
self._epochs = args.epochs
46+
self._batch_size = args.batch_size
47+
self._train_size = args.train_size
48+
self._test_size = args.test_size
49+
self._threads = args.threads
50+
self._shots = args.shots
51+
self._n_qubits = args.n_qubits
52+
self._max_gates = args.max_gates
53+
self._dataset_size = args.dataset_size
54+
self._target_folder = args.target_folder
55+
self._checkpoint = args.checkpoint
56+
self._new_image_dim = args.new_image_dim
57+
58+
@property
59+
def epochs(self) -> int:
60+
"""Get epochs data"""
61+
return self._epochs # type: ignore
62+
63+
@epochs.setter
64+
def epochs(self, value: int):
65+
"""Set epochs data"""
66+
self._epochs = value
67+
68+
@property
69+
def batch_size(self) -> int:
70+
"""Get batch_size data"""
71+
return self._batch_size # type: ignore
72+
73+
@batch_size.setter
74+
def batch_size(self, value: int):
75+
"""Set batch_size data"""
76+
self._batch_size = value
77+
78+
@property
79+
def train_size(self) -> int:
80+
"""Get train_size data"""
81+
return self._train_size # type: ignore
82+
83+
@train_size.setter
84+
def train_size(self, value: int):
85+
"""Set train_size data"""
86+
self._train_size = value
87+
88+
@property
89+
def test_size(self) -> int:
90+
"""Get test_size data"""
91+
return self._test_size # type: ignore
92+
93+
@test_size.setter
94+
def test_size(self, value: int):
95+
"""Set test_size data"""
96+
self._test_size = value
97+
98+
@property
99+
def threads(self) -> int:
100+
"""Get threads data"""
101+
return self._threads # type: ignore
102+
103+
@threads.setter
104+
def threads(self, value: int):
105+
"""Set threads data"""
106+
self._threads = value
107+
108+
@property
109+
def shots(self) -> int:
110+
"""Get shots data"""
111+
return self._shots # type: ignore
112+
113+
@shots.setter
114+
def shots(self, value: int):
115+
"""Set shots data"""
116+
self._shots = value
117+
118+
@property
119+
def n_qubits(self) -> int:
120+
"""Get n_qubits data"""
121+
return self._n_qubits # type: ignore
122+
123+
@n_qubits.setter
124+
def n_qubits(self, value: int):
125+
"""Set n_qubits data"""
126+
self._n_qubits = value
127+
128+
@property
129+
def max_gates(self) -> int:
130+
"""Get max_gates data"""
131+
return self._max_gates # type: ignore
132+
133+
@max_gates.setter
134+
def max_gates(self, value: int):
135+
"""Set max_gates data"""
136+
self._max_gates = value
137+
138+
@property
139+
def dataset_size(self) -> int:
140+
"""Get dataset_size data"""
141+
return self._dataset_size # type: ignore
142+
143+
@dataset_size.setter
144+
def dataset_size(self, value: int):
145+
"""Set dataset_size data"""
146+
self._dataset_size = value
147+
148+
@property
149+
def target_folder(self) -> FilePath:
150+
"""Get target_folder data"""
151+
return self._target_folder # type: ignore
152+
153+
@target_folder.setter
154+
def target_folder(self, value: FilePath):
155+
"""Set target_folder data"""
156+
self._target_folder = value
157+
158+
@property
159+
def checkpoint(self) -> Optional[FilePath]:
160+
"""Get checkpoint data"""
161+
return self._checkpoint # type: ignore
162+
163+
@checkpoint.setter
164+
def checkpoint(self, value: Optional[FilePath]):
165+
"""Set checkpoint data"""
166+
self._checkpoint = value
167+
168+
@property
169+
def new_image_dim(self) -> Dimensions:
170+
"""Get new_image_dim data"""
171+
return self._new_image_dim # type: ignore
172+
173+
@new_image_dim.setter
174+
def new_image_dim(self, value: Dimensions):
175+
"""Set new_image_dim data"""
176+
self._new_image_dim = value
177+
178+
def __str__(self) -> str:
179+
string = f"epochs: {self._epochs}\n"
180+
string += f"batch size: {self._batch_size}\n"
181+
string += f"train size: {self._train_size}\n"
182+
string += f"test size: {self._test_size}\n"
183+
string += f"threads: {self._threads}\n"
184+
string += f"shots: {self._shots}\n"
185+
string += f"n qubits: {self._n_qubits}\n"
186+
string += f"max gates: {self._max_gates}\n"
187+
string += f"dataset size: {self._dataset_size}\n"
188+
string += f"target_folder: {self._target_folder}\n"
189+
string += f"checkpoint: {self._checkpoint}\n"
190+
string += f"new image dim: {self._new_image_dim}\n"
191+
192+
return string
39193

40194

41195
def parse_args() -> Arguments:
@@ -48,7 +202,7 @@ def parse_args() -> Arguments:
48202
parser.add_argument("--batch-size", type=int, default=DEFAULT_BATCH_SIZE)
49203
parser.add_argument("--train-size", type=float, default=DEFAULT_TRAIN_PERCENTAGE)
50204
parser.add_argument("--test-size", type=float, default=DEFAULT_TEST_PERCENTAGE)
51-
parser.add_argument("--checkpoint", type=str, default=None)
205+
parser.add_argument("--checkpoint", type=str, default=DEFAULT_CHECKPOINT)
52206

53207
parser.add_argument("--threads", type=int, default=DEFAULT_THREADS)
54208

@@ -62,19 +216,7 @@ def parse_args() -> Arguments:
62216

63217
args = parser.parse_args(sys.argv[1:])
64218

65-
parsed_arguments: Arguments = {
66-
"epochs": args.epochs,
67-
"batch_size": args.batch_size,
68-
"train_size": args.train_size,
69-
"test_size": args.test_size,
70-
"checkpoint": args.checkpoint,
71-
"threads": args.threads,
72-
"shots": args.shots,
73-
"n_qubits": args.n_qubits,
74-
"max_gates": args.max_gates,
75-
"dataset_size": args.dataset_size,
76-
"target_folder": args.target_folder,
77-
"new_image_dim": args.new_image_dim,
78-
}
219+
parsed_arguments = Arguments()
220+
parsed_arguments.parse(args)
79221

80222
return parsed_arguments

compose.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,13 @@ services:
2727
volumes:
2828
- ./data:/home/airflow/data
2929
- ./dags:/opt/airflow/dags
30+
deploy:
31+
resources:
32+
reservations:
33+
devices:
34+
- capabilities: [gpu]
35+
count: all
36+
driver: nvidia
3037
environment:
3138
- USER=${AIRFLOW_USERNAME}
3239
- PASSWORD=${AIRFLOW_PASSWORD}

dags/dataset.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
DEFAULT_SHOTS,
2323
DEFAULT_DATASET_SIZE,
2424
DEFAULT_THREADS,
25-
dataset_file,
2625
)
2726
from generate.ghz import gen_circuit
2827

@@ -48,7 +47,7 @@
4847
"""
4948

5049
gen_df = PythonOperator(
51-
task_id="gen_df", python_callable=start_df, op_args=[dataset_file(folder)]
50+
task_id="gen_df", python_callable=start_df, op_args=[folder]
5251
)
5352
gen_df.doc_md = """
5453
Generate an empty dataframe and saves it as an csv file.
@@ -93,7 +92,8 @@
9392
with resized and normalized images.
9493
"""
9594

96-
pack_img = BashOperator(task_id="pack_images", bash_command="make pack")
95+
command = f"zip -r {folder}/dataset-images.zip {folder}/dataset/"
96+
pack_img = BashOperator(task_id="pack_images", bash_command=command)
9797

9898
pack_img.doc_md = """
9999
This task is meant to get all .jpeg images that were generated, and pack them
@@ -103,7 +103,7 @@
103103
gen_ghz = PythonOperator(
104104
task_id="gen_ghz",
105105
python_callable=gen_circuit,
106-
op_args=[DEFAULT_NUM_QUBITS, DEFAULT_TARGET_FOLDER, DEFAULT_NEW_DIM],
106+
op_args=[DEFAULT_NUM_QUBITS, folder, DEFAULT_NEW_DIM],
107107
)
108108

109109
gen_ghz.doc_md = """

0 commit comments

Comments
 (0)