diff --git a/configs/base.yaml b/configs/base.yaml
index 66a1153a..54e0f15a 100644
--- a/configs/base.yaml
+++ b/configs/base.yaml
@@ -14,4 +14,4 @@ train:
   no_batch_norm: false
   initial_learning_rate: 0.01
   weight_decay: 0.0001
-  epochs: 10
+  epochs: 3
\ No newline at end of file
diff --git a/main.py b/main.py
index 5f6c86da..7e9dbb65 100644
--- a/main.py
+++ b/main.py
@@ -13,6 +13,8 @@
 from typing import Dict, Any
 import os
 import time
+import torch.nn as nn
+import torch.nn.functional as F
 
 
 class RepresentationType(Enum):
@@ -27,6 +29,8 @@ def set_seed(seed):
     torch.backends.cudnn.benchmark = False
     np.random.seed(seed)
 
+# Before the change
+"""
 def compute_epe_error(pred_flow: torch.Tensor, gt_flow: torch.Tensor):
     '''
     end-point-error (ground truthと予測値の二乗誤差)を計算
@@ -35,6 +39,20 @@ def compute_epe_error(pred_flow: torch.Tensor, gt_flow: torch.Tensor):
     '''
     epe = torch.mean(torch.mean(torch.norm(pred_flow - gt_flow, p=2, dim=1), dim=(1, 2)), dim=0)
     return epe
+"""
+# After the change
+def compute_epe_error(pred_flow_dict: Dict[str, torch.Tensor], gt_flow: torch.Tensor):
+    '''
+    Compute the end-point error (mean L2 distance between the predicted and ground-truth flow).
+    pred_flow_dict: Dict[str, torch.Tensor] => dict of predicted optical flows; only 'avg_flow' is read
+    gt_flow: torch.Tensor, Shape: torch.Size([B, 2, 480, 640]) => ground-truth optical flow
+    '''
+    # Resize the ground truth to the averaged flow's spatial size, then compare
+    avg_flow = pred_flow_dict['avg_flow']
+    gt_flow_resized = nn.functional.interpolate(gt_flow, size=avg_flow.shape[-2:], mode='bilinear', align_corners=False)
+    epe = torch.mean(torch.norm(avg_flow - gt_flow_resized, p=2, dim=1))
+
+    return epe
 
 def save_optical_flow_to_npy(flow: torch.Tensor, file_name: str):
     '''
@@ -152,6 +170,8 @@ def main(args: DictConfig):
     model.load_state_dict(torch.load(model_path, map_location=device))
     model.eval()
     flow: torch.Tensor = torch.tensor([]).to(device)
+    # Before the change
+    """
     with torch.no_grad():
         print("start test")
         for batch in tqdm(test_data):
@@ -159,6 +179,16 @@ def main(args: DictConfig):
             event_image = batch["event_volume"].to(device)
             batch_flow = model(event_image) # [1, 2, 480, 640]
             flow = torch.cat((flow, batch_flow), dim=0) # [N, 2, 480, 640]
+        print("test done")"""
+    # After the change
+    with torch.no_grad():
+        print("start test")
+        for batch in tqdm(test_data):
+            event_image = batch["event_volume"].to(device) # NOTE(review): author notes [B, 2, 4, 480, 640] -> LSTM -> [B, 4, 480, 640]; no LSTM is visible in this patch - confirm
+            batch_flow_dict = model(event_image) # the model now returns a dict of flows
+            batch_flow = batch_flow_dict['avg_flow'] # take the averaged flow out of the dict
+            flow = torch.cat((flow, batch_flow), dim=0) # adjust the shape here if needed
+        print("test done")
 
     # ------------------
     # save submission
diff --git a/src/models/evflownet.py b/src/models/evflownet.py
index ddfab828..04ee3ec8 100644
--- a/src/models/evflownet.py
+++ b/src/models/evflownet.py
@@ -2,6 +2,7 @@
 from torch import nn
 from src.models.base import *
 from typing import Dict, Any
+import torch.nn.functional as F
 
 _BASE_CHANNELS = 64
 
@@ -29,7 +30,7 @@ def __init__(self, args):
         self.decoder4 = upsample_conv2d_and_predict_flow(in_channels=2*_BASE_CHANNELS+2,
                         out_channels=int(_BASE_CHANNELS/2), do_batch_norm=not self._args.no_batch_norm)
 
-    def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+    def forward(self, inputs: torch.Tensor) -> Dict[str, torch.Tensor]:
         # encoder
         skip_connections = {}
         inputs = self.encoder1(inputs)
@@ -62,7 +63,13 @@ def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
         inputs, flow = self.decoder4(inputs)
         flow_dict['flow3'] = flow.clone()
 
-        return flow
+        max_size = [max(flow.size(2) for flow in flow_dict.values()), max(flow.size(3) for flow in flow_dict.values())]
+        resized_flows = [F.interpolate(flow, size=max_size, mode='bilinear', align_corners=False) for flow in flow_dict.values()]
+        total_flow = sum(resized_flows)
+        avg_flow = total_flow / len(flow_dict)
+        flow_dict['avg_flow'] = avg_flow
+
+        return flow_dict
 
 
 # if __name__ == "__main__":