|
7 | 7 | - weights loaded later by asparagus.BaseModule via load_state_dict(strict=False) |
8 | 8 | """ |
9 | 9 |
|
| 10 | +import math |
| 11 | + |
| 12 | +from gardening_tools.modules.networks.BaseNet import BaseNet |
| 13 | +from gardening_tools.modules.networks.components.transformer import PatchDecode |
10 | 14 | import torch.nn as nn |
11 | 15 | from torch import Tensor |
12 | 16 |
|
@@ -62,18 +66,68 @@ def _features(self, x: Tensor) -> Tensor: |
62 | 66 | def forward(self, x: Tensor) -> Tensor: |
63 | 67 | """Encoder + head.""" |
64 | 68 | return self.head(self._features(x)) |
65 | | - |
| 69 | + |
66 | 70 | def _encode(self, x: Tensor) -> Tensor: |
67 | 71 | """Encoder output in the format used for linear probing.""" |
68 | 72 | feat = self._features(x) |
69 | 73 | return feat[:, :, None, None, None] |
70 | 74 |
|
71 | 75 |
|
72 | | -class SmriMaeSegBackbone(nn.Module): |
73 | | - """Placeholder for ViT-based segmentation backbone.""" |
| 76 | +class SmriMaeSegBackbone(BaseNet): |
| 77 | + """MAE ViT segmentation backbone with a Primus-like patch decoder.""" |
| 78 | + |
| 79 | + def __init__( |
| 80 | + self, |
| 81 | + input_channels: int, |
| 82 | + output_channels: int, |
| 83 | + img_size: int | tuple[int, int, int] = (160, 160, 160), |
| 84 | + patch_size: int | tuple[int, int, int] = (16, 16, 16), |
| 85 | + depth: int = 12, |
| 86 | + embed_dim: int = 768, |
| 87 | + num_heads: int = 12, |
| 88 | + dimensions: str = "3D", |
| 89 | + **_ignored, |
| 90 | + ): |
| 91 | + super().__init__() |
| 92 | + assert dimensions == "3D", f"only 3D supported, got dimensions={dimensions}" |
| 93 | + |
| 94 | + self.num_classes = output_channels |
| 95 | + self.stem_weight_name = "encoder.patch_embed.weight" |
74 | 96 |
|
75 | | - def __init__(self, *args, **kwargs): |
76 | | - raise NotImplementedError("SmriMaeSegBackbone is not yet implemented") |
| 97 | + self.encoder = MaskedViT( |
| 98 | + img_size=img_size, |
| 99 | + patch_size=patch_size, |
| 100 | + in_chans=input_channels, |
| 101 | + depth=depth, |
| 102 | + embed_dim=embed_dim, |
| 103 | + num_heads=num_heads, |
| 104 | + class_token=True, |
| 105 | + ) |
| 106 | + self.grid_size = self.encoder.patchify.grid_size |
| 107 | + self.decoder = PatchDecode( |
| 108 | + patch_size=self.encoder.patchify.patch_size, |
| 109 | + embed_dim=embed_dim, |
| 110 | + out_channels=output_channels, |
| 111 | + ) |
77 | 112 |
|
78 | 113 | def forward(self, x: Tensor) -> Tensor: |
79 | | - raise NotImplementedError |
| 114 | + _, _, patch_embeds, _, _ = self.encoder(x) |
| 115 | + expected_tokens = math.prod(self.grid_size) |
| 116 | + if patch_embeds.shape[1] != expected_tokens: |
| 117 | + raise ValueError( |
| 118 | + "unexpected MAE patch token count: " |
| 119 | + f"got {patch_embeds.shape[1]}, expected {expected_tokens}" |
| 120 | + ) |
| 121 | + |
| 122 | + features = patch_embeds.reshape( |
| 123 | + x.shape[0], |
| 124 | + *self.grid_size, |
| 125 | + patch_embeds.shape[-1], |
| 126 | + ) |
| 127 | + features = features.permute(0, 4, 1, 2, 3).contiguous() |
| 128 | + return self.decoder(features) |
| 129 | + |
| 130 | + # Inherits BaseNet.sliding_window_predict. That implementation sums |
| 131 | + # overlapping logits without normalizing by an overlap-count map. |
| 132 | + # TODO: replace inherited sliding-window accumulation with normalized overlap averaging. |
| 133 | + # TODO: consider Gaussian/Hann weighting so patch centers contribute more than patch borders. |
0 commit comments