Update support for global cond U-Net#331
Open
sgreenbury wants to merge 3 commits intomainfrom
Open
Conversation
AzulaUNetProcessor.forward(x, x_noise=None) regressed after the TemporalUNetBackbone refactor: None was forwarded into the base class, which hit AttributeError on t.ndim. Substitute zeros of shape (B, mod_features) to match what map() already does. Also drop the duplicate include_global_cond/global_cond_channels check; TemporalBackboneBase already raises.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Analogous PR to #330 but for U-Net.
This pull request updates the Azula UNet processor to support global conditioning and improves flexibility in normalization and modulation. The main changes include adding support for an optional global conditioning vector, updating configuration files to control this feature, and refactoring the processor to use the new
TemporalUNetBackbonewith enhanced argument handling.Global conditioning support and configuration:
include_global_condandglobal_cond_channelsarguments to theAzulaUNetProcessorand corresponding YAML config files, allowing the model to optionally use a global conditioning vector. [1] [2] [3] [4]forwardandmapmethods to handle global conditioning, including input validation and correct tensor passing. [1] [2]Processor and backbone refactoring:
UNetto the newTemporalUNetBackbone, introducing new arguments for normalization (norm,groups), modulation, and global conditioning. [1] [2]Documentation improvements: