diff --git a/eg3d/training/volumetric_rendering/renderer.py b/eg3d/training/volumetric_rendering/renderer.py index a27aea61..382db021 100644 --- a/eg3d/training/volumetric_rendering/renderer.py +++ b/eg3d/training/volumetric_rendering/renderer.py @@ -36,6 +36,22 @@ def generate_planes(): [1, 0, 0], [0, 1, 0]]], dtype=torch.float32) +# def project_onto_planes(planes, coordinates): +# """ +# Does a projection of a 3D point onto a batch of 2D planes, +# returning 2D plane coordinates. + +# Takes plane axes of shape n_planes, 3, 3 +# # Takes coordinates of shape N, M, 3 +# # returns projections of shape N*n_planes, M, 2 +# """ +# N, M, C = coordinates.shape +# n_planes, _, _ = planes.shape +# coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N*n_planes, M, 3) +# inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(N, -1, -1, -1).reshape(N*n_planes, 3, 3) +# projections = torch.bmm(coordinates, inv_planes) +# return projections[..., :2] + def project_onto_planes(planes, coordinates): """ Does a projection of a 3D point onto a batch of 2D planes, @@ -45,12 +61,22 @@ def project_onto_planes(planes, coordinates): # Takes coordinates of shape N, M, 3 # returns projections of shape N*n_planes, M, 2 """ - N, M, C = coordinates.shape - n_planes, _, _ = planes.shape - coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N*n_planes, M, 3) - inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(N, -1, -1, -1).reshape(N*n_planes, 3, 3) - projections = torch.bmm(coordinates, inv_planes) - return projections[..., :2] + + # # ORIGINAL + # N, M, C = coordinates.shape + # xy_coords = coordinates[..., [0, 1]] + # xz_coords = coordinates[..., [0, 2]] + # zx_coords = coordinates[..., [2, 0]] + # return torch.stack([xy_coords, xz_coords, zx_coords], dim=1).reshape(N*3, M, 2) + + # FIXED + N, M, _ = coordinates.shape + xy_coords = coordinates[..., [0, 1]] + yz_coords = coordinates[..., [1, 2]] + zx_coords = coordinates[..., [2, 0]] + return torch.stack([xy_coords, yz_coords, zx_coords], dim=1).reshape(N*3, M, 2) + + def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None): assert padding_mode == 'zeros'