Skip to content

Commit a702110

Browse files
committed
Fixed regional controlnet only usage
1 parent da8f931 commit a702110

1 file changed

Lines changed: 18 additions & 6 deletions

File tree

ai_diffusion/model/control.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -123,11 +123,13 @@ def to_api(self, bounds: Bounds | None = None, time: int | None = None):
123123
if self.mode.is_ip_adapter and not layer.bounds.is_zero:
124124
bounds = None # ignore mask bounds, use layer bounds
125125

126-
image = (
127-
self._regional_control_image(bounds, time)
128-
if self.mode is ControlMode.regional
129-
else layer.get_pixels(bounds, time)
130-
)
126+
image = None
127+
if self.mode is ControlMode.regional:
128+
image = self._regional_control_image(bounds, time)
129+
if image is None:
130+
image = self._regional_control_image_from_layer(layer, bounds, time)
131+
if image is None:
132+
image = layer.get_pixels(bounds, time)
131133

132134
if self.mode.is_lines or self.mode is ControlMode.stencil:
133135
image.make_opaque(background=Qt.GlobalColor.white)
@@ -143,11 +145,20 @@ def to_api(self, bounds: Bounds | None = None, time: int | None = None):
143145
strength = self.strength / self.strength_multiplier
144146
return ControlInput(self.mode, image, strength, (self.start, self.end))
145147

148+
def _regional_control_image_from_layer(
149+
self, layer: Layer, bounds: Bounds | None, time: int | None
150+
):
151+
bounds = bounds or Bounds.from_extent(self._model.document.extent)
152+
image = Image.create(bounds.extent, fill=Qt.GlobalColor.white)
153+
image.draw_image(layer.get_pixels(bounds, time))
154+
return image
155+
146156
def _regional_control_image(self, bounds: Bounds | None, time: int | None):
147157
from .region import RegionLink
148158

149159
bounds = bounds or Bounds.from_extent(self._model.document.extent)
150160
image = Image.create(bounds.extent, fill=Qt.GlobalColor.white)
161+
has_region_layer = False
151162
root = self._model.regions
152163

153164
for layer in root.layers.all:
@@ -159,8 +170,9 @@ def _regional_control_image(self, bounds: Bounds | None, time: int | None):
159170
continue
160171

161172
image.draw_image(layer.get_pixels(bounds, time))
173+
has_region_layer = True
162174

163-
return image
175+
return image if has_region_layer else None
164176

165177
def generate(self):
166178
self._generate_job = self._model.generate_control_layer(self)

0 commit comments

Comments
 (0)