Skip to content

Add initial ControlNet backbone#2586

Open
Amitavoo wants to merge 2 commits intokeras-team:masterfrom
Amitavoo:controlnet-backbone
Open

Add initial ControlNet backbone#2586
Amitavoo wants to merge 2 commits intokeras-team:masterfrom
Amitavoo:controlnet-backbone

Conversation

@Amitavoo
Copy link
Contributor

@Amitavoo Amitavoo commented Feb 6, 2026

This PR adds a minimal ControlNet conditioning backbone for Keras-Hub.

The backbone encodes a 1-channel conditioning image (e.g., Canny or depth map)
into multi-scale features that can later be injected into a diffusion UNet.

Included:

  • ControlNetBackbone implementation
  • Basic smoke and shape tests

This is intended as the first step toward full ControlNet support in follow-up PRs.

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @Amitavoo, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request establishes the initial ControlNetBackbone component within Keras-Hub, providing a fundamental building block for conditional image generation. This backbone is engineered to process single-channel conditioning inputs, such as Canny edge maps or depth maps, transforming them into a set of multi-scale feature representations. These extracted features are intended to serve as crucial conditioning signals for diffusion UNets, paving the way for more advanced ControlNet functionalities in future developments.

Highlights

  • ControlNet Backbone Implementation: Introduced the ControlNetBackbone class, a keras.Model subclass, designed to encode 1-channel conditioning images into multi-scale features.
  • Multi-Scale Feature Extraction: The backbone utilizes a sequence of Conv2D layers with relu activation and MaxPooling2D layers to generate feature maps at different spatial scales (scale_1, scale_2, scale_3).
  • Comprehensive Unit Tests: Added dedicated tests to ensure the ControlNetBackbone correctly builds and runs, returns the expected output dictionary keys, and maintains correct tensor ranks and batch sizes.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • keras_hub/src/models/controlnet/controlnet_backbone.py
    • Added ControlNetBackbone class, inheriting from keras.Model.
    • Implemented __init__ to define three Conv2D layers with increasing filters (64, 128, 256) and a MaxPooling2D layer.
    • Defined the build method to correctly handle input shape propagation for the convolutional layers.
    • Implemented the call method to perform the forward pass, applying convolutions and pooling, and returning a dictionary of multi-scale feature maps (scale_1, scale_2, scale_3).
  • keras_hub/src/models/controlnet/controlnet_backbone_test.py
    • Added test_controlnet_backbone_smoke to verify basic model instantiation and execution.
    • Added test_controlnet_backbone_required_keys to assert that the output dictionary contains the expected feature scale keys.
    • Added test_controlnet_backbone_rank to confirm that all output tensors are 4D and maintain the correct batch size.
Activity
  • No specific activity (comments, reviews, or progress updates) has been recorded for this pull request yet.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a ControlNetBackbone, but the implementation deviates from the repository's contribution guidelines in several key areas. The backbone model should be implemented using the Keras Functional API and inherit from keras_hub.models.Backbone. The associated tests also need to be updated to use the standardized TestCase helpers. I've provided detailed suggestions to align the code with the project's standards.

Comment on lines +1 to +46
import keras
import tensorflow as tf


class ControlNetBackbone(keras.Model):
"""Lightweight conditioning encoder for ControlNet."""

def __init__(self, **kwargs):
super().__init__(**kwargs)

self.down1 = keras.layers.Conv2D(
64, kernel_size=3, padding="same", activation="relu"
)
self.down2 = keras.layers.Conv2D(
128, kernel_size=3, padding="same", activation="relu"
)
self.down3 = keras.layers.Conv2D(
256, kernel_size=3, padding="same", activation="relu"
)

self.pool = keras.layers.MaxPooling2D(pool_size=2)

def build(self, input_shape):
self.down1.build(input_shape)
b, h, w, c = input_shape
half_shape = (b, h // 2, w // 2, 64)
self.down2.build(half_shape)
quarter_shape = (b, h // 4, w // 4, 128)
self.down3.build(quarter_shape)

super().build(input_shape)

def call(self, x):
f1 = self.down1(x)
p1 = self.pool(f1)

f2 = self.down2(p1)
p2 = self.pool(f2)

f3 = self.down3(p2)

return {
"scale_1": f1,
"scale_2": f2,
"scale_3": f3,
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The current implementation of ControlNetBackbone deviates significantly from the KerasHub style guide. To align with the repository's standards, the model should be refactored.

Here are the key issues and how the suggestion addresses them:

  • Inheritance: The model should inherit from keras_hub.models.backbone.Backbone instead of keras.Model to gain standard functionality like from_preset(). (Style Guide: line 86)
  • Functional API: Backbones must be implemented using the Keras Functional API inside the __init__ method, not as a subclassed model with a call method. This makes the model structure explicit and avoids the need for a manual build() method. (Style Guide: line 79)
  • Docstrings: The class is missing a comprehensive Google-style docstring, including Args and an Example section. (Style Guide: lines 366-371)
  • Serialization: A get_config() method is required for proper serialization. (Style Guide: line 528)
  • Backend Agnostic: The unused import tensorflow as tf should be removed to maintain backend-agnostic code. (Style Guide: line 7)
  • Standard Naming: The input should be named pixel_values as per the convention for image models. (Style Guide: line 67)
  • Export: The class should be decorated with @keras_hub_export to make it part of the public API. (Style Guide: line 85)

I've provided a code suggestion that refactors the entire class to follow these guidelines.

import keras
import numpy as np

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone


@keras_hub_export("keras_hub.models.ControlNetBackbone")
class ControlNetBackbone(Backbone):
    """Lightweight conditioning encoder for ControlNet.

    This backbone model takes a conditioning image (e.g., a Canny edge map) and
    encodes it into multi-scale feature maps. These feature maps can then be
    injected into the intermediate layers of a diffusion model like a UNet to
    condition its generation process.

    This implementation is a simplified version of the ControlNet encoder,
    consisting of a series of convolutional and pooling layers.

    Args:
        input_shape: tuple. The shape of the input image, excluding the batch
            dimension. Defaults to `(None, None, 1)`.
        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
            for model computations and weights.

    Example:
    ```python
    import numpy as np

    # Define the input for the conditioning image.
    input_data = {
        "pixel_values": np.ones(shape=(1, 512, 512, 1), dtype="float32"),
    }

    # Create a ControlNetBackbone instance.
    model = ControlNetBackbone()

    # Get the multi-scale feature maps.
    outputs = model(input_data)
    # `outputs` will be a dictionary with keys "scale_1", "scale_2", "scale_3".
    ```
    """

    def __init__(
        self,
        input_shape=(None, None, 1),
        dtype=None,
        **kwargs,
    ):
        # === Layers ===
        self.down1 = keras.layers.Conv2D(
            64, kernel_size=3, padding="same", activation="relu", name="down1_conv"
        )
        self.down2 = keras.layers.Conv2D(
            128, kernel_size=3, padding="same", activation="relu", name="down2_conv"
        )
        self.down3 = keras.layers.Conv2D(
            256, kernel_size=3, padding="same", activation="relu", name="down3_conv"
        )
        self.pool = keras.layers.MaxPooling2D(pool_size=2, name="pool")

        # === Functional Model ===
        pixel_values = keras.Input(
            shape=input_shape, dtype=dtype, name="pixel_values"
        )

        f1 = self.down1(pixel_values)
        p1 = self.pool(f1)

        f2 = self.down2(p1)
        p2 = self.pool(f2)

        f3 = self.down3(p2)

        outputs = {
            "scale_1": f1,
            "scale_2": f2,
            "scale_3": f3,
        }

        super().__init__(
            inputs=pixel_values,
            outputs=outputs,
            dtype=dtype,
            **kwargs,
        )

        # Store parameters as attributes.
        self.input_shape_arg = input_shape

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "input_shape": self.input_shape_arg,
            }
        )
        return config
References
  1. Backbone models must be implemented using the Keras Functional API within the __init__ method, rather than as a subclassed model with a call method. (link)
  2. Backbone models should inherit from keras_hub.models.Backbone to ensure they have standard features like from_preset. (link)
  3. All public classes must have Google-style docstrings, including Args and Example sections. (link)
  4. All layers and models must implement a get_config() method for serialization. (link)
  5. All code must be backend-agnostic, which means avoiding direct imports from tensorflow and using keras.ops instead. (link)
  6. Model inputs should use standardized names. For image models, pixel_values is the convention. (link)
  7. Public models must be decorated with @keras_hub_export to be included in the library's public API. (link)

Comment on lines +1 to +38
import tensorflow as tf
from keras_hub.src.models.controlnet.controlnet_backbone import (
ControlNetBackbone,
)


def test_controlnet_backbone_smoke():
"""Basic smoke test: model builds and runs."""
model = ControlNetBackbone()

x = tf.random.uniform((1, 512, 512, 1))
outputs = model(x)

assert isinstance(outputs, dict)


def test_controlnet_backbone_required_keys():
"""Ensure expected feature scales exist."""
model = ControlNetBackbone()
x = tf.random.uniform((1, 512, 512, 1))

outputs = model(x)

assert "scale_1" in outputs
assert "scale_2" in outputs
assert "scale_3" in outputs


def test_controlnet_backbone_rank():
"""Each output should be a 4D tensor (B, H, W, C)."""
model = ControlNetBackbone()
x = tf.random.uniform((2, 256, 256, 1))

outputs = model(x)

for v in outputs.values():
assert len(v.shape) == 4
assert v.shape[0] == 2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The test file does not follow the standardized testing structure required by the KerasHub contribution guidelines. The tests should be refactored to use the provided TestCase class and its helper methods.

Specifically:

  • Test Structure: Tests should be methods within a class that inherits from keras_hub.src.tests.test_case.TestCase. (Style Guide: line 449)
  • Standardized Routines: Instead of manual checks, the tests should use helper methods like self.run_backbone_test() and self.run_model_saving_test() to ensure comprehensive coverage of basic functionality, shape inference, and serialization. (Style Guide: lines 409, 412, 424)
  • Backend Agnostic: The tests use tensorflow directly. This should be replaced with keras.ops to ensure the tests are backend-agnostic. (Style Guide: line 7)
  • Test Input Size: The input tensors are very large (512x512). Tests should use small inputs (e.g., 32x32) for faster execution. (Style Guide: line 404)

The suggested code refactors the entire test file to align with these standards.

import pytest
from keras import ops

from keras_hub.src.models.controlnet.controlnet_backbone import ControlNetBackbone
from keras_hub.src.tests.test_case import TestCase


class ControlNetBackboneTest(TestCase):
    def setUp(self):
        self.init_kwargs = {}
        self.input_data = {
            "pixel_values": ops.ones((2, 32, 32, 1), dtype="float32"),
        }

    def test_backbone_basics(self):
        self.run_backbone_test(
            cls=ControlNetBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output_shape={
                "scale_1": (2, 32, 32, 64),
                "scale_2": (2, 16, 16, 128),
                "scale_3": (2, 8, 8, 256),
            },
        )

    @pytest.mark.large
    def test_saved_model(self):
        self.run_model_saving_test(
            cls=ControlNetBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
        )
References
  1. Tests must be implemented as methods within a class that inherits from TestCase. (link)
  2. Standardized test helper methods like run_backbone_test and run_model_saving_test must be used to ensure consistent and thorough testing. (link)
  3. All code, including tests, must be backend-agnostic and should use keras.ops instead of framework-specific modules like tensorflow. (link)
  4. Test inputs should be small to ensure fast test execution. (link)

@sachinprasadhs sachinprasadhs added the new model For PRs that contribute a new model to the Keras Hub registry. label Feb 9, 2026
@Amitavoo
Copy link
Contributor Author

Hi @keras-team, just checking in on this PR when you have time. Happy to make any changes or add tests if needed. Thanks!

@sachinprasadhs
Copy link
Collaborator

Thanks for the PR.
Please add all the code related to this model in a single PR, it would be easy for us to review and suggest.

@Amitavoo
Copy link
Contributor Author

Thanks for the clarification@sachinprasadhs!

Understood — I will extend this PR to include the full ControlNet implementation
(backbone, UNet integration, preprocessor, presets, and corresponding tests)
so it can be reviewed as a complete model addition.

I’ll update this PR shortly with the additional components.

@samudraneel05
Copy link

Thanks for the PR. Please add all the code related to this model in a single PR, it would be easy for us to review and suggest.

Hi, is this the case just for this model or all new model introductions? If everything in one go is the normal in KerasHub, then CONTRIBUTING_MODELS.md should be updated... it still says to add changes like backbone, tokenizer, preprocessor one by one in different PRs.

Copy link
Contributor Author

@Amitavoo Amitavoo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sachinprasadhs

I’ve added the full ControlNet implementation in a single PR as suggested, including backbone, UNet, preprocessor, layers, presets, and corresponding tests.

All unit tests are passing locally, and the full model smoke tests are working correctly.

Please let me know if you’d like any architectural changes or refinements.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

new model For PRs that contribute a new model to the Keras Hub registry. stat:awaiting response from contributor

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants