|
4 | 4 |
|
5 | 5 | from airflow import DAG |
6 | 6 | from airflow.providers.standard.operators.bash import BashOperator |
7 | | -from airflow.providers.standard.operators.python import PythonOperator |
| 7 | +from airflow.providers.standard.operators.python import PythonOperator, BranchPythonOperator |
8 | 8 | from airflow.providers.standard.operators.trigger_dagrun import TriggerDagRunOperator |
9 | 9 |
|
10 | 10 | from generate.dataset import ( |
|
13 | 13 | remove_duplicated_files, |
14 | 14 | transform_images, |
15 | 15 | start_df, |
| 16 | + Checkpoint, |
| 17 | + Stages |
16 | 18 | ) |
17 | 19 | from utils.constants import ( |
18 | 20 | DEFAULT_NUM_QUBITS, |
|
22 | 24 | DEFAULT_SHOTS, |
23 | 25 | DEFAULT_DATASET_SIZE, |
24 | 26 | DEFAULT_THREADS, |
| 27 | + images_gen_checkpoint_file |
25 | 28 | ) |
26 | 29 | from generate.ghz import gen_circuit |
27 | 30 | from export.kaggle import upload_dataset as upload_dataset_kaggle |
|
31 | 34 | "depends_on_past": True, |
32 | 35 | } |
33 | 36 |
|
| 37 | +GEN_IMAGES_TASK_ID = 'gen_images' |
| 38 | +REMOVE_DUPLICATES_TASK_ID = 'remove_duplicates' |
| 39 | +TRANSFORM_TASK_ID = 'transform_images' |
| 40 | + |
| 41 | +def next_step(checkpoint:Checkpoint) -> str: |
| 42 | + """ |
| 43 | + checks teh current checkpoint and returns the next |
| 44 | + task_id. |
| 45 | + """ |
| 46 | + |
| 47 | + if (checkpoint.stage == Stages.GEN_IMAGES): |
| 48 | + return GEN_IMAGES_TASK_ID |
| 49 | + |
| 50 | + if(checkpoint.stage == Stages.DUPLICATES): |
| 51 | + return REMOVE_DUPLICATES_TASK_ID |
| 52 | + |
| 53 | + return TRANSFORM_TASK_ID |
| 54 | + |
| 55 | +def update_checkpoint(checkpoint:Checkpoint, stage:Stages): |
| 56 | + """ |
| 57 | + Updates the checkpoint to start next task. |
| 58 | + """ |
| 59 | + checkpoint.index = 0 |
| 60 | + checkpoint.files = [] |
| 61 | + checkpoint.stage = Stages.TRANSFORM |
| 62 | + checkpoint.save() |
| 63 | + |
| 64 | + |
34 | 65 | with DAG( |
35 | 66 | dag_id="build_dataset", |
36 | 67 | default_args=default_args, |
|
47 | 78 | create_folder.doc_md = """ |
48 | 79 | Create a folder (if it doesn't exist) to store images. |
49 | 80 | """ |
| 81 | + |
| 82 | + checkpoint = Checkpoint(images_gen_checkpoint_file(folder)) |
| 83 | + |
| 84 | + branch_checkpoint = BranchPythonOperator( |
| 85 | + task_id="check_checkpoint", |
| 86 | + python_callable=next_step, |
| 87 | + op_args=[checkpoint] |
| 88 | + ) |
| 89 | + branch_checkpoint.doc_md = """ |
| 90 | + Choose the next task based on the current checkpoint. |
| 91 | + """ |
50 | 92 |
|
51 | 93 | gen_df = PythonOperator( |
52 | 94 | task_id="gen_df", python_callable=start_df, op_args=[folder] |
|
56 | 98 | """ |
57 | 99 |
|
58 | 100 | gen_images = PythonOperator( |
59 | | - task_id="gen_images", |
| 101 | + task_id=GEN_IMAGES_TASK_ID, |
60 | 102 | python_callable=generate_images, |
61 | 103 | op_args=[ |
62 | 104 | folder, |
|
65 | 107 | DEFAULT_SHOTS, |
66 | 108 | DEFAULT_DATASET_SIZE, |
67 | 109 | DEFAULT_THREADS, |
| 110 | + checkpoint |
68 | 111 | ], |
69 | 112 | ) |
70 | 113 |
|
|
73 | 116 | Qiskit framework. |
74 | 117 | """ |
75 | 118 |
|
| 119 | + transtion_gen_to_remove = PythonOperator( |
| 120 | + task_id="gen_to_remove", |
| 121 | + python_callable=update_checkpoint, |
| 122 | + op_args=[checkpoint, Stages.DUPLICATES] |
| 123 | + ) |
| 124 | + |
| 125 | + transtion_gen_to_remove.doc_md = """ |
| 126 | + Update checkpoint to start removing duplicated files. |
| 127 | + """ |
| 128 | + |
76 | 129 | remove_duplicates = PythonOperator( |
77 | | - task_id="remove_duplicates", |
| 130 | + task_id=REMOVE_DUPLICATES_TASK_ID, |
78 | 131 | python_callable=remove_duplicated_files, |
79 | | - op_args=[folder], |
| 132 | + op_args=[folder, checkpoint], |
80 | 133 | ) |
81 | 134 |
|
82 | 135 | remove_duplicates.doc_md = """ |
83 | 136 | Remove files that have the same hashes. |
84 | 137 | """ |
85 | 138 |
|
| 139 | + transition_remove_to_transform = PythonOperator( |
| 140 | + task_id="remove_to_transform", |
| 141 | + python_callable=update_checkpoint, |
| 142 | + op_args=[checkpoint, Stages.TRANSFORM] |
| 143 | + ) |
| 144 | + |
| 145 | + transition_remove_to_transform.doc_md = """ |
| 146 | + Update checkpoint to start transforming images. |
| 147 | + """ |
| 148 | + |
86 | 149 | transform_img = PythonOperator( |
87 | | - task_id="transform_images", |
| 150 | + task_id=TRANSFORM_TASK_ID, |
88 | 151 | python_callable=transform_images, |
89 | | - op_args=[folder, DEFAULT_NEW_DIM], |
| 152 | + op_args=[folder, DEFAULT_NEW_DIM, checkpoint], |
90 | 153 | ) |
91 | 154 |
|
92 | 155 | transform_img.doc_md = """ |
|
137 | 200 | """ |
138 | 201 |
|
139 | 202 | create_folder >> [gen_ghz, gen_df] |
140 | | - gen_df >> gen_images |
141 | | - gen_images >> remove_duplicates |
142 | | - remove_duplicates >> transform_img |
| 203 | + gen_df >> branch_checkpoint |
| 204 | + |
| 205 | + branch_checkpoint >> gen_images |
| 206 | + branch_checkpoint >> remove_duplicates |
| 207 | + branch_checkpoint >> transform_img |
| 208 | + |
| 209 | + gen_images >> transtion_gen_to_remove |
| 210 | + transtion_gen_to_remove >> remove_duplicates |
| 211 | + remove_duplicates >> transition_remove_to_transform |
| 212 | + transition_remove_to_transform >> transform_img |
143 | 213 | transform_img >> pack_img |
144 | 214 |
|
145 | 215 | [gen_ghz, pack_img] >> trigger_dag_train |
|
0 commit comments