diff --git a/keras_hub/src/models/controlnet/__init__.py b/keras_hub/src/models/controlnet/__init__.py
new file mode 100644
index 0000000000..a8de1443d1
--- /dev/null
+++ b/keras_hub/src/models/controlnet/__init__.py
@@ -0,0 +1,17 @@
+from .controlnet_backbone import ControlNetBackbone
+from .controlnet_preprocessor import ControlNetPreprocessor
+from .controlnet_unet import ControlNetUNet
+from .controlnet import ControlNet
+from .controlnet_presets import controlnet_presets, from_preset
+from .controlnet_layers import ZeroConv2D, ControlInjection
+
+__all__ = [
+    "ControlNetBackbone",
+    "ControlNetPreprocessor",
+    "ControlNetUNet",
+    "ControlNet",
+    "controlnet_presets",
+    "from_preset",
+    "ZeroConv2D",
+    "ControlInjection",
+]
diff --git a/keras_hub/src/models/controlnet/controlnet.py b/keras_hub/src/models/controlnet/controlnet.py
new file mode 100644
index 0000000000..298d67f85c
--- /dev/null
+++ b/keras_hub/src/models/controlnet/controlnet.py
@@ -0,0 +1,42 @@
+import keras
+
+from .controlnet_backbone import ControlNetBackbone
+from .controlnet_preprocessor import ControlNetPreprocessor
+from .controlnet_unet import ControlNetUNet
+
+
+class ControlNet(keras.Model):
+
+    def __init__(self, image_size=128, base_channels=64, **kwargs):
+        super().__init__(**kwargs)
+
+        self.image_size = image_size
+        self.base_channels = base_channels
+
+        self.preprocessor = ControlNetPreprocessor(
+            target_size=(image_size, image_size)
+        )
+        self.backbone = ControlNetBackbone()
+        self.unet = ControlNetUNet(base_channels=base_channels)
+
+    def call(self, inputs):
+        image = inputs["image"]
+        control = inputs["control"]
+
+        image = self.preprocessor(image)
+        control = self.preprocessor(control)
+        control_features = self.backbone(control)
+
+        output = self.unet(image, control_features)
+
+        return output
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "image_size": self.image_size,
+                "base_channels": self.base_channels,
+            }
+        )
+        return config
diff --git a/keras_hub/src/models/controlnet/controlnet_backbone.py b/keras_hub/src/models/controlnet/controlnet_backbone.py
new file mode 100644
index 0000000000..1ca46b4f2f
--- /dev/null
+++ b/keras_hub/src/models/controlnet/controlnet_backbone.py
@@ -0,0 +1,46 @@
+import keras
+import tensorflow as tf
+
+
+class ControlNetBackbone(keras.Model):
+    """Lightweight conditioning encoder for ControlNet."""
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+        self.down1 = keras.layers.Conv2D(
+            64, kernel_size=3, padding="same", activation="relu"
+        )
+        self.down2 = keras.layers.Conv2D(
+            128, kernel_size=3, padding="same", activation="relu"
+        )
+        self.down3 = keras.layers.Conv2D(
+            256, kernel_size=3, padding="same", activation="relu"
+        )
+
+        self.pool = keras.layers.MaxPooling2D(pool_size=2)
+
+    def build(self, input_shape):
+        self.down1.build(input_shape)
+        b, h, w, c = input_shape
+        half_shape = (b, h // 2, w // 2, 64)
+        self.down2.build(half_shape)
+        quarter_shape = (b, h // 4, w // 4, 128)
+        self.down3.build(quarter_shape)
+
+        super().build(input_shape)
+
+    def call(self, x):
+        f1 = self.down1(x)
+        p1 = self.pool(f1)
+
+        f2 = self.down2(p1)
+        p2 = self.pool(f2)
+
+        f3 = self.down3(p2)
+
+        return {
+            "scale_1": f1,
+            "scale_2": f2,
+            "scale_3": f3,
+        }
diff --git a/keras_hub/src/models/controlnet/controlnet_backbone_test.py b/keras_hub/src/models/controlnet/controlnet_backbone_test.py
new file mode 100644
index 0000000000..da08181711
--- /dev/null
+++ b/keras_hub/src/models/controlnet/controlnet_backbone_test.py
@@ -0,0 +1,41 @@
+import tensorflow as tf
+from keras_hub.src.models.controlnet.controlnet_backbone import (
+    ControlNetBackbone,
+)
+
+
+def test_controlnet_backbone_smoke():
+    model = ControlNetBackbone()
+    x = tf.random.uniform((1, 512, 512, 1))
+    outputs = model(x)
+    assert isinstance(outputs, dict)
+
+
+def test_controlnet_backbone_required_keys():
+    model = ControlNetBackbone()
+    x = tf.random.uniform((1, 512, 512, 1))
+    outputs = model(x)
+
+    assert "scale_1" in outputs
+    assert "scale_2" in outputs
+    assert "scale_3" in outputs
+
+
+def test_controlnet_backbone_rank():
+    model = ControlNetBackbone()
+    x = tf.random.uniform((2, 256, 256, 1))
+    outputs = model(x)
+
+    for v in outputs.values():
+        assert len(v.shape) == 4
+        assert v.shape[0] == 2
+
+
+def test_controlnet_backbone_spatial_scaling():
+    model = ControlNetBackbone()
+    x = tf.random.uniform((1, 256, 256, 1))
+    outputs = model(x)
+
+    assert outputs["scale_1"].shape[1:3] == (256, 256)
+    assert outputs["scale_2"].shape[1:3] == (128, 128)
+    assert outputs["scale_3"].shape[1:3] == (64, 64)
diff --git a/keras_hub/src/models/controlnet/controlnet_layers.py b/keras_hub/src/models/controlnet/controlnet_layers.py
new file mode 100644
index 0000000000..5ca77bd9cb
--- /dev/null
+++ b/keras_hub/src/models/controlnet/controlnet_layers.py
@@ -0,0 +1,45 @@
+import keras
+from keras import layers
+
+
+class ZeroConv2D(layers.Layer):
+
+    def __init__(self, filters, **kwargs):
+        super().__init__(**kwargs)
+        self.filters = filters
+        self.conv = layers.Conv2D(
+            filters,
+            kernel_size=1,
+            padding="same",
+            kernel_initializer="zeros",
+            bias_initializer="zeros",
+        )
+
+    def call(self, inputs):
+        return self.conv(inputs)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({"filters": self.filters})
+        return config
+
+
+class ControlInjection(layers.Layer):
+
+    def __init__(self, out_channels, **kwargs):
+        super().__init__(**kwargs)
+        self.out_channels = out_channels
+        self.projection = ZeroConv2D(out_channels)
+
+    def call(self, x, control):
+        if x.shape[1:3] != control.shape[1:3]:
+            raise ValueError(
+                f"Spatial mismatch: {x.shape[1:3]} vs {control.shape[1:3]}"
+            )
+        control = self.projection(control)
+        return x + control
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({"out_channels": self.out_channels})
+        return config
diff --git a/keras_hub/src/models/controlnet/controlnet_layers_test.py b/keras_hub/src/models/controlnet/controlnet_layers_test.py
new file mode 100644
index 0000000000..0a769b30b7
--- /dev/null
+++ b/keras_hub/src/models/controlnet/controlnet_layers_test.py
@@ -0,0 +1,16 @@
+import tensorflow as tf
+from keras_hub.src.models.controlnet.controlnet_layers import ZeroConv2D
+
+
+def test_zero_conv_output_shape():
+    layer = ZeroConv2D(64)
+    x = tf.random.uniform((1, 128, 128, 3))
+    y = layer(x)
+    assert y.shape == (1, 128, 128, 64)
+
+
+def test_zero_conv_initial_output_is_zero():
+    layer = ZeroConv2D(64)
+    x = tf.random.uniform((1, 64, 64, 3))
+    y = layer(x)
+    assert tf.reduce_sum(tf.abs(y)).numpy() == 0.0
diff --git a/keras_hub/src/models/controlnet/controlnet_preprocessor.py b/keras_hub/src/models/controlnet/controlnet_preprocessor.py
new file mode 100644
index 0000000000..c73253a449
--- /dev/null
+++ b/keras_hub/src/models/controlnet/controlnet_preprocessor.py
@@ -0,0 +1,39 @@
+import keras
+import tensorflow as tf
+
+
+class ControlNetPreprocessor(keras.layers.Layer):
+    def __init__(self, target_size=(512, 512), **kwargs):
+        super().__init__(**kwargs)
+        self.target_size = tuple(target_size)
+
+    def call(self, inputs):
+        x = tf.convert_to_tensor(inputs)
+
+        if x.shape.rank != 4:
+            raise ValueError("Inputs must be a 4D tensor (batch, height, width, channels).")
+
+        x = tf.image.resize(x, self.target_size)
+        x = tf.cast(x, tf.float32)
+
+        max_val = tf.reduce_max(x)
+        x = tf.cond(
+            max_val > 1.0,
+            lambda: x / 255.0,
+            lambda: x,
+        )
+
+        return x
+
+    def compute_output_shape(self, input_shape):
+        return (
+            input_shape[0],
+            self.target_size[0],
+            self.target_size[1],
+            input_shape[-1],
+        )
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({"target_size": self.target_size})
+        return config
diff --git a/keras_hub/src/models/controlnet/controlnet_preprocessor_test.py b/keras_hub/src/models/controlnet/controlnet_preprocessor_test.py
new file mode 100644
index 0000000000..0723636582
--- /dev/null
+++ b/keras_hub/src/models/controlnet/controlnet_preprocessor_test.py
@@ -0,0 +1,30 @@
+import tensorflow as tf
+from keras_hub.src.models.controlnet.controlnet_preprocessor import ControlNetPreprocessor
+
+
+def test_controlnet_preprocessor_output_shape():
+    layer = ControlNetPreprocessor(target_size=(128, 128))
+
+    x = tf.random.uniform((1, 256, 256, 3), maxval=255, dtype=tf.float32)
+    y = layer(x)
+
+    assert y.shape == (1, 128, 128, 3)
+
+
+def test_controlnet_preprocessor_scaling():
+    layer = ControlNetPreprocessor(target_size=(64, 64))
+
+    x = tf.ones((1, 128, 128, 3)) * 255.0
+    y = layer(x)
+
+    assert tf.reduce_max(y).numpy() <= 1.0
+    assert tf.reduce_min(y).numpy() >= 0.0
+
+
+def test_controlnet_preprocessor_dtype():
+    layer = ControlNetPreprocessor(target_size=(64, 64))
+
+    x = tf.random.uniform((1, 128, 128, 3), maxval=255, dtype=tf.float32)
+    y = layer(x)
+
+    assert y.dtype == tf.float32
diff --git a/keras_hub/src/models/controlnet/controlnet_presets.py b/keras_hub/src/models/controlnet/controlnet_presets.py
new file mode 100644
index 0000000000..58d2b28481
--- /dev/null
+++ b/keras_hub/src/models/controlnet/controlnet_presets.py
@@ -0,0 +1,20 @@
+from .controlnet import ControlNet
+
+
+controlnet_presets = {
+    "controlnet_base": {
+        "description": "Minimal ControlNet base configuration.",
+        "config": {
+            "image_size": 128,
+            "base_channels": 64,
+        },
+    }
+}
+
+
+def from_preset(preset_name):
+    if preset_name not in controlnet_presets:
+        raise ValueError(f"Unknown preset: {preset_name}")
+
+    config = controlnet_presets[preset_name]["config"]
+    return ControlNet(**config)
diff --git a/keras_hub/src/models/controlnet/controlnet_presets_test.py b/keras_hub/src/models/controlnet/controlnet_presets_test.py
new file mode 100644
index 0000000000..8c343ab826
--- /dev/null
+++ b/keras_hub/src/models/controlnet/controlnet_presets_test.py
@@ -0,0 +1,15 @@
+import tensorflow as tf
+from keras_hub.src.models.controlnet.controlnet_presets import from_preset
+
+
+def test_controlnet_from_preset():
+    model = from_preset("controlnet_base")
+
+    inputs = {
+        "image": tf.random.uniform((1, 128, 128, 3)),
+        "control": tf.random.uniform((1, 128, 128, 3)),
+    }
+
+    outputs = model(inputs)
+
+    assert outputs.shape == (1, 128, 128, 3)
diff --git a/keras_hub/src/models/controlnet/controlnet_test.py b/keras_hub/src/models/controlnet/controlnet_test.py
new file mode 100644
index 0000000000..ea766db5e1
--- /dev/null
+++ b/keras_hub/src/models/controlnet/controlnet_test.py
@@ -0,0 +1,15 @@
+import tensorflow as tf
+from keras_hub.src.models.controlnet.controlnet import ControlNet
+
+
+def test_controlnet_full_model_smoke():
+    model = ControlNet()
+
+    inputs = {
+        "image": tf.random.uniform((1, 128, 128, 3)),
+        "control": tf.random.uniform((1, 128, 128, 3)),
+    }
+
+    outputs = model(inputs)
+
+    assert outputs.shape == (1, 128, 128, 3)
diff --git a/keras_hub/src/models/controlnet/controlnet_unet.py b/keras_hub/src/models/controlnet/controlnet_unet.py
new file mode 100644
index 0000000000..b1fa7567f5
--- /dev/null
+++ b/keras_hub/src/models/controlnet/controlnet_unet.py
@@ -0,0 +1,40 @@
+import keras
+from .controlnet_layers import ControlInjection
+
+
+class ControlNetUNet(keras.Model):
+
+    def __init__(self, base_channels=64, **kwargs):
+        super().__init__(**kwargs)
+
+        self.base_channels = base_channels
+
+        self.conv1 = keras.layers.Conv2D(
+            base_channels, 3, padding="same", activation="relu"
+        )
+
+        self.inject = ControlInjection(base_channels)
+
+        self.conv2 = keras.layers.Conv2D(
+            base_channels, 3, padding="same", activation="relu"
+        )
+
+        self.out_conv = keras.layers.Conv2D(
+            3, 1, padding="same"
+        )
+
+    def call(self, image, control_features):
+        if "scale_1" not in control_features:
+            raise ValueError("Expected 'scale_1' in control_features.")
+
+        x = self.conv1(image)
+        x = self.inject(x, control_features["scale_1"])
+        x = self.conv2(x)
+        x = self.out_conv(x)
+
+        return x
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({"base_channels": self.base_channels})
+        return config
diff --git a/keras_hub/src/models/controlnet/controlnet_unet_test.py b/keras_hub/src/models/controlnet/controlnet_unet_test.py
new file mode 100644
index 0000000000..370575e471
--- /dev/null
+++ b/keras_hub/src/models/controlnet/controlnet_unet_test.py
@@ -0,0 +1,15 @@
+import tensorflow as tf
+from keras_hub.src.models.controlnet.controlnet_unet import ControlNetUNet
+
+
+def test_controlnet_unet_smoke():
+    model = ControlNetUNet()
+
+    image = tf.random.uniform((1, 128, 128, 3))
+    control_features = {
+        "scale_1": tf.random.uniform((1, 128, 128, 64))
+    }
+
+    outputs = model(image, control_features)
+
+    assert outputs.shape == (1, 128, 128, 3)