46 changes: 46 additions & 0 deletions keras_hub/src/models/controlnet/controlnet_backbone.py
@@ -0,0 +1,46 @@
import keras
import tensorflow as tf


class ControlNetBackbone(keras.Model):
    """Lightweight conditioning encoder for ControlNet."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.down1 = keras.layers.Conv2D(
            64, kernel_size=3, padding="same", activation="relu"
        )
        self.down2 = keras.layers.Conv2D(
            128, kernel_size=3, padding="same", activation="relu"
        )
        self.down3 = keras.layers.Conv2D(
            256, kernel_size=3, padding="same", activation="relu"
        )

        self.pool = keras.layers.MaxPooling2D(pool_size=2)

    def build(self, input_shape):
        self.down1.build(input_shape)
        b, h, w, c = input_shape
        half_shape = (b, h // 2, w // 2, 64)
        self.down2.build(half_shape)
        quarter_shape = (b, h // 4, w // 4, 128)
        self.down3.build(quarter_shape)

        super().build(input_shape)

    def call(self, x):
        f1 = self.down1(x)
        p1 = self.pool(f1)

        f2 = self.down2(p1)
        p2 = self.pool(f2)

        f3 = self.down3(p2)

        return {
            "scale_1": f1,
            "scale_2": f2,
            "scale_3": f3,
        }
Comment on lines +1 to +46
Contributor
critical

The current implementation of ControlNetBackbone deviates significantly from the KerasHub style guide. To align with the repository's standards, the model should be refactored.

Here are the key issues and how the suggestion addresses them:

  • Inheritance: The model should inherit from keras_hub.models.Backbone instead of keras.Model to gain standard functionality like from_preset(). (Style Guide: line 86)
  • Functional API: Backbones must be implemented using the Keras Functional API inside the __init__ method, not as a subclassed model with a call method. This makes the model structure explicit and avoids the need for a manual build() method. (Style Guide: line 79)
  • Docstrings: The class is missing a comprehensive Google-style docstring, including Args and an Example section. (Style Guide: lines 366-371)
  • Serialization: A get_config() method is required for proper serialization. (Style Guide: line 528)
  • Backend Agnostic: The unused import tensorflow as tf should be removed to maintain backend-agnostic code. (Style Guide: line 7)
  • Standard Naming: The input should be named pixel_values as per the convention for image models. (Style Guide: line 67)
  • Export: The class should be decorated with @keras_hub_export to make it part of the public API. (Style Guide: line 85)

I've provided a code suggestion that refactors the entire class to follow these guidelines.

import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone


@keras_hub_export("keras_hub.models.ControlNetBackbone")
class ControlNetBackbone(Backbone):
    """Lightweight conditioning encoder for ControlNet.

    This backbone model takes a conditioning image (e.g., a Canny edge map) and
    encodes it into multi-scale feature maps. These feature maps can then be
    injected into the intermediate layers of a diffusion model like a UNet to
    condition its generation process.

    This implementation is a simplified version of the ControlNet encoder,
    consisting of a series of convolutional and pooling layers.

    Args:
        input_shape: tuple. The shape of the input image, excluding the batch
            dimension. Defaults to `(None, None, 1)`.
        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
            for model computations and weights.

    Example:
    ```python
    import numpy as np

    # Define the input for the conditioning image.
    input_data = {
        "pixel_values": np.ones(shape=(1, 512, 512, 1), dtype="float32"),
    }

    # Create a ControlNetBackbone instance.
    model = ControlNetBackbone()

    # Get the multi-scale feature maps.
    outputs = model(input_data)
    # `outputs` will be a dictionary with keys "scale_1", "scale_2", "scale_3".
    ```
    """

    def __init__(
        self,
        input_shape=(None, None, 1),
        dtype=None,
        **kwargs,
    ):
        # === Layers ===
        self.down1 = keras.layers.Conv2D(
            64, kernel_size=3, padding="same", activation="relu", name="down1_conv"
        )
        self.down2 = keras.layers.Conv2D(
            128, kernel_size=3, padding="same", activation="relu", name="down2_conv"
        )
        self.down3 = keras.layers.Conv2D(
            256, kernel_size=3, padding="same", activation="relu", name="down3_conv"
        )
        self.pool = keras.layers.MaxPooling2D(pool_size=2, name="pool")

        # === Functional Model ===
        pixel_values = keras.Input(
            shape=input_shape, dtype=dtype, name="pixel_values"
        )

        f1 = self.down1(pixel_values)
        p1 = self.pool(f1)

        f2 = self.down2(p1)
        p2 = self.pool(f2)

        f3 = self.down3(p2)

        outputs = {
            "scale_1": f1,
            "scale_2": f2,
            "scale_3": f3,
        }

        super().__init__(
            inputs=pixel_values,
            outputs=outputs,
            dtype=dtype,
            **kwargs,
        )

        # Store parameters as attributes.
        self.input_shape_arg = input_shape

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "input_shape": self.input_shape_arg,
            }
        )
        return config
References
  1. Backbone models must be implemented using the Keras Functional API within the __init__ method, rather than as a subclassed model with a call method. (link)
  2. Backbone models should inherit from keras_hub.models.Backbone to ensure they have standard features like from_preset. (link)
  3. All public classes must have Google-style docstrings, including Args and Example sections. (link)
  4. All layers and models must implement a get_config() method for serialization. (link)
  5. All code must be backend-agnostic, which means avoiding direct imports from tensorflow and using keras.ops instead. (link)
  6. Model inputs should use standardized names. For image models, pixel_values is the convention. (link)
  7. Public models must be decorated with @keras_hub_export to be included in the library's public API. (link)
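As a quick sanity check on the suggested architecture (a sketch of ours, not part of the PR), the parameter counts of the three convolutions can be derived by hand, assuming the single-channel input from the docstring example:

```python
def conv2d_params(kernel, c_in, c_out):
    # A 2D convolution holds kernel * kernel * c_in weights per output
    # channel, plus one bias per output channel.
    return kernel * kernel * c_in * c_out + c_out


# Channel progression 1 -> 64 -> 128 -> 256, all 3x3 kernels.
counts = {
    "down1_conv": conv2d_params(3, 1, 64),
    "down2_conv": conv2d_params(3, 64, 128),
    "down3_conv": conv2d_params(3, 128, 256),
}
print(counts)
# {'down1_conv': 640, 'down2_conv': 73856, 'down3_conv': 295168}
```

These values should match what `model.summary()` reports for the corresponding layers after the refactor.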

38 changes: 38 additions & 0 deletions keras_hub/src/models/controlnet/controlnet_backbone_test.py
@@ -0,0 +1,38 @@
import tensorflow as tf

from keras_hub.src.models.controlnet.controlnet_backbone import (
    ControlNetBackbone,
)


def test_controlnet_backbone_smoke():
    """Basic smoke test: model builds and runs."""
    model = ControlNetBackbone()

    x = tf.random.uniform((1, 512, 512, 1))
    outputs = model(x)

    assert isinstance(outputs, dict)


def test_controlnet_backbone_required_keys():
    """Ensure expected feature scales exist."""
    model = ControlNetBackbone()
    x = tf.random.uniform((1, 512, 512, 1))

    outputs = model(x)

    assert "scale_1" in outputs
    assert "scale_2" in outputs
    assert "scale_3" in outputs


def test_controlnet_backbone_rank():
    """Each output should be a 4D tensor (B, H, W, C)."""
    model = ControlNetBackbone()
    x = tf.random.uniform((2, 256, 256, 1))

    outputs = model(x)

    for v in outputs.values():
        assert len(v.shape) == 4
        assert v.shape[0] == 2
Contributor
critical

The test file does not follow the standardized testing structure required by the KerasHub contribution guidelines. The tests should be refactored to use the provided TestCase class and its helper methods.

Specifically:

  • Test Structure: Tests should be methods within a class that inherits from keras_hub.src.tests.test_case.TestCase. (Style Guide: line 449)
  • Standardized Routines: Instead of manual checks, the tests should use helper methods like self.run_backbone_test() and self.run_model_saving_test() to ensure comprehensive coverage of basic functionality, shape inference, and serialization. (Style Guide: lines 409, 412, 424)
  • Backend Agnostic: The tests use tensorflow directly. This should be replaced with keras.ops to ensure the tests are backend-agnostic. (Style Guide: line 7)
  • Test Input Size: The input tensors are very large (512x512). Tests should use small inputs (e.g., 32x32) for faster execution. (Style Guide: line 404)

The suggested code refactors the entire test file to align with these standards.

import pytest
from keras import ops

from keras_hub.src.models.controlnet.controlnet_backbone import ControlNetBackbone
from keras_hub.src.tests.test_case import TestCase


class ControlNetBackboneTest(TestCase):
    def setUp(self):
        self.init_kwargs = {}
        self.input_data = {
            "pixel_values": ops.ones((2, 32, 32, 1), dtype="float32"),
        }

    def test_backbone_basics(self):
        self.run_backbone_test(
            cls=ControlNetBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output_shape={
                "scale_1": (2, 32, 32, 64),
                "scale_2": (2, 16, 16, 128),
                "scale_3": (2, 8, 8, 256),
            },
        )

    @pytest.mark.large
    def test_saved_model(self):
        self.run_model_saving_test(
            cls=ControlNetBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
        )
References
  1. Tests must be implemented as methods within a class that inherits from TestCase. (link)
  2. Standardized test helper methods like run_backbone_test and run_model_saving_test must be used to ensure consistent and thorough testing. (link)
  3. All code, including tests, must be backend-agnostic and should use keras.ops instead of framework-specific modules like tensorflow. (link)
  4. Test inputs should be small to ensure fast test execution. (link)
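The `expected_output_shape` entries in the suggested test follow mechanically from the architecture: each `padding="same"` convolution preserves H and W, and each 2x2 max-pool halves them before the next convolution. A small helper (the name `expected_scales` is ours, purely illustrative and not part of the PR) makes the derivation explicit:

```python
def expected_scales(batch, h, w, filters=(64, 128, 256)):
    # Each scale applies a stride-1 "same" conv (spatial size unchanged,
    # channel count set by `filters`), then a 2x2 max-pool halves H and W
    # before the next scale. No pool follows the final conv.
    shapes = {}
    for i, channels in enumerate(filters, start=1):
        shapes[f"scale_{i}"] = (batch, h, w, channels)
        h, w = h // 2, w // 2
    return shapes


print(expected_scales(2, 32, 32))
# {'scale_1': (2, 32, 32, 64), 'scale_2': (2, 16, 16, 128),
#  'scale_3': (2, 8, 8, 256)}
```

These are exactly the shapes passed to `run_backbone_test` above, so any change to kernel strides or pooling would surface as a shape mismatch in that test.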
