Skip to content

Commit 7fa2174

Browse files
committed
adding checkpoint to airflow pipelines
1 parent 821cc1c commit 7fa2174

File tree

3 files changed

+115
-10
lines changed

3 files changed

+115
-10
lines changed

dags/dataset.py

Lines changed: 79 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from airflow import DAG
66
from airflow.providers.standard.operators.bash import BashOperator
7-
from airflow.providers.standard.operators.python import PythonOperator
7+
from airflow.providers.standard.operators.python import PythonOperator, BranchPythonOperator
88
from airflow.providers.standard.operators.trigger_dagrun import TriggerDagRunOperator
99

1010
from generate.dataset import (
@@ -13,6 +13,8 @@
1313
remove_duplicated_files,
1414
transform_images,
1515
start_df,
16+
Checkpoint,
17+
Stages
1618
)
1719
from utils.constants import (
1820
DEFAULT_NUM_QUBITS,
@@ -22,6 +24,7 @@
2224
DEFAULT_SHOTS,
2325
DEFAULT_DATASET_SIZE,
2426
DEFAULT_THREADS,
27+
images_gen_checkpoint_file
2528
)
2629
from generate.ghz import gen_circuit
2730
from export.kaggle import upload_dataset as upload_dataset_kaggle
@@ -31,6 +34,34 @@
3134
"depends_on_past": True,
3235
}
3336

37+
GEN_IMAGES_TASK_ID = 'gen_images'
38+
REMOVE_DUPLICATES_TASK_ID = 'remove_duplicates'
39+
TRANSFORM_TASK_ID = 'transform_images'
40+
41+
def next_step(checkpoint: Checkpoint) -> str:
    """
    Map the stage stored in the checkpoint to the task_id to run next.

    Returns GEN_IMAGES_TASK_ID when the checkpoint stage is
    Stages.GEN_IMAGES, REMOVE_DUPLICATES_TASK_ID when it is
    Stages.DUPLICATES, and TRANSFORM_TASK_ID for any other stage.
    """
    if checkpoint.stage == Stages.GEN_IMAGES:
        task_id = GEN_IMAGES_TASK_ID
    elif checkpoint.stage == Stages.DUPLICATES:
        task_id = REMOVE_DUPLICATES_TASK_ID
    else:
        # Every remaining stage resumes at the transform task.
        task_id = TRANSFORM_TASK_ID

    return task_id
54+
55+
def update_checkpoint(checkpoint: Checkpoint, stage: Stages) -> None:
    """
    Reset the checkpoint and move it to ``stage`` so the next task can start.

    Args:
        checkpoint: Checkpoint object to update. Its ``index`` is set to 0,
            its ``files`` list is emptied, and ``checkpoint.save()`` is
            called afterwards.
        stage: Stage to record in the checkpoint.
    """
    # NOTE(review): index/files look like per-stage progress markers, which
    # is why they are cleared on a stage change -- confirm against Checkpoint.
    checkpoint.index = 0
    checkpoint.files = []

    # Use the requested stage. This used to be hard-coded to Stages.TRANSFORM,
    # which made every transition (including the one that passes
    # Stages.DUPLICATES) record the transform stage.
    checkpoint.stage = stage
    checkpoint.save()
63+
64+
3465
with DAG(
3566
dag_id="build_dataset",
3667
default_args=default_args,
@@ -47,6 +78,17 @@
4778
create_folder.doc_md = """
4879
Create a folder (if it doesn't exist) to store images.
4980
"""
81+
82+
checkpoint = Checkpoint(images_gen_checkpoint_file(folder))
83+
84+
branch_checkpoint = BranchPythonOperator(
85+
task_id="check_checkpoint",
86+
python_callable=next_step,
87+
op_args=[checkpoint]
88+
)
89+
branch_checkpoint.doc_md = """
90+
Choose the next task based on the current checkpoint.
91+
"""
5092

5193
gen_df = PythonOperator(
5294
task_id="gen_df", python_callable=start_df, op_args=[folder]
@@ -56,7 +98,7 @@
5698
"""
5799

58100
gen_images = PythonOperator(
59-
task_id="gen_images",
101+
task_id=GEN_IMAGES_TASK_ID,
60102
python_callable=generate_images,
61103
op_args=[
62104
folder,
@@ -65,6 +107,7 @@
65107
DEFAULT_SHOTS,
66108
DEFAULT_DATASET_SIZE,
67109
DEFAULT_THREADS,
110+
checkpoint
68111
],
69112
)
70113

@@ -73,20 +116,40 @@
73116
Qiskit framework.
74117
"""
75118

119+
transtion_gen_to_remove = PythonOperator(
120+
task_id="gen_to_remove",
121+
python_callable=update_checkpoint,
122+
op_args=[checkpoint, Stages.DUPLICATES]
123+
)
124+
125+
transtion_gen_to_remove.doc_md = """
126+
Update checkpoint to start removing duplicated files.
127+
"""
128+
76129
remove_duplicates = PythonOperator(
77-
task_id="remove_duplicates",
130+
task_id=REMOVE_DUPLICATES_TASK_ID,
78131
python_callable=remove_duplicated_files,
79-
op_args=[folder],
132+
op_args=[folder, checkpoint],
80133
)
81134

82135
remove_duplicates.doc_md = """
83136
Remove files that have the same hashes.
84137
"""
85138

139+
transition_remove_to_transform = PythonOperator(
140+
task_id="remove_to_transform",
141+
python_callable=update_checkpoint,
142+
op_args=[checkpoint, Stages.TRANSFORM]
143+
)
144+
145+
transition_remove_to_transform.doc_md = """
146+
Update checkpoint to start transforming images.
147+
"""
148+
86149
transform_img = PythonOperator(
87-
task_id="transform_images",
150+
task_id=TRANSFORM_TASK_ID,
88151
python_callable=transform_images,
89-
op_args=[folder, DEFAULT_NEW_DIM],
152+
op_args=[folder, DEFAULT_NEW_DIM, checkpoint],
90153
)
91154

92155
transform_img.doc_md = """
@@ -137,9 +200,16 @@
137200
"""
138201

139202
create_folder >> [gen_ghz, gen_df]
140-
gen_df >> gen_images
141-
gen_images >> remove_duplicates
142-
remove_duplicates >> transform_img
203+
gen_df >> branch_checkpoint
204+
205+
branch_checkpoint >> gen_images
206+
branch_checkpoint >> remove_duplicates
207+
branch_checkpoint >> transform_img
208+
209+
gen_images >> transtion_gen_to_remove
210+
transtion_gen_to_remove >> remove_duplicates
211+
remove_duplicates >> transition_remove_to_transform
212+
transition_remove_to_transform >> transform_img
143213
transform_img >> pack_img
144214

145215
[gen_ghz, pack_img] >> trigger_dag_train

dags/train.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from train import setup_and_run_training
99
from args.parser import Arguments
1010
from utils.constants import DEFAULT_TARGET_FOLDER
11+
from utils.helpers import get_latest_model_checkpoint
1112
from export.kaggle import upload_model as upload_model_kaggle
1213
from export.huggingface import upload_model as upload_model_hf
1314

@@ -17,6 +18,10 @@
1718
args = Arguments()
1819
args.target_folder = folder
1920

21+
checkpoint = get_latest_model_checkpoint(folder)
22+
if(checkpoint):
23+
args.checkpoint = checkpoint
24+
2025
train = PythonOperator(
2126
task_id="train_model",
2227
python_callable=setup_and_run_training,

utils/helpers.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22

33
from typing import List, Optional
44
import random
5+
import os
56

67
import torch
78

8-
from utils.constants import DEBUG
9+
from utils.constants import DEBUG, CHECKPOINT_FILE_PREFIX
10+
from utils.datatypes import FilePath
911

1012

1113
class PlotImages:
@@ -53,3 +55,31 @@ def get_measurements(n_qubits: int) -> List[int]:
5355
total_measurements = random.randint(1, n_qubits)
5456
qubits = list(range(n_qubits))
5557
return random.sample(qubits, total_measurements)
58+
59+
def get_latest_model_checkpoint(target_folder: FilePath) -> Optional[FilePath]:
    """Return the path of the most recently modified checkpoint.

    Args:
        target_folder: Directory searched (non-recursively) for entries whose
            names start with ``CHECKPOINT_FILE_PREFIX``.

    Returns:
        ``target_folder`` joined with the name of the matching entry that has
        the greatest modification time, or ``None`` when ``target_folder`` is
        not an existing directory or holds no matching entry.
    """
    # isdir rather than exists: a path that points at a regular file now
    # yields None instead of letting os.listdir raise NotADirectoryError.
    if not os.path.isdir(target_folder):
        return None

    candidates = [
        os.path.join(target_folder, name)
        for name in os.listdir(target_folder)
        if name.startswith(CHECKPOINT_FILE_PREFIX)
    ]

    if not candidates:
        return None

    # max() keeps the first entry on ties, matching a max-then-index lookup.
    return max(candidates, key=os.path.getmtime)
81+
82+
83+
84+
85+

0 commit comments

Comments
 (0)