Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion configs/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ train:
no_batch_norm: false
initial_learning_rate: 0.01
weight_decay: 0.0001
epochs: 10
epochs: 3
31 changes: 31 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from typing import Dict, Any
import os
import time
import torch.nn as nn
import torch.nn.functional as F


class RepresentationType(Enum):
Expand All @@ -27,6 +29,8 @@ def set_seed(seed):
torch.backends.cudnn.benchmark = False
np.random.seed(seed)

# 変更前
"""
def compute_epe_error(pred_flow: torch.Tensor, gt_flow: torch.Tensor):
'''
end-point-error (ground truthと予測値の二乗誤差)を計算
Expand All @@ -35,6 +39,20 @@ def compute_epe_error(pred_flow: torch.Tensor, gt_flow: torch.Tensor):
'''
epe = torch.mean(torch.mean(torch.norm(pred_flow - gt_flow, p=2, dim=1), dim=(1, 2)), dim=0)
return epe
"""
# 変更後
def compute_epe_error(pred_flow_dict: Dict[str, torch.Tensor], gt_flow: torch.Tensor):
'''
end-point-error (ground truthと予測値の二乗誤差)を計算
pred_flow_dict: Dict[str, torch.Tensor] => 予測したオプティカルフローデータの辞書
gt_flow: torch.Tensor, Shape: torch.Size([B, 2, 480, 640]) => 正解のオプティカルフローデータ
'''
# 平均フローを使用してEPEを計算
avg_flow = pred_flow_dict['avg_flow']
gt_flow_resized = nn.functional.interpolate(gt_flow, size=avg_flow.shape[-2:], mode='bilinear', align_corners=False)
epe = torch.mean(torch.norm(avg_flow - gt_flow_resized, p=2, dim=1))

return epe

def save_optical_flow_to_npy(flow: torch.Tensor, file_name: str):
'''
Expand Down Expand Up @@ -126,6 +144,7 @@ def main(args: DictConfig):
for i, batch in enumerate(tqdm(train_data)):
batch: Dict[str, Any]
event_image = batch["event_volume"].to(device) # [B, 4, 480, 640]
print(batch["event_volume"].shape)
ground_truth_flow = batch["flow_gt"].to(device) # [B, 2, 480, 640]
flow = model(event_image) # [B, 2, 480, 640]
loss: torch.Tensor = compute_epe_error(flow, ground_truth_flow)
Expand All @@ -152,13 +171,25 @@ def main(args: DictConfig):
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
flow: torch.Tensor = torch.tensor([]).to(device)
# 変更前
"""
with torch.no_grad():
print("start test")
for batch in tqdm(test_data):
batch: Dict[str, Any]
event_image = batch["event_volume"].to(device)
batch_flow = model(event_image) # [1, 2, 480, 640]
flow = torch.cat((flow, batch_flow), dim=0) # [N, 2, 480, 640]
print("test done")"""
# 変更後
with torch.no_grad():
print("start test")
for batch in tqdm(test_data):
event_image = batch["event_volume"].to(device) # [B, 2, 4, 480, 640] - > LSTM - > [B, 4, 480, 640]
batch_flow_dict = model(event_image) # モデルの出力が辞書
batch_flow = batch_flow_dict['avg_flow'] # 平均フローデータを辞書から取得
flow = torch.cat((flow, batch_flow), dim=0) # 必要に応じて形状を調整

print("test done")
# ------------------
# save submission
Expand Down
11 changes: 9 additions & 2 deletions src/models/evflownet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from torch import nn
from src.models.base import *
from typing import Dict, Any
import torch.nn.functional as F

_BASE_CHANNELS = 64

Expand Down Expand Up @@ -29,7 +30,7 @@ def __init__(self, args):
self.decoder4 = upsample_conv2d_and_predict_flow(in_channels=2*_BASE_CHANNELS+2,
out_channels=int(_BASE_CHANNELS/2), do_batch_norm=not self._args.no_batch_norm)

def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
def forward(self, inputs: torch.Tensor) -> Dict[str, torch.Tensor]:
# encoder
skip_connections = {}
inputs = self.encoder1(inputs)
Expand Down Expand Up @@ -62,7 +63,13 @@ def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
inputs, flow = self.decoder4(inputs)
flow_dict['flow3'] = flow.clone()

return flow
max_size = [max(flow.size(2) for flow in flow_dict.values()), max(flow.size(3) for flow in flow_dict.values())]
resized_flows = [F.interpolate(flow, size=max_size, mode='bilinear', align_corners=False) for flow in flow_dict.values()]
total_flow = sum(resized_flows)
avg_flow = total_flow / len(flow_dict)
flow_dict['avg_flow'] = avg_flow

return flow_dict


# if __name__ == "__main__":
Expand Down