
Commit 497ebc7

Merge pull request #51 from ua-datalab/create_pytests_uspantek
Found a bug in the earlier sst2 pytest; fixed it and added 4 more for uspantek. All 8 pass (4 for sst2, 4 for uspantek).
2 parents 008f11a + fb112f9 commit 497ebc7

4 files changed: +207, -121 lines

Project-Plan.md

Lines changed: 24 additions & 4 deletions
@@ -4,12 +4,14 @@
 * live or dead status of all experiments as of: Dec 10th 2024
 * sst2_classical_1: runs well end to end + has a pytest. Latest version is in the branch titled run_sst1_classical1
 * sst2_classical2: i.e. with the BobCat parser; is hitting the `both inputs must be same dtype` error again. End of the road for now. Latest version is in the branch titled run_sst_classical2
-* sst2_quantum1:
-* sst2_quantum2:
+* sst2_quantum1: stuck on timeout / the first .fit doesn't respond even after a long time. Tried adding a pytest, but not able to cleanly capture a timeout (see the sketch after this diff).
+* sst2_quantum2: same, stuck on timeout / the first .fit doesn't respond even after a long time. Tried adding a pytest, but not able to cleanly capture a timeout.
+
 * spanish_classical_1:
 * spanish_classical2:
 * spanish_quantum1:
 * spanish_quantum2:
+
 * uspantek_classical_1:
 * uspantek_classical2:
 * uspantek_quantum1:
@@ -54,13 +56,31 @@

 ## Dec 12th 2024
 todos:
-3. add pytest for run_sst1_run_sst1_quantum1
+- add pytest for run_sst1_run_sst1_quantum1 --- done
 - add timeout
 - add "if wandb"
 - merge to staging
 - merge to main
-4. add pytest for run_sst1_run_sst1_quantum2
+- add pytest for run_sst1_run_sst1_quantum2 --- done/ignored
 - mostly the same as above: timeout
+- make wandb optional --- done
+  - in a branch called make_wandb_optional_commandline_arg --- done
+  - run pytest locally --- done
+  - merge to main --- done
+  - pull and run pytest locally: started at 7pm Dec 12th
+- create 8 pytests for spanish
+  - classical1 x2
+  - classical2 x2
+  - quantum1 x2
+  - quantum2 x2
+- create 8 pytests for uspantek
+  - classical1 x2
+  - classical2 x2
+  - quantum1 x2
+  - quantum2 x2
+- add python dictionary
+- start tuning uspantek with whichever of the 8 above you think is best
+
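The sst2_quantum1/2 status above notes that a hanging `.fit` could not be cleanly captured in a pytest. One possible approach, assuming the pytest-timeout plugin (`pip install pytest-timeout`) is acceptable as a dev dependency: the test is killed and reported as a failure once the limit is hit, instead of blocking the whole suite. The helper name below is hypothetical and stands in for whatever builds and fits model 1 on a tiny slice of sst2.

```python
import pytest

@pytest.mark.timeout(600)  # fail cleanly if .fit hangs for more than 10 minutes
def test_sst2_quantum1_fit_does_not_hang():
    # run_tiny_sst2_quantum1_experiment is a hypothetical wrapper around
    # run_experiment() with a few training lines and 1-2 epochs
    loss, acc = run_tiny_sst2_quantum1_experiment()
    assert 0.0 <= acc <= 1.0
```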

README.md

Lines changed: 1 addition & 1 deletion
@@ -84,4 +84,4 @@ Also note, during development it's a healthy habit to always run pytest before th

 `--do_debug`: pass this only if you want debugging done, i.e. you are planning to attach a debugging process from an IDE like Visual Studio Code.

-
+`--use_wandb`: pass this if you want to turn on wandb, i.e. log all variables online. Making it optional since wandb doesn't work well with CyVerse.
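For context, a hypothetical invocation that turns the new flag on (illustrative values only; the full set of required arguments is defined in `parse_arguments()` in classify.py):

```bash
python classify.py --dataset sst2 --parser BobCatParser --ansatz SpiderAnsatz \
    --max_tokens_per_sent 10 --use_wandb
```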

classify.py

Lines changed: 45 additions & 35 deletions
@@ -407,11 +407,11 @@ def generate_OOV_parameterising_model(trained_qnlp_model, train_vocab_embeddings
     hist = OOV_NN_model.fit(NN_train_X, np.array(NN_train_Y), validation_split=0.2, verbose=1, epochs=epochs_model3_oov_model, callbacks=[callback])
     print(hist.history.keys())
     print(f'OOV NN model final epoch loss: {(hist.history["loss"][-1], hist.history["val_loss"][-1])}')
-    plt.plot(hist.history['loss'], label='loss')
-    plt.plot(hist.history['val_loss'], label='val_loss')
-    plt.xlabel('Epoch')
-    plt.ylabel('Error')
-    plt.legend()
+    # plt.plot(hist.history['loss'], label='loss')
+    # plt.plot(hist.history['val_loss'], label='val_loss')
+    # plt.xlabel('Epoch')
+    # plt.ylabel('Error')
+    # plt.legend()
     # plt.show()  # the code expects the user to close the figure manually; commented out temporarily since that was preventing a smooth run/debugging of the code

     best_model = OOV_NN_model
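This hunk (and the larger one further down) comments out plotting because `plt.show()` blocks until the window is closed. If the loss curves are still wanted, a non-blocking sketch is to write the figure to disk instead; the backend call and output path below are assumptions, not code from this commit:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend: no window is ever opened
import matplotlib.pyplot as plt

# hist comes from OOV_NN_model.fit(...) as in the hunk above
plt.plot(hist.history['loss'], label='loss')
plt.plot(hist.history['val_loss'], label='val_loss')
plt.xlabel('Epoch')
plt.ylabel('Error')
plt.legend()
plt.savefig("oov_nn_model_loss.png")  # hypothetical output path
plt.close()
```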
@@ -528,13 +528,18 @@ def read_glue_data(dataset_downloaded,split,lines_to_read=0):



-def read_data(filename):
+def read_data(filename, lines_to_read):
     labels, sentences = [], []
+    line_counter = 0
     with open(filename) as f:
         for line in f:
             t = float(line[0])
             labels.append([t, 1-t])
             sentences.append(line[1:].strip())
+            line_counter += 1
+            if (line_counter > lines_to_read):
+                break
+
     return labels, sentences

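One thing to watch in the new truncation logic: `line_counter` is incremented after appending, and the loop only breaks once the counter exceeds `lines_to_read`, so `lines_to_read + 1` lines are actually kept. A sketch of an equivalent reader without the off-by-one, assuming the intent is "at most `lines_to_read` lines":

```python
from itertools import islice

def read_data(filename, lines_to_read):
    labels, sentences = [], []
    with open(filename) as f:
        for line in islice(f, lines_to_read):  # stops after lines_to_read lines
            t = float(line[0])                 # first character is the 0/1 label
            labels.append([t, 1 - t])
            sentences.append(line[1:].strip())
    return labels, sentences
```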

@@ -595,7 +600,7 @@ def convert_diagram_to_circuits_with_try_catch(diagrams, ansatz, labels, split):
     return list_circuits, list_labels


-def run_experiment(train_diagrams, train_labels, val_diagrams, val_labels, test_diagrams, test_labels, eval_metrics, seed, embedding_model, ansatz_class, single_qubit_params, base_dimension_for_noun, base_dimension_for_sent, base_dimension_for_prep_phrase, no_of_layers_in_ansatz, expose_model1_val_during_model_initialization, batch_size, learning_rate_model1, model_class_to_use, epochs_train_model1, trainer_class_to_use, do_model3_tuning, learning_rate_model3, maxparams, epochs_model3_oov_model, model14type):
+def run_experiment(train_diagrams, train_labels, val_diagrams, val_labels, test_diagrams, test_labels, eval_metrics, seed, embedding_model, ansatz_class, single_qubit_params, base_dimension_for_noun, base_dimension_for_sent, base_dimension_for_prep_phrase, no_of_layers_in_ansatz, expose_model1_val_during_model_initialization, batch_size, learning_rate_model1, model_class_to_use, epochs_train_model1, trainer_class_to_use, do_model3_tuning, learning_rate_model3, maxparams, epochs_model3_oov_model, model14type, use_wandb):
     if ansatz_class in [IQPAnsatz, Sim15Ansatz, Sim14Ansatz]:
         ansatz_obj = ansatz_class({AtomicType.NOUN: base_dimension_for_noun,
                                    AtomicType.SENTENCE: base_dimension_for_sent,
@@ -723,24 +728,24 @@ def run_experiment(train_diagrams, train_labels, val_diagrams, val_labels,test_d
         """if there are no OOV words, we don't need model 2 through model 4.
         just use model 1 to evaluate and exit"""
         if oov_word_count == 0:
-            import matplotlib.pyplot as plt
+
             import numpy as np

-            fig1, ((ax_tl, ax_tr), (ax_bl, ax_br)) = plt.subplots(2, 2, sharey='row', figsize=(10, 6))
+            # fig1, ((ax_tl, ax_tr), (ax_bl, ax_br)) = plt.subplots(2, 2, sharey='row', figsize=(10, 6))

-            ax_tl.set_title('Training set')
-            ax_tr.set_title('Development set')
-            ax_bl.set_xlabel('Epochs')
-            ax_br.set_xlabel('Epochs')
-            ax_bl.set_ylabel('Accuracy')
-            ax_tl.set_ylabel('Loss')
+            # ax_tl.set_title('Training set')
+            # ax_tr.set_title('Development set')
+            # ax_bl.set_xlabel('Epochs')
+            # ax_br.set_xlabel('Epochs')
+            # ax_bl.set_ylabel('Accuracy')
+            # ax_tl.set_ylabel('Loss')

-            colours = iter(plt.rcParams['axes.prop_cycle'].by_key()['color'])
-            range_ = np.arange(1, trainer_class_to_use.epochs+1)
-            ax_tl.plot(range_, trainer_class_to_use.train_epoch_costs, color=next(colours))
-            ax_bl.plot(range_, trainer_class_to_use.train_eval_results['acc'], color=next(colours))
-            ax_tr.plot(range_, trainer_class_to_use.val_costs, color=next(colours))
-            ax_br.plot(range_, trainer_class_to_use.val_eval_results['acc'], color=next(colours))
+            # colours = iter(plt.rcParams['axes.prop_cycle'].by_key()['color'])
+            # range_ = np.arange(1, trainer_class_to_use.epochs+1)
+            # ax_tl.plot(range_, trainer_class_to_use.train_epoch_costs, color=next(colours))
+            # ax_bl.plot(range_, trainer_class_to_use.train_eval_results['acc'], color=next(colours))
+            # ax_tr.plot(range_, trainer_class_to_use.val_costs, color=next(colours))
+            # ax_br.plot(range_, trainer_class_to_use.val_eval_results['acc'], color=next(colours))


            val_preds = model1_obj.get_diagram_output(val_circuits)
@@ -791,6 +796,9 @@ def run_experiment(train_diagrams, train_labels, val_diagrams, val_labels,test_d


     }
+
+    if bool(use_wandb):
+        wandb.log({"accuracy_model4": smart_acc.item(), "loss_model4": smart_loss.item()})

     return smart_loss.item(), smart_acc.item()
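A pattern worth considering instead of guarding each `wandb.log` call: wandb's own API supports a disabled mode, in which `wandb.log` becomes a no-op. This is standard wandb behavior, not code from this commit; a minimal sketch:

```python
import wandb

# One init decides everything downstream; no per-call `if use_wandb` needed.
wandb.init(project="qnlp_nov2024_expts",
           mode="online" if use_wandb else "disabled")
wandb.log({"accuracy_model4": 0.93})  # silently ignored when disabled
```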

@@ -869,7 +877,15 @@ def perform_task(args):
     # a unique name to identify this run inside wandb data and graphs
     arch = f"{args.ansatz}+'_'+{args.dataset}+'_'+{args.parser}+'_'+{args.trainer}+'_'+{args.model14type}+'_'+{embedding_model}"

-    wandb.init(
+    if bool(args.use_wandb):
+        # import here so it is only needed when wandb is turned on
+        import subprocess
+
+        # run the wandb CLI in a subprocess to switch wandb to online mode
+        subprocess.Popen('WANDB online', shell=True)
+
+        wandb.init(
             project="qnlp_nov2024_expts",
             config={
                 "learning_rate_model1": args.learning_rate_model1,
@@ -907,23 +923,16 @@ def perform_task(args):

     else:
         # read the base data, i.e. plain-text English.
-        train_labels, train_data = read_data(os.path.join(args.data_base_folder, TRAIN))
-        val_labels, val_data = read_data(os.path.join(args.data_base_folder, DEV))
-        test_labels, test_data = read_data(os.path.join(args.data_base_folder, TEST))
+        train_labels, train_data = read_data(os.path.join(args.data_base_folder, TRAIN), lines_to_read=args.no_of_training_data_points_to_use)
+        val_labels, val_data = read_data(os.path.join(args.data_base_folder, DEV), lines_to_read=args.no_of_training_data_points_to_use)
+        test_labels, test_data = read_data(os.path.join(args.data_base_folder, TEST), lines_to_read=args.no_of_training_data_points_to_use)




     """some datasets like spanish, uspantek, sst2 have some sentences which bobcat doesn't like. Putting it
     in a try/except so that the code doesn't completely halt; at least the rest of the dataset can be used
     """
-    # if (args.dataset in ["uspantek","spanish"]):
-    #     train_diagrams, train_labels = convert_to_diagrams_with_try_catch(parser_obj, train_data, train_labels, spacy_tokeniser, split="train")
-    #     val_diagrams, val_labels = convert_to_diagrams_with_try_catch(parser_obj, val_data, val_labels, spacy_tokeniser, split="val")
-    #     test_diagrams, test_labels = convert_to_diagrams_with_try_catch(parser_obj, test_data, test_labels, spacy_tokeniser, split="test")
-    # else:
-

     # convert the plain text input to ZX diagrams
     train_diagrams, train_labels = convert_to_diagrams_with_try_catch(args, parser_obj, train_data, train_labels, spacy_tokeniser, split="train")
@@ -970,7 +979,7 @@ def perform_task(args):
     # But commenting out due to lack of RAM in laptop
     tf_seed = args.seed
     tf.random.set_seed(tf_seed)
-    return run_experiment(train_diagrams, train_labels, val_diagrams, val_labels, test_diagrams, test_labels, eval_metrics, tf_seed, embedding_model, args.ansatz, args.single_qubit_params, args.base_dimension_for_noun, args.base_dimension_for_sent, args.base_dimension_for_prep_phrase, args.no_of_layers_in_ansatz, args.expose_model1_val_during_model_initialization, args.batch_size, args.learning_rate_model1, args.model14type, args.epochs_train_model1, args.trainer, args.do_model3_tuning, args.learning_rate_model3, args.maxparams, args.epochs_model3_oov_model, args.model14type)
+    return run_experiment(train_diagrams, train_labels, val_diagrams, val_labels, test_diagrams, test_labels, eval_metrics, tf_seed, embedding_model, args.ansatz, args.single_qubit_params, args.base_dimension_for_noun, args.base_dimension_for_sent, args.base_dimension_for_prep_phrase, args.no_of_layers_in_ansatz, args.expose_model1_val_during_model_initialization, args.batch_size, args.learning_rate_model1, args.model14type, args.epochs_train_model1, args.trainer, args.do_model3_tuning, args.learning_rate_model3, args.maxparams, args.epochs_model3_oov_model, args.model14type, args.use_wandb)

 def parse_name_model(val):
     try:
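One detail to flag in the hunk above: all three `read_data` calls pass `args.no_of_training_data_points_to_use`, including the DEV and TEST splits, even though a separate `--no_of_test_data_points_to_use` flag exists (see the argparse hunk below). If the test split is meant to honor its own flag, the hypothetical adjustment would be:

```python
test_labels, test_data = read_data(os.path.join(args.data_base_folder, TEST),
                                   lines_to_read=args.no_of_test_data_points_to_use)
```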
@@ -1054,8 +1063,7 @@ def do_debug(val):  # uncomment only for debugging/accessing breakpoints
         print("not doing any debugging")

 def parse_arguments():
-    parser = argparse.ArgumentParser(description="Description of your script.")
-    parser.add_argument('--do_debug', action="store_true", help="to run debug or not to debug. If yes, will uncomment the attachment code")
+    parser = argparse.ArgumentParser(description="Description of your script.")
     parser.add_argument('--dataset', type=str, required=True, default="food_it", help="type of dataset; choose from [sst2, uspantek, spanish, food_it, msr_paraphrase_corpus]")
     parser.add_argument('--parser', type=parse_name_parser, required=True, help="type of parser to use: [BobCatParser, Spider]")
     parser.add_argument('--ansatz', type=parse_name_ansatz, required=True, help="type of ansatz to use: [IQPAnsatz, SpiderAnsatz, Sim14Ansatz, Sim15Ansatz, TensorAnsatz]")
@@ -1081,6 +1089,8 @@ def parse_arguments():
     parser.add_argument('--no_of_test_data_points_to_use', type=int, default=10, required=False, help="65k of sst data was taking a long time; temporarily training on smaller data")
     parser.add_argument('--single_qubit_params', type=int, default=3, required=False, help="")
     parser.add_argument('--max_tokens_per_sent', type=int, required=True, help="Bobcat parser doesn't like longer sentences; 9 or 10 is about the upper limit")
+    parser.add_argument('--do_debug', action="store_true", help="to run debug or not to debug. If yes, will uncomment the attachment code")
+    parser.add_argument('--use_wandb', action="store_true", help="turn on wandb. making it optional since wandb doesn't work well with CyVerse")


     return parser.parse_args()
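For clarity on the two new flags (standard argparse semantics, not repo-specific behavior): `action="store_true"` arguments default to `False` and flip to `True` only when the flag is present, which is why the `bool(args.use_wandb)` checks above are safe even when the flag is omitted.

```python
import argparse

# Minimal illustration of the store_true semantics used by --do_debug/--use_wandb.
p = argparse.ArgumentParser()
p.add_argument('--use_wandb', action="store_true")
print(p.parse_args([]).use_wandb)               # False: flag omitted
print(p.parse_args(['--use_wandb']).use_wandb)  # True: flag present
```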
