Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions keras_hub/src/models/controlnet/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from .controlnet_backbone import ControlNetBackbone
from .controlnet_preprocessor import ControlNetPreprocessor
from .controlnet_unet import ControlNetUNet
from .controlnet import ControlNet
from .controlnet_presets import controlnet_presets, from_preset
from .controlnet_layers import ZeroConv2D, ControlInjection

__all__ = [
"ControlNetBackbone",
"ControlNetPreprocessor",
"ControlNetUNet",
"ControlNet",
"controlnet_presets",
"from_preset",
"ZeroConv2D",
"ControlInjection",
]
42 changes: 42 additions & 0 deletions keras_hub/src/models/controlnet/controlnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import keras

from .controlnet_backbone import ControlNetBackbone
from .controlnet_preprocessor import ControlNetPreprocessor
from .controlnet_unet import ControlNetUNet


class ControlNet(keras.Model):

def __init__(self, image_size=128, base_channels=64, **kwargs):
super().__init__(**kwargs)

self.image_size = image_size
self.base_channels = base_channels

self.preprocessor = ControlNetPreprocessor(
target_size=(image_size, image_size)
)
self.backbone = ControlNetBackbone()
self.unet = ControlNetUNet(base_channels=base_channels)

def call(self, inputs):
image = inputs["image"]
control = inputs["control"]

image = self.preprocessor(image)
control = self.preprocessor(control)
control_features = self.backbone(control)

output = self.unet(image, control_features)

return output

def get_config(self):
config = super().get_config()
config.update(
{
"image_size": self.image_size,
"base_channels": self.base_channels,
}
)
return config
46 changes: 46 additions & 0 deletions keras_hub/src/models/controlnet/controlnet_backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import keras
import tensorflow as tf


class ControlNetBackbone(keras.Model):
"""Lightweight conditioning encoder for ControlNet."""

def __init__(self, **kwargs):
super().__init__(**kwargs)

self.down1 = keras.layers.Conv2D(
64, kernel_size=3, padding="same", activation="relu"
)
self.down2 = keras.layers.Conv2D(
128, kernel_size=3, padding="same", activation="relu"
)
self.down3 = keras.layers.Conv2D(
256, kernel_size=3, padding="same", activation="relu"
)

self.pool = keras.layers.MaxPooling2D(pool_size=2)

def build(self, input_shape):
self.down1.build(input_shape)
b, h, w, c = input_shape
half_shape = (b, h // 2, w // 2, 64)
self.down2.build(half_shape)
quarter_shape = (b, h // 4, w // 4, 128)
self.down3.build(quarter_shape)

super().build(input_shape)

def call(self, x):
f1 = self.down1(x)
p1 = self.pool(f1)

f2 = self.down2(p1)
p2 = self.pool(f2)

f3 = self.down3(p2)

return {
"scale_1": f1,
"scale_2": f2,
"scale_3": f3,
}
Comment on lines +1 to +46
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The current implementation of ControlNetBackbone deviates significantly from the KerasHub style guide. To align with the repository's standards, the model should be refactored.

Here are the key issues and how the suggestion addresses them:

  • Inheritance: The model should inherit from keras_hub.models.backbone.Backbone instead of keras.Model to gain standard functionality like from_preset(). (Style Guide: line 86)
  • Functional API: Backbones must be implemented using the Keras Functional API inside the __init__ method, not as a subclassed model with a call method. This makes the model structure explicit and avoids the need for a manual build() method. (Style Guide: line 79)
  • Docstrings: The class is missing a comprehensive Google-style docstring, including Args and an Example section. (Style Guide: lines 366-371)
  • Serialization: A get_config() method is required for proper serialization. (Style Guide: line 528)
  • Backend Agnostic: The unused import tensorflow as tf should be removed to maintain backend-agnostic code. (Style Guide: line 7)
  • Standard Naming: The input should be named pixel_values as per the convention for image models. (Style Guide: line 67)
  • Export: The class should be decorated with @keras_hub_export to make it part of the public API. (Style Guide: line 85)

I've provided a code suggestion that refactors the entire class to follow these guidelines.

import keras
import numpy as np

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone


@keras_hub_export("keras_hub.models.ControlNetBackbone")
class ControlNetBackbone(Backbone):
    """Lightweight conditioning encoder for ControlNet.

    This backbone model takes a conditioning image (e.g., a Canny edge map) and
    encodes it into multi-scale feature maps. These feature maps can then be
    injected into the intermediate layers of a diffusion model like a UNet to
    condition its generation process.

    This implementation is a simplified version of the ControlNet encoder,
    consisting of a series of convolutional and pooling layers.

    Args:
        input_shape: tuple. The shape of the input image, excluding the batch
            dimension. Defaults to `(None, None, 1)`.
        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
            for model computations and weights.

    Example:
    ```python
    import numpy as np

    # Define the input for the conditioning image.
    input_data = {
        "pixel_values": np.ones(shape=(1, 512, 512, 1), dtype="float32"),
    }

    # Create a ControlNetBackbone instance.
    model = ControlNetBackbone()

    # Get the multi-scale feature maps.
    outputs = model(input_data)
    # `outputs` will be a dictionary with keys "scale_1", "scale_2", "scale_3".
    ```
    """

    def __init__(
        self,
        input_shape=(None, None, 1),
        dtype=None,
        **kwargs,
    ):
        # === Layers ===
        self.down1 = keras.layers.Conv2D(
            64, kernel_size=3, padding="same", activation="relu", name="down1_conv"
        )
        self.down2 = keras.layers.Conv2D(
            128, kernel_size=3, padding="same", activation="relu", name="down2_conv"
        )
        self.down3 = keras.layers.Conv2D(
            256, kernel_size=3, padding="same", activation="relu", name="down3_conv"
        )
        self.pool = keras.layers.MaxPooling2D(pool_size=2, name="pool")

        # === Functional Model ===
        pixel_values = keras.Input(
            shape=input_shape, dtype=dtype, name="pixel_values"
        )

        f1 = self.down1(pixel_values)
        p1 = self.pool(f1)

        f2 = self.down2(p1)
        p2 = self.pool(f2)

        f3 = self.down3(p2)

        outputs = {
            "scale_1": f1,
            "scale_2": f2,
            "scale_3": f3,
        }

        super().__init__(
            inputs=pixel_values,
            outputs=outputs,
            dtype=dtype,
            **kwargs,
        )

        # Store parameters as attributes.
        self.input_shape_arg = input_shape

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "input_shape": self.input_shape_arg,
            }
        )
        return config
References
  1. Backbone models must be implemented using the Keras Functional API within the __init__ method, rather than as a subclassed model with a call method. (link)
  2. Backbone models should inherit from keras_hub.models.Backbone to ensure they have standard features like from_preset. (link)
  3. All public classes must have Google-style docstrings, including Args and Example sections. (link)
  4. All layers and models must implement a get_config() method for serialization. (link)
  5. All code must be backend-agnostic, which means avoiding direct imports from tensorflow and using keras.ops instead. (link)
  6. Model inputs should use standardized names. For image models, pixel_values is the convention. (link)
  7. Public models must be decorated with @keras_hub_export to be included in the library's public API. (link)

41 changes: 41 additions & 0 deletions keras_hub/src/models/controlnet/controlnet_backbone_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import tensorflow as tf
from keras_hub.src.models.controlnet.controlnet_backbone import (
ControlNetBackbone,
)


def test_controlnet_backbone_smoke():
model = ControlNetBackbone()
x = tf.random.uniform((1, 512, 512, 1))
outputs = model(x)
assert isinstance(outputs, dict)


def test_controlnet_backbone_required_keys():
model = ControlNetBackbone()
x = tf.random.uniform((1, 512, 512, 1))
outputs = model(x)

assert "scale_1" in outputs
assert "scale_2" in outputs
assert "scale_3" in outputs


def test_controlnet_backbone_rank():
model = ControlNetBackbone()
x = tf.random.uniform((2, 256, 256, 1))
outputs = model(x)

for v in outputs.values():
assert len(v.shape) == 4
assert v.shape[0] == 2


def test_controlnet_backbone_spatial_scaling():
model = ControlNetBackbone()
x = tf.random.uniform((1, 256, 256, 1))
outputs = model(x)

assert outputs["scale_1"].shape[1:3] == (256, 256)
assert outputs["scale_2"].shape[1:3] == (128, 128)
assert outputs["scale_3"].shape[1:3] == (64, 64)
45 changes: 45 additions & 0 deletions keras_hub/src/models/controlnet/controlnet_layers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import keras
from keras import layers


class ZeroConv2D(layers.Layer):

def __init__(self, filters, **kwargs):
super().__init__(**kwargs)
self.filters = filters
self.conv = layers.Conv2D(
filters,
kernel_size=1,
padding="same",
kernel_initializer="zeros",
bias_initializer="zeros",
)

def call(self, inputs):
return self.conv(inputs)

def get_config(self):
config = super().get_config()
config.update({"filters": self.filters})
return config


class ControlInjection(layers.Layer):

def __init__(self, out_channels, **kwargs):
super().__init__(**kwargs)
self.out_channels = out_channels
self.projection = ZeroConv2D(out_channels)

def call(self, x, control):
if x.shape[1:3] != control.shape[1:3]:
raise ValueError(
f"Spatial mismatch: {x.shape[1:3]} vs {control.shape[1:3]}"
)
control = self.projection(control)
return x + control

def get_config(self):
config = super().get_config()
config.update({"out_channels": self.out_channels})
return config
16 changes: 16 additions & 0 deletions keras_hub/src/models/controlnet/controlnet_layers_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import tensorflow as tf
from keras_hub.src.models.controlnet.controlnet_layers import ZeroConv2D


def test_zero_conv_output_shape():
layer = ZeroConv2D(64)
x = tf.random.uniform((1, 128, 128, 3))
y = layer(x)
assert y.shape == (1, 128, 128, 64)


def test_zero_conv_initial_output_is_zero():
layer = ZeroConv2D(64)
x = tf.random.uniform((1, 64, 64, 3))
y = layer(x)
assert tf.reduce_sum(tf.abs(y)).numpy() == 0.0
39 changes: 39 additions & 0 deletions keras_hub/src/models/controlnet/controlnet_preprocessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import keras
import tensorflow as tf


class ControlNetPreprocessor(keras.layers.Layer):
def __init__(self, target_size=(512, 512), **kwargs):
super().__init__(**kwargs)
self.target_size = tuple(target_size)

def call(self, inputs):
x = tf.convert_to_tensor(inputs)

if x.shape.rank != 4:
raise ValueError("Inputs must be a 4D tensor (batch, height, width, channels).")

x = tf.image.resize(x, self.target_size)
x = tf.cast(x, tf.float32)

max_val = tf.reduce_max(x)
x = tf.cond(
max_val > 1.0,
lambda: x / 255.0,
lambda: x,
)

return x

def compute_output_shape(self, input_shape):
return (
input_shape[0],
self.target_size[0],
self.target_size[1],
input_shape[-1],
)

def get_config(self):
config = super().get_config()
config.update({"target_size": self.target_size})
return config
30 changes: 30 additions & 0 deletions keras_hub/src/models/controlnet/controlnet_preprocessor_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import tensorflow as tf
from keras_hub.src.models.controlnet.controlnet_preprocessor import ControlNetPreprocessor


def test_controlnet_preprocessor_output_shape():
layer = ControlNetPreprocessor(target_size=(128, 128))

x = tf.random.uniform((1, 256, 256, 3), maxval=255, dtype=tf.float32)
y = layer(x)

assert y.shape == (1, 128, 128, 3)


def test_controlnet_preprocessor_scaling():
layer = ControlNetPreprocessor(target_size=(64, 64))

x = tf.ones((1, 128, 128, 3)) * 255.0
y = layer(x)

assert tf.reduce_max(y).numpy() <= 1.0
assert tf.reduce_min(y).numpy() >= 0.0


def test_controlnet_preprocessor_dtype():
layer = ControlNetPreprocessor(target_size=(64, 64))

x = tf.random.uniform((1, 128, 128, 3), maxval=255, dtype=tf.float32)
y = layer(x)

assert y.dtype == tf.float32
20 changes: 20 additions & 0 deletions keras_hub/src/models/controlnet/controlnet_presets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from .controlnet import ControlNet


controlnet_presets = {
"controlnet_base": {
"description": "Minimal ControlNet base configuration.",
"config": {
"image_size": 128,
"base_channels": 64,
},
}
}


def from_preset(preset_name):
if preset_name not in controlnet_presets:
raise ValueError(f"Unknown preset: {preset_name}")

config = controlnet_presets[preset_name]["config"]
return ControlNet(**config)
15 changes: 15 additions & 0 deletions keras_hub/src/models/controlnet/controlnet_presets_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import tensorflow as tf
from keras_hub.src.models.controlnet.controlnet_presets import from_preset


def test_controlnet_from_preset():
model = from_preset("controlnet_base")

inputs = {
"image": tf.random.uniform((1, 128, 128, 3)),
"control": tf.random.uniform((1, 128, 128, 3)),
}

outputs = model(inputs)

assert outputs.shape == (1, 128, 128, 3)
15 changes: 15 additions & 0 deletions keras_hub/src/models/controlnet/controlnet_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import tensorflow as tf
from keras_hub.src.models.controlnet.controlnet import ControlNet


def test_controlnet_full_model_smoke():
model = ControlNet()

inputs = {
"image": tf.random.uniform((1, 128, 128, 3)),
"control": tf.random.uniform((1, 128, 128, 3)),
}

outputs = model(inputs)

assert outputs.shape == (1, 128, 128, 3)
40 changes: 40 additions & 0 deletions keras_hub/src/models/controlnet/controlnet_unet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import keras
from .controlnet_layers import ControlInjection


class ControlNetUNet(keras.Model):

def __init__(self, base_channels=64, **kwargs):
super().__init__(**kwargs)

self.base_channels = base_channels

self.conv1 = keras.layers.Conv2D(
base_channels, 3, padding="same", activation="relu"
)

self.inject = ControlInjection(base_channels)

self.conv2 = keras.layers.Conv2D(
base_channels, 3, padding="same", activation="relu"
)

self.out_conv = keras.layers.Conv2D(
3, 1, padding="same"
)

def call(self, image, control_features):
if "scale_1" not in control_features:
raise ValueError("Expected 'scale_1' in control_features.")

x = self.conv1(image)
x = self.inject(x, control_features["scale_1"])
x = self.conv2(x)
x = self.out_conv(x)

return x

def get_config(self):
config = super().get_config()
config.update({"base_channels": self.base_channels})
return config
15 changes: 15 additions & 0 deletions keras_hub/src/models/controlnet/controlnet_unet_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import tensorflow as tf
from keras_hub.src.models.controlnet.controlnet_unet import ControlNetUNet


def test_controlnet_unet_smoke():
model = ControlNetUNet()

image = tf.random.uniform((1, 128, 128, 3))
control_features = {
"scale_1": tf.random.uniform((1, 128, 128, 64))
}

outputs = model(image, control_features)

assert outputs.shape == (1, 128, 128, 3)
Loading