Skip to content

Commit 22a21fe

Browse files
author
valhassan
committed
Refactor SegmentationDOFA to enable encoder freezing and improve trainable parameter tracking.
1 parent ae958fd commit 22a21fe

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

geo_deep_learning/tasks_with_models/segmentation_dofa.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,22 @@ def __init__(self,
4040
self.std = std
4141
self.data_type_max = data_type_max
4242
self.num_classes = num_classes
43-
self.model = DOFASeg(encoder, pretrained, image_size, wavelengths, self.num_classes)
43+
self.model = DOFASeg(encoder, pretrained, freeze_encoder=False,
44+
image_size=image_size, wavelengths=wavelengths,
45+
num_classes=self.num_classes)
46+
47+
# param_status = self.model.get_trainable_parameters()
48+
# print(f"Trainable parameters: {param_status['trainable']}")
49+
# print(f"Frozen parameters: {param_status['frozen']}")
50+
51+
# Count trainable vs total parameters
52+
# total_params = sum(p.numel() for p in self.model.parameters())
53+
# trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
54+
55+
# print(f"\nTotal parameters: {total_params:,}")
56+
# print(f"Trainable parameters: {trainable_params:,}")
57+
# print(f"Percentage trainable: {100 * trainable_params / total_params:.2f}%")
58+
4459
if weights_from_checkpoint_path:
4560
print(f"Loading weights from checkpoint: {weights_from_checkpoint_path}")
4661
checkpoint = torch.load(weights_from_checkpoint_path)

0 commit comments

Comments
 (0)