This repository was archived by the owner on Mar 10, 2026. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 327
Expand file tree
/
Copy pathvivit_layers.py
More file actions
129 lines (111 loc) · 4.02 KB
/
vivit_layers.py
File metadata and controls
129 lines (111 loc) · 4.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# Copyright 2024 The KerasCV Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keras_cv.api_export import keras_cv_export
from keras_cv.backend import keras
from keras_cv.backend import ops
@keras_cv_export(
    # TODO: export path is typo'd ("Emebedding" -> "Embedding"). Kept
    # unchanged here because the registered name is public API and is
    # baked into existing saved models; fixing it requires exporting
    # under both names for backward compatibility.
    "keras_cv.layers.TubeletEmebedding",
    package="keras_cv.layers",
)
class TubeletEmbedding(keras.layers.Layer):
    """
    A Keras layer for spatio-temporal tube embedding applied to input sequences
    retrieved from video frames.

    A 3D convolution with stride equal to its kernel size extracts
    non-overlapping spatio-temporal "tubelets" from the video volume; the
    resulting patch grid is then flattened into a token sequence of shape
    (batch, num_patches, embed_dim).

    References:
        - [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691)
        (ICCV 2021)

    Args:
        embed_dim: int, number of dimensions in the embedding space.
            Defaults to 128.
        patch_size: tuple , size of the spatio-temporal patch.
            Specifies the size for each dimension.
            Defaults to (8,8,8).
    """

    def __init__(self, embed_dim=128, patch_size=(8, 8, 8), **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.patch_size = patch_size
        # stride == kernel size => non-overlapping tubelets.
        self.projection = keras.layers.Conv3D(
            filters=self.embed_dim,
            kernel_size=self.patch_size,
            strides=self.patch_size,
            data_format="channels_last",
            padding="VALID",
        )
        # Collapses the (T', H', W') patch grid into one token axis.
        self.flatten = keras.layers.Reshape(target_shape=(-1, self.embed_dim))

    def _batched_shape(self, input_shape):
        # Single source of truth for the shape tuple previously duplicated
        # across build() and compute_output_shape().
        # NOTE(review): this treats `input_shape` as the unbatched
        # (frames, height, width, channels) shape and prepends a batch axis.
        # The standard Keras contract passes the *batched* 5-tuple to
        # build(), under which this would drop the channel dim — confirm
        # the caller's convention before changing.
        return (
            None,
            input_shape[0],
            input_shape[1],
            input_shape[2],
            input_shape[3],
        )

    def build(self, input_shape):
        super().build(input_shape)
        batched_shape = self._batched_shape(input_shape)
        self.projection.build(batched_shape)
        projected_patch_shape = self.projection.compute_output_shape(
            batched_shape
        )
        self.flatten.build(projected_patch_shape)

    def compute_output_shape(self, input_shape):
        projected_patch_shape = self.projection.compute_output_shape(
            self._batched_shape(input_shape)
        )
        return self.flatten.compute_output_shape(projected_patch_shape)

    def call(self, videos):
        # (batch, T, H, W, C) -> (batch, T', H', W', embed_dim)
        projected_patches = self.projection(videos)
        # -> (batch, num_patches, embed_dim)
        flattened_patches = self.flatten(projected_patches)
        return flattened_patches

    def get_config(self):
        # Required for save/load round-trips: without this, a layer built
        # with non-default embed_dim/patch_size would silently reload with
        # the defaults.
        config = super().get_config()
        config.update(
            {
                "embed_dim": self.embed_dim,
                "patch_size": self.patch_size,
            }
        )
        return config
@keras_cv_export(
    "keras_cv.layers.PositionalEncoder",
    package="keras_cv.layers",
)
class PositionalEncoder(keras.layers.Layer):
    """
    A Keras layer for adding positional information to the encoded video tokens.

    Adds a learned absolute position embedding (one vector per token index)
    to the input token sequence. The sequence length must be statically
    known at build time, since it sizes the embedding table.

    References:
        - [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691)
        (ICCV 2021)

    Args:
        embed_dim: int, number of dimensions in the embedding space.
            Defaults to 128.
    """

    def __init__(self, embed_dim=128, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim

    def build(self, input_shape):
        super().build(input_shape)
        # Expects (batch, num_tokens, embed_dim). num_tokens must be a
        # concrete int here — a dynamic (None) sequence length cannot size
        # the embedding table.
        _, num_tokens, _ = input_shape
        self.position_embedding = keras.layers.Embedding(
            input_dim=num_tokens, output_dim=self.embed_dim
        )
        self.position_embedding.build(input_shape)
        # Fixed index vector [0, 1, ..., num_tokens - 1] used for every call.
        self.positions = ops.arange(start=0, stop=num_tokens, step=1)

    def call(self, encoded_tokens):
        # (num_tokens, embed_dim) position table, broadcast over the batch.
        encoded_positions = self.position_embedding(self.positions)
        encoded_tokens = encoded_tokens + encoded_positions
        return encoded_tokens

    def get_config(self):
        # Required for save/load round-trips: without this, a non-default
        # embed_dim would silently reload as 128.
        config = super().get_config()
        config.update({"embed_dim": self.embed_dim})
        return config