Hello, I am trying to build the Video Swin Transformer with a different `embed_dim` (say 32) and different `depths` (say `[2, 2, 8, 4]`), but at the moment only three variants with fixed parameters are available: tiny, small, and base.
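For reference, the three builders (`swin3d_t`, `swin3d_s`, `swin3d_b`) use the values below, which match the tiny/small/base configurations in the paper. In every variant each attention head gets 32 channels, and the channel count doubles at each stage, so keeping the same head width means scaling `num_heads` along with `embed_dim`. The `heads_for` helper is hypothetical (not part of torchvision) and only illustrates that rule.

```python
from typing import Dict, List

# Configurations of the built-in builders (swin3d_t, swin3d_s, swin3d_b).
# All three also share patch_size=[2, 4, 4] and window_size=[8, 7, 7].
VARIANTS: Dict[str, dict] = {
    "tiny": {"embed_dim": 96, "depths": [2, 2, 6, 2], "num_heads": [3, 6, 12, 24]},
    "small": {"embed_dim": 96, "depths": [2, 2, 18, 2], "num_heads": [3, 6, 12, 24]},
    "base": {"embed_dim": 128, "depths": [2, 2, 18, 2], "num_heads": [4, 8, 16, 32]},
}


def heads_for(embed_dim: int, num_stages: int = 4, head_dim: int = 32) -> List[int]:
    """Hypothetical helper: choose num_heads so every head has `head_dim` channels.

    Stage i works on embed_dim * 2**i channels, because patch merging
    doubles the channel count between stages.
    """
    heads = []
    for stage in range(num_stages):
        channels = embed_dim * 2**stage
        if channels % head_dim != 0:
            raise ValueError(f"stage {stage}: {channels} channels not divisible by {head_dim}")
        heads.append(channels // head_dim)
    return heads


# The built-in variants follow this rule...
for name, cfg in VARIANTS.items():
    assert heads_for(cfg["embed_dim"]) == cfg["num_heads"], name

# ...and the embed_dim asked about in this issue would give:
print(heads_for(32))  # [1, 2, 4, 8]
```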
Could you add another builder function that allows building a custom model as well? For example:
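```python
# Proposed addition to torchvision/models/video/swin_transformer.py, where
# register_model, handle_legacy_interface, Swin3D_B_Weights, SwinTransformer3d,
# _swin_transformer3d and the typing names are already imported.
@register_model()
@handle_legacy_interface(weights=("pretrained", Swin3D_B_Weights.KINETICS400_V1))
def swin3d_custom(
    *,
    weights: Optional[Swin3D_B_Weights] = None,
    progress: bool = True,
    patch_size: List[int],
    embed_dim: int,
    depths: List[int],
    num_heads: List[int],
    window_size: List[int],
    **kwargs: Any,
) -> SwinTransformer3d:
    """
    Constructs a Swin3D architecture with a user-defined configuration from
    `Video Swin Transformer <https://arxiv.org/abs/2106.13230>`_.

    Args:
        weights (:class:`~torchvision.models.video.Swin3D_B_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.video.Swin3D_B_Weights` below for
            more details, and possible values. Pretrained weights can only be
            loaded when the configuration matches ``swin3d_b``. By default, no
            pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        patch_size (List[int]): Patch size in the (T, H, W) dimensions.
        embed_dim (int): Number of channels after patch embedding.
        depths (List[int]): Number of Swin blocks in each stage.
        num_heads (List[int]): Number of attention heads in each stage.
        window_size (List[int]): Attention window size in the (T, H, W) dimensions.
        **kwargs: parameters passed to the ``torchvision.models.video.swin_transformer.SwinTransformer3d``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/swin_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.video.Swin3D_B_Weights
        :members:
    """
    weights = Swin3D_B_Weights.verify(weights)
    return _swin_transformer3d(
        patch_size=patch_size,
        embed_dim=embed_dim,
        depths=depths,
        num_heads=num_heads,
        window_size=window_size,
        stochastic_depth_prob=0.1,
        weights=weights,
        progress=progress,
        **kwargs,
    )
```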
At the moment, these parameters are hard-coded in each of the three variants.
Thank you!