Skip to content

Commit 14a0aef

Browse files
committed
inception model
1 parent 1a2a76e commit 14a0aef

2 files changed

Lines changed: 11 additions & 219 deletions

File tree

src/napatrackmater/Trackvector.py

Lines changed: 11 additions & 217 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,61 +1227,6 @@ def TrackVolumeMaker(
12271227
)
12281228

12291229

1230-
def create_h5(
1231-
save_dir,
1232-
train_size=0.95,
1233-
save_name="cellfate_vision_training_data_gbr",
1234-
):
1235-
"""
1236-
Create HDF5 file with training and validation data for morphodynamic model in TZYX format.
1237-
1238-
Args:
1239-
save_dir (str): Directory containing image and label files.
1240-
train_size (float): Proportion of data to use for training.
1241-
save_name (str): Name of the output HDF5 file (without extension).
1242-
"""
1243-
data = []
1244-
labels = []
1245-
1246-
# Gather all TIFF files and their corresponding labels
1247-
all_files = os.listdir(save_dir)
1248-
tif_files = sorted([f for f in all_files if f.endswith(".tif")])
1249-
1250-
# Load images and labels in T, Z, Y, X format
1251-
for tif_file in tif_files:
1252-
image_path = os.path.join(save_dir, tif_file)
1253-
image = imread(image_path)
1254-
1255-
# Ensure the image has T, Z, Y, X format, where T is the leading dimension
1256-
if image.ndim != 4:
1257-
raise ValueError(
1258-
f"Image {tif_file} does not have four dimensions, expected T, Z, Y, X."
1259-
)
1260-
1261-
data.append(image)
1262-
1263-
# Load corresponding label CSV
1264-
csv_path = os.path.join(save_dir, os.path.splitext(tif_file)[0] + ".csv")
1265-
with open(csv_path) as csvfile:
1266-
reader = csv.reader(csvfile, delimiter=",")
1267-
label = np.array(list(reader)[0]).astype(np.float32)
1268-
labels.append(label)
1269-
1270-
data = np.array(data) # Shape: (N, T, Z, Y, X)
1271-
labels = np.array(labels) # Shape: (N, label_dim)
1272-
1273-
train_data, val_data, train_labels, val_labels = train_test_split(
1274-
data, labels, train_size=train_size, shuffle=True, random_state=42
1275-
)
1276-
1277-
h5_save_path = os.path.join(save_dir, f"{save_name}.h5")
1278-
with h5py.File(h5_save_path, "w") as hf:
1279-
hf.create_dataset("train_arrays", data=train_data)
1280-
hf.create_dataset("train_labels", data=train_labels)
1281-
hf.create_dataset("val_arrays", data=val_data)
1282-
hf.create_dataset("val_labels", data=val_labels)
1283-
1284-
print(f"HDF5 training data saved to {h5_save_path} in T, Z, Y, X format.")
12851230

12861231

12871232
def create_analysis_tracklets(
@@ -2870,7 +2815,9 @@ def train_gbr_neural_net(
28702815
max_shift=1.05,
28712816
max_scale=1.05,
28722817
max_mask_ratio=0.1,
2873-
augment = False
2818+
augment = False,
2819+
attn_heads = 8,
2820+
seq_len = 25
28742821
):
28752822

28762823
if isinstance(block_config, int):
@@ -2894,6 +2841,8 @@ def train_gbr_neural_net(
28942841
learning_rate=learning_rate,
28952842
n_pos=n_pos,
28962843
attention_dim=attention_dim,
2844+
attn_heads = attn_heads,
2845+
seq_len = seq_len
28972846
)
28982847

28992848
if augment:
@@ -2946,6 +2895,8 @@ def train_mitosis_neural_net(
29462895
scheduler_choice="plateau",
29472896
attention_dim: int = 64,
29482897
n_pos: list = (8,),
2898+
attn_heads = 8,
2899+
seq_len = 25
29492900
):
29502901

29512902
if isinstance(block_config, int):
@@ -2970,6 +2921,8 @@ def train_mitosis_neural_net(
29702921
learning_rate=learning_rate,
29712922
n_pos=n_pos,
29722923
attention_dim=attention_dim,
2924+
attn_heads = attn_heads,
2925+
seq_len = seq_len
29732926
)
29742927

29752928
mitosis_inception.setup_timeseries_transforms()
@@ -2980,6 +2933,8 @@ def train_mitosis_neural_net(
29802933
mitosis_inception.setup_densenet_model()
29812934
if model_type == "attention":
29822935
mitosis_inception.setup_hybrid_attention_model()
2936+
if model_type == 'qkv':
2937+
mitosis_inception.setup_inception_qkv_model()
29832938

29842939
mitosis_inception.setup_logger()
29852940
mitosis_inception.setup_checkpoint()
@@ -2988,65 +2943,6 @@ def train_mitosis_neural_net(
29882943
mitosis_inception.train()
29892944

29902945

2991-
def train_gbr_vision_neural_net(
2992-
save_path,
2993-
h5_file,
2994-
input_shape,
2995-
box_vector=7,
2996-
start_kernel=7,
2997-
mid_kernel=3,
2998-
startfilter=64,
2999-
growth_rate=32,
3000-
depth={"depth_0": 12, "depth_1": 24, "depth_2": 16},
3001-
num_classes=3,
3002-
batch_size=64,
3003-
num_workers=0,
3004-
learning_rate=0.001,
3005-
epochs=100,
3006-
accelerator="cuda",
3007-
devices=1,
3008-
loss_function="oneat",
3009-
experiment_name="mitosis",
3010-
scheduler_choice="plateau",
3011-
oneat_accuracy=True,
3012-
crop_size=None,
3013-
pool_first=True,
3014-
):
3015-
3016-
mitosis_inception = MitosisInception(
3017-
h5_file=h5_file,
3018-
num_classes=num_classes,
3019-
num_workers=num_workers,
3020-
epochs=epochs,
3021-
log_path=save_path,
3022-
batch_size=batch_size,
3023-
accelerator=accelerator,
3024-
devices=devices,
3025-
experiment_name=experiment_name,
3026-
scheduler_choice=scheduler_choice,
3027-
loss_function=loss_function,
3028-
learning_rate=learning_rate,
3029-
)
3030-
3031-
mitosis_inception.setup_gbr_vision_h5_datasets(crop_size=crop_size)
3032-
3033-
mitosis_inception.setup_densenet_vision_model(
3034-
input_shape,
3035-
num_classes,
3036-
box_vector,
3037-
start_kernel,
3038-
mid_kernel,
3039-
startfilter,
3040-
depth,
3041-
growth_rate,
3042-
pool_first=pool_first,
3043-
)
3044-
3045-
mitosis_inception.setup_logger()
3046-
mitosis_inception.setup_checkpoint()
3047-
mitosis_inception.setup_adam()
3048-
mitosis_inception.setup_lightning_model(oneat_accuracy=oneat_accuracy)
3049-
mitosis_inception.train()
30502946

30512947

30522948
def plot_metrics_from_npz(npz_file):
@@ -4529,109 +4425,7 @@ def weighted_prediction(predictions, weights):
45294425
return most_common_prediction
45304426

45314427

4532-
def vision_inception_model_prediction(
4533-
dataframe,
4534-
trackmate_id,
4535-
raw_image,
4536-
class_map,
4537-
model,
4538-
device="cpu",
4539-
crop_size=(25, 8, 128, 128),
4540-
):
4541-
"""
4542-
Generate predictions for an inception-style vision model based on patches around each point in a tracklet.
4543-
4544-
Parameters:
4545-
dataframe (pd.DataFrame): The dataframe containing track information.
4546-
trackmate_id (int): The TrackMate track ID for which to generate predictions.
4547-
tracklet_length (int): The number of time points in each tracklet.
4548-
raw_image (np.array): The raw image from which patches are extracted.
4549-
class_map (dict): Mapping of class indices to labels.
4550-
model (torch.nn.Module): The model to use for predictions.
4551-
device (str): Device for running predictions, 'cpu' or 'cuda'.
4552-
crop_size (tuple): Size of the crop around each point (imagesizex, imagesizey, imagesizez).
4553-
4554-
Returns:
4555-
predictions (list): Predicted class labels for each tracklet.
4556-
weights (list): Prediction confidence or logits for each tracklet.
4557-
"""
4558-
model = model.to(device)
4559-
model.eval()
4560-
sub_trackmate_dataframe = dataframe[dataframe["TrackMate Track ID"] == trackmate_id]
45614428

4562-
# Extract the dimensions of the crop
4563-
imagesizet, sizez, sizex, sizey = crop_size
4564-
tracklet_predictions = []
4565-
tracklet_weights = []
4566-
4567-
for tracklet_id in sub_trackmate_dataframe["Track ID"].unique():
4568-
tracklet_sub_dataframe = sub_trackmate_dataframe[
4569-
sub_trackmate_dataframe["Track ID"] == tracklet_id
4570-
]
4571-
4572-
sub_trackmate_dataframe = tracklet_sub_dataframe.sort_values(by="t")
4573-
total_duration = sub_trackmate_dataframe["Track Duration"].max()
4574-
tracklet_blocks = []
4575-
for i in range(0, len(sub_trackmate_dataframe), imagesizet):
4576-
tracklet_block = sub_trackmate_dataframe.iloc[i : i + imagesizet][
4577-
["t", "z", "y", "x"]
4578-
].values
4579-
if len(tracklet_block) == imagesizet:
4580-
tracklet_blocks.append(tracklet_block)
4581-
4582-
for tracklet_block in tracklet_blocks:
4583-
stitched_volume = []
4584-
for (t, z, y, x) in tracklet_block:
4585-
small_image = raw_image[int(t)]
4586-
4587-
if (
4588-
x > sizex / 2
4589-
and z > sizez / 2
4590-
and y > sizey / 2
4591-
and z + int(sizez / 2) < raw_image.shape[1]
4592-
and y + int(sizey / 2) < raw_image.shape[2]
4593-
and x + int(sizex / 2) < raw_image.shape[3]
4594-
and t < raw_image.shape[0]
4595-
):
4596-
crop_xminus = x - int(sizex / 2)
4597-
crop_xplus = x + int(sizex / 2)
4598-
crop_yminus = y - int(sizey / 2)
4599-
crop_yplus = y + int(sizey / 2)
4600-
crop_zminus = z - int(sizez / 2)
4601-
crop_zplus = z + int(sizez / 2)
4602-
region = (
4603-
slice(int(crop_zminus), int(crop_zplus)),
4604-
slice(int(crop_yminus), int(crop_yplus)),
4605-
slice(int(crop_xminus), int(crop_xplus)),
4606-
)
4607-
4608-
crop_image = small_image[region]
4609-
stitched_volume.append(crop_image)
4610-
stitched_volume = np.stack(stitched_volume, axis=0)
4611-
if stitched_volume.shape[0] == imagesizet:
4612-
with torch.no_grad():
4613-
prediction_vector = model(
4614-
torch.unsqueeze(
4615-
torch.tensor(stitched_volume, dtype=torch.float32), dim=0
4616-
)
4617-
)
4618-
4619-
class_logits = prediction_vector[0, : len(class_map)]
4620-
4621-
most_frequent_prediction = get_most_frequent_prediction(class_logits)
4622-
if most_frequent_prediction is not None:
4623-
most_predicted_class = class_map[int(most_frequent_prediction)]
4624-
tracklet_predictions.append(most_predicted_class)
4625-
tracklet_weights.append(total_duration)
4626-
4627-
if tracklet_predictions:
4628-
final_weighted_prediction = weighted_prediction(
4629-
tracklet_predictions, tracklet_weights
4630-
)
4631-
return final_weighted_prediction
4632-
4633-
else:
4634-
return "UnClassified"
46354429

46364430

46374431
def inception_dual_model_prediction(

src/napatrackmater/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
filter_and_get_tracklets,
3838
create_h5,
3939
normalize_image_in_chunks,
40-
vision_inception_model_prediction,
4140
inception_dual_model_prediction
4241
)
4342

@@ -133,7 +132,6 @@ def load_json(fpath):
133132
"filter_and_get_tracklets",
134133
"create_h5",
135134
"normalize_image_in_chunks",
136-
"vision_inception_model_prediction",
137135
"inception_dual_model_prediction",
138136
"affine_transform",
139137
"apply_alpha_drift",

0 commit comments

Comments
 (0)