diff --git a/src/electrai/dataloader/mp.py b/src/electrai/dataloader/mp.py
index 8684a26..5b7b880 100644
--- a/src/electrai/dataloader/mp.py
+++ b/src/electrai/dataloader/mp.py
@@ -66,21 +66,23 @@ def __init__(
         rho_type: str,
         data_augmentation: bool = True,
         random_state: int = 42,
+        tile_size: int | None = None,
     ):
         """
         Parameters
         ----------
         data: list of voxel data of length batch_size.
         rho_type: chgcar or elfcar.
-        data_size: target size of data.
-        label_size: target size of label.
-        pyrho_uf: pyrho upsampling factor
+        data_augmentation: whether to apply random rotations.
+        random_state: seed for reproducibility.
+        tile_size: spatial tile size for training (None = full volume).
         """
         self.data = data
         self.data_precision = data_precision
         self.rho_type = rho_type
         self.da = data_augmentation
         self.rng = np.random.default_rng(random_state)
+        self.tile_size = tile_size
 
     def __len__(self):
         return len(self.data)
@@ -122,6 +124,49 @@ def rotate(d):
         else:
             return [rotate(rotate(rotate(d))) for d in data_lst]
 
+    def extract_tile(
+        self, data: torch.Tensor, label: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Extract a random tile with periodic wrapping.
+
+        Unit cells smaller than tile_size are first replicated with torch.tile
+        to form a supercell. torch.roll then applies a random periodic shift,
+        and a fixed-size tile is sliced from the origin.
+        """
+        if self.tile_size is None:
+            return data, label
+
+        D, H, W = data.shape[-3:]
+        size = self.tile_size
+
+        # Replicate if any dimension is smaller than tile_size
+        if size > D or size > H or size > W:
+            reps_d = (size + D - 1) // D
+            reps_h = (size + H - 1) // H
+            reps_w = (size + W - 1) // W
+
+            # Data shape is (1, D, H, W) after unsqueeze
+            reps = (1, reps_d, reps_h, reps_w)
+
+            data = torch.tile(data, reps)
+            label = torch.tile(label, reps)
+
+            D, H, W = data.shape[-3:]
+
+        # Random shift (handles periodicity via roll)
+        shift_d = int(self.rng.integers(0, D))
+        shift_h = int(self.rng.integers(0, H))
+        shift_w = int(self.rng.integers(0, W))
+
+        data = torch.roll(data, shifts=(shift_d, shift_h, shift_w), dims=(-3, -2, -1))
+        label = torch.roll(label, shifts=(shift_d, shift_h, shift_w), dims=(-3, -2, -1))
+
+        # Extract tile from origin
+        data = data[..., :size, :size, :size]
+        label = label[..., :size, :size, :size]
+
+        return data, label
+
     def __getitem__(self, idx: int):
         data = self.read_data(self.data[idx][0])
         label = self.read_data(self.data[idx][1])
@@ -136,6 +181,9 @@ def __getitem__(self, idx: int):
         data = torch.tensor(data, dtype=dtype_map[self.data_precision]).unsqueeze(0)
         label = torch.tensor(label, dtype=dtype_map[self.data_precision]).unsqueeze(0)
 
+        # Extract tile before augmentation (for memory efficiency)
+        data, label = self.extract_tile(data, label)
+
         if self.da:
             data, label = self.rand_rotate([data, label])
         return data, label
@@ -168,19 +216,24 @@ def load_data(cfg):
         random_state=cfg.random_state,
     ).data_split()
 
+    tile_size = getattr(cfg, "tile_size", None)
+
     train_data = RhoData(
         train_set,
         cfg.data_precision,
         cfg.rho_type,
         cfg.data_augmentation,
         cfg.random_state,
+        tile_size=tile_size,
     )
     test_data = RhoData(
         test_set,
         cfg.data_precision,
         cfg.rho_type,
-        cfg.data_augmentation,
-        cfg.random_state,
+        data_augmentation=False,
+        random_state=cfg.random_state,
+        tile_size=None,  # Full volume for validation
     )
+
     return train_data, test_data
diff --git a/src/electrai/lightning.py b/src/electrai/lightning.py
index a46c7d5..02702a5 100644
--- a/src/electrai/lightning.py
+++ b/src/electrai/lightning.py
@@ -41,8 +41,10 @@ def training_step(self, batch):
 
     def validation_step(self, batch):
         x, y = batch
+        pred = self(x)
         loss = self.loss_fn(pred, y)
+
         self.log(
             "val_loss", loss, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True
        )
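
Reviewer note (not part of the patch): a minimal, self-contained sketch of the tile -> roll -> slice pattern that `extract_tile` uses. The shapes and the `size` value below are made up for illustration; it assumes only PyTorch.

```python
import torch

size = 4  # hypothetical tile_size
# A small "unit cell" volume, shape (C, D, H, W), smaller than the tile
vol = torch.arange(27, dtype=torch.float32).reshape(1, 3, 3, 3)

# Replicate the 3x3x3 cell until every spatial dim covers `size`
reps = tuple((size + s - 1) // s for s in vol.shape[-3:])
supercell = torch.tile(vol, (1, *reps))  # -> (1, 6, 6, 6)

# Random periodic shift, then a fixed-size tile sliced from the origin
shifts = tuple(int(torch.randint(0, s, ()).item()) for s in supercell.shape[-3:])
tile = torch.roll(supercell, shifts=shifts, dims=(-3, -2, -1))[..., :size, :size, :size]

assert tile.shape == (1, size, size, size)
# Periodicity check: every voxel in the tile comes from the original cell
assert set(tile.flatten().tolist()) <= set(vol.flatten().tolist())
```

Because `__getitem__` tiles before `rand_rotate`, the rotations operate on tensors of size tile_size cubed rather than on the full supercell, which is where the memory saving comes from.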