@@ -1227,61 +1227,6 @@ def TrackVolumeMaker(
12271227 )
12281228
12291229
1230- def create_h5 (
1231- save_dir ,
1232- train_size = 0.95 ,
1233- save_name = "cellfate_vision_training_data_gbr" ,
1234- ):
1235- """
1236- Create HDF5 file with training and validation data for morphodynamic model in TZYX format.
1237-
1238- Args:
1239- save_dir (str): Directory containing image and label files.
1240- train_size (float): Proportion of data to use for training.
1241- save_name (str): Name of the output HDF5 file (without extension).
1242- """
1243- data = []
1244- labels = []
1245-
1246- # Gather all TIFF files and their corresponding labels
1247- all_files = os .listdir (save_dir )
1248- tif_files = sorted ([f for f in all_files if f .endswith (".tif" )])
1249-
1250- # Load images and labels in T, Z, Y, X format
1251- for tif_file in tif_files :
1252- image_path = os .path .join (save_dir , tif_file )
1253- image = imread (image_path )
1254-
1255- # Ensure the image has T, Z, Y, X format, where T is the leading dimension
1256- if image .ndim != 4 :
1257- raise ValueError (
1258- f"Image { tif_file } does not have four dimensions, expected T, Z, Y, X."
1259- )
1260-
1261- data .append (image )
1262-
1263- # Load corresponding label CSV
1264- csv_path = os .path .join (save_dir , os .path .splitext (tif_file )[0 ] + ".csv" )
1265- with open (csv_path ) as csvfile :
1266- reader = csv .reader (csvfile , delimiter = "," )
1267- label = np .array (list (reader )[0 ]).astype (np .float32 )
1268- labels .append (label )
1269-
1270- data = np .array (data ) # Shape: (N, T, Z, Y, X)
1271- labels = np .array (labels ) # Shape: (N, label_dim)
1272-
1273- train_data , val_data , train_labels , val_labels = train_test_split (
1274- data , labels , train_size = train_size , shuffle = True , random_state = 42
1275- )
1276-
1277- h5_save_path = os .path .join (save_dir , f"{ save_name } .h5" )
1278- with h5py .File (h5_save_path , "w" ) as hf :
1279- hf .create_dataset ("train_arrays" , data = train_data )
1280- hf .create_dataset ("train_labels" , data = train_labels )
1281- hf .create_dataset ("val_arrays" , data = val_data )
1282- hf .create_dataset ("val_labels" , data = val_labels )
1283-
1284- print (f"HDF5 training data saved to { h5_save_path } in T, Z, Y, X format." )
12851230
12861231
12871232def create_analysis_tracklets (
@@ -2870,7 +2815,9 @@ def train_gbr_neural_net(
28702815 max_shift = 1.05 ,
28712816 max_scale = 1.05 ,
28722817 max_mask_ratio = 0.1 ,
2873- augment = False
2818+ augment = False ,
2819+ attn_heads = 8 ,
2820+ seq_len = 25
28742821):
28752822
28762823 if isinstance (block_config , int ):
@@ -2894,6 +2841,8 @@ def train_gbr_neural_net(
28942841 learning_rate = learning_rate ,
28952842 n_pos = n_pos ,
28962843 attention_dim = attention_dim ,
2844+ attn_heads = attn_heads ,
2845+ seq_len = seq_len
28972846 )
28982847
28992848 if augment :
@@ -2946,6 +2895,8 @@ def train_mitosis_neural_net(
29462895 scheduler_choice = "plateau" ,
29472896 attention_dim : int = 64 ,
29482897 n_pos : list = (8 ,),
2898+ attn_heads = 8 ,
2899+ seq_len = 25
29492900):
29502901
29512902 if isinstance (block_config , int ):
@@ -2970,6 +2921,8 @@ def train_mitosis_neural_net(
29702921 learning_rate = learning_rate ,
29712922 n_pos = n_pos ,
29722923 attention_dim = attention_dim ,
2924+ attn_heads = attn_heads ,
2925+ seq_len = seq_len
29732926 )
29742927
29752928 mitosis_inception .setup_timeseries_transforms ()
@@ -2980,6 +2933,8 @@ def train_mitosis_neural_net(
29802933 mitosis_inception .setup_densenet_model ()
29812934 if model_type == "attention" :
29822935 mitosis_inception .setup_hybrid_attention_model ()
2936+ if model_type == 'qkv' :
2937+ mitosis_inception .setup_inception_qkv_model ()
29832938
29842939 mitosis_inception .setup_logger ()
29852940 mitosis_inception .setup_checkpoint ()
@@ -2988,65 +2943,6 @@ def train_mitosis_neural_net(
29882943 mitosis_inception .train ()
29892944
29902945
2991- def train_gbr_vision_neural_net (
2992- save_path ,
2993- h5_file ,
2994- input_shape ,
2995- box_vector = 7 ,
2996- start_kernel = 7 ,
2997- mid_kernel = 3 ,
2998- startfilter = 64 ,
2999- growth_rate = 32 ,
3000- depth = {"depth_0" : 12 , "depth_1" : 24 , "depth_2" : 16 },
3001- num_classes = 3 ,
3002- batch_size = 64 ,
3003- num_workers = 0 ,
3004- learning_rate = 0.001 ,
3005- epochs = 100 ,
3006- accelerator = "cuda" ,
3007- devices = 1 ,
3008- loss_function = "oneat" ,
3009- experiment_name = "mitosis" ,
3010- scheduler_choice = "plateau" ,
3011- oneat_accuracy = True ,
3012- crop_size = None ,
3013- pool_first = True ,
3014- ):
3015-
3016- mitosis_inception = MitosisInception (
3017- h5_file = h5_file ,
3018- num_classes = num_classes ,
3019- num_workers = num_workers ,
3020- epochs = epochs ,
3021- log_path = save_path ,
3022- batch_size = batch_size ,
3023- accelerator = accelerator ,
3024- devices = devices ,
3025- experiment_name = experiment_name ,
3026- scheduler_choice = scheduler_choice ,
3027- loss_function = loss_function ,
3028- learning_rate = learning_rate ,
3029- )
3030-
3031- mitosis_inception .setup_gbr_vision_h5_datasets (crop_size = crop_size )
3032-
3033- mitosis_inception .setup_densenet_vision_model (
3034- input_shape ,
3035- num_classes ,
3036- box_vector ,
3037- start_kernel ,
3038- mid_kernel ,
3039- startfilter ,
3040- depth ,
3041- growth_rate ,
3042- pool_first = pool_first ,
3043- )
3044-
3045- mitosis_inception .setup_logger ()
3046- mitosis_inception .setup_checkpoint ()
3047- mitosis_inception .setup_adam ()
3048- mitosis_inception .setup_lightning_model (oneat_accuracy = oneat_accuracy )
3049- mitosis_inception .train ()
30502946
30512947
30522948def plot_metrics_from_npz (npz_file ):
@@ -4529,109 +4425,7 @@ def weighted_prediction(predictions, weights):
45294425 return most_common_prediction
45304426
45314427
4532- def vision_inception_model_prediction (
4533- dataframe ,
4534- trackmate_id ,
4535- raw_image ,
4536- class_map ,
4537- model ,
4538- device = "cpu" ,
4539- crop_size = (25 , 8 , 128 , 128 ),
4540- ):
4541- """
4542- Generate predictions for an inception-style vision model based on patches around each point in a tracklet.
4543-
4544- Parameters:
4545- dataframe (pd.DataFrame): The dataframe containing track information.
4546- trackmate_id (int): The TrackMate track ID for which to generate predictions.
4547- tracklet_length (int): The number of time points in each tracklet.
4548- raw_image (np.array): The raw image from which patches are extracted.
4549- class_map (dict): Mapping of class indices to labels.
4550- model (torch.nn.Module): The model to use for predictions.
4551- device (str): Device for running predictions, 'cpu' or 'cuda'.
4552- crop_size (tuple): Size of the crop around each point (imagesizex, imagesizey, imagesizez).
4553-
4554- Returns:
4555- predictions (list): Predicted class labels for each tracklet.
4556- weights (list): Prediction confidence or logits for each tracklet.
4557- """
4558- model = model .to (device )
4559- model .eval ()
4560- sub_trackmate_dataframe = dataframe [dataframe ["TrackMate Track ID" ] == trackmate_id ]
45614428
4562- # Extract the dimensions of the crop
4563- imagesizet , sizez , sizex , sizey = crop_size
4564- tracklet_predictions = []
4565- tracklet_weights = []
4566-
4567- for tracklet_id in sub_trackmate_dataframe ["Track ID" ].unique ():
4568- tracklet_sub_dataframe = sub_trackmate_dataframe [
4569- sub_trackmate_dataframe ["Track ID" ] == tracklet_id
4570- ]
4571-
4572- sub_trackmate_dataframe = tracklet_sub_dataframe .sort_values (by = "t" )
4573- total_duration = sub_trackmate_dataframe ["Track Duration" ].max ()
4574- tracklet_blocks = []
4575- for i in range (0 , len (sub_trackmate_dataframe ), imagesizet ):
4576- tracklet_block = sub_trackmate_dataframe .iloc [i : i + imagesizet ][
4577- ["t" , "z" , "y" , "x" ]
4578- ].values
4579- if len (tracklet_block ) == imagesizet :
4580- tracklet_blocks .append (tracklet_block )
4581-
4582- for tracklet_block in tracklet_blocks :
4583- stitched_volume = []
4584- for (t , z , y , x ) in tracklet_block :
4585- small_image = raw_image [int (t )]
4586-
4587- if (
4588- x > sizex / 2
4589- and z > sizez / 2
4590- and y > sizey / 2
4591- and z + int (sizez / 2 ) < raw_image .shape [1 ]
4592- and y + int (sizey / 2 ) < raw_image .shape [2 ]
4593- and x + int (sizex / 2 ) < raw_image .shape [3 ]
4594- and t < raw_image .shape [0 ]
4595- ):
4596- crop_xminus = x - int (sizex / 2 )
4597- crop_xplus = x + int (sizex / 2 )
4598- crop_yminus = y - int (sizey / 2 )
4599- crop_yplus = y + int (sizey / 2 )
4600- crop_zminus = z - int (sizez / 2 )
4601- crop_zplus = z + int (sizez / 2 )
4602- region = (
4603- slice (int (crop_zminus ), int (crop_zplus )),
4604- slice (int (crop_yminus ), int (crop_yplus )),
4605- slice (int (crop_xminus ), int (crop_xplus )),
4606- )
4607-
4608- crop_image = small_image [region ]
4609- stitched_volume .append (crop_image )
4610- stitched_volume = np .stack (stitched_volume , axis = 0 )
4611- if stitched_volume .shape [0 ] == imagesizet :
4612- with torch .no_grad ():
4613- prediction_vector = model (
4614- torch .unsqueeze (
4615- torch .tensor (stitched_volume , dtype = torch .float32 ), dim = 0
4616- )
4617- )
4618-
4619- class_logits = prediction_vector [0 , : len (class_map )]
4620-
4621- most_frequent_prediction = get_most_frequent_prediction (class_logits )
4622- if most_frequent_prediction is not None :
4623- most_predicted_class = class_map [int (most_frequent_prediction )]
4624- tracklet_predictions .append (most_predicted_class )
4625- tracklet_weights .append (total_duration )
4626-
4627- if tracklet_predictions :
4628- final_weighted_prediction = weighted_prediction (
4629- tracklet_predictions , tracklet_weights
4630- )
4631- return final_weighted_prediction
4632-
4633- else :
4634- return "UnClassified"
46354429
46364430
46374431def inception_dual_model_prediction (
0 commit comments