diff --git a/src/datasets.py b/src/datasets.py index e1d02f94..7eef713e 100644 --- a/src/datasets.py +++ b/src/datasets.py @@ -317,26 +317,57 @@ def rectify_events(self, x: np.ndarray, y: np.ndarray): return rectify_map[y, x] def get_data(self, index) -> Dict[str, any]: - ts_start: int = self.timestamps_flow[index] - self.delta_t_us - ts_end: int = self.timestamps_flow[index] + #ts_start: int = self.timestamps_flow[index] - self.delta_t_us + #ts_end: int = self.timestamps_flow[index] + ts_start1: int = self.timestamps_flow[index] - self.delta_t_us + ts_end1: int = self.timestamps_flow[index] + ts_start2: int = self.timestamps_flow[index + 1] - self.delta_t_us + ts_end2: int = self.timestamps_flow[index + 1] - file_index = self.indices[index] + #file_index = self.indices[index] + file_index1 = self.indices[index] + file_index2 = self.indices[index + 1] output = { - 'file_index': file_index, - 'timestamp': self.timestamps_flow[index], + #'file_index': file_index, + 'file_index1': file_index1, + 'file_index2': file_index2, + #'timestamp': self.timestamps_flow[index], + 'timestamp1': self.timestamps_flow[index], + 'timestamp2': self.timestamps_flow[index + 1], 'seq_name': self.seq_name } # Save sample for benchmark submission output['save_submission'] = file_index in self.idx_to_visualize output['visualize'] = self.visualize_samples - event_data = self.event_slicer.get_events( - ts_start, ts_end) - p = event_data['p'] - t = event_data['t'] - x = event_data['x'] - y = event_data['y'] + + #event_data = self.event_slicer.get_events(ts_start, ts_end) + event_data1 = self.event_slicer.get_events(ts_start1, ts_end1) + event_data2 = self.event_slicer.get_events(ts_start2, ts_end2) + + #p = event_data['p'] + #t = event_data['t'] + #x = event_data['x'] + #y = event_data['y'] + + #event_data1 = self.event_slicer.get_events(ts_start1, ts_end1) + p1 = event_data1['p'] + t1 = event_data1['t'] + x1 = event_data1['x'] + y1 = event_data1['y'] + + #event_data2 = self.event_slicer.get_events(ts_start2, ts_end2) + p2 = event_data2['p'] + t2 = event_data2['t'] + x2 = event_data2['x'] + y2 = event_data2['y'] + + p = np.concatenate((p1, p2)) + t = np.concatenate((t1, t2)) + x = np.concatenate((x1, x2)) + y = np.concatenate((y1, y2)) + xy_rect = self.rectify_events(x, y) x_rect = xy_rect[:, 0] y_rect = xy_rect[:, 1] @@ -344,19 +375,24 @@ def get_data(self, index) -> Dict[str, any]: if self.voxel_grid is None: raise NotImplementedError else: - event_representation = self.events_to_voxel_grid( - p, t, x_rect, y_rect) + event_representation = self.events_to_voxel_grid(p, t, x_rect, y_rect) output['event_volume'] = event_representation output['name_map'] = self.name_idx - if self.load_gt: - output['flow_gt' - ] = [torch.tensor(x) for x in self.load_flow(self.flow_png[index])] - - output['flow_gt' - ][0] = torch.moveaxis(output['flow_gt'][0], -1, 0) - output['flow_gt' - ][1] = torch.unsqueeze(output['flow_gt'][1], 0) + if self.load_gt: + #output['flow_gt'] = [torch.tensor(x) for x in self.load_flow(self.flow_png[index])] + flow_gt1 = [torch.tensor(x) for x in self.load_flow(self.flow_png[index])] + flow_gt2 = [torch.tensor(x) for x in self.load_flow(self.flow_png[index + 1])] + + #output['flow_gt'][0] = torch.moveaxis(output['flow_gt'][0], -1, 0) + #output['flow_gt'][1] = torch.unsqueeze(output['flow_gt'][1], 0) + flow_gt1[0] = torch.moveaxis(flow_gt1[0], -1, 0) + flow_gt1[1] = torch.unsqueeze(flow_gt1[1], 0) + flow_gt2[0] = torch.moveaxis(flow_gt2[0], -1, 0) + flow_gt2[1] = torch.unsqueeze(flow_gt2[1], 0) + + output['flow_gt'] = [flow_gt1, flow_gt2] + return output def __getitem__(self, idx):