Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/decomon/keras_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def batch_multid_dot(
) -> Tensor:
"""Dot product of tensors by batch, along multiple axes

Hypothesis: we sum over last axes of x and first axes (skipping the batch one) of x.
Hypothesis: we sum over last axes of x and first axes (skipping the batch one) of y.

The 1-dimensional equivalent would be `batch_dot(x,y, axes=(-1, 1))`
or `keras.layers.Dot(axes=(-1, 1))(x,y)`
Expand Down
6 changes: 2 additions & 4 deletions src/decomon/layers/custom/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,2 @@
from .max import DecomonMax
from .min import DecomonMin
from .mul import DecomonMulConstant
from .utils import get_affine_lower_bound_max, get_affine_upper_bound_max
from .constant import DecomonMulConstant
from .reduce import DecomonMax, DecomonMin
1 change: 1 addition & 0 deletions src/decomon/layers/custom/constant/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .constant import DecomonMulConstant
10 changes: 10 additions & 0 deletions src/decomon/layers/custom/constant/constant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from keras_custom.layers import MulConstant

from decomon.layers.layer import DecomonLayer


class DecomonMulConstant(DecomonLayer):
layer: MulConstant
linear = True
diagonal = True
use_bias = False
165 changes: 0 additions & 165 deletions src/decomon/layers/custom/max.py

This file was deleted.

76 changes: 0 additions & 76 deletions src/decomon/layers/custom/min.py

This file was deleted.

46 changes: 0 additions & 46 deletions src/decomon/layers/custom/mul.py

This file was deleted.

6 changes: 6 additions & 0 deletions src/decomon/layers/custom/reduce/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from .max import DecomonMax
from .min import DecomonMin
from .utils import (
get_affine_lower_bound_max_before_reduction,
get_affine_upper_bound_max_before_reduction,
)
37 changes: 37 additions & 0 deletions src/decomon/layers/custom/reduce/max.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# define non native class Max
# Decomon Custom for Max(axis...)
from typing import Any

from keras_custom.layers import Max

from decomon.layers.layer import DecomonLayer
from decomon.types import Tensor

from .utils import (
get_affine_lower_bound_max_before_reduction,
get_affine_upper_bound_max_before_reduction,
get_batch_multi_dot_repr_for_axis_reduce_weights,
)


class DecomonMax(DecomonLayer):
layer: Max
linear = False
increasing = True

def get_affine_bounds(self, lower: Tensor, upper: Tensor, **kwargs: Any) -> tuple[Tensor, Tensor, Tensor, Tensor]:
return get_affine_bounds_max(lower=lower, upper=upper, axis=self.layer.axis, keepdims=self.layer.keepdims)


def get_affine_bounds_max(
lower: Tensor, upper: Tensor, axis: int, keepdims: bool
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
w_l, b_l = get_affine_lower_bound_max_before_reduction(lower, upper, axis=axis, keepdims=keepdims)
w_u, b_u = get_affine_upper_bound_max_before_reduction(lower, upper, axis=axis, keepdims=keepdims)

# for now the lower bound is given by sum(w_l*x, axis=axis) + b,
# so we need another transformation to get the final w_l
w_l = get_batch_multi_dot_repr_for_axis_reduce_weights(w_l, axis=axis, keepdims=keepdims)
w_u = get_batch_multi_dot_repr_for_axis_reduce_weights(w_u, axis=axis, keepdims=keepdims)

return (w_l, b_l, w_u, b_u)
Loading
Loading