Skip to content
This repository was archived by the owner on Jul 9, 2025. It is now read-only.

Commit a993686

Browse files
author
Owen Vallis
authored
#303 fix dtype policy bug in GEM layers. (#304)
* GEM layers create a general pooling layer in the init, but we didn't pass the kwargs. This means the general pooling layer didn't have the dtype policy. This caused the GEM layers to fail when using a mixed_float16 dtype policy, as the general pooling layer returns float32 while the GEM dtype policy is float16. The fix is to pass all kwargs on to the general pooling layer.
* Patch bump.
* Cap the TF version at 2.9 for the current master branch.
1 parent 0a01bdc commit a993686

File tree

3 files changed

+7
-7
lines changed

3 files changed

+7
-7
lines changed

setup.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -78,7 +78,7 @@ def get_version(rel_path):
7878
"mkdocs-autorefs",
7979
"mkdocs-material",
8080
"mkdocstrings",
81-
"mypy",
81+
"mypy<=0.982",
8282
"pytest",
8383
"pytype",
8484
"setuptools",
@@ -87,9 +87,9 @@ def get_version(rel_path):
8787
"types-tabulate",
8888
"wheel",
8989
],
90-
"tensorflow": ["tensorflow>=2.4"],
91-
"tensorflow-gpu": ["tensorflow-gpu>=2.4"],
92-
"tensorflow-cpu": ["tensorflow-cpu>=2.4"],
90+
"tensorflow": ["tensorflow>=2.4,<=2.9"],
91+
"tensorflow-gpu": ["tensorflow-gpu>=2.4,<=2.9"],
92+
"tensorflow-cpu": ["tensorflow-cpu>=2.4,<=2.9"],
9393
},
9494
classifiers=[
9595
"Development Status :: 5 - Production/Stable",

tensorflow_similarity/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
__version__ = "0.16.8"
14+
__version__ = "0.16.9"
1515

1616

1717
from . import algebra # noqa

tensorflow_similarity/layers.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -145,7 +145,7 @@ def __init__(self, p: float = 3.0, data_format: Optional[str] = None, keepdims:
145145
super().__init__(p=p, data_format=data_format, keepdims=keepdims, **kwargs)
146146

147147
self.input_spec = layers.InputSpec(ndim=3)
148-
self.gap = layers.GlobalAveragePooling1D(data_format=data_format, keepdims=keepdims)
148+
self.gap = layers.GlobalAveragePooling1D(data_format=data_format, keepdims=keepdims, **kwargs)
149149
self.step_axis = 1 if self.data_format == "channels_last" else 2
150150

151151
def call(self, inputs: FloatTensor) -> FloatTensor:
@@ -231,7 +231,7 @@ def __init__(self, p: float = 3.0, data_format: Optional[str] = None, keepdims:
231231
super().__init__(p=p, data_format=data_format, keepdims=keepdims, **kwargs)
232232

233233
self.input_spec = layers.InputSpec(ndim=4)
234-
self.gap = layers.GlobalAveragePooling2D(data_format, keepdims)
234+
self.gap = layers.GlobalAveragePooling2D(data_format=data_format, keepdims=keepdims, **kwargs)
235235

236236
def call(self, inputs: FloatTensor) -> FloatTensor:
237237
x = inputs

0 commit comments

Comments (0)