Skip to content

Commit c68d724

Browse files
adamjstewartrwightman
authored andcommitted
adapt_input_conv: add type hints
1 parent 105a667 commit c68d724

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

timm/models/_manipulate.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import torch
99
import torch.utils.checkpoint
1010
from torch import nn as nn
11+
from torch import Tensor
1112

1213
from timm.layers import use_reentrant_ckpt
1314

@@ -284,7 +285,7 @@ def forward(_x):
284285
return x
285286

286287

287-
def adapt_input_conv(in_chans, conv_weight):
288+
def adapt_input_conv(in_chans: int, conv_weight: Tensor) -> Tensor:
288289
conv_type = conv_weight.dtype
289290
conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU
290291
O, I, J, K = conv_weight.shape

0 commit comments

Comments
 (0)