-
Notifications
You must be signed in to change notification settings - Fork 407
Description
Hello — please update the codebase with the fix for Issue #450. Without it, 1×1 SubMConv layers produce significantly different computation results.
I am using the following code to compare sparse and dense computations. Until this bug is fixed, the two results differ significantly when kernel_size=1.
import torch
from torch import nn
import spconv
import spconv.pytorch
from spconv.core import ConvAlgo
from spconv.pytorch import SparseSequential, SparseConv2d
class SparseMapping(spconv.pytorch.SparseModule):
    """Map sparse pillar features through a single 1x1 SubMConv2d layer.

    Bug fix: the constructor was named ``init`` instead of ``__init__``, so it
    was never called and ``self.mapping`` was never created; instantiating the
    class with keyword arguments also raised a TypeError.
    """

    def __init__(self,
                 in_channels=64,
                 out_channels=256,
                 ):
        super(SparseMapping, self).__init__()
        # 1x1 submanifold convolution with the Native algorithm — the exact
        # configuration affected by spconv Issue #450.
        self.mapping = SparseSequential(
            spconv.pytorch.SubMConv2d(in_channels=in_channels,
                                      out_channels=out_channels,
                                      kernel_size=1, padding=0, stride=1,
                                      bias=False, algo=ConvAlgo.Native),
        )

    def forward(self, pillar_features, coors, input_shape):
        """Run the 1x1 conv on a SparseConvTensor and return a dense tensor.

        Args:
            pillar_features: [N, C] float features, one row per valid point.
            coors: [N, 3] int coordinates as (batch_id, y, x).
            input_shape: spatial shape [H, W] of the dense grid.

        Returns:
            Dense output tensor of shape [B, out_channels, H, W].
        """
        # Batch size is inferred from the distinct batch ids in the coords.
        batch_size = len(torch.unique(coors[:, 0]))
        x = spconv.pytorch.SparseConvTensor(
            pillar_features, coors, input_shape, batch_size)
        x = self.mapping(x)
        return x.dense()
class DenseMapping(nn.Module):
    """Dense 1x1 conv counterpart of SparseMapping, masked to valid pixels.

    Bug fix: the constructor was named ``init`` instead of ``__init__``, so
    ``nn.Module.__init__`` never ran with these arguments and instantiating
    the class with keyword arguments raised a TypeError.
    """

    def __init__(self,
                 in_channels=64,
                 out_channels=256,
                 ):
        super(DenseMapping, self).__init__()
        # bias=False to match the SubMConv2d in SparseMapping.
        self.mapping = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=1, padding=0, stride=1, bias=False),
        )

    def forward(self, x, mask):
        """Apply the masked 1x1 convolution.

        Args:
            x: [B, C_in, H, W] dense feature map.
            mask: [B, 1, H, W] validity mask (bool or float).

        Returns:
            [B, C_out, H, W] output, zeroed outside the mask.
        """
        cur_mask = mask.float()
        x = x * cur_mask        # zero invalid locations before the conv
        x = self.mapping(x)     # [B, C_out, H/8, W/8]
        x = x * cur_mask        # re-mask; a 1x1 conv cannot leak, kept for parity
        return x
def feed_weight(sparse_model, dense_model) -> None:
    """Feed weights of sparse model to dense model

    Args:
        sparse_model: Sparse model
        dense_model: Dense model
    """
    # Put both models in inference mode before touching their weights.
    sparse_model.eval()
    dense_model.eval()

    src_state = sparse_model.state_dict()
    dst_state = dense_model.state_dict()
    print(f"acc: {sparse_model.mapping[0].weight.dtype, dense_model.mapping[0].weight.dtype}")

    for key in src_state:
        if key not in dst_state.keys():
            raise ValueError(f"'{key}' does not exist in dense model")
        tensor = src_state[key].data
        if key == "mapping.0.weight":
            # The sparse conv weight is stored with kernel dims before the
            # input channel; reorder to Conv2d's (out, in, kH, kW) layout.
            tensor = tensor.permute(0, 3, 1, 2)
        dst_state[key].data.copy_(tensor)

    dense_model.load_state_dict(dst_state, strict=True)
    print("Feed weight success")
def test_sparse2dense_outputs(sparse_model, dense_model, B=1, C=64, H=32, W=32,
                              device="cuda"):
    """Compare outputs between sparse and dense model

    Args:
        sparse_model: Sparse model
        dense_model: Dense model
        B: Batch size. Defaults to 1.
        C: Feature dims. Defaults to 64.
        H: Height for dense features. Defaults to 32.
        W: Width for dense features. Defaults to 32.
        device: Torch device to run on. Defaults to "cuda" (backward
            compatible with the previously hard-coded device).

    Returns:
        Tuple of (mean absolute error, max absolute error) between outputs.
    """
    device = torch.device(device)
    sparse_model = sparse_model.to(device).eval()
    dense_model = dense_model.to(device).eval()

    # Construct inputs for sparse model with shape [N,C] + [N,3(batch_id,y,x)]
    dense_feats = torch.randn(B, C, H, W, device=device, dtype=torch.float32)
    dense_mask = torch.randn(B, 1, H, W, device=device, dtype=torch.float32) > 0.8
    dense_feats = dense_feats * dense_mask

    # Vectorized replacement for the original O(B*H*W) Python triple loop:
    # torch.nonzero returns (b, y, x) indices in row-major order, which is
    # exactly the order the original nested loops produced.
    coords = torch.nonzero(dense_mask[:, 0]).to(torch.int32)
    b_idx, y_idx, x_idx = coords.long().unbind(dim=1)
    sparse_feats = dense_feats[b_idx, :, y_idx, x_idx]  # [N, C]

    # Bug fix: N was previously B*H*W (every grid cell) but printed as the
    # number of valid points; count the mask-on points instead.
    N = coords.shape[0]

    # Inference
    with torch.no_grad():
        sparse_out = sparse_model(pillar_features=sparse_feats, coors=coords, input_shape=[H, W])
        dense_out = dense_model(dense_feats, dense_mask)

    abs_err = torch.abs(sparse_out - dense_out)
    mae = abs_err.mean().item()
    max_err = abs_err.max().item()
    print("=" * 80)
    print(f"Test case params: B={B}, C={C}, H={H}, W={W}, Valid points={N}")
    print(f"Shape for sparse inputs: sparse_feats[{sparse_feats.shape}], coords[{coords.shape}]")
    print(f"Shape for dense inputs: dense_feats[{dense_feats.shape}]")
    print(f"Shape for sparse outputs: sparse_out[{sparse_out.shape}]")
    print(f"Shape for dense outputs: dense_out[{dense_out.shape}]")
    print("=" * 80)
    print(f"MAE for outputs: {mae:.8f}")
    print(f"Max Err for outputs: {max_err:.8f}")
    if mae < 1e-6:
        print("Success")
    else:
        print("Fail")
    return mae, max_err
if __name__ == "__main__":
    # Bug fix: the guard compared an undefined name ``name`` to "main"
    # (a NameError at import time); use the standard __name__ idiom.
    # Fix seeds and disable cuDNN / TF32 so the sparse and dense paths
    # produce numerically comparable results.
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.allow_tf32 = False
    torch.backends.cuda.matmul.allow_tf32 = False

    sparse_model = SparseMapping(
        in_channels=64,
        out_channels=256
    )
    dense_model = DenseMapping(
        in_channels=64,
        out_channels=256
    )
    feed_weight(sparse_model=sparse_model, dense_model=dense_model)
    mae, max_err = test_sparse2dense_outputs(sparse_model=sparse_model, dense_model=dense_model, H=16, W=16)