diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index 84f5661..0000000 Binary files a/.DS_Store and /dev/null differ diff --git a/.gitignore b/.gitignore index b7faf40..dfa1df5 100644 --- a/.gitignore +++ b/.gitignore @@ -205,3 +205,30 @@ cython_debug/ marimo/_static/ marimo/_lsp/ __marimo__/ + + + +pretrained_model +__pycache__/ +*.pyc +*.pyo +*.pyd +*.json +*.jsonl +env/ +venv/ +.venv/ +results/ +fusion_result.json +*output* +*kernel_meta/ +data/*.json +maps +data/metric_cache +data/maps +tmp +navsim_v1 +nuscenes +extra-info/ +pretrained_model/ +.DS_Store \ No newline at end of file diff --git a/.pylintrc b/.pylintrc index 7e930ac..aa823bf 100644 --- a/.pylintrc +++ b/.pylintrc @@ -31,7 +31,7 @@ extension-pkg-allow-list= # be loaded. Extensions are loading into the active Python interpreter and may # run arbitrary code. (This is an alternative name to extension-pkg-allow-list # for backward compatibility.) -extension-pkg-whitelist=cv2 +extension-pkg-whitelist= # Return non-zero exit code if any of these messages/categories are detected, # even if score is above --fail-under value. Syntax same as enable. Messages @@ -59,15 +59,16 @@ ignore-paths= # Emacs file locks ignore-patterns=^\.# -# List of module names for which member attributes should not be checked -# (useful for modules/projects where namespaces are manipulated during runtime -# and thus existing member attributes cannot be deduced by static analysis). It -# supports qualified module names, as well as Unix pattern matching. -ignored-modules=cv2 +# List of module names for which member attributes should not be checked and +# will not be imported (useful for modules/projects where namespaces are +# manipulated during runtime and thus existing member attributes cannot be +# deduced by static analysis). It supports qualified module names, as well as +# Unix pattern matching. +ignored-modules=torch # Python code to execute, usually for sys.path manipulation such as # pygtk.require(). 
-init-hook='import sys; sys.path.append(".")' +#init-hook= # Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the # number of processors available to use, and will cap the count on Windows to @@ -86,9 +87,13 @@ load-plugins= # Pickle collected data for later comparisons. persistent=yes +# Resolve imports to .pyi stubs if available. May reduce no-member messages and +# increase not-an-iterable messages. +prefer-stubs=no + # Minimum Python version to use for version dependent checks. Will default to # the version used to run pylint. -py-version=3.10 +py-version=3.12 # Discover python modules and packages in the file system subtree. recursive=no @@ -99,10 +104,6 @@ recursive=no # source root. source-roots= -# When enabled, pylint would attempt to guess common misconfiguration and emit -# user-friendly hints instead of false-positive error messages. -suggestion-mode=yes - # Allow loading of arbitrary C extensions. Extensions are imported into the # active Python interpreter and may run arbitrary code. unsafe-load-any-extension=no @@ -229,6 +230,11 @@ name-group= # not require a docstring. no-docstring-rgx=^_ +# Regular expression matching correct parameter specification variable names. +# If left empty, parameter specification variable names will be checked with +# the set naming style. +#paramspec-rgx= + # List of decorators that produce properties, such as abc.abstractproperty. Add # to this list to register other decorators that produce valid properties. # These decorators are taken in consideration only for invalid-name. @@ -242,13 +248,17 @@ property-classes=abc.abstractproperty # variable names will be checked with the set naming style. #typevar-rgx= +# Regular expression matching correct type variable tuple names. If left empty, +# type variable tuple names will be checked with the set naming style. +#typevartuple-rgx= + # Naming style matching correct variable names. 
variable-naming-style=snake_case # Regular expression matching correct variable names. Overrides variable- # naming-style. If left empty, variable names will be checked with the set # naming style. -variable-rgx=(_?[a-z][A-Za-z0-9]{0,30})|([A-Z0-9]{1,30}) +#variable-rgx= [CLASSES] @@ -285,23 +295,26 @@ exclude-too-few-public-methods= ignored-parents= # Maximum number of arguments for function / method. -max-args=7 +max-args=10 # Maximum number of attributes for a class (see R0902). -max-attributes=20 +max-attributes=7 # Maximum number of boolean expressions in an if statement (see R0916). max-bool-expr=5 # Maximum number of branch for function / method body. -max-branches=12 +max-branches=40 # Maximum number of locals for function / method body. -max-locals=15 +max-locals=50 # Maximum number of parents for a class (see R0901). max-parents=7 +# Maximum number of positional arguments for function / method. +max-positional-arguments=10 + # Maximum number of public methods for a class (see R0904). max-public-methods=20 @@ -309,10 +322,10 @@ max-public-methods=20 max-returns=6 # Maximum number of statements in function / method body. -max-statements=300 +max-statements=150 # Minimum number of public methods for a class (see R0903). -min-public-methods=1 +min-public-methods=2 [EXCEPTIONS] @@ -336,11 +349,13 @@ indent-after-paren=4 # tab). indent-string=' ' -# Maximum number of characters on a single line. +# Maximum number of characters on a single line. Pylint's default of 100 is +# based on PEP 8's guidance that teams may choose line lengths up to 99 +# characters. max-line-length=150 # Maximum number of lines in a module. -max-module-lines=2000 +max-module-lines=1000 # Allow the body of a class to be on the same line as the declaration if body # contains single statement. @@ -421,11 +436,21 @@ confidence=HIGH, # --enable=similarities". 
If you want to run only the classes checker, but have # no Warning level messages displayed, use "--disable=all --enable=classes # --disable=W". -disable=too-many-arguments, - too-many-locals, - too-many-branches, - protected-access - +disable=raw-checker-failed, + bad-inline-option, + locally-disabled, + file-ignored, + suppressed-message, + useless-suppression, + deprecated-pragma, + use-symbolic-message-instead, + use-implicit-booleaness-not-comparison-to-string, + use-implicit-booleaness-not-comparison-to-zero, + C0116, + C0114, + W0621, + E0601, + W0718 # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option @@ -443,9 +468,13 @@ timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests. [MISCELLANEOUS] +# Whether or not to search for fixme's in docstrings. +check-fixme-in-docstring=no + # List of note tags to take in consideration, separated by a comma. notes=FIXME, - XXX + XXX, + TODO # Regular expression of note tags to take in consideration. notes-rgx= @@ -465,7 +494,7 @@ never-returning-functions=sys.exit,argparse.parse_error # Let 'consider-using-join' be raised when the separator to join on would be # non-empty (resulting in expected fixes of the type: ``"- " + " - # ".join(items)``) -# suggest-join-with-non-empty-separator=yes +suggest-join-with-non-empty-separator=yes [REPORTS] @@ -481,10 +510,10 @@ evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor # used to format the message information. See doc for all details. msg-template= -# Set the output format. Available formats are: text, parseable, colorized, -# json2 (improved json format), json (old json format) and msvs (visual -# studio). You can also give a reporter class, e.g. -# mypackage.mymodule.MyReporterClass. +# Set the output format. 
Available formats are: 'text', 'parseable', +# 'colorized', 'json2' (improved json format), 'json' (old json format), msvs +# (visual studio) and 'github' (GitHub actions). You can also give a reporter +# class, e.g. mypackage.mymodule.MyReporterClass. #output-format= # Tells whether to display a full report or only the messages. @@ -586,7 +615,7 @@ ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace # of finding the hint is based on edit distance. missing-member-hint=yes -# The minimum edit distance a name should have in order to be considered a +# The maximum edit distance a name should have in order to be considered a # similar match for a missing member name. missing-member-hint-distance=1 @@ -630,4 +659,4 @@ init-import=no # List of qualified module names which can have objects that can redefine # builtins. -redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io \ No newline at end of file +redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md index 50cfc2f..5271aa3 100644 --- a/README.md +++ b/README.md @@ -27,17 +27,23 @@
+ + ## 📰 News - **`2025/12/06`**: 🎉🎉🎉 Paper submitted on [Arxiv](https://arxiv.org/pdf/2512.06112). + + ## 📅️ Roadmap | Status | Milestone | ETA | | :----: | :----------------------------------------------------------------------------------------------------: | :--------: | -| 🚀 | **[Releasing the inference source code](https://github.com/fudan-generative-vision/WAM-Flow)** | 2025.12.21 | -| 🚀 | **[Pretrained models on Huggingface](https://huggingface.co/fudan-generative-ai/WAM-Flow)** | TBD | -| 🚀 | **[Releasing the training scripts](#training)** | TBD | +| ✅ | **[Release the SFT and inference code](https://github.com/fudan-generative-vision/WAM-Flow)** | 2025.12.19 | +| 🚀 | **[Pretrained models on Huggingface](https://huggingface.co/fudan-generative-ai/WAM-Flow)** | TBD | +| 🚀 | **[Release the evaluation code](https://huggingface.co/fudan-generative-ai/WAM-Flow)** | TBD | +| 🚀 | **[Release the RL code](https://github.com/fudan-generative-vision/WAM-Flow)** | TBD | +| 🚀 | **[Release the pre-processed training data](#training)** | TBD | ## 📸 Showcase @@ -45,14 +51,68 @@ ## 🏆 Qualitative Results on NAVSIM ### NAVSIM-v1 benchmark results -![navsim-v1](assets/navsim-v1.png) +
+ navsim-v1 +
+ ### NAVSIM-v2 benchmark results -![navsim-v2](assets/navsim-v2.png) +
+navsim-v2 +
+ + ## 🔧️ Framework ![framework](assets/Figure_2.png) Our method takes as input a front-view image, a natural-language navigation command with a system prompt, and the ego-vehicle states, and outputs an 8-waypoint future trajectory spanning 4 seconds through parallel denoising. The model is first trained via supervised fine-tuning to learn accurate trajectory prediction. We then apply simulatorguided GRPO to further optimize closed-loop behavior. The GRPO reward function integrates safety constraints (collision avoidance, drivable-area compliance) with performance objectives (ego-progress, time-to-collision, comfort). + +## Preparation + +### Environment +```sh +conda create --name wam-flow python=3.10 +conda activate wam-flow +pip install -r requirements.txt +``` + +### Checkpoint + + +### Data + + + +## Training + +### Pretraining + +### SFT + +### RL + +### Debug +```sh +sh script/sft_debug.sh +``` + + +## Inference +```sh +sh script/infer.sh +``` + + +## Evaluation + +### NAVSIM-v1 + +### NAVSIM-v2 + +### nuScenes + + + ## 📝 Citation If you find our work useful for your research, please consider citing the paper: @@ -66,9 +126,13 @@ If you find our work useful for your research, please consider citing the paper: } ``` + + ## ⚠️ Social Risks and Mitigations The integration of Vision-Language-Action models into autonomous driving introduces ethical challenges, particularly regarding the opacity of neural decision-making and its impact on road safety. To mitigate these risks, it is imperative to implement explainable AI frameworks and robust safe protocols that ensure predictable vehicle behavior in long-tailed scenarios. Furthermore, addressing concerns over data privacy and public surveillance requires transparent data governance and rigorous de-identification practices. By prioritizing safety-critical alignment and ethical compliance, this research promotes the responsible development and deployment of VLA-based autonomous systems. 
+ + ## 🤗 Acknowledgements -We gratefully acknowledge the contributors to the [Janus](https://github.com/deepseek-ai/Janus), [FUDOKI](https://github.com/fudoki-hku/FUDOKI) and [flow_matching](https://github.com/facebookresearch/flow_matching) repositories, whose commitment to open source has provided us with their excellent codebases and pretrained models. \ No newline at end of file +We gratefully acknowledge the contributors to the [Recogdrive](https://github.com/xiaomi-research/recogdrive), [Janus](https://github.com/deepseek-ai/Janus), [FUDOKI](https://github.com/fudoki-hku/FUDOKI) and [flow_matching](https://github.com/facebookresearch/flow_matching) repositories, whose commitment to open source has provided us with their excellent codebases and pretrained models. \ No newline at end of file diff --git a/config/accelerate_config_ds2.yaml b/config/accelerate_config_ds2.yaml new file mode 100644 index 0000000..325fe7c --- /dev/null +++ b/config/accelerate_config_ds2.yaml @@ -0,0 +1,21 @@ +compute_environment: LOCAL_MACHINE +debug: true +deepspeed_config: + deepspeed_multinode_launcher: standard + gradient_accumulation_steps: 1 + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: false + zero_stage: 0 # use 2 to accelerate +distributed_type: DEEPSPEED +downcast_bf16: "no" +main_training_function: main +mixed_precision: "fp16" +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git a/config/debug.yaml b/config/debug.yaml new file mode 100644 index 0000000..4d7fe62 --- /dev/null +++ b/config/debug.yaml @@ -0,0 +1,36 @@ +model_path: pretrained_model/fudoki +text_embedding_path: pretrained_model/fudoki/text_embedding.pt +pretrain_model_path: pretrained_model/wam-flow/navsim + +stage: s2 +txt_max_length: 500 +new_embedding_path: "" + +pretrain_path: "" +ckpt_path: "" +train_llm_emb: false + +data_list: [ + 
"data/navsim_debug.jsonl", +] + +mixed_precision: "no" +accumulate_grad_batches: 1 +max_grad_norm: 1.0 +seed: 42 +source_distribution: "uniform" +vocab_size: 102400 +batch_size: 1 +dataloader_num_workers: 8 +learning_rate: 5e-6 +lr_scheduler_type: "cosine" # "constant", "cosine" +lr_warmup_steps: 500 +max_train_steps: 40000 +max_epochs: 100 +uncond_prob: 0 +use_quantize: true +random_seed: false +l2_loss_weight: 0 + +checkpointing_steps: 4000 +checkpoints_total_limit: 50 \ No newline at end of file diff --git a/config/pretrain.yaml b/config/pretrain.yaml new file mode 100644 index 0000000..6563464 --- /dev/null +++ b/config/pretrain.yaml @@ -0,0 +1,53 @@ +model_path: pretrained_model/fudoki +pretrain_model_path: /cache/xyf_model/ +text_embedding_path: pretrained_model/fudoki/text_embedding.pt + +stage: s2 +txt_max_length: 500 +new_embedding_path: "" + +pretrain_path: "" +ckpt_path: "" +train_llm_emb: false + +data_list: [ + "path/to/llava_v1_5_mix665k_2.jsonl", + "path/to/dataset_coda_lm.jsonl", + "path/to/dataset_drivegpt4.jsonl", + "path/to/dataset_lingoqa.jsonl", + "path/to/dataset_maplm.jsonl", + "path/to/dataset_nuscenes_qa.jsonl", + "path/to/dataset_omnidrive.jsonl", + "path/to/dataset_senna.jsonl", + "path/to/dataset_talk2car.jsonl", + "path/to/drivelm_change_box_type.jsonl", + + "path/to/nuplan_recogdrive.jsonl", + "path/to/navsim_recogdrive.jsonl", + "path/to/navsim_recogdrive.jsonl", + + "path/to/navsim_668k.jsonl", + "path/to/navsim_103k.jsonl", + "path/to/nuscenes_train.jsonl", +] + +mixed_precision: "no" +accumulate_grad_batches: 4 +max_grad_norm: 1.0 +seed: 42 +source_distribution: "uniform" +vocab_size: 102400 +batch_size: 1 +dataloader_num_workers: 8 +learning_rate: 1e-5 +lr_scheduler_type: "constant" # cosine, constant +lr_warmup_steps: 0 +max_train_steps: 80000 +max_epochs: 10 +uncond_prob: 0 +use_quantize: true +random_seed: true +l2_loss_weight: 0 + +checkpointing_steps: 2500 +checkpoints_total_limit: 50 \ No newline at end of file diff --git 
a/config/sft_navsim.yaml b/config/sft_navsim.yaml new file mode 100644 index 0000000..54264a7 --- /dev/null +++ b/config/sft_navsim.yaml @@ -0,0 +1,36 @@ +model_path: pretrained_model/fudoki +text_embedding_path: pretrained_model/fudoki/text_embedding.pt +pretrain_model_path: pretrained_model/wam-flow/navsim + +stage: s2 +txt_max_length: 500 +new_embedding_path: "" + +pretrain_path: "" +ckpt_path: "" +train_llm_emb: false + +data_list: [ + "path/to/navsim_668k.jsonl", +] + +mixed_precision: "no" +accumulate_grad_batches: 1 +max_grad_norm: 1.0 +seed: 42 +source_distribution: "uniform" +vocab_size: 102400 +batch_size: 1 +dataloader_num_workers: 8 +learning_rate: 5e-6 +lr_scheduler_type: "cosine" # "constant", "cosine" +lr_warmup_steps: 500 +max_train_steps: 40000 +max_epochs: 100 +uncond_prob: 0 +use_quantize: true +random_seed: false +l2_loss_weight: 0 + +checkpointing_steps: 4000 +checkpoints_total_limit: 50 \ No newline at end of file diff --git a/config/sft_nuscenes.yaml b/config/sft_nuscenes.yaml new file mode 100644 index 0000000..17848ab --- /dev/null +++ b/config/sft_nuscenes.yaml @@ -0,0 +1,36 @@ +model_path: pretrained_model/fudoki +pretrain_model_path: /cache/xyf_model/ +text_embedding_path: /cache/models/LucasJinWang-FUDOKI/text_embedding.pt + +stage: s2 +txt_max_length: 500 +new_embedding_path: "" + +pretrain_path: "" +ckpt_path: "" +train_llm_emb: false + +data_list: [ + "path/to/nuscenes_train.jsonl", +] + +mixed_precision: "no" +accumulate_grad_batches: 1 +max_grad_norm: 1.0 +seed: 42 +source_distribution: "uniform" +vocab_size: 102400 +batch_size: 1 +dataloader_num_workers: 8 +learning_rate: 5e-6 +lr_scheduler_type: "cosine" # "constant", "cosine" +lr_warmup_steps: 100 +max_train_steps: 10000 +max_epochs: 100 +uncond_prob: 0 +use_quantize: true +random_seed: false +l2_loss_weight: 0 + +checkpointing_steps: 1000 +checkpoints_total_limit: 50 \ No newline at end of file diff --git 
a/data/navsim_data/sensor_blobs/test/2021.09.09.17.18.51_veh-48_00889_01147/CAM_F0/9a6f0331d98258a0.jpg b/data/navsim_data/sensor_blobs/test/2021.09.09.17.18.51_veh-48_00889_01147/CAM_F0/9a6f0331d98258a0.jpg new file mode 100644 index 0000000..786c25e Binary files /dev/null and b/data/navsim_data/sensor_blobs/test/2021.09.09.17.18.51_veh-48_00889_01147/CAM_F0/9a6f0331d98258a0.jpg differ diff --git a/data/navsim_data/sensor_blobs/trainval/2021.06.09.18.23.43_veh-35_03190_03392/CAM_F0/f997082b36a65b27.jpg b/data/navsim_data/sensor_blobs/trainval/2021.06.09.18.23.43_veh-35_03190_03392/CAM_F0/f997082b36a65b27.jpg new file mode 100644 index 0000000..6a80392 Binary files /dev/null and b/data/navsim_data/sensor_blobs/trainval/2021.06.09.18.23.43_veh-35_03190_03392/CAM_F0/f997082b36a65b27.jpg differ diff --git a/data/navsim_data/sensor_blobs/trainval/2021.06.09.19.40.26_veh-12_00279_01212/CAM_F0/e69a584baa1a571f.jpg b/data/navsim_data/sensor_blobs/trainval/2021.06.09.19.40.26_veh-12_00279_01212/CAM_F0/e69a584baa1a571f.jpg new file mode 100644 index 0000000..a90ead4 Binary files /dev/null and b/data/navsim_data/sensor_blobs/trainval/2021.06.09.19.40.26_veh-12_00279_01212/CAM_F0/e69a584baa1a571f.jpg differ diff --git a/data/navsim_data/sensor_blobs/trainval/2021.06.23.14.06.20_veh-26_00020_01142/CAM_F0/ab90c429b6c851ae.jpg b/data/navsim_data/sensor_blobs/trainval/2021.06.23.14.06.20_veh-26_00020_01142/CAM_F0/ab90c429b6c851ae.jpg new file mode 100644 index 0000000..49fc0aa Binary files /dev/null and b/data/navsim_data/sensor_blobs/trainval/2021.06.23.14.06.20_veh-26_00020_01142/CAM_F0/ab90c429b6c851ae.jpg differ diff --git a/data/navsim_data/sensor_blobs/trainval/2021.07.09.23.23.48_veh-26_02228_04624/CAM_F0/12bfa1b0a8615565.jpg b/data/navsim_data/sensor_blobs/trainval/2021.07.09.23.23.48_veh-26_02228_04624/CAM_F0/12bfa1b0a8615565.jpg new file mode 100644 index 0000000..3a107ee Binary files /dev/null and 
b/data/navsim_data/sensor_blobs/trainval/2021.07.09.23.23.48_veh-26_02228_04624/CAM_F0/12bfa1b0a8615565.jpg differ diff --git a/data/navsim_data/sensor_blobs/trainval/2021.07.16.00.51.05_veh-17_01352_01901/CAM_F0/6c1a5fe095f95f1e.jpg b/data/navsim_data/sensor_blobs/trainval/2021.07.16.00.51.05_veh-17_01352_01901/CAM_F0/6c1a5fe095f95f1e.jpg new file mode 100644 index 0000000..7267ca2 Binary files /dev/null and b/data/navsim_data/sensor_blobs/trainval/2021.07.16.00.51.05_veh-17_01352_01901/CAM_F0/6c1a5fe095f95f1e.jpg differ diff --git a/data/navsim_data/sensor_blobs/trainval/2021.07.16.16.01.30_veh-38_02497_03871/CAM_F0/c2a1255ee2935f44.jpg b/data/navsim_data/sensor_blobs/trainval/2021.07.16.16.01.30_veh-38_02497_03871/CAM_F0/c2a1255ee2935f44.jpg new file mode 100644 index 0000000..4d24cc4 Binary files /dev/null and b/data/navsim_data/sensor_blobs/trainval/2021.07.16.16.01.30_veh-38_02497_03871/CAM_F0/c2a1255ee2935f44.jpg differ diff --git a/data/navsim_data/sensor_blobs/trainval/2021.07.16.18.49.56_veh-26_00833_03384/CAM_F0/bcf630737fbf5d29.jpg b/data/navsim_data/sensor_blobs/trainval/2021.07.16.18.49.56_veh-26_00833_03384/CAM_F0/bcf630737fbf5d29.jpg new file mode 100644 index 0000000..11c3237 Binary files /dev/null and b/data/navsim_data/sensor_blobs/trainval/2021.07.16.18.49.56_veh-26_00833_03384/CAM_F0/bcf630737fbf5d29.jpg differ diff --git a/data/navsim_data/sensor_blobs/trainval/2021.10.11.02.57.41_veh-50_00352_00535/CAM_F0/e9c0df7e1fae5b0c.jpg b/data/navsim_data/sensor_blobs/trainval/2021.10.11.02.57.41_veh-50_00352_00535/CAM_F0/e9c0df7e1fae5b0c.jpg new file mode 100644 index 0000000..66edd85 Binary files /dev/null and b/data/navsim_data/sensor_blobs/trainval/2021.10.11.02.57.41_veh-50_00352_00535/CAM_F0/e9c0df7e1fae5b0c.jpg differ diff --git a/flow_matching/__init__.py b/flow_matching/__init__.py new file mode 100644 index 0000000..7975227 --- /dev/null +++ b/flow_matching/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
def resize_pad(image, image_size=384, padding_color=(127, 127, 127)):
    """Letterbox ``image`` onto a square canvas, preserving its aspect ratio.

    The longer side is scaled to ``image_size`` and the resized image is
    centred on an ``image_size`` x ``image_size`` RGB canvas filled with
    ``padding_color``.

    Args:
        image: PIL image to resize and pad.
        image_size: Side length, in pixels, of the square output.
        padding_color: RGB fill used for the border (mid-grey by default).

    Returns:
        A new RGB image of size ``(image_size, image_size)``. For an input
        with a non-positive width or height, the result of a plain resize
        to that size is returned instead (no padding is applied).
    """
    w, h = image.size
    if w <= 0 or h <= 0:
        # Degenerate input: there is no aspect ratio to preserve.
        return image.resize((image_size, image_size), Image.Resampling.BILINEAR)

    resize_scale = image_size / max(w, h)
    # Clamp to 1px so an extreme aspect ratio never yields an empty side.
    # Because of this clamp no further "non-positive size" check is needed.
    new_w = max(1, int(w * resize_scale))
    new_h = max(1, int(h * resize_scale))

    resized = image.resize((new_w, new_h), Image.Resampling.BILINEAR)

    canvas = Image.new('RGB', (image_size, image_size), padding_color)
    paste_x = (image_size - new_w) // 2
    paste_y = (image_size - new_h) // 2
    canvas.paste(resized, (paste_x, paste_y))
    return canvas
+ """ + def __init__( + self, + data_list: list, + vl_chat_processor: VLChatProcessor, + txt_max_length=500 + ): + super().__init__() + self.vl_chat_processor = vl_chat_processor + self.txt_max_length = txt_max_length + self.list_data_dict = [] + + self.split_token = None + + self.transform_img = transforms.Compose([ + transforms.Lambda(resize_pad), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) + ]) + + for path in data_list: + jsonl_list = [] + ext = os.path.splitext(path)[-1].lower() + + with open(path, "r", encoding="utf-8") as f: + if ext == ".jsonl": + for line in f: + line = line.strip() + if line: + jsonl_list.append(json.loads(line)) + elif ext == ".json": + data = json.load(f) + if isinstance(data, list): + jsonl_list.extend(data) + else: + jsonl_list.append(data) + else: + raise ValueError(f"Unsupported file extension: {ext}") + + self.list_data_dict.extend(jsonl_list) + + + def __len__(self): + return len(self.list_data_dict) + + def __getitem__(self, i): + try: + sample = self._get_item(i) + return sample + except Exception as e: + print(f"[Warning] Error loading index {i}: {e}") + # random choice + rand_idx = np.random.randint(0, len(self.list_data_dict)) + print(f"[Retry] Loading random index {rand_idx} instead.") + return self.__getitem__(rand_idx) + + + def _get_item(self, i): + item = self.list_data_dict[i] + + if "image" not in item: + raise ValueError("Currently only image-based samples are supported.") + + image_list = item["image"] + conversation = item["conversations"] + + if conversation[0]["from"] == "system": + addition_system_prompt = conversation[0]["value"] + conversation = conversation[1:] + else: + addition_system_prompt = "" + + # conv_length = len(conversation) + # assert conv_length == 2, "only support single turn" + + conversation = self.construct_conv(conversation) + + data_dict = self.process_image_item( + image_list, + conversation, + 
system_prompt=addition_system_prompt, + txt_max_length=self.txt_max_length + ) + return data_dict + + def construct_conv(self, conv): + _conv = [] + for item in conv: + role = item["from"] + content = item["value"] + _item = {} + + if role == "human": + _item["role"] = "User" + elif role == "gpt": + _item["role"] = "Assistant" + else: + raise ValueError("role must be human or gpt") + + if "" in content: + content = content.replace("", "") + + _item["content"] = content + + _conv.append(_item) + + return _conv + + def _find_split_token(self, input_ids, split_token_length): + # start index for "Assistant:" + start_index = -1 + for j in range(len(input_ids) - split_token_length, 0, -1): + if input_ids[j:j + split_token_length].numpy().tolist() == self.split_token: + start_index = j + break + return start_index + + def process_image_item( + self, + image_paths, + conversation, + system_prompt="", + txt_max_length=500 + ): + imgs = [] + if isinstance(image_paths, str): + image_paths = [image_paths] + + for path in image_paths: + img = Image.open(path).convert("RGB") + imgs.append(self.transform_img(img)) + + if len(imgs) > 0: + imgs = torch.stack(imgs, dim=0) # [N, C, H, W] + img_len = len(imgs) * IMG_LEN + else: + imgs = None + img_len = 0 # default + + generation_understanding_indicator = 0 + + sft_format = self.vl_chat_processor.apply_sft_template_for_multi_turn_prompts( + conversations=conversation, + sft_format=self.vl_chat_processor.sft_format, + system_prompt=self.vl_chat_processor.system_prompt + system_prompt, + ) + + # tokenize + input_ids = self.vl_chat_processor.tokenizer.encode(sft_format) + input_ids = torch.LongTensor(input_ids) + + # add image tokens to the input_ids + image_token_mask = input_ids == self.vl_chat_processor.image_id + image_indices = image_token_mask.nonzero() + assert len(image_indices) == len(image_paths), \ + f"Number of images ({len(image_paths)}) does not match number of image tokens ({len(image_indices)})" + + input_ids, _ = 
self.vl_chat_processor.add_image_token( + image_indices=image_indices, + input_ids=input_ids, + ) + + # pad tokens + if input_ids.shape[0] >= txt_max_length + img_len: + rows_to_pad = random.randint(0, 50) + else: + rows_to_pad = txt_max_length + img_len - input_ids.shape[0] + input_ids = torch.cat([input_ids, torch.LongTensor([self.vl_chat_processor.pad_id]).repeat(rows_to_pad)], dim=0) + attention_mask = torch.zeros((input_ids.shape[0]), dtype=torch.bool) + attention_mask[:] = True + + # obtain image token mask and fill in img token_ids + if imgs is not None: + image_expanded_token_mask = (input_ids == self.vl_chat_processor.image_id).to(dtype=int) + image_expanded_mask_indices = torch.where(image_expanded_token_mask == 1)[0] + input_ids[image_expanded_mask_indices] = 0 + else: + image_expanded_token_mask = torch.zeros_like(input_ids) + + # obtain text token mask + # support multi turn, indicating the last one + if self.split_token is None: + self.split_token = self.vl_chat_processor.tokenizer.encode("Assistant:", add_special_tokens=False) + split_token_length = len(self.split_token) + start_index = self._find_split_token(input_ids, split_token_length) + + text_expanded_token_mask = torch.zeros_like(image_expanded_token_mask) + if start_index != -1: + text_expanded_token_mask[(start_index+split_token_length):] = 1 + else: + raise ValueError("Split token not found in input_ids") + + + generation_or_understanding_mask = generation_understanding_indicator + data_info = {} + data_info['text_token_mask'] = text_expanded_token_mask + data_info['image_token_mask'] = image_expanded_token_mask + data_info['generation_or_understanding_mask'] = torch.Tensor([generation_or_understanding_mask]) + + data_info['attention_mask'] = attention_mask + data_info['sft_format'] = sft_format + + data_info['understanding_img'] = imgs + data_info['has_understanding_img'] = torch.Tensor([True]).to(dtype=int) + + data_info["input_ids"] = torch.LongTensor(input_ids) + + # print("\n\n\n", 
class SupervisedDataset(Dataset):
    """Synthetic dataset of tokenized, space-separated decimal numbers.

    A grid of evenly spaced values between ``min_num`` and ``max_num`` is
    built once. Each item draws ``seq_len`` distinct grid values, formats
    them with two decimals, joins them with single spaces and tokenizes the
    resulting string.

    Args:
        vl_chat_processor: Object exposing a ``tokenizer`` attribute whose
            ``encode(text, add_special_tokens=False)`` returns token ids.
        seq_len: How many distinct numbers are drawn per item.
        total_samples: Reported dataset length.
        min_num: Lower end of the number grid (inclusive).
        max_num: Upper end of the number grid (inclusive).
        interval: Nominal spacing between neighbouring grid values.

    Raises:
        ValueError: if ``seq_len`` exceeds the number of grid values, since
            sampling is done without replacement.
    """

    def __init__(self, vl_chat_processor, seq_len=512, total_samples=100000,
                 min_num=-100, max_num=100, interval=0.01):
        super().__init__()
        self.tokenizer = vl_chat_processor.tokenizer
        self.seq_len = seq_len
        self.total_samples = total_samples
        # round() before int(): a quotient such as 0.3 / 0.1 evaluates to
        # 2.9999999999999996 and plain truncation would drop a grid point.
        num_points = int(round((max_num - min_num) / interval)) + 1
        self.all_num = np.linspace(min_num, max_num, num_points)
        if seq_len > len(self.all_num):
            raise ValueError(
                f"seq_len ({seq_len}) exceeds the number of available "
                f"values ({len(self.all_num)})"
            )
        self.min_num = min_num
        self.max_num = max_num
        self.interval = interval

    def __len__(self):
        return self.total_samples

    def __getitem__(self, idx):
        # ``idx`` is intentionally unused: every access draws a fresh sample.
        # Sampling positions (not a list copy of the grid) avoids rebuilding
        # a Python list of every grid value on each call.
        positions = random.sample(range(len(self.all_num)), self.seq_len)
        sampled_tokens = [f"{self.all_num[pos]:.2f}" for pos in positions]
        token_ids = self.tokenizer.encode(" ".join(sampled_tokens), add_special_tokens=False)
        token_ids = torch.tensor(token_ids, dtype=torch.long)
        return {"input_ids": token_ids}
+ +from .generalized_loss import MixturePathGeneralizedKL + +__all__ = [ + "MixturePathGeneralizedKL", +] diff --git a/flow_matching/loss/generalized_loss.py b/flow_matching/loss/generalized_loss.py new file mode 100644 index 0000000..cc1507e --- /dev/null +++ b/flow_matching/loss/generalized_loss.py @@ -0,0 +1,80 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the CC-by-NC license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import Tensor +from torch.nn.modules.loss import _Loss + +from flow_matching.path import MixtureDiscreteProbPath + + +class MixturePathGeneralizedKL(_Loss): + r"""A generalized KL loss for discrete flow matching. + A class that measures the generalized KL of a discrete flow model :math:`p_{1|t}` w.r.t. a probability path given by ``path``. Note: this class is assuming that the model is trained on the same path. + + For a model trained on a space :math:`\mathcal{S} = \mathcal{T}^d`, :math:`\mathcal{T} = [K] = \set{1,2,\ldots,K}`, the loss is given by + + .. math:: + \ell_i(x_1, x_t, t) = -\frac{\dot{\kappa}_t}{1-\kappa_t} \biggr[ p_{1|t}(x_t^i|x_t) -\delta_{x^i_1}(x_t^i) + (1-\delta_{x^i_1}(x_t^i))\left(\log p_{1|t}(x_1^i|x_t)\right)\biggr], + + where :math:`\kappa_t` is the scheduler associated with ``path``. + + Args: + path (MixtureDiscreteProbPath): Probability path (x-prediction training). + reduction (str, optional): Specify the reduction to apply to the output ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction is applied to the output, ``'mean'``: the output is reduced by mean over sequence elements, ``'sum'``: the output is reduced by sum over sequence elements. Defaults to 'mean'. 
+ """ + + def __init__(self, path: MixtureDiscreteProbPath, reduction: str = "mean") -> None: + super().__init__(None, None, reduction) + self.path = path + + def forward(self, logits: Tensor, x_1: Tensor, x_t: Tensor, t: Tensor) -> Tensor: + r"""Evaluates the generalized KL loss. + + Args: + logits (Tensor): posterior model output (i.e., softmax(``logits``) :math:`=p_{1|t}(x|x_t)`), shape (batch, d, K). + x_1 (Tensor): target data point :math:`x_1 \sim q`, shape (batch, d). + x_t (Tensor): conditional sample at :math:`x_t \sim p_t(\cdot|x_1)`, shape (batch, d). + t (Tensor): times in :math:`[0,1]`, shape (batch). + + Raises: + ValueError: reduction value must be one of ``'none'`` | ``'mean'`` | ``'sum'``. + + Returns: + Tensor: Generalized KL loss. + """ + x_1_shape = x_1.shape + + # extract x_1 value of log(p_{1|t}(x|x_t)). + log_p_1t = torch.log_softmax(logits, dim=-1) + log_p_1t_x1 = torch.gather(log_p_1t, dim=-1, index=x_1.unsqueeze(-1)) + log_p_1t_x1 = log_p_1t_x1.view(*x_1_shape) + + # extract x_t value of p_{1|t}(x|x_t). 
+ p_1t = torch.exp(log_p_1t) + p_1t_xt = torch.gather(p_1t, dim=-1, index=x_t.unsqueeze(-1)) + p_1t_xt = p_1t_xt.view(*x_1_shape) + + scheduler_output = self.path.scheduler(t) + + jump_coefficient = ( + scheduler_output.d_alpha_t / (1 - scheduler_output.alpha_t) + )[(...,) + (None,) * (x_1.dim() - 1)] + jump_coefficient = jump_coefficient.repeat(1, *x_1_shape[1:]) + delta_x1_xt = (x_t == x_1).to(log_p_1t.dtype) + + loss = -jump_coefficient * ( + p_1t_xt - delta_x1_xt + (1 - delta_x1_xt) * log_p_1t_x1 + ) + + if self.reduction == "mean": + return torch.mean(loss) + elif self.reduction == "sum": + return torch.sum(loss) + elif self.reduction == "none": + return loss + else: + raise ValueError(f"{self.reduction} is not a valid value for reduction") diff --git a/flow_matching/path/__init__.py b/flow_matching/path/__init__.py new file mode 100644 index 0000000..c33cd0d --- /dev/null +++ b/flow_matching/path/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the CC-by-NC license found in the +# LICENSE file in the root directory of this source tree. + +from .affine import AffineProbPath, CondOTProbPath +from .geodesic import GeodesicProbPath +from .mixture import MixtureDiscreteProbPath, MixtureDiscreteSoftmaxProbPath +from .path import ProbPath +from .path_sample import DiscretePathSample, PathSample + + +__all__ = [ + "ProbPath", + "AffineProbPath", + "CondOTProbPath", + "MixtureDiscreteProbPath", + "GeodesicProbPath", + "PathSample", + "DiscretePathSample", +] diff --git a/flow_matching/path/affine.py b/flow_matching/path/affine.py new file mode 100644 index 0000000..81cb7ed --- /dev/null +++ b/flow_matching/path/affine.py @@ -0,0 +1,260 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the CC-by-NC license found in the +# LICENSE file in the root directory of this source tree. 
+ +from torch import Tensor + +from flow_matching.path.path import ProbPath +from flow_matching.path.path_sample import PathSample +from flow_matching.path.scheduler.scheduler import CondOTScheduler, Scheduler +from flow_matching.utils import expand_tensor_like + + +class AffineProbPath(ProbPath): + r"""The ``AffineProbPath`` class represents a specific type of probability path where the transformation between distributions is affine. + An affine transformation can be represented as: + + .. math:: + + X_t = \alpha_t X_1 + \sigma_t X_0, + + where :math:`X_t` is the transformed data point at time `t`. :math:`X_0` and :math:`X_1` are the source and target data points, respectively. :math:`\alpha_t` and :math:`\sigma_t` are the parameters of the affine transformation at time `t`. + + The scheduler is responsible for providing the time-dependent parameters :math:`\alpha_t` and :math:`\sigma_t`, as well as their derivatives, which define the affine transformation at any given time `t`. + + Using ``AffineProbPath`` in the flow matching framework: + + .. code-block:: python + + # Instantiates a probability path + my_path = AffineProbPath(...) + mse_loss = torch.nn.MSELoss() + + for x_1 in dataset: + # Sets x_0 to random noise + x_0 = torch.randn() + + # Sets t to a random value in [0,1] + t = torch.rand() + + # Samples the conditional path X_t ~ p_t(X_t|X_0,X_1) + path_sample = my_path.sample(x_0=x_0, x_1=x_1, t=t) + + # Computes the MSE loss w.r.t. the velocity + loss = mse_loss(path_sample.dx_t, my_model(x_t, t)) + loss.backward() + + Args: + scheduler (Scheduler): An instance of a scheduler that provides the parameters :math:`\alpha_t`, :math:`\sigma_t`, and their derivatives over time. 
+ + """ + + def __init__(self, scheduler: Scheduler): + self.scheduler = scheduler + + def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> PathSample: + r"""Sample from the affine probability path: + + | given :math:`(X_0,X_1) \sim \pi(X_0,X_1)` and a scheduler :math:`(\alpha_t,\sigma_t)`. + | return :math:`X_0, X_1, X_t = \alpha_t X_1 + \sigma_t X_0`, and the conditional velocity at :math:`X_t, \dot{X}_t = \dot{\alpha}_t X_1 + \dot{\sigma}_t X_0`. + + Args: + x_0 (Tensor): source data point, shape (batch_size, ...). + x_1 (Tensor): target data point, shape (batch_size, ...). + t (Tensor): times in [0,1], shape (batch_size). + + Returns: + PathSample: a conditional sample at :math:`X_t \sim p_t`. + """ + self.assert_sample_shape(x_0=x_0, x_1=x_1, t=t) + + scheduler_output = self.scheduler(t) + + alpha_t = expand_tensor_like( + input_tensor=scheduler_output.alpha_t, expand_to=x_1 + ) + sigma_t = expand_tensor_like( + input_tensor=scheduler_output.sigma_t, expand_to=x_1 + ) + d_alpha_t = expand_tensor_like( + input_tensor=scheduler_output.d_alpha_t, expand_to=x_1 + ) + d_sigma_t = expand_tensor_like( + input_tensor=scheduler_output.d_sigma_t, expand_to=x_1 + ) + + # construct xt ~ p_t(x|x1). + x_t = sigma_t * x_0 + alpha_t * x_1 + dx_t = d_sigma_t * x_0 + d_alpha_t * x_1 + + return PathSample(x_t=x_t, dx_t=dx_t, x_1=x_1, x_0=x_0, t=t) + + def target_to_velocity(self, x_1: Tensor, x_t: Tensor, t: Tensor) -> Tensor: + r"""Convert from x_1 representation to velocity. + + | given :math:`X_1`. + | return :math:`\dot{X}_t`. + + Args: + x_1 (Tensor): target data point. + x_t (Tensor): path sample at time t. + t (Tensor): time in [0,1]. + + Returns: + Tensor: velocity. 
+ """ + scheduler_output = self.scheduler(t) + + alpha_t = scheduler_output.alpha_t + d_alpha_t = scheduler_output.d_alpha_t + sigma_t = scheduler_output.sigma_t + d_sigma_t = scheduler_output.d_sigma_t + + a_t = d_sigma_t / sigma_t + b_t = (d_alpha_t * sigma_t - d_sigma_t * alpha_t) / sigma_t + + return a_t * x_t + b_t * x_1 + + def epsilon_to_velocity(self, epsilon: Tensor, x_t: Tensor, t: Tensor) -> Tensor: + r"""Convert from epsilon representation to velocity. + + | given :math:`\epsilon`. + | return :math:`\dot{X}_t`. + + Args: + epsilon (Tensor): noise in the path sample. + x_t (Tensor): path sample at time t. + t (Tensor): time in [0,1]. + + Returns: + Tensor: velocity. + """ + scheduler_output = self.scheduler(t) + + alpha_t = scheduler_output.alpha_t + d_alpha_t = scheduler_output.d_alpha_t + sigma_t = scheduler_output.sigma_t + d_sigma_t = scheduler_output.d_sigma_t + + a_t = d_alpha_t / alpha_t + b_t = (d_sigma_t * alpha_t - d_alpha_t * sigma_t) / alpha_t + + return a_t * x_t + b_t * epsilon + + def velocity_to_target(self, velocity: Tensor, x_t: Tensor, t: Tensor) -> Tensor: + r"""Convert from velocity to x_1 representation. + + | given :math:`\dot{X}_t`. + | return :math:`X_1`. + + Args: + velocity (Tensor): velocity at the path sample. + x_t (Tensor): path sample at time t. + t (Tensor): time in [0,1]. + + Returns: + Tensor: target data point. + """ + scheduler_output = self.scheduler(t) + + alpha_t = scheduler_output.alpha_t + d_alpha_t = scheduler_output.d_alpha_t + sigma_t = scheduler_output.sigma_t + d_sigma_t = scheduler_output.d_sigma_t + + a_t = -d_sigma_t / (d_alpha_t * sigma_t - d_sigma_t * alpha_t) + b_t = sigma_t / (d_alpha_t * sigma_t - d_sigma_t * alpha_t) + + return a_t * x_t + b_t * velocity + + def epsilon_to_target(self, epsilon: Tensor, x_t: Tensor, t: Tensor) -> Tensor: + r"""Convert from epsilon representation to x_1 representation. + + | given :math:`\epsilon`. + | return :math:`X_1`. 
+ + Args: + epsilon (Tensor): noise in the path sample. + x_t (Tensor): path sample at time t. + t (Tensor): time in [0,1]. + + Returns: + Tensor: target data point. + """ + scheduler_output = self.scheduler(t) + + alpha_t = scheduler_output.alpha_t + sigma_t = scheduler_output.sigma_t + + a_t = 1 / alpha_t + b_t = -sigma_t / alpha_t + + return a_t * x_t + b_t * epsilon + + def velocity_to_epsilon(self, velocity: Tensor, x_t: Tensor, t: Tensor) -> Tensor: + r"""Convert from velocity to noise representation. + + | given :math:`\dot{X}_t`. + | return :math:`\epsilon`. + + Args: + velocity (Tensor): velocity at the path sample. + x_t (Tensor): path sample at time t. + t (Tensor): time in [0,1]. + + Returns: + Tensor: noise in the path sample. + """ + scheduler_output = self.scheduler(t) + + alpha_t = scheduler_output.alpha_t + d_alpha_t = scheduler_output.d_alpha_t + sigma_t = scheduler_output.sigma_t + d_sigma_t = scheduler_output.d_sigma_t + + a_t = -d_alpha_t / (d_sigma_t * alpha_t - d_alpha_t * sigma_t) + b_t = alpha_t / (d_sigma_t * alpha_t - d_alpha_t * sigma_t) + + return a_t * x_t + b_t * velocity + + def target_to_epsilon(self, x_1: Tensor, x_t: Tensor, t: Tensor) -> Tensor: + r"""Convert from x_1 representation to velocity. + + | given :math:`X_1`. + | return :math:`\epsilon`. + + Args: + x_1 (Tensor): target data point. + x_t (Tensor): path sample at time t. + t (Tensor): time in [0,1]. + + Returns: + Tensor: noise in the path sample. + """ + scheduler_output = self.scheduler(t) + + alpha_t = scheduler_output.alpha_t + sigma_t = scheduler_output.sigma_t + + a_t = 1 / sigma_t + b_t = -alpha_t / sigma_t + + return a_t * x_t + b_t * x_1 + + +class CondOTProbPath(AffineProbPath): + r"""The ``CondOTProbPath`` class represents a conditional optimal transport probability path. + + This class is a specialized version of the ``AffineProbPath`` that uses a conditional optimal transport scheduler to determine the parameters of the affine transformation. 
+ + The parameters :math:`\alpha_t` and :math:`\sigma_t` for the conditional optimal transport path are defined as: + + .. math:: + + \alpha_t = t \quad \text{and} \quad \sigma_t = 1 - t. + """ + + def __init__(self): + self.scheduler = CondOTScheduler() diff --git a/flow_matching/path/geodesic.py b/flow_matching/path/geodesic.py new file mode 100644 index 0000000..d04bf67 --- /dev/null +++ b/flow_matching/path/geodesic.py @@ -0,0 +1,100 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the CC-by-NC license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from torch import Tensor +from torch.func import jvp, vmap + +from flow_matching.path.path import ProbPath + +from flow_matching.path.path_sample import PathSample +from flow_matching.path.scheduler import ConvexScheduler +from flow_matching.utils import expand_tensor_like + +from flow_matching.utils.manifolds import geodesic, Manifold + + +class GeodesicProbPath(ProbPath): + r"""The ``GeodesicProbPath`` class represents a specific type of probability path where the transformation between distributions is defined through the geodesic path. + Mathematically, a geodesic path can be represented as: + + .. math:: + + X_t = \psi_t(X_0 | X_1) = \exp_{X_1}(\kappa_t \log_{X_1}(X_0)), + + where :math:`X_t` is the transformed data point at time `t`, :math:`X_0` and :math:`X_1` are the source and target data points, respectively, and :math:`\kappa_t` is a scheduler. + + The scheduler is responsible for providing the time-dependent :math:`\kappa_t` and must be differentiable. + + Using ``GeodesicProbPath`` in the flow matching framework: + + .. 
    def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> PathSample:
        r"""Sample from the Riemannian probability path with geodesic interpolation:

        | given :math:`(X_0,X_1) \sim \pi(X_0,X_1)` and a scheduler :math:`\kappa_t`.
        | return :math:`X_0, X_1, X_t = \exp_{X_1}(\kappa_t \log_{X_1}(X_0))`, and the conditional velocity at :math:`X_t, \dot{X}_t`.

        Args:
            x_0 (Tensor): source data point, shape (batch_size, ...).
            x_1 (Tensor): target data point, shape (batch_size, ...).
            t (Tensor): times in [0,1], shape (batch_size).

        Returns:
            PathSample: A conditional sample at :math:`X_t \sim p_t`. Its ``t`` field holds the expanded copy of ``t`` built below, not the tensor that was passed in.
        """
        self.assert_sample_shape(x_0=x_0, x_1=x_1, t=t)
        # Expand t against x_1 with its last axis cut to length 1, and clone it
        # so the tensor handed to jvp/vmap is a separate copy.
        t = expand_tensor_like(input_tensor=t, expand_to=x_1[..., 0:1]).clone()

        def cond_u(x_0, x_1, t):
            # Per-sample geodesic between x_0 and x_1 on the manifold.
            path = geodesic(self.manifold, x_0, x_1)
            # jvp with a tangent of ones returns both the point on the path
            # and its derivative with respect to t in a single pass.
            x_t, dx_t = jvp(
                lambda t: path(self.scheduler(t).alpha_t),
                (t,),
                (torch.ones_like(t).to(t),),
            )
            return x_t, dx_t

        # vmap applies cond_u independently along the leading (batch) axis.
        x_t, dx_t = vmap(cond_u)(x_0, x_1, t)
        x_t = x_t.reshape_as(x_1)
        dx_t = dx_t.reshape_as(x_1)

        return PathSample(x_t=x_t, dx_t=dx_t, x_1=x_1, x_0=x_0, t=t)
code-block:: python + + >>> x_0 = torch.zeros((1, 3, 3)) + >>> x_1 = torch.ones((1, 3, 3)) + + >>> path = MixtureDiscreteProbPath(PolynomialConvexScheduler(n=1.0)) + >>> result = path.sample(x_0, x_1, t=torch.tensor([0.1])).x_t + >>> result + tensor([[[0.0, 0.0, 0.0], + [0.0, 0.0, 1.0], + [0.0, 0.0, 0.0]]]) + + >>> result = path.sample(x_0, x_1, t=torch.tensor([0.5])).x_t + >>> result + tensor([[[1.0, 0.0, 1.0], + [0.0, 1.0, 0.0], + [0.0, 1.0, 0.0]]]) + + >>> result = path.sample(x_0, x_1, t=torch.tensor([1.0])).x_t + >>> result + tensor([[[1.0, 1.0, 1.0], + [1.0, 1.0, 1.0], + [1.0, 1.0, 1.0]]]) + + Args: + scheduler (ConvexScheduler): The scheduler that provides :math:`\sigma_t`. + """ + + def __init__(self, scheduler: ConvexScheduler): + assert isinstance( + scheduler, ConvexScheduler + ), "Scheduler for ConvexProbPath must be a ConvexScheduler." + + self.scheduler = scheduler + + def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> DiscretePathSample: + r"""Sample from the affine probability path: + | given :math:`(X_0,X_1) \sim \pi(X_0,X_1)` and a scheduler :math:`(\alpha_t,\sigma_t)`. + | return :math:`X_0, X_1, t`, and :math:`X_t \sim p_t`. + Args: + x_0 (Tensor): source data point, shape (batch_size, ...). + x_1 (Tensor): target data point, shape (batch_size, ...). + t (Tensor): times in [0,1], shape (batch_size). + + Returns: + DiscretePathSample: a conditional sample at :math:`X_t ~ p_t`. + """ + self.assert_sample_shape(x_0=x_0, x_1=x_1, t=t) + + sigma_t = self.scheduler(t).sigma_t + + sigma_t = expand_tensor_like(input_tensor=sigma_t, expand_to=x_1) + + source_indices = torch.rand(size=x_1.shape, device=x_1.device) < sigma_t + x_t = torch.where(condition=source_indices, input=x_0, other=x_1) + + return DiscretePathSample(x_t=x_t, x_1=x_1, x_0=x_0, t=t) + + def posterior_to_velocity( + self, posterior_logits: Tensor, x_t: Tensor, t: Tensor + ) -> Tensor: + r"""Convert the factorized posterior to velocity. + + | given :math:`p(X_1|X_t)`. 
In the factorized case: :math:`\prod_i p(X_1^i | X_t)`. + | return :math:`u_t`. + + Args: + posterior_logits (Tensor): logits of the x_1 posterior conditional on x_t, shape (..., vocab size). + x_t (Tensor): path sample at time t, shape (...). + t (Tensor): time in [0,1]. + + Returns: + Tensor: velocity. + """ + posterior = torch.softmax(posterior_logits, dim=-1) + vocabulary_size = posterior.shape[-1] + x_t = F.one_hot(x_t, num_classes=vocabulary_size) + t = unsqueeze_to_match(source=t, target=x_t) + + scheduler_output = self.scheduler(t) + + kappa_t = scheduler_output.alpha_t + d_kappa_t = scheduler_output.d_alpha_t + + return (d_kappa_t / (1 - kappa_t)) * (posterior - x_t) + + +class MixtureDiscreteSoftmaxProbPath(ProbPath): + def __init__(self, mode, embedding_path): + self.a = 0.9 + self.c = 3 + assert mode in ['image', 'text'], f"Unsupported mode probability path: {mode}" + self.mode = mode + self.embedding_path = embedding_path + self.embedding = self.get_embedding(embedding_path) + self.embedding.weight.requires_grad = False + self.embedding = self.embedding + torch.cuda.empty_cache() + + def get_embedding(self, embedding_path): + # with torch.serialization.safe_globals([torch.nn.modules.sparse.Embedding]): + embedding = torch.load(embedding_path, map_location="cpu") + embedding.requires_grad_(False) + torch.cuda.empty_cache() + return embedding.cuda() + + def set_embedding(self, new_embedding): + self.embedding = new_embedding + + def metric(self, z): + z_flattened = z.view(-1, z.shape[-1]) + z = F.normalize(z, p=2, dim=-1) + z_flattened = F.normalize(z_flattened, p=2, dim=-1) + embedding = F.normalize(self.embedding.weight, p=2, dim=-1) + d = (torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(embedding**2, dim=1) - 2 * \ + torch.einsum('bd,dn->bn', z_flattened, torch.einsum('n d -> d n', embedding))) ** 2 + return d + + def get_prob_distribution(self, z, t): + b, s = z.shape[:2] + d = self.metric(z) + d = d.reshape(b, s, -1) + beta_t = 
class ProbPath(ABC):
    r"""Abstract base class for probability paths.

    A probability path carries the distribution :math:`p(X_0)` to :math:`p(X_1)` as
    :math:`t` goes from 0 to 1.

    Concrete paths support flow-matching training through two functionalities:
    (1) sampling the conditional probability path and (2) conversion between
    various training objectives. A high-level sketch, where ``MyPath`` stands for
    any concrete subclass:

    .. code-block:: python

        # Instantiate a concrete probability path
        my_path = MyPath(...)

        for x_0, x_1 in dataset:
            # One random time in [0,1] per batch element
            t = torch.rand(x_0.shape[0])

            # Sample the conditional path X_t ~ p_t(X_t|X_0,X_1)
            path_sample = my_path.sample(x_0=x_0, x_1=x_1, t=t)

            # Optimize the model. The loss depends on the model and on the path.
            loss(path_sample, my_model(path_sample.x_t, t)).backward()

    """

    @abstractmethod
    def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> PathSample:
        r"""Draw a sample from the probability path (implemented by subclasses).

        | given :math:`(X_0,X_1) \sim \pi(X_0,X_1)`.
        | returns :math:`X_0, X_1, X_t \sim p_t(X_t)`, and a conditional target :math:`Y`, all bundled in a ``PathSample``.

        Args:
            x_0 (Tensor): source data point, shape (batch_size, ...).
            x_1 (Tensor): target data point, shape (batch_size, ...).
            t (Tensor): times in [0,1], shape (batch_size).

        Returns:
            PathSample: a conditional sample.
        """

    def assert_sample_shape(self, x_0: Tensor, x_1: Tensor, t: Tensor):
        """Check that ``t`` is a 1-D vector with one entry per batch element.

        Args:
            x_0 (Tensor): source data point; only its leading dimension is inspected.
            x_1 (Tensor): target data point; only its leading dimension is inspected.
            t (Tensor): time vector to validate.

        Raises:
            AssertionError: if ``t`` is not one-dimensional, or if the leading
                dimensions of ``t``, ``x_0`` and ``x_1`` are not all equal.
        """
        time_shape = t.shape

        is_vector = t.ndim == 1
        assert (
            is_vector
        ), f"The time vector t must have shape [batch_size]. Got {time_shape}."

        # Only indexed once `t` is known to be 1-D, so `t.shape[0]` exists.
        leading_dims_agree = t.shape[0] == x_0.shape[0] == x_1.shape[0]
        assert (
            leading_dims_agree
        ), f"Time t dimension must match the batch size [{x_1.shape[0]}]. Got {time_shape}"
@dataclass
class PathSample:
    r"""Represents a sample of a conditional-flow generated probability path.

    All fields are required (none has a default value).

    Attributes:
        x_1 (Tensor): the target sample :math:`X_1`.
        x_0 (Tensor): the source sample :math:`X_0`.
        t (Tensor): the time sample :math:`t`.
        x_t (Tensor): samples :math:`X_t \sim p_t(X_t)`, shape (batch_size, ...).
        dx_t (Tensor): conditional target :math:`\frac{\partial X}{\partial t}`, shape: (batch_size, ...).

    """

    x_1: Tensor = field(metadata={"help": "target samples X_1 (batch_size, ...)."})
    x_0: Tensor = field(metadata={"help": "source samples X_0 (batch_size, ...)."})
    t: Tensor = field(metadata={"help": "time samples t (batch_size, ...)."})
    x_t: Tensor = field(
        metadata={"help": "samples x_t ~ p_t(X_t), shape (batch_size, ...)."}
    )
    dx_t: Tensor = field(
        metadata={"help": "conditional target dX_t, shape: (batch_size, ...)."}
    )


@dataclass
class DiscretePathSample:
    # Raw docstring: ``\sim`` is not a valid escape sequence in a normal string
    # literal and triggers a warning on recent Python versions.
    r"""
    Represents a sample of a conditional-flow generated discrete probability path.

    Same fields as ``PathSample`` except that there is no ``dx_t`` target.
    All fields are required (none has a default value).

    Attributes:
        x_1 (Tensor): the target sample :math:`X_1`.
        x_0 (Tensor): the source sample :math:`X_0`.
        t (Tensor): the time sample :math:`t`.
        x_t (Tensor): the sample along the path :math:`X_t \sim p_t`.
    """

    x_1: Tensor = field(metadata={"help": "target samples X_1 (batch_size, ...)."})
    x_0: Tensor = field(metadata={"help": "source samples X_0 (batch_size, ...)."})
    t: Tensor = field(metadata={"help": "time samples t (batch_size, ...)."})
    x_t: Tensor = field(
        metadata={"help": "samples X_t ~ p_t(X_t), shape (batch_size, ...)."}
    )
class ScheduleTransformedModel(ModelWrapper):
    """
    Re-express a velocity model under a different scheduler.

    Wraps an existing velocity model that was built for ``original_scheduler`` and
    exposes the velocity field that corresponds to ``new_scheduler`` instead. Only the
    time dynamics change; the wrapped model is called as-is.

    Example:

    .. code-block:: python

        import torch
        from flow_matching.path.scheduler import CondOTScheduler, CosineScheduler, ScheduleTransformedModel
        from flow_matching.solver import ODESolver

        # Initialize the model and schedulers
        model = ...

        original_scheduler = CondOTScheduler()
        new_scheduler = CosineScheduler()

        # Create the transformed model
        transformed_model = ScheduleTransformedModel(
            velocity_model=model,
            original_scheduler=original_scheduler,
            new_scheduler=new_scheduler
        )

        # Set up the solver
        solver = ODESolver(velocity_model=transformed_model)

        x_0 = torch.randn([10, 2])  # Example initial condition

        x_1 = solver.sample(
            time_steps=torch.tensor([0.0, 1.0]),
            x_init=x_0,
            step_size=1/1000
        )[1]

    Args:
        velocity_model (ModelWrapper): The original velocity model to be transformed.
        original_scheduler (Scheduler): The scheduler used by the original model. Must implement the snr_inverse function.
        new_scheduler (Scheduler): The new scheduler to be applied to the model.
    """

    def __init__(
        self,
        velocity_model: ModelWrapper,
        original_scheduler: Scheduler,
        new_scheduler: Scheduler,
    ):
        super().__init__(model=velocity_model)
        self.original_scheduler = original_scheduler
        self.new_scheduler = new_scheduler

        # A missing attribute and a non-callable attribute are rejected alike.
        snr_inverse = getattr(self.original_scheduler, "snr_inverse", None)
        assert callable(
            snr_inverse
        ), "The original scheduler must have a callable 'snr_inverse' method."

    def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor:
        r"""
        Evaluate the marginal velocity field under the new scheduler.

        This is a post-training change of scheduler for affine conditional flows: a
        marginal velocity field :math:`u_t(x)` generated with the original scheduler
        is mapped to the marginal velocity field :math:`\bar{u}_r(x)` of the new
        scheduler, keeping the same data coupling.
        The mapping relies on the scale-time (ST) transformation between the two
        conditional flows:

        .. math::

            \bar{X}_r = s_r X_{t_r},

        where :math:`X_t` and :math:`\bar{X}_r` are defined by their respective schedulers.
        The ST transformation is computed as:

        .. math::

            t_r = \rho^{-1}(\bar{\rho}(r)) \quad \text{and} \quad s_r = \frac{\bar{\sigma}_r}{\sigma_{t_r}}.

        Here, :math:`\rho(t)` is the signal-to-noise ratio (SNR) defined as:

        .. math::

            \rho(t) = \frac{\alpha_t}{\sigma_t}.

        :math:`\bar{\rho}(r)` is similarly defined for the new scheduler.
        The marginal velocity for the new scheduler is then given by:

        .. math::

            \bar{u}_r(x) = \left(\frac{\dot{s}_r}{s_r}\right) x + s_r \dot{t}_r u_{t_r}\left(\frac{x}{s_r}\right).

        Args:
            x (Tensor): :math:`x_t`, the input tensor.
            t (Tensor): The time tensor (denoted as :math:`r` above).
            **extras: Additional arguments for the model.
        Returns:
            Tensor: The transformed velocity.
        """
        # Coefficients of the new scheduler at the requested time r.
        new_out = self.new_scheduler(t=t)
        alpha_r, sigma_r = new_out.alpha_t, new_out.sigma_t
        d_alpha_r, d_sigma_r = new_out.d_alpha_t, new_out.d_sigma_t

        # Time of the original scheduler that has the same signal-to-noise ratio.
        t_r = self.original_scheduler.snr_inverse(alpha_r / sigma_r)

        # Coefficients of the original scheduler at that matched time.
        old_out = self.original_scheduler(t=t_r)
        alpha_t, sigma_t = old_out.alpha_t, old_out.sigma_t
        d_alpha_t, d_sigma_t = old_out.d_alpha_t, old_out.d_sigma_t

        # Scale factor s_r between the two flows.
        s_r = sigma_r / sigma_t

        # Derivative of the matched time t_r with respect to r.
        dt_numerator = sigma_t * sigma_t * (sigma_r * d_alpha_r - alpha_r * d_sigma_r)
        dt_denominator = sigma_r * sigma_r * (sigma_t * d_alpha_t - alpha_t * d_sigma_t)
        dt_r = dt_numerator / dt_denominator

        # Derivative of the scale factor s_r with respect to r.
        ds_r = (sigma_t * d_sigma_r - sigma_r * d_sigma_t * dt_r) / (sigma_t * sigma_t)

        # Query the wrapped model at the rescaled state and the matched time.
        u_t = self.model(x=x / s_r, t=t_r, **extras)

        return ds_r * x / s_r + dt_r * s_r * u_t
@dataclass
class SchedulerOutput:
    r"""Scheduler coefficients and their time derivatives, as returned by a ``Scheduler`` call.

    Attributes:
        alpha_t (Tensor): :math:`\alpha_t`, shape (...).
        sigma_t (Tensor): :math:`\sigma_t`, shape (...).
        d_alpha_t (Tensor): :math:`\frac{\partial}{\partial t}\alpha_t`, shape (...).
        d_sigma_t (Tensor): :math:`\frac{\partial}{\partial t}\sigma_t`, shape (...).

    """

    alpha_t: Tensor = field(metadata={"help": "alpha_t"})
    sigma_t: Tensor = field(metadata={"help": "sigma_t"})
    d_alpha_t: Tensor = field(metadata={"help": "Derivative of alpha_t."})
    d_sigma_t: Tensor = field(metadata={"help": "Derivative of sigma_t."})


class Scheduler(ABC):
    """Base Scheduler class."""

    @abstractmethod
    def __call__(self, t: Tensor) -> SchedulerOutput:
        r"""
        Args:
            t (Tensor): times in [0,1], shape (...).

        Returns:
            SchedulerOutput: :math:`\alpha_t,\sigma_t,\frac{\partial}{\partial t}\alpha_t,\frac{\partial}{\partial t}\sigma_t`
        """
        ...

    @abstractmethod
    def snr_inverse(self, snr: Tensor) -> Tensor:
        r"""
        Computes :math:`t` from the signal-to-noise ratio :math:`\frac{\alpha_t}{\sigma_t}`.

        Args:
            snr (Tensor): The signal-to-noise, shape (...)

        Returns:
            Tensor: t, shape (...)
        """
        ...


class ConvexScheduler(Scheduler):
    r"""Base class for schedulers that are inverted through :math:`\kappa_t`.

    Subclasses implement ``__call__`` and ``kappa_inverse``; ``snr_inverse`` is
    derived from ``kappa_inverse``.
    """

    @abstractmethod
    def __call__(self, t: Tensor) -> SchedulerOutput:
        # Raw docstring: in a normal string literal ``\a`` and ``\f`` would be
        # turned into control characters and ``\s`` is an invalid escape.
        r"""Scheduler for convex paths.

        Args:
            t (Tensor): times in [0,1], shape (...).

        Returns:
            SchedulerOutput: :math:`\alpha_t,\sigma_t,\frac{\partial}{\partial t}\alpha_t,\frac{\partial}{\partial t}\sigma_t`
        """
        ...

    @abstractmethod
    def kappa_inverse(self, kappa: Tensor) -> Tensor:
        # Raw docstring: ``\k`` is an invalid escape in a normal string literal.
        r"""
        Computes :math:`t` from :math:`\kappa_t`.

        Args:
            kappa (Tensor): :math:`\kappa`, shape (...)

        Returns:
            Tensor: t, shape (...)
        """
        ...

    def snr_inverse(self, snr: Tensor) -> Tensor:
        r"""
        Computes :math:`t` from the signal-to-noise ratio :math:`\frac{\alpha_t}{\sigma_t}`.

        The ratio is first mapped to :math:`\kappa = \frac{\text{snr}}{1 + \text{snr}}`
        and then inverted with ``kappa_inverse``.

        Args:
            snr (Tensor): The signal-to-noise, shape (...)

        Returns:
            Tensor: t, shape (...)
        """
        kappa_t = snr / (1.0 + snr)

        return self.kappa_inverse(kappa=kappa_t)


class CondOTScheduler(ConvexScheduler):
    """CondOT Scheduler."""

    def __call__(self, t: Tensor) -> SchedulerOutput:
        r"""Return :math:`\alpha_t = t`, :math:`\sigma_t = 1 - t` and their constant derivatives (+1 and -1).

        Args:
            t (Tensor): times, shape (...).
        """
        return SchedulerOutput(
            alpha_t=t,
            sigma_t=1 - t,
            d_alpha_t=torch.ones_like(t),
            d_sigma_t=-torch.ones_like(t),
        )

    def kappa_inverse(self, kappa: Tensor) -> Tensor:
        r"""Return ``kappa`` unchanged, since :math:`\alpha_t = t` for this scheduler."""
        return kappa


class PolynomialConvexScheduler(ConvexScheduler):
    """Polynomial Scheduler."""

    def __init__(self, n: Union[float, int]) -> None:
        """
        Args:
            n (Union[float, int]): exponent of the polynomial, must be positive.

        Raises:
            AssertionError: if ``n`` is not a float or int, or is not positive.
        """
        assert isinstance(
            n, (float, int)
        ), f"`n` must be a float or int. Got {type(n)=}."
        assert n > 0, f"`n` must be positive. Got {n=}."

        self.n = n

    def __call__(self, t: Tensor) -> SchedulerOutput:
        r"""Return :math:`\alpha_t = t^n`, :math:`\sigma_t = 1 - t^n` and their derivatives.

        Args:
            t (Tensor): times, shape (...).
        """
        return SchedulerOutput(
            alpha_t=t**self.n,
            sigma_t=1 - t**self.n,
            d_alpha_t=self.n * (t ** (self.n - 1)),
            d_sigma_t=-self.n * (t ** (self.n - 1)),
        )

    def kappa_inverse(self, kappa: Tensor) -> Tensor:
        r"""Return :math:`\kappa^{1/n}`, the inverse of :math:`t \mapsto t^n`."""
        return torch.pow(kappa, 1.0 / self.n)


class VPScheduler(Scheduler):
    """Variance Preserving Scheduler."""

    def __init__(self, beta_min: float = 0.1, beta_max: float = 20.0) -> None:
        """
        Args:
            beta_min (float): the ``b`` coefficient used in ``__call__``. Defaults to 0.1.
            beta_max (float): the ``B`` coefficient used in ``__call__``. Defaults to 20.0.
        """
        self.beta_min = beta_min
        self.beta_max = beta_max
        super().__init__()

    def __call__(self, t: Tensor) -> SchedulerOutput:
        r"""Evaluate the scheduler at ``t``.

        With :math:`b` = ``beta_min``, :math:`B` = ``beta_max`` and
        :math:`T = \frac{1}{2}(1-t)^2 (B-b) + (1-t) b`, returns
        :math:`\alpha_t = e^{-T/2}`, :math:`\sigma_t = \sqrt{1 - e^{-T}}` and their
        derivatives with respect to :math:`t`.

        Args:
            t (Tensor): times, shape (...).
        """
        b = self.beta_min
        B = self.beta_max
        T = 0.5 * (1 - t) ** 2 * (B - b) + (1 - t) * b
        # Derivative of T with respect to t.
        dT = -(1 - t) * (B - b) - b

        return SchedulerOutput(
            alpha_t=torch.exp(-0.5 * T),
            sigma_t=torch.sqrt(1 - torch.exp(-T)),
            d_alpha_t=-0.5 * dT * torch.exp(-0.5 * T),
            d_sigma_t=0.5 * dT * torch.exp(-T) / torch.sqrt(1 - torch.exp(-T)),
        )

    def snr_inverse(self, snr: Tensor) -> Tensor:
        r"""Computes :math:`t` from the signal-to-noise ratio :math:`\frac{\alpha_t}{\sigma_t}`.

        Args:
            snr (Tensor): The signal-to-noise, shape (...)

        Returns:
            Tensor: t, shape (...)
        """
        # snr^2 / (snr^2 + 1) equals exp(-T) for the alpha_t, sigma_t of __call__.
        T = -torch.log(snr**2 / (snr**2 + 1))
        b = self.beta_min
        B = self.beta_max
        # Solve the quadratic T = 0.5 * (B - b) * s^2 + b * s for s = 1 - t.
        t = 1 - ((-b + torch.sqrt(b**2 + 2 * (B - b) * T)) / (B - b))
        return t


class LinearVPScheduler(Scheduler):
    """Linear Variance Preserving Scheduler."""

    def __call__(self, t: Tensor) -> SchedulerOutput:
        r"""Return :math:`\alpha_t = t`, :math:`\sigma_t = \sqrt{1 - t^2}` and their derivatives.

        Args:
            t (Tensor): times, shape (...).
        """
        return SchedulerOutput(
            alpha_t=t,
            sigma_t=(1 - t**2) ** 0.5,
            d_alpha_t=torch.ones_like(t),
            d_sigma_t=-t / (1 - t**2) ** 0.5,
        )

    def snr_inverse(self, snr: Tensor) -> Tensor:
        r"""Computes :math:`t = \sqrt{\text{snr}^2 / (1 + \text{snr}^2)}` from the signal-to-noise ratio.

        Args:
            snr (Tensor): The signal-to-noise, shape (...)

        Returns:
            Tensor: t, shape (...)
        """
        return torch.sqrt(snr**2 / (1 + snr**2))


class CosineScheduler(Scheduler):
    """Cosine Scheduler."""

    def __call__(self, t: Tensor) -> SchedulerOutput:
        r"""Return :math:`\alpha_t = \sin(\pi t / 2)`, :math:`\sigma_t = \cos(\pi t / 2)` and their derivatives.

        Args:
            t (Tensor): times, shape (...).
        """
        pi = torch.pi
        return SchedulerOutput(
            alpha_t=torch.sin(pi / 2 * t),
            sigma_t=torch.cos(pi / 2 * t),
            d_alpha_t=pi / 2 * torch.cos(pi / 2 * t),
            d_sigma_t=-pi / 2 * torch.sin(pi / 2 * t),
        )

    def snr_inverse(self, snr: Tensor) -> Tensor:
        r"""Computes :math:`t = \frac{2}{\pi}\arctan(\text{snr})` from the signal-to-noise ratio.

        Args:
            snr (Tensor): The signal-to-noise, shape (...)

        Returns:
            Tensor: t, shape (...)
        """
        return 2.0 * torch.atan(snr) / torch.pi
from .discrete_solver import (
    MixtureDiscreteEulerSolver,
    MixtureDiscreteSoftmaxEulerSolver,
)
from .ode_solver import ODESolver
from .riemannian_ode_solver import RiemannianODESolver
from .solver import Solver

# `ModelWrapper` is advertised in `__all__` below, so it has to be bound in this
# namespace; otherwise `from flow_matching.solver import *` fails with AttributeError.
# Imported after `.discrete_solver`, which already imports `flow_matching.utils`.
from flow_matching.utils import ModelWrapper

# Public API of the package: every name listed here is imported above.
__all__ = [
    "ODESolver",
    "Solver",
    "ModelWrapper",
    "MixtureDiscreteEulerSolver",
    "MixtureDiscreteSoftmaxEulerSolver",
    "RiemannianODESolver",
]
math:: + + \begin{align*} + & X_1^i \sim p_{1|t}^i(\cdot|X_t)\\ + & \lambda^i \gets \sum_{x^i\ne X_t^i} u_t^i(x^i, X_t^i|X_1^i)\\ + & Z^i_{\text{change}} \sim U[0,1]\\ + & X_{t+h}^i \sim \begin{cases} + \frac{u_t^i(\cdot, X_t^i|X_1^i)}{\lambda^i}(1-\delta_{X_t^i}(\cdot)) \text{ if $Z^i_{\text{change}}\le 1-e^{-h\lambda^i}$}\\ + \delta_{X_t^i}(\cdot) \text{ else } + \end{cases} + \end{align*} + + Where :math:`p_{1|t}(\cdot|X_t)` is the output of ``model``, and the conditional probability velocity is of the mixture probability path is: + + .. math:: + + u_t^i(x^i, y^i|x_1^i) = \hat{u}_t^i(x^i, y^i|x_1^i) + c_{\text{div\_free}}\left[\hat{u}_t^i(x^i, y^i|x_1^i) - \check{u}_t^i(x^i, y^i|x_1^i) \right], + + where + + .. math:: + \hat{u}_t^i(x^i, y^i|x_1^i) = \frac{\dot{\kappa}_t}{1-\kappa_t} \left[ \delta_{x_1^i}(x^i) - \delta_{y^i}(x^i) \right], + + and + + .. math:: + + \check{u}_t^i(x^i, y^i|x_1^i) = \frac{\dot{\kappa}_t}{\kappa_t}\left[ \delta_{y^i}(x^i) - p(x^i) \right]. + + The source distribution :math:`p(x^i)` is given by ``p``. + + Args: + model (ModelWrapper): trained with x-prediction, outputting posterior probabilities (in the range :math:`[0,1]`), output must be [..., vocabulary_size]. + path (MixtureDiscreteProbPath): Probability path used for x-prediction training. + vocabulary_size (int): size of the discrete vocabulary. + source_distribution_p (Optional[Tensor], optional): Source distribution, must be of shape [vocabulary_size]. Required only when divergence-free term for the probability velocity is non-zero. Defaults to None. 
+ """ + + def __init__( + self, + model: ModelWrapper, + path: MixtureDiscreteProbPath, + vocabulary_size: int, + source_distribution_p: Optional[Tensor] = None, + ): + super().__init__() + self.model = model + self.path = path + self.vocabulary_size = vocabulary_size + + if source_distribution_p is not None: + assert source_distribution_p.shape == torch.Size( + [vocabulary_size] + ), f"Source distribution p dimension must match the vocabulary size {vocabulary_size}. Got {source_distribution_p.shape}." + + self.source_distribution_p = source_distribution_p + + @torch.no_grad() + def sample( + self, + x_init: Tensor, + step_size: Optional[float], + div_free: Union[float, Callable[[float], float]] = 0.0, + dtype_categorical: torch.dtype = torch.float32, + time_grid: Tensor = torch.tensor([0.0, 1.0]), + return_intermediates: bool = False, + verbose: bool = False, + **model_extras, + ) -> Tensor: + """ + Sample a sequence of discrete values from the given model. + + .. code-block:: python + + import torch + from flow_matching.utils import ModelWrapper + from flow_matching.solver import MixtureDiscreteEulerSolver + + class DummyModel(ModelWrapper): + def __init__(self): + super().__init__(None) + def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor: + return ... + + model = DummyModel() + solver = MixtureDiscreteEulerSolver(model=model) + + x_init = torch.LongTensor([122, 725]) + step_size = 0.001 + time_grid = torch.tensor([0.0, 1.0]) + + result = solver.sample(x_init=x_init, step_size=step_size, time_grid=time_grid) + + Args: + x_init (Tensor): The initial state. + step_size (Optional[float]): If float then time discretization is uniform with the given step size. If None then time discretization is set to be time_grid. + div_free (Union[float, Callable[[float], float]]): The coefficient of the divergence-free term in the probability velocity. Can be either a float or a time dependent function. Defaults to 0.0. 
+ dtype_categorical (torch.dtype): Precision to use for categorical sampler. Defaults to torch.float32. + time_grid (Tensor): The CTMC process is solved in the interval [time_grid[0], time_grid[-1]] and if step_size is None then time discretization is set by the time grid. Defaults to torch.tensor([0.0,1.0]). + return_intermediates (bool): If True then return intermediate time steps according to time_grid. Defaults to False. + verbose (bool): Whether to print progress bars. Defaults to False. + **model_extras: Additional input for the model. + + Returns: + Tensor: The sampled sequence of discrete values. + """ + if not div_free == 0.0: + assert ( + self.source_distribution_p is not None + ), "Source distribution p must be specified in order to add a divergence-free term to the probability velocity." + + # Initialize the current state `x_t` with the initial state `X_0`. + time_grid = time_grid.to(device=x_init.device) + + if step_size is None: + # If step_size is None then set the t discretization to time_grid. + t_discretization = time_grid + n_steps = len(time_grid) - 1 + else: + # If step_size is float then t discretization is uniform with step size set by step_size. + t_init = time_grid[0].item() + t_final = time_grid[-1].item() + assert ( + t_final - t_init + ) > step_size, f"Time interval [time_grid[0], time_grid[-1]] must be larger than step_size. Got a time interval [{t_init}, {t_final}] and step_size {step_size}." + + n_steps = ceil((t_final - t_init) / step_size) + t_discretization = torch.tensor( + [t_init + step_size * i for i in range(n_steps)] + [t_final], + device=x_init.device, + ) + + if return_intermediates: + # get order of intermediate steps: + order = torch.argsort(time_grid) + # Compute intermediate steps to return via nearest points in t_discretization to time_grid. 
+ time_grid = get_nearest_times( + time_grid=time_grid, t_discretization=t_discretization + ) + + x_t = x_init.clone() + steps_counter = 0 + res = [] + + if return_intermediates: + res = [x_init.clone()] + + if verbose: + ctx = tqdm(total=t_final, desc=f"NFE: {steps_counter}") + else: + ctx = nullcontext() + + with ctx: + for i in range(n_steps): + t = t_discretization[i : i + 1] + h = t_discretization[i + 1 : i + 2] - t_discretization[i : i + 1] + + # Sample x_1 ~ p_1|t( \cdot |x_t) + p_1t = self.model(x=x_t, t=t.repeat(x_t.shape[0]), **model_extras) + x_1 = categorical(p_1t.to(dtype=dtype_categorical)) + + # Checks if final step + if i == n_steps - 1: + x_t = x_1 + else: + # Compute u_t(x|x_t,x_1) + scheduler_output = self.path.scheduler(t=t) + + k_t = scheduler_output.alpha_t + d_k_t = scheduler_output.d_alpha_t + + delta_1 = F.one_hot(x_1, num_classes=self.vocabulary_size).to( + k_t.dtype + ) + u = d_k_t / (1 - k_t) * delta_1 + + # Add divergence-free part + div_free_t = div_free(t) if callable(div_free) else div_free + + if div_free_t > 0: + p_0 = self.source_distribution_p[(None,) * x_t.dim()] + u = u + div_free_t * d_k_t / (k_t * (1 - k_t)) * ( + (1 - k_t) * p_0 + k_t * delta_1 + ) + + # Set u_t(x_t|x_t,x_1) = 0 + delta_t = F.one_hot(x_t, num_classes=self.vocabulary_size) + u = torch.where( + delta_t.to(dtype=torch.bool), torch.zeros_like(u), u + ) + + # Sample x_t ~ u_t( \cdot |x_t,x_1) + intensity = u.sum(dim=-1) # Assuming u_t(xt|xt,x1) := 0 + mask_jump = torch.rand( + size=x_t.shape, device=x_t.device + ) < 1 - torch.exp(-h * intensity) + + if mask_jump.sum() > 0: + x_t[mask_jump] = categorical( + u[mask_jump].to(dtype=dtype_categorical) + ) + + steps_counter += 1 + t = t + h + + if return_intermediates and (t in time_grid): + res.append(x_t.clone()) + + if verbose: + ctx.n = t.item() + ctx.refresh() + ctx.set_description(f"NFE: {steps_counter}") + + if return_intermediates: + if step_size is None: + return torch.stack(res, dim=0) + else: + return 
torch.stack(res, dim=0)[order] + else: + return x_t + + +class MixtureDiscreteSoftmaxEulerSolver(Solver): + r"""Solver that simulates the CTMC process :math:`(X_t)_{t_{\text{init}}\leq t\leq t_{\text{final}}}` defined by :math:`p_t` the marginal probability path of ``path``. + Given :math:`X_t \sim p_t`, the algorithm of solver step from :math:`t` to :math:`t+h` for the i-th coordinate is: + + .. math:: + + \begin{align*} + & X_1^i \sim p_{1|t}^i(\cdot|X_t)\\ + & \lambda^i \gets \sum_{x^i\ne X_t^i} u_t^i(x^i, X_t^i|X_1^i)\\ + & Z^i_{\text{change}} \sim U[0,1]\\ + & X_{t+h}^i \sim \begin{cases} + \frac{u_t^i(\cdot, X_t^i|X_1^i)}{\lambda^i}(1-\delta_{X_t^i}(\cdot)) \text{ if $Z^i_{\text{change}}\le 1-e^{-h\lambda^i}$}\\ + \delta_{X_t^i}(\cdot) \text{ else } + \end{cases} + \end{align*} + + Where :math:`p_{1|t}(\cdot|X_t)` is the output of ``model``, and the conditional probability velocity is of the mixture probability path is: + + .. math:: + + u_t^i(x^i, y^i|x_1^i) = \hat{u}_t^i(x^i, y^i|x_1^i) + c_{\text{div\_free}}\left[\hat{u}_t^i(x^i, y^i|x_1^i) - \check{u}_t^i(x^i, y^i|x_1^i) \right], + + where + + .. math:: + \hat{u}_t^i(x^i, y^i|x_1^i) = \frac{\dot{\kappa}_t}{1-\kappa_t} \left[ \delta_{x_1^i}(x^i) - \delta_{y^i}(x^i) \right], + + and + + .. math:: + + \check{u}_t^i(x^i, y^i|x_1^i) = \frac{\dot{\kappa}_t}{\kappa_t}\left[ \delta_{y^i}(x^i) - p(x^i) \right]. + + The source distribution :math:`p(x^i)` is given by ``p``. + + Args: + model (ModelWrapper): trained with x-prediction, outputting posterior probabilities (in the range :math:`[0,1]`), output must be [..., vocabulary_size]. + path (MixtureDiscreteProbPath): Probability path used for x-prediction training. + vocabulary_size (int): size of the discrete vocabulary. + source_distribution_p (Optional[Tensor], optional): Source distribution, must be of shape [vocabulary_size]. Required only when divergence-free term for the probability velocity is non-zero. Defaults to None. 
+ """ + + def __init__( + self, + model: ModelWrapper, + path_txt: MixtureDiscreteSoftmaxProbPath, + path_img: MixtureDiscreteSoftmaxProbPath, + vocabulary_size_txt: int, + vocabulary_size_img: int, + ): + super().__init__() + self.model = model + self.path_txt = path_txt + self.path_img = path_img + self.vocabulary_size_txt = vocabulary_size_txt + self.vocabulary_size_img = vocabulary_size_img + + @torch.no_grad() + def sample( + self, + x_init: Tensor, + step_size: Optional[float], + div_free: Union[float, Callable[[float], float]] = 0.0, + dtype_categorical: torch.dtype = torch.float32, + time_grid: Tensor = torch.tensor([0.0, 1.0]), + return_intermediates: bool = False, + verbose: bool = False, + # callback: bool = False, + **model_extras, + ) -> Tensor: + """ + Sample a sequence of discrete values from the given model. + + .. code-block:: python + + import torch + from flow_matching.utils import ModelWrapper + from flow_matching.solver import MixtureDiscreteEulerSolver + + class DummyModel(ModelWrapper): + def __init__(self): + super().__init__(None) + def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor: + return ... + + model = DummyModel() + solver = MixtureDiscreteEulerSolver(model=model) + + x_init = torch.LongTensor([122, 725]) + step_size = 0.001 + time_grid = torch.tensor([0.0, 1.0]) + + result = solver.sample(x_init=x_init, step_size=step_size, time_grid=time_grid) + + Args: + x_init (Tensor): The initial state. + step_size (Optional[float]): If float then time discretization is uniform with the given step size. If None then time discretization is set to be time_grid. + div_free (Union[float, Callable[[float], float]]): The coefficient of the divergence-free term in the probability velocity. Can be either a float or a time dependent function. Defaults to 0.0. + dtype_categorical (torch.dtype): Precision to use for categorical sampler. Defaults to torch.float32. 
+ time_grid (Tensor): The CTMC process is solved in the interval [time_grid[0], time_grid[-1]] and if step_size is None then time discretization is set by the time grid. Defaults to torch.tensor([0.0,1.0]). + return_intermediates (bool): If True then return intermediate time steps according to time_grid. Defaults to False. + verbose (bool): Whether to print progress bars. Defaults to False. + **model_extras: Additional input for the model. + + Returns: + Tensor: The sampled sequence of discrete values. + """ + if not div_free == 0.0: + assert ( + self.source_distribution_p is not None + ), "Source distribution p must be specified in order to add a divergence-free term to the probability velocity." + + # Initialize the current state `x_t` with the initial state `X_0`. + time_grid = time_grid.to(device=x_init.device) + + if step_size is None: + # If step_size is None then set the t discretization to time_grid. + t_discretization = time_grid + n_steps = len(time_grid) - 1 + else: + # If step_size is float then t discretization is uniform with step size set by step_size. + t_init = time_grid[0].item() + t_final = time_grid[-1].item() + assert ( + t_final - t_init + ) > step_size, f"Time interval [time_grid[0], time_grid[-1]] must be larger than step_size. Got a time interval [{t_init}, {t_final}] and step_size {step_size}." + + n_steps = ceil((t_final - t_init) / step_size) + t_discretization = torch.tensor( + [t_init + step_size * i for i in range(n_steps)] + [t_final], + device=x_init.device, + ) + + if return_intermediates: + # get order of intermediate steps: + order = torch.argsort(time_grid) + # Compute intermediate steps to return via nearest points in t_discretization to time_grid. 
+ time_grid = get_nearest_times( + time_grid=time_grid, t_discretization=t_discretization + ) + + x_t = x_init.clone() + steps_counter = 0 + res = [] + + if return_intermediates: + if self.model.g_or_u == 'generation': + res = [x_init.clone()[model_extras['datainfo']['image_token_mask']==1].reshape(x_init.shape[0], -1)] + elif self.model.g_or_u =='understanding': + res = [x_init.clone()[model_extras['datainfo']['text_token_mask']==1].reshape(x_init.shape[0], -1)] + else: + res = [x_init.clone()] + + + if verbose: + ctx = tqdm(total=time_grid[-1].item(), desc=f"NFE: {steps_counter}") + else: + ctx = nullcontext() + + with ctx: + original_x_t = x_t.clone() + batch_size = original_x_t.shape[0] + for i in range(n_steps): + t = t_discretization[i : i + 1] + h = t_discretization[i + 1 : i + 2] - t_discretization[i : i + 1] + + # Sample x_1 ~ p_1|t( \cdot |x_t) + p_1t_txt, p_1t_img, data_info = self.model(x=x_t, **model_extras) + if p_1t_txt is None: + x_1 = categorical(p_1t_img.to(dtype=dtype_categorical)) + x_1 = x_1[data_info['image_token_mask']==1].reshape(batch_size, -1) + x_t = x_t[data_info['image_token_mask']==1].reshape(batch_size, -1) + # x_1 = x_1 * data_info['image_token_mask'] + x_t * (1 - data_info['image_token_mask']) + elif p_1t_img is None: + x_1 = categorical(p_1t_txt.to(dtype=dtype_categorical)) + x_1 = x_1[data_info['text_token_mask']==1].reshape(batch_size, -1) + x_t = x_t[data_info['text_token_mask']==1].reshape(batch_size, -1) + # x_1 = x_1 * data_info['text_token_mask'] + x_t * (1 - data_info['text_token_mask']) + else: + x_1_img = categorical(p_1t_img.to(dtype=dtype_categorical)) + x_1_txt = categorical(p_1t_txt.to(dtype=dtype_categorical)) + x_1_img = x_1_img[data_info['image_token_mask']==1].reshape(batch_size, -1) + x_1_txt = x_1_txt[data_info['text_token_mask']==1].reshape(batch_size, -1) + x_t_img = x_t[data_info['image_token_mask']==1].reshape(batch_size, -1) + x_t_txt = x_t[data_info['text_token_mask']==1].reshape(batch_size, -1) + # 
x_1_txt = x_1_txt * data_info['text_token_mask'] + x_t * (1 - data_info['text_token_mask']) + # x_1_img = x_1_img * data_info['image_token_mask'] + x_t * (1 - data_info['image_token_mask']) + # x_1 = x_1_txt * (1 - data_info['image_token_mask']) + x_1_img * data_info['image_token_mask'] + + + # Checks if final step + if i == n_steps - 1: + if p_1t_txt is None: + x_t = x_1 + elif p_1t_img is None: + x_t = x_1 + else: + x_t = original_x_t.clone() + x_t[data_info['image_token_mask']==1] = x_1_img.flatten() + x_t[data_info['text_token_mask']==1] = x_1_txt.flatten() + + if return_intermediates: + res.append(x_t.clone()) + else: + if p_1t_txt is None: + # Compute p_t(x|x_1) + emb_x_1 = self.path_img.embedding(x_1) + prob_x_t = self.path_img.get_prob_distribution(emb_x_1, t) + prob_x_t = prob_x_t.reshape(-1, prob_x_t.shape[-1]) + # Comptute the metric + emb_x_t = self.path_img.embedding(x_t) + emb_x_t_flattened = F.normalize(emb_x_t.view(-1, emb_x_t.shape[-1]), p=2, dim=-1) + emb_x_1_flattened = F.normalize(emb_x_1.view(-1, emb_x_1.shape[-1]), p=2, dim=-1) + distance_x_t_2_x_1 = (torch.sum(emb_x_t_flattened ** 2, dim=1, keepdim=True) + torch.sum(emb_x_1_flattened**2, dim=1, keepdim=True) - 2 * torch.einsum('bd,bd->b', emb_x_t_flattened, emb_x_1_flattened).unsqueeze(1)) ** 2 + distance_x_1_2_x = self.path_img.metric(emb_x_1) + distance = F.relu(distance_x_t_2_x_1 - distance_x_1_2_x) + if t ==0 : + d_beta_t = 0 + else: + d_beta_t = self.path_img.c * self.path_img.a * ((t / (1 - t)) ** (self.path_img.a - 1)) * 1 / ((1 - t) ** 2) + # get u + u = prob_x_t * d_beta_t * distance + # print(f"prob_x_t:{prob_x_t}") + # print(f"d_beta_t:{d_beta_t}") + # print(f"distance:{distance}") + # print(f"t:{t}, {t.dtype}") + u = u.reshape(x_1.shape[0], x_1.shape[1], -1) + + # Set u_t(x_t|x_t,x_1) = 0 + delta_t = F.one_hot(x_t, num_classes=self.vocabulary_size_img) + u = torch.where( + delta_t.to(dtype=torch.bool), torch.zeros_like(u), u + ) + + # Sample x_t ~ u_t( \cdot |x_t,x_1) + intensity 
= u.sum(dim=-1) # Assuming u_t(xt|xt,x1) := 0 + # print(f"intensity:{intensity.sum()}") + mask_jump = torch.rand( + size=x_t.shape, device=x_t.device + ) < 1 - torch.exp(-h * intensity) + # torch.save(u, f'u_{u.device}.pt') + if mask_jump.sum() > 0: + x_t[mask_jump] = categorical( + u[mask_jump].to(dtype=dtype_categorical) + ) + if return_intermediates: + res.append(x_t.clone()) + # if callback: + # yield x_t + # res.append(x_1.clone()) + original_x_t[data_info['image_token_mask']==1] = x_t.flatten() + # original_x_t[data_info['image_token_mask']==1] = x_1.flatten() + x_t = original_x_t.clone() + elif p_1t_img is None: + # Compute p_t(x|x_1) + emb_x_1 = self.path_txt.embedding(x_1) + prob_x_t = self.path_txt.get_prob_distribution(emb_x_1, t) + prob_x_t = prob_x_t.reshape(-1, prob_x_t.shape[-1]) + # Comptute the metric + emb_x_t = self.path_txt.embedding(x_t) + emb_x_t_flattened = F.normalize(emb_x_t.view(-1, emb_x_t.shape[-1]), p=2, dim=-1) + emb_x_1_flattened = F.normalize(emb_x_1.view(-1, emb_x_1.shape[-1]), p=2, dim=-1) + distance_x_t_2_x_1 = (torch.sum(emb_x_t_flattened ** 2, dim=1, keepdim=True) + torch.sum(emb_x_1_flattened**2, dim=1, keepdim=True) - 2 * torch.einsum('bd,bd->b', emb_x_t_flattened, emb_x_1_flattened).unsqueeze(1)) ** 2 + distance_x_1_2_x = self.path_txt.metric(emb_x_1) + distance = F.relu(distance_x_t_2_x_1 - distance_x_1_2_x) + if t ==0 : + d_beta_t = 0 + else: + d_beta_t = self.path_txt.c * self.path_txt.a * ((t / (1 - t)) ** (self.path_txt.a - 1)) * 1 / ((1 - t) ** 2) + # get u + u = prob_x_t * d_beta_t * distance + # print(f"prob_x_t:{prob_x_t}") + # print(f"d_beta_t:{d_beta_t}") + # print(f"distance:{distance}") + # print(f"t:{t}, {t.dtype}") + u = u.reshape(x_1.shape[0], x_1.shape[1], -1) + + # Set u_t(x_t|x_t,x_1) = 0 + delta_t = F.one_hot(x_t, num_classes=self.vocabulary_size_txt) + u = torch.where( + delta_t.to(dtype=torch.bool), torch.zeros_like(u), u + ) + + # Sample x_t ~ u_t( \cdot |x_t,x_1) + intensity = u.sum(dim=-1) # Assuming 
u_t(xt|xt,x1) := 0 + mask_jump = torch.rand( + size=x_t.shape, device=x_t.device + ) < 1 - torch.exp(-h * intensity) + # torch.save(u, f'u_{u.device}.pt') + if mask_jump.sum() > 0: + x_t[mask_jump] = categorical( + u[mask_jump].to(dtype=dtype_categorical) + ) + if return_intermediates: + res.append(x_t.clone()) + # if callback: + # yield x_t + original_x_t[data_info['text_token_mask']==1] = x_t.flatten() + x_t = original_x_t.clone() + else: + # The text part + x_t = x_t_txt.clone() + x_1 = x_1_txt.clone() + # Compute p_t(x|x_1) + emb_x_1 = self.path_txt.embedding(x_1) + prob_x_t = self.path_txt.get_prob_distribution(emb_x_1, t) + prob_x_t = prob_x_t.reshape(-1, prob_x_t.shape[-1]) + # Comptute the metric + emb_x_t = self.path_txt.embedding(x_t) + emb_x_t_flattened = F.normalize(emb_x_t.view(-1, emb_x_t.shape[-1]), p=2, dim=-1) + emb_x_1_flattened = F.normalize(emb_x_1.view(-1, emb_x_1.shape[-1]), p=2, dim=-1) + distance_x_t_2_x_1 = (torch.sum(emb_x_t_flattened ** 2, dim=1, keepdim=True) + torch.sum(emb_x_1_flattened**2, dim=1, keepdim=True) - 2 * torch.einsum('bd,bd->b', emb_x_t_flattened, emb_x_1_flattened).unsqueeze(1)) ** 2 + distance_x_1_2_x = self.path_txt.metric(emb_x_1) + distance = F.relu(distance_x_t_2_x_1 - distance_x_1_2_x) + if t ==0 : + d_beta_t = 0 + else: + d_beta_t = self.path_txt.c * self.path_txt.a * ((t / (1 - t)) ** (self.path_txt.a - 1)) * 1 / ((1 - t) ** 2) + # get u + u = prob_x_t * d_beta_t * distance + # print(f"prob_x_t:{prob_x_t}") + # print(f"d_beta_t:{d_beta_t}") + # print(f"distance:{distance}") + # print(f"t:{t}, {t.dtype}") + u = u.reshape(x_1.shape[0], x_1.shape[1], -1) + + # Set u_t(x_t|x_t,x_1) = 0 + delta_t = F.one_hot(x_t, num_classes=self.vocabulary_size_txt) + u = torch.where( + delta_t.to(dtype=torch.bool), torch.zeros_like(u), u + ) + + # Sample x_t ~ u_t( \cdot |x_t,x_1) + intensity = u.sum(dim=-1) # Assuming u_t(xt|xt,x1) := 0 + mask_jump = torch.rand( + size=x_t.shape, device=x_t.device + ) < 1 - torch.exp(-h * intensity) 
+ # torch.save(u, f'u_{u.device}.pt') + if mask_jump.sum() > 0: + x_t[mask_jump] = categorical( + u[mask_jump].to(dtype=dtype_categorical) + ) + original_x_t[data_info['text_token_mask']==1] = x_t.flatten() + + # The image part + x_t = x_t_img.clone() + x_1 = x_1_img.clone() + scheduler_output = self.path_img.scheduler(t=t) + emb_x_1 = self.path_img.embedding(x_1) + prob_x_t = self.path_img.get_prob_distribution(emb_x_1, t) + prob_x_t = prob_x_t.reshape(-1, prob_x_t.shape[-1]) + # Comptute the metric + emb_x_t = self.path_img.embedding(x_t) + emb_x_t_flattened = F.normalize(emb_x_t.view(-1, emb_x_t.shape[-1]), p=2, dim=-1) + emb_x_1_flattened = F.normalize(emb_x_1.view(-1, emb_x_1.shape[-1]), p=2, dim=-1) + distance_x_t_2_x_1 = (torch.sum(emb_x_t_flattened ** 2, dim=1, keepdim=True) + torch.sum(emb_x_1_flattened**2, dim=1, keepdim=True) - 2 * torch.einsum('bd,bd->b', emb_x_t_flattened, emb_x_1_flattened).unsqueeze(1)) ** 2 + distance_x_1_2_x = self.path_img.metric(emb_x_1) + distance = F.relu(distance_x_t_2_x_1 - distance_x_1_2_x) + if t ==0 : + d_beta_t = 0 + else: + d_beta_t = self.path_img.c * self.path_img.a * ((t / (1 - t)) ** (self.path_img.a - 1)) * 1 / ((1 - t) ** 2) + # get u + u = prob_x_t * d_beta_t * distance + # print(f"prob_x_t:{prob_x_t}") + # print(f"d_beta_t:{d_beta_t}") + # print(f"distance:{distance}") + # print(f"t:{t}, {t.dtype}") + u = u.reshape(x_1.shape[0], x_1.shape[1], -1) + + # Set u_t(x_t|x_t,x_1) = 0 + delta_t = F.one_hot(x_t, num_classes=self.vocabulary_size_img) + u = torch.where( + delta_t.to(dtype=torch.bool), torch.zeros_like(u), u + ) + + # Sample x_t ~ u_t( \cdot |x_t,x_1) + intensity = u.sum(dim=-1) # Assuming u_t(xt|xt,x1) := 0 + mask_jump = torch.rand( + size=x_t.shape, device=x_t.device + ) < 1 - torch.exp(-h * intensity) + # torch.save(u, f'u_{u.device}.pt') + if mask_jump.sum() > 0: + x_t[mask_jump] = categorical( + u[mask_jump].to(dtype=dtype_categorical) + ) + original_x_t[data_info['image_token_mask']==1] = x_t.flatten() 
class ODESolver(Solver):
    """Solve ordinary differential equations (ODEs) driven by a velocity model.

    The velocity field is integrated over a time grid with the numerical
    solvers exposed by ``torchdiffeq.odeint``.

    Args:
        velocity_model (Union[ModelWrapper, Callable]): a velocity field model receiving :math:`(x,t)` and returning :math:`u_t(x)`
    """

    def __init__(self, velocity_model: Union[ModelWrapper, Callable]):
        super().__init__()
        self.velocity_model = velocity_model

    @torch.no_grad()
    def sample(
        self,
        x_init: Tensor,
        step_size: Optional[float],
        method: str = "euler",
        atol: float = 1e-5,
        rtol: float = 1e-5,
        time_grid: Tensor = torch.tensor([0.0, 1.0]),
        return_intermediates: bool = False,
        **model_extras,
    ) -> Union[Tensor, Sequence[Tensor]]:
        r"""Solve the ODE with the velocity field.

        Example:

        .. code-block:: python

            import torch
            from flow_matching.utils import ModelWrapper
            from flow_matching.solver import ODESolver

            class DummyModel(ModelWrapper):
                def __init__(self):
                    super().__init__(None)

                def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
                    return torch.ones_like(x) * 3.0 * t**2

            velocity_model = DummyModel()
            solver = ODESolver(velocity_model=velocity_model)
            x_init = torch.tensor([0.0, 0.0])
            step_size = 0.001
            time_grid = torch.tensor([0.0, 1.0])

            result = solver.sample(x_init=x_init, step_size=step_size, time_grid=time_grid)

        Args:
            x_init (Tensor): initial conditions (e.g., source samples :math:`X_0 \sim p`). Shape: [batch_size, ...].
            step_size (Optional[float]): The step size. Must be None for adaptive step solvers.
            method (str): A method supported by torchdiffeq. Defaults to "euler". Other commonly used solvers are "dopri5", "midpoint" and "heun3". For a complete list, see torchdiffeq.
            atol (float): Absolute tolerance, used for adaptive step solvers.
            rtol (float): Relative tolerance, used for adaptive step solvers.
            time_grid (Tensor): The process is solved in the interval [min(time_grid), max(time_grid)] and if step_size is None then time discretization is set by the time grid. May specify a descending time_grid to solve in the reverse direction. Defaults to torch.tensor([0.0, 1.0]).
            return_intermediates (bool, optional): If True then return intermediate time steps according to time_grid. Defaults to False.
            **model_extras: Additional input for the model.

        Returns:
            Union[Tensor, Sequence[Tensor]]: The last timestep when return_intermediates=False, otherwise all values specified in time_grid.
        """
        time_grid = time_grid.to(x_init.device)

        def ode_func(t, x):
            return self.velocity_model(x=x, t=t, **model_extras)

        # Only fixed-step solvers receive a step size option.
        ode_opts = {"step_size": step_size} if step_size is not None else {}

        # Approximate ODE solution with numerical ODE solver
        sol = odeint(
            ode_func,
            x_init,
            time_grid,
            method=method,
            options=ode_opts,
            atol=atol,
            rtol=rtol,
        )

        if return_intermediates:
            return sol
        return sol[-1]

    @torch.no_grad()
    def compute_likelihood(
        self,
        x_1: Tensor,
        log_p0: Callable[[Tensor], Tensor],
        step_size: Optional[float],
        method: str = "euler",
        atol: float = 1e-5,
        rtol: float = 1e-5,
        time_grid: Tensor = torch.tensor([1.0, 0.0]),
        return_intermediates: bool = False,
        exact_divergence: bool = False,
        **model_extras,
    ) -> Union[Tuple[Tensor, Tensor], Tuple[Sequence[Tensor], Tensor]]:
        r"""Solve for log likelihood given a target sample at :math:`t=1`.

        Works similarly to sample, but solves the ODE in reverse to compute the log-likelihood. The velocity model must be differentiable with respect to x.
        The function assumes log_p0 is the log probability of the source distribution at :math:`t=0`.

        Args:
            x_1 (Tensor): target sample (e.g., samples :math:`X_1 \sim p_1`).
            log_p0 (Callable[[Tensor], Tensor]): Log probability function of the source distribution.
            step_size (Optional[float]): The step size. Must be None for adaptive step solvers.
            method (str): A method supported by torchdiffeq. Defaults to "euler". Other commonly used solvers are "dopri5", "midpoint" and "heun3". For a complete list, see torchdiffeq.
            atol (float): Absolute tolerance, used for adaptive step solvers.
            rtol (float): Relative tolerance, used for adaptive step solvers.
            time_grid (Tensor): If step_size is None then time discretization is set by the time grid. Must start at 1.0 and end at 0.0, otherwise the likelihood computation is not valid. Defaults to torch.tensor([1.0, 0.0]).
            return_intermediates (bool, optional): If True then return intermediate time steps according to time_grid. Otherwise only return the final sample. Defaults to False.
            exact_divergence (bool): Whether to compute the exact divergence or use the Hutchinson estimator.
            **model_extras: Additional input for the model.

        Returns:
            Union[Tuple[Tensor, Tensor], Tuple[Sequence[Tensor], Tensor]]: Samples at time_grid and log likelihood values of given x_1.
        """
        assert (
            time_grid[0] == 1.0 and time_grid[-1] == 0.0
        ), f"Time grid must start at 1.0 and end at 0.0. Got {time_grid}"

        # Keep the grid on the same device as the state, as ``sample`` does.
        time_grid = time_grid.to(x_1.device)

        # Fix the random projection for the Hutchinson divergence estimator:
        # one Rademacher (+1/-1) vector shared by every solver step.
        if not exact_divergence:
            z = (torch.randn_like(x_1) < 0) * 2.0 - 1.0

        def ode_func(x, t):
            return self.velocity_model(x=x, t=t, **model_extras)

        def dynamics_func(t, states):
            xt = states[0]
            # The method runs under no_grad; gradients are re-enabled locally
            # because the divergence needs d(ut)/d(xt).
            with torch.set_grad_enabled(True):
                xt.requires_grad_()
                ut = ode_func(xt, t)

                if exact_divergence:
                    # Exact divergence: sum_i d(ut_i)/d(xt_i) over the
                    # flattened non-batch coordinates, so inputs with any
                    # number of trailing dimensions are handled.
                    # NOTE(review): relies on ``gradient`` returning a tensor
                    # shaped like ``xt`` - confirm against flow_matching.utils.
                    ut_flat = ut.flatten(start_dim=1)
                    div = torch.zeros(
                        ut_flat.shape[0], dtype=ut_flat.dtype, device=ut_flat.device
                    )
                    for i in range(ut_flat.shape[1]):
                        grad_i = gradient(ut_flat[:, i], xt, create_graph=True)
                        div = div + grad_i.flatten(start_dim=1)[:, i]
                else:
                    # Compute Hutchinson divergence estimator E[z^T D_x(ut) z]
                    ut_dot_z = torch.einsum(
                        "ij,ij->i", ut.flatten(start_dim=1), z.flatten(start_dim=1)
                    )
                    grad_ut_dot_z = gradient(ut_dot_z, xt)
                    div = torch.einsum(
                        "ij,ij->i",
                        grad_ut_dot_z.flatten(start_dim=1),
                        z.flatten(start_dim=1),
                    )

            return ut.detach(), div.detach()

        # Augmented state: (sample, accumulated divergence integral).
        y_init = (x_1, torch.zeros(x_1.shape[0], device=x_1.device))
        ode_opts = {"step_size": step_size} if step_size is not None else {}

        sol, log_det = odeint(
            dynamics_func,
            y_init,
            time_grid,
            method=method,
            options=ode_opts,
            atol=atol,
            rtol=rtol,
        )

        x_source = sol[-1]
        source_log_p = log_p0(x_source)

        if return_intermediates:
            return sol, source_log_p + log_det[-1]
        return sol[-1], source_log_p + log_det[-1]
class RiemannianODESolver(Solver):
    r"""Riemannian ODE solver
    Initialize the ``RiemannianODESolver``.

    Args:
        manifold (Manifold): the manifold to solve on.
        velocity_model (ModelWrapper): a velocity field model receiving :math:`(x,t)`
            and returning :math:`u_t(x)` which is assumed to lie on the tangent plane at `x`.
    """

    def __init__(self, manifold: Manifold, velocity_model: ModelWrapper):
        super().__init__()
        self.manifold = manifold
        self.velocity_model = velocity_model

    @torch.no_grad()
    def sample(
        self,
        x_init: Tensor,
        step_size: float,
        projx: bool = True,
        proju: bool = True,
        method: str = "euler",
        time_grid: Tensor = torch.tensor([0.0, 1.0]),
        return_intermediates: bool = False,
        verbose: bool = False,
        **model_extras,
    ) -> Tensor:
        r"""Solve the ODE with the `velocity_field` on the manifold.

        Args:
            x_init (Tensor): initial conditions (e.g., source samples :math:`X_0 \sim p`).
            step_size (float): The step size. May be None, in which case the sorted ``time_grid`` itself is used as the discretization.
            projx (bool): Whether to project the point onto the manifold at each step. Defaults to True.
            proju (bool): Whether to project the vector field onto the tangent plane at each step. Defaults to True.
            method (str): One of ["euler", "midpoint", "rk4"]. Defaults to "euler".
            time_grid (Tensor, optional): The process is solved in the interval [min(time_grid), max(time_grid)] and if step_size is None then time discretization is set by the time grid. Defaults to torch.tensor([0.0,1.0]).
            return_intermediates (bool, optional): If True then return intermediate time steps according to time_grid. Defaults to False.
            verbose (bool, optional): Whether to print progress bars. Defaults to False.
            **model_extras: Additional input for the model.

        Returns:
            Tensor: The state at max(time_grid), or the stacked states interpolated at every time in time_grid when return_intermediates=True.
        """
        step_fns = {
            "euler": _euler_step,
            "midpoint": _midpoint_step,
            "rk4": _rk4_step,
        }
        assert method in step_fns.keys(), f"Unknown method {method}"
        step_fn = step_fns[method]

        # --- Factor this out.
        # The grid is sorted, so integration always runs forward in time.
        time_grid = torch.sort(time_grid.to(device=x_init.device)).values

        if step_size is None:
            # If step_size is None then set the t discretization to time_grid.
            t_discretization = time_grid
        else:
            # If step_size is float then t discretization is uniform with step size set by step_size.
            t_init = time_grid[0].item()
            t_final = time_grid[-1].item()
            assert (
                t_final - t_init
            ) > step_size, f"Time interval [min(time_grid), max(time_grid)] must be larger than step_size. Got a time interval [{t_init}, {t_final}] and step_size {step_size}."

            n_steps = math.ceil((t_final - t_init) / step_size)
            # Offset by t_init so that grids not starting at 0 are integrated
            # over [t_init, t_final] rather than [0, t_final].
            t_discretization = torch.tensor(
                [t_init + step_size * i for i in range(n_steps)] + [t_final],
                device=x_init.device,
            )
        # ---
        t0s = t_discretization[:-1]

        if verbose:
            t0s = tqdm(t0s)

        if return_intermediates:
            xts = []
            i_ret = 0  # index of the next time_grid entry still to be emitted

        xt = x_init
        for t0, t1 in zip(t0s, t_discretization[1:]):
            dt = t1 - t0
            xt_next = step_fn(
                self.velocity_model,
                xt,
                t0,
                dt,
                manifold=self.manifold,
                projx=projx,
                proju=proju,
            )
            if return_intermediates:
                # Emit every requested time inside [t0, t1], interpolated
                # between the two step endpoints.
                while (
                    i_ret < len(time_grid)
                    and t0 <= time_grid[i_ret]
                    and time_grid[i_ret] <= t1
                ):
                    xts.append(
                        interp(self.manifold, xt, xt_next, t0, t1, time_grid[i_ret])
                    )
                    i_ret += 1
            xt = xt_next

        if return_intermediates:
            return torch.stack(xts, dim=0)
        return xt
def _midpoint_step(
    velocity_model: Callable,
    xt: Tensor,
    t0: Tensor,
    dt: Tensor,
    manifold: Manifold,
    projx: bool = True,
    proju: bool = True,
) -> Tensor:
    r"""Advance the state by one explicit midpoint step on a manifold.

    The velocity is first evaluated at ``(xt, t0)`` to reach a half-step
    point, then re-evaluated there and used for the full step from ``xt``.

    Args:
        velocity_model (Callable): the velocity model
        xt (Tensor): tensor containing the state at time t0
        t0 (Tensor): the time at which this step is taken
        dt (Tensor): the step size
        manifold (Manifold): a manifold object
        projx (bool, optional): whether to project the state onto the manifold. Defaults to True.
        proju (bool, optional): whether to project the velocity onto the tangent plane. Defaults to True.

    Returns:
        Tensor: tensor containing the state after the step
    """

    def tangent_velocity(x, t):
        # Optionally project the raw model output onto the tangent plane at x.
        raw = velocity_model(x, t)
        return manifold.proju(x, raw) if proju else raw

    def onto_manifold(x):
        # Optionally project a point back onto the manifold.
        return manifold.projx(x) if projx else x

    half_step = 0.5 * dt
    x_half = onto_manifold(xt + half_step * tangent_velocity(xt, t0))
    v_half = tangent_velocity(x_half, t0 + half_step)

    return onto_manifold(xt + dt * v_half)
class Solver(ABC, nn.Module):
    """Abstract base class for solvers.

    Concrete solvers subclass this and implement :meth:`sample`; the base
    class cannot be instantiated because ``sample`` is abstract.
    """

    @abstractmethod
    def sample(self, x_0: Tensor = None) -> Tensor:
        """Produce a sample; must be overridden by subclasses.

        Args:
            x_0 (Tensor, optional): input passed to the solver. Defaults to None.

        Returns:
            Tensor: defined by the subclass (this base implementation does nothing).
        """
def categorical(probs: Tensor) -> Tensor:
    r"""Categorical sampler according to weights in the last dimension of ``probs`` using :func:`torch.multinomial`.

    Args:
        probs (Tensor): probabilities (or unnormalized non-negative weights),
            with the categories along the last dimension. A single 1-D
            distribution is accepted as well as batched input.

    Returns:
        Tensor: Samples, with shape ``probs.shape[:-1]`` (a 0-dim tensor for 1-D input).
    """
    num_categories = probs.shape[-1]

    # torch.multinomial only accepts 1-D or 2-D input, so all leading
    # dimensions are collapsed into a single batch axis. ``reshape`` is used
    # instead of ``flatten(0, -2)`` because the latter raises for 1-D input.
    flat_probs = probs.reshape(-1, num_categories)
    draws = torch.multinomial(flat_probs, 1, replacement=True)

    return draws.view(probs.shape[:-1])
class MaskedSourceDistribution(SourceDistribution):
    """Source distribution that returns the same mask token at every position."""

    def __init__(self, mask_token: int) -> None:
        self.mask_token = mask_token

    @property
    def masked(self) -> bool:
        return True

    def sample(self, tensor_size: Tuple[int, ...], device: torch.device) -> Tensor:
        """Return a long tensor of shape ``tensor_size`` filled with the mask token."""
        # Build the tensor directly as int64: going through a float tensor
        # would round token ids that are not exactly representable.
        return torch.full(
            tensor_size, self.mask_token, dtype=torch.long, device=device
        )

    def sample_like(self, tensor_like: Tensor) -> Tensor:
        """Return a long tensor shaped like ``tensor_like`` filled with the mask token."""
        # dtype is forced to int64 so the fill value never passes through
        # ``tensor_like``'s dtype (which may be a narrow int or a half float).
        return torch.full_like(tensor_like, self.mask_token, dtype=torch.long)


class UniformSourceDistribution(SourceDistribution):
    """Source distribution that draws token ids uniformly from ``[0, vocab_size)``."""

    def __init__(self, vocab_size: int) -> None:
        self.vocab_size = vocab_size

    @property
    def masked(self) -> bool:
        return False

    def sample(self, tensor_size: Tuple[int, ...], device: torch.device) -> Tensor:
        """Return uniformly random token ids of shape ``tensor_size``."""
        return torch.randint(size=tensor_size, high=self.vocab_size, device=device)

    def sample_like(self, tensor_like: Tensor) -> Tensor:
        """Return uniformly random token ids with the shape and dtype of ``tensor_like``."""
        return torch.randint_like(tensor_like, high=self.vocab_size)


def get_path(scheduler_type: str, exponent: Optional[float] = None) -> ProbPath:
    """Build a mixture discrete probability path for the named scheduler.

    Args:
        scheduler_type (str): only ``"polynomial"`` is supported.
        exponent (Optional[float]): passed to the polynomial scheduler as ``n``.

    Raises:
        ValueError: if ``scheduler_type`` is not supported.

    Returns:
        ProbPath: a ``MixtureDiscreteProbPath`` using the chosen scheduler.
    """
    if scheduler_type == "polynomial":
        scheduler = PolynomialConvexScheduler(n=exponent)
    else:
        raise ValueError(f"{scheduler_type} is not supported")

    return MixtureDiscreteProbPath(scheduler=scheduler)


def get_source_distribution(
    source_distribution: str, vocab_size: int
) -> SourceDistribution:
    """Build the source distribution selected by name.

    Args:
        source_distribution (str): ``"mask"`` or ``"uniform"``.
        vocab_size (int): vocabulary size; for ``"mask"`` it is also used as
            the mask token id (i.e. the first id past the vocabulary).

    Raises:
        ValueError: if ``source_distribution`` is not supported.

    Returns:
        SourceDistribution: the requested distribution.
    """
    if source_distribution == "mask":
        return MaskedSourceDistribution(mask_token=vocab_size)
    elif source_distribution == "uniform":
        return UniformSourceDistribution(vocab_size=vocab_size)
    else:
        raise ValueError(f"{source_distribution} is not supported")


def get_loss_function(loss_function: str, path: Optional[ProbPath] = None) -> _Loss:
    """Build the loss selected by name.

    Args:
        loss_function (str): ``"cross_entropy"`` or ``"generalized_kl"``.
        path (Optional[ProbPath]): required for ``"generalized_kl"``.

    Raises:
        ValueError: if ``loss_function`` is not supported.

    Returns:
        _Loss: the requested loss module.
    """
    if loss_function == "cross_entropy":
        return torch.nn.CrossEntropyLoss()
    elif loss_function == "generalized_kl":
        assert path is not None

        return MixturePathGeneralizedKL(path=path)
    else:
        raise ValueError(f"{loss_function} is not supported")
--git a/flow_matching/utils/manifolds/__init__.py b/flow_matching/utils/manifolds/__init__.py new file mode 100644 index 0000000..1148872 --- /dev/null +++ b/flow_matching/utils/manifolds/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the CC-by-NC license found in the +# LICENSE file in the root directory of this source tree. + +from .manifold import Euclidean, Manifold +from .sphere import Sphere +from .torus import FlatTorus +from .utils import geodesic + +__all__ = [ + "Euclidean", + "Manifold", + "Sphere", + "FlatTorus", + "geodesic", +] diff --git a/flow_matching/utils/manifolds/manifold.py b/flow_matching/utils/manifolds/manifold.py new file mode 100644 index 0000000..52a6a1b --- /dev/null +++ b/flow_matching/utils/manifolds/manifold.py @@ -0,0 +1,93 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the CC-by-NC license found in the +# LICENSE file in the root directory of this source tree. + +import abc + +import torch.nn as nn +from torch import Tensor + + +class Manifold(nn.Module, metaclass=abc.ABCMeta): + """A manifold class that contains projection operations and logarithm and exponential maps.""" + + @abc.abstractmethod + def expmap(self, x: Tensor, u: Tensor) -> Tensor: + r"""Computes exponential map :math:`\exp_x(u)`. + + Args: + x (Tensor): point on the manifold + u (Tensor): tangent vector at point :math:`x` + + Raises: + NotImplementedError: if not implemented + + Returns: + Tensor: transported point + """ + raise NotImplementedError + + @abc.abstractmethod + def logmap(self, x: Tensor, y: Tensor) -> Tensor: + r"""Computes logarithmic map :math:`\log_x(y)`. 
class Sphere(Manifold):
    """Represents a hypersphere in :math:`R^D`.

    NOTE(review): the formulas below look like the unit-sphere ones, i.e.
    points are presumably expected to have unit norm along the last
    dimension - confirm against callers.
    """

    # Thresholds below which the closed-form maps fall back to a first-order
    # approximation, keyed by dtype.
    EPS = {torch.float32: 1e-4, torch.float64: 1e-7}

    def _eps(self, dtype: torch.dtype) -> float:
        # Dtypes without an explicit entry (e.g. half / bfloat16) reuse the
        # float32 threshold instead of raising KeyError.
        return self.EPS.get(dtype, self.EPS[torch.float32])

    def expmap(self, x: Tensor, u: Tensor) -> Tensor:
        norm_u = u.norm(dim=-1, keepdim=True)
        exp = x * torch.cos(norm_u) + u * torch.sin(norm_u) / norm_u
        # For very small tangent vectors the closed form divides by ~0, so the
        # projected Euclidean step is selected instead.
        retr = self.projx(x + u)
        cond = norm_u > self._eps(norm_u.dtype)

        return torch.where(cond, exp, retr)

    def logmap(self, x: Tensor, y: Tensor) -> Tensor:
        u = self.proju(x, y - x)
        dist = self.dist(x, y, keepdim=True)
        eps = self._eps(x.dtype)
        cond = dist.gt(eps)
        # Rescale the projected difference to have length ``dist``; for nearly
        # coincident points the projected difference is used as is.
        result = torch.where(
            cond,
            u * dist / u.norm(dim=-1, keepdim=True).clamp_min(eps),
            u,
        )
        return result

    def projx(self, x: Tensor) -> Tensor:
        return x / x.norm(dim=-1, keepdim=True)

    def proju(self, x: Tensor, u: Tensor) -> Tensor:
        return u - (x * u).sum(dim=-1, keepdim=True) * x

    def dist(self, x: Tensor, y: Tensor, *, keepdim=False) -> Tensor:
        inner = (x * y).sum(-1, keepdim=keepdim)
        # Rounding can push the inner product slightly outside [-1, 1], where
        # acos returns NaN; clamp to the valid domain first.
        return torch.acos(inner.clamp(-1.0, 1.0))
def geodesic(
    manifold: Manifold, start_point: Tensor, end_point: Tensor
) -> Callable[[Tensor], Tensor]:
    """Build a function of time that traces the geodesic between two points.

    Args:
        manifold (Manifold): the manifold to compute geodesic on.
        start_point (Tensor): point on the manifold at :math:`t=0`.
        end_point (Tensor): point on the manifold at :math:`t=1`.

    Returns:
        Callable[[Tensor], Tensor]: a function that takes in :math:`t` and outputs the geodesic at time :math:`t`.
    """
    # Computed once: the tangent vector at the start that reaches the end at t=1.
    initial_velocity = manifold.logmap(start_point, end_point)

    def path(t: Tensor) -> Tensor:
        """Evaluate the geodesic at the given times.

        Args:
            t (Tensor): Times at which to compute points of the geodesics.
                The einsum pattern consumes it as a 1-D tensor.

        Returns:
            Tensor: geodesic path evaluated at time t.
        """
        # Scale the velocity by every time value; a new time axis is inserted
        # just before the last (coordinate) axis.
        scaled_velocity = torch.einsum("i,...k->...ik", t, initial_velocity)
        # The start point gets a matching singleton time axis.
        return manifold.expmap(start_point.unsqueeze(-2), scaled_velocity)

    return path
def unsqueeze_to_match(source: Tensor, target: Tensor, how: str = "suffix") -> Tensor:
    """
    Add singleton dimensions to ``source`` until it has as many dimensions as ``target``.

    Args:
        source (Tensor): The source tensor to be unsqueezed.
        target (Tensor): The target tensor to match the dimensionality of.
        how (str, optional): Whether to unsqueeze the source tensor at the beginning
            ("prefix") or end ("suffix"). Defaults to "suffix".

    Returns:
        Tensor: The unsqueezed source tensor. ``source`` itself is returned when
        it already has at least as many dimensions as ``target``.
    """
    assert how in (
        "prefix",
        "suffix",
    ), f"{how} is not supported, only 'prefix' and 'suffix' are supported."

    missing_dims = target.dim() - source.dim()
    if missing_dims <= 0:
        return source

    # Indexing with ``None`` inserts a size-1 axis, exactly like ``unsqueeze``.
    new_axes = (None,) * missing_dims
    if how == "prefix":
        return source[new_axes + (Ellipsis,)]
    return source[(Ellipsis,) + new_axes]
def gradient(
    output: Tensor,
    x: Tensor,
    grad_outputs: Optional[Tensor] = None,
    create_graph: bool = False,
) -> Tensor:
    """
    Differentiate the inner product of ``output`` and ``grad_outputs`` with respect to :math:`x`.

    Args:
        output (Tensor): [N, D] Output of the function.
        x (Tensor): [N, d_1, d_2, ... ] input
        grad_outputs (Optional[Tensor]): [N, D] Gradient of outputs, if `None`,
            then will use a tensor of ones
        create_graph (bool): If True, graph of the derivative will be constructed, allowing
            to compute higher order derivative products. Defaults to False.
    Returns:
        Tensor: [N, d_1, d_2, ... ]. the gradient w.r.t x.
    """
    # With no explicit weights every output element contributes with weight 1.
    weights = (
        torch.ones_like(output).detach() if grad_outputs is None else grad_outputs
    )
    grads = torch.autograd.grad(
        output, x, grad_outputs=weights, create_graph=create_graph
    )
    # ``autograd.grad`` returns a tuple; only its first entry is handed back.
    return grads[0]
def top_k_logits(logits, top_k=None):
    """Keep the ``top_k`` largest logits along the last dimension and mask the rest.

    Args:
        logits (Tensor): scores whose last dimension indexes the candidates.
        top_k (int, optional): number of candidates to keep, expected to be
            at least 1. Values larger than the last dimension are capped to it.
            ``None`` (the default) disables filtering.

    Returns:
        Tensor: ``logits`` with every entry strictly below the k-th largest
        value of its row replaced by the most negative finite value of the
        dtype. Entries tied with the k-th value are kept. When ``top_k`` is
        ``None`` the input is returned unchanged.
    """
    if top_k is None:
        # The declared default used to crash in ``min(None, int)``.
        return logits

    top_k = min(top_k, logits.size(-1))  # Safety check
    # Smallest value that still belongs to the top-k of each row.
    kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
    # Remove all tokens with a score below the last token of the top-k.
    indices_to_remove = logits < kth_largest
    return logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
def add_gumbel_noise(self, logits, temperature, dtype):
    """
    Perturb ``logits`` with random noise for Gumbel-max style categorical sampling.

    The Gumbel max is a method for sampling categorical distributions.
    According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality.

    NOTE(review): this docstring used to end with "Thus, we use float64", but
    the computation runs in whatever ``dtype`` is passed in, and the only
    caller in this class passes ``result_txt.dtype``. Confirm whether float64
    was intended.

    Args:
        logits: scores to perturb.
        temperature: exponent applied to the noise; ``0.0`` disables the noise.
        dtype: dtype the logits are cast to and the noise is drawn in.

    Returns:
        ``logits`` unchanged when ``temperature == 0.0``; otherwise
        ``exp(logits) / (-log(u)) ** temperature`` with ``u`` drawn by
        ``torch.rand_like``. The result is therefore no longer in logit space.
    """
    if temperature == 0.0:
        return logits  # Skip noise when temperature is 0
    logits = logits.to(dtype)
    # One uniform sample per logit, drawn in the requested dtype.
    noise = torch.rand_like(logits, dtype=dtype)
    gumbel_noise = (-torch.log(noise)) ** temperature
    return logits.exp() / gumbel_noise
import sys

# On Python 3.10 and newer, re-export every public name of ``collections.abc``
# on the ``collections`` module itself, so code that still looks the ABCs up
# on ``collections`` keeps working.
if not sys.version_info < (3, 10):
    print("Python version is above 3.10, patching the collections module.")
    # Monkey patch collections
    import collections
    import collections.abc

    vars(collections).update(
        {
            abc_name: getattr(collections.abc, abc_name)
            for abc_name in collections.abc.__all__
        }
    )
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +from .image_processing_vlm import VLMImageProcessor +from .modeling_vlm import MultiModalityCausalLM +from .processing_vlm import VLChatProcessor + +__all__ = [ + "VLMImageProcessor", + "VLChatProcessor", + "MultiModalityCausalLM", +] diff --git a/fudoki/janus/models/action_tokenizer.py b/fudoki/janus/models/action_tokenizer.py new file mode 100644 index 0000000..b653d51 --- /dev/null +++ b/fudoki/janus/models/action_tokenizer.py @@ -0,0 +1,83 @@ +""" +action_tokenizer.py + +Extension class; wraps base LLM/VLM tokenizer with logic to discretize and tokenize continuous robot actions. 
class ActionTokenizer:
    """Discretize continuous robot actions and map them onto the tail of a tokenizer vocabulary."""

    def __init__(
        # The tokenizer annotation is quoted so it is not evaluated when the
        # method is defined.
        self, tokenizer: "PreTrainedTokenizerBase", bins: int = 256, min_action: int = -1, max_action: int = 1
    ) -> None:
        """
        Discretizes continuous robot actions into N bins per dimension and maps to the least used tokens.

        NOTE =>> by default, assumes a BPE-style tokenizer akin to the LlamaTokenizer, where *the least used tokens*
                 appear at the end of the vocabulary!

        :param tokenizer: Base LLM/VLM tokenizer to extend.
        :param bins: Number of bins for each continuous value; we'll adopt a uniform binning strategy.
        :param min_action: Minimum action value (for clipping, setting lower bound on bin interval).
        :param max_action: Maximum action value (for clipping, setting upper bound on bin interval).
        """
        self.tokenizer, self.n_bins, self.min_action, self.max_action = tokenizer, bins, min_action, max_action

        # Create Uniform Bins + Compute Bin Centers
        self.bins = np.linspace(min_action, max_action, self.n_bins)
        self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0

        # [Contract] Set "action_token_begin_idx" based on `self.tokenizer.vocab_size - (self.n_bins + 1)`
        #   =>> Assumes we're always overwriting the final `n_bins` tokens of the vocabulary!
        self.action_token_begin_idx: int = int(self.tokenizer.vocab_size - (self.n_bins + 1))

    def _action_token_ids(self, action: np.ndarray) -> np.ndarray:
        """Clip ``action`` to the configured range, bin it, and count the bin index back from ``vocab_size``."""
        clipped = np.clip(action, a_min=float(self.min_action), a_max=float(self.max_action))
        return self.tokenizer.vocab_size - np.digitize(clipped, self.bins)

    def __call__(self, action: np.ndarray) -> Union[str, List[str]]:
        """Clip & bin actions to *the last `n_bins` tokens* of the vocabulary (e.g., tokenizer.vocab[-256:]).

        :param action: Continuous action values; 1-D for a single action, higher-rank for a batch.
        :return: ``tokenizer.decode`` output for a 1-D input, ``tokenizer.batch_decode`` output otherwise.
        """
        token_ids = self._action_token_ids(action)

        # Handle single element vs. batch
        if token_ids.ndim == 1:
            return self.tokenizer.decode(token_ids.tolist())
        return self.tokenizer.batch_decode(token_ids.tolist())

    def encode_actions_to_token_ids(self, action: np.ndarray) -> Union[List[int], List[List[int]]]:
        """Return the token IDs (not decoded text) for the given actions.

        :param action: Continuous action values; 1-D for a single action, higher-rank for a batch.
        :return: Token IDs as plain Python ints, nested like the input.
        """
        # ``tolist`` yields Python ints for every rank, so the single-action
        # path no longer hands back numpy scalars while the batch path
        # returns ints.
        return self._action_token_ids(action).tolist()

    def decode_token_ids_to_actions(self, action_token_ids: np.ndarray) -> np.ndarray:
        """
        Returns continuous actions for discrete action token IDs.

        NOTE =>> Because of the way the actions are discretized w.r.t. the bins (and not the bin centers), the
                 digitization returns bin indices between [1, # bins], inclusive, when there are actually only
                 (# bins - 1) bin intervals.

                 Therefore, if the digitization returns the last possible index, we map this to the last bin interval.

        EXAMPLE =>> Let's say self.bins has 256 values. Then self.bin_centers has 255 values. Digitization returns
                    indices between [1, 256]. We subtract 1 from all indices so that they are between [0, 255]. There
                    is still one index (i==255) that would cause an out-of-bounds error if used to index into
                    self.bin_centers. Therefore, if i==255, we subtract 1 from it so that it just becomes the index of
                    the last bin center. We implement this simply via clipping between [0, 255 - 1].
        """
        discretized_actions = self.tokenizer.vocab_size - action_token_ids
        discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)

        return self.bin_centers[discretized_actions]

    @property
    def vocab_size(self) -> int:
        """Number of action bins (not the size of the wrapped tokenizer's vocabulary)."""
        return self.n_bins
class CLIPVisionTower(nn.Module):
    """Vision backbone wrapper: optional pixel normalisation, backbone forward, feature selection."""

    def __init__(
        self,
        model_name: str = "siglip_large_patch16_384",
        image_size: Union[Tuple[int, int], int] = 336,
        select_feature: str = "patch",
        select_layer: int = -2,
        select_layers: Optional[list] = None,
        ckpt_path: str = "",
        pixel_mean: Optional[List[float]] = None,
        pixel_std: Optional[List[float]] = None,
        **kwargs,
    ):
        """
        Args:
            model_name (str): backbone identifier; its prefix selects the builder
                used in ``build_vision_tower``.
            image_size (Union[Tuple[int, int], int]): forwarded to the backbone builder.
            select_feature (str): one of ``"patch"``, ``"cls_patch"`` or ``"same"``
                (see ``feature_select``). Overridden to ``"same"`` for siglip backbones.
            select_layer (int): index into ``hidden_states`` when the backbone
                returns an output object instead of a tensor.
            select_layers (list, optional): stored on the instance; not used in this class.
            ckpt_path (str): forwarded to the backbone builder.
            pixel_mean (List[float], optional): per-channel mean for input normalisation.
            pixel_std (List[float], optional): per-channel std for input normalisation.
                Normalisation is applied only when both mean and std are given.
            **kwargs: extra entries merged into the builder parameters.
        """
        super().__init__()

        self.model_name = model_name
        self.select_feature = select_feature
        self.select_layer = select_layer
        self.select_layers = select_layers

        vision_tower_params = {
            "model_name": model_name,
            "image_size": image_size,
            "ckpt_path": ckpt_path,
            "select_layer": select_layer,
        }
        vision_tower_params.update(kwargs)
        self.vision_tower, self.forward_kwargs = self.build_vision_tower(
            vision_tower_params
        )

        if pixel_mean is not None and pixel_std is not None:
            image_norm = torchvision.transforms.Normalize(
                mean=pixel_mean, std=pixel_std
            )
        else:
            image_norm = None

        self.image_norm = image_norm

    def build_vision_tower(self, vision_tower_params):
        """Instantiate the backbone selected by ``self.model_name``.

        Args:
            vision_tower_params (dict): keyword arguments for the builder.

        Raises:
            NotImplementedError: for ``"sam*"`` model names, whose builder is
                not available in this module.

        Returns:
            tuple: ``(vision_tower, forward_kwargs)`` where ``forward_kwargs``
            are passed to the backbone on every forward call.
        """
        if self.model_name.startswith("siglip"):
            # The siglip builder receives ``select_layer`` itself, so its
            # output is used as-is.
            self.select_feature = "same"
            vision_tower = create_siglip_vit(**vision_tower_params)
            forward_kwargs = dict()

        elif self.model_name.startswith("sam"):
            # ``create_sam_vit`` is neither defined nor imported here, so this
            # branch used to die with a NameError. Fail with an explicit,
            # descriptive error instead.
            raise NotImplementedError(
                f"SAM vision towers are not available in this module (model_name={self.model_name!r})."
            )

        else:  # huggingface
            from transformers import CLIPVisionModel

            # NOTE(review): the builder params are passed as keyword arguments
            # and include no checkpoint path under the name `from_pretrained`
            # expects -- confirm this branch works before relying on it.
            vision_tower = CLIPVisionModel.from_pretrained(**vision_tower_params)
            forward_kwargs = dict(output_hidden_states=True)

        return vision_tower, forward_kwargs

    def feature_select(self, image_forward_outs):
        """Pick the feature tensor to return from the backbone output.

        Args:
            image_forward_outs: either a tensor (already the selected layer's
                features) or an object exposing ``hidden_states``.

        Raises:
            ValueError: if ``self.select_feature`` is not a supported mode.

        Returns:
            torch.Tensor: the features with the first token dropped for
            ``"patch"``, unchanged for ``"cls_patch"`` and ``"same"``.
        """
        if isinstance(image_forward_outs, torch.Tensor):
            # the output already holds the self.select_layer's features
            image_features = image_forward_outs
        else:
            image_features = image_forward_outs.hidden_states[self.select_layer]

        if self.select_feature == "patch":
            # drop the leading (cls) token
            return image_features[:, 1:]
        if self.select_feature in ("cls_patch", "same"):
            return image_features
        raise ValueError(f"Unexpected select feature: {self.select_feature}")

    def forward(self, images):
        """

        Args:
            images (torch.Tensor): [b, 3, H, W]

        Returns:
            image_features (torch.Tensor): [b, n_patch, d]
        """

        if self.image_norm is not None:
            images = self.image_norm(images)

        image_forward_outs = self.vision_tower(images, **self.forward_kwargs)
        image_features = self.feature_select(image_forward_outs)
        return image_features
def expand2square(pil_img, background_color):
    """Pad an image to a square canvas, centring it along its shorter side.

    Args:
        pil_img: image exposing ``size`` (width, height) and ``mode``.
        background_color: fill colour for the padded area.

    Returns:
        The input image itself when it is already square, otherwise a new
        image whose side equals the longer edge of the input.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img

    if width > height:
        side = width
        # Wider than tall: pad above and below.
        offset = (0, (width - height) // 2)
    else:
        side = height
        # Taller than wide: pad left and right.
        offset = ((height - width) // 2, 0)

    canvas = Image.new(pil_img.mode, (side, side), background_color)
    canvas.paste(pil_img, offset)
    return canvas
class VLMImageProcessor(BaseImageProcessor):
    """Resize, square-pad, rescale and normalise images into ``pixel_values``."""

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        image_size: int,
        min_size: int = 14,
        image_mean: Union[Tuple[float, float, float], List[float]] = (
            0.48145466,
            0.4578275,
            0.40821073,
        ),
        image_std: Union[Tuple[float, float, float], List[float]] = (
            0.26862954,
            0.26130258,
            0.27577711,
        ),
        rescale_factor: float = 1.0 / 255.0,
        do_normalize: bool = True,
        **kwargs,
    ):
        """
        Args:
            image_size (int): side length of the square output.
            min_size (int): lower bound applied to each resized edge.
            image_mean: per-channel mean used for normalisation and, scaled to
                0-255, as the padding colour. ``None`` selects mid-grey padding.
            image_std: per-channel std used for normalisation.
            rescale_factor (float): multiplier applied to pixel values before
                normalisation.
            do_normalize (bool): whether to apply mean/std normalisation.
            **kwargs: forwarded to ``BaseImageProcessor``.
        """
        super().__init__(**kwargs)

        self.image_size = image_size
        self.rescale_factor = rescale_factor
        self.image_mean = image_mean
        self.image_std = image_std
        self.min_size = min_size
        self.do_normalize = do_normalize

        if image_mean is None:
            self.background_color = (127, 127, 127)
        else:
            self.background_color = tuple(int(x * 255) for x in image_mean)

    def resize(self, pil_img: Image.Image) -> np.ndarray:
        """Resize the longer edge to ``image_size`` and pad to a square.

        Args:
            pil_img (PIL.Image): [H, W, 3] in PIL.Image in RGB

        Raises:
            ValueError: if the input or the computed target size has a
                non-positive edge.

        Returns:
            x (np.ndarray): [3, self.image_size, self.image_size]
        """

        width, height = pil_img.size
        # Validate before dividing: an empty image used to fail with
        # ZeroDivisionError before the size check could run.
        if width <= 0 or height <= 0:
            raise ValueError(f"Invalid size! orig size = {pil_img.size}")

        max_size = max(width, height)
        size = [
            max(int(height / max_size * self.image_size), self.min_size),
            max(int(width / max_size * self.image_size), self.min_size),
        ]

        if size[0] <= 0 or size[1] <= 0:
            # The details go into the exception instead of a stray print.
            raise ValueError(
                f"Invalid size! orig size = {pil_img.size}, new size = {size}"
            )

        pil_img = torchvision.transforms.functional.resize(
            pil_img,
            size,
            interpolation=torchvision.transforms.functional.InterpolationMode.BICUBIC,
            antialias=True,
        )

        pil_img = expand2square(pil_img, self.background_color)
        x = to_numpy_array(pil_img)

        # [H, W, 3] -> [3, H, W]
        x = np.transpose(x, (2, 0, 1))

        return x

    def preprocess(self, images, return_tensors: str = "pt", **kwargs) -> BatchFeature:
        """Run the full pipeline on a list of images.

        Args:
            images: iterable of PIL images.
            return_tensors (str): tensor type passed to ``BatchFeature``.
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            BatchFeature: holds the processed images under ``pixel_values``.
        """
        # resize and pad to [self.image_size, self.image_size]
        # then convert from [H, W, 3] to [3, H, W]
        images: List[np.ndarray] = [self.resize(image) for image in images]

        # rescale from [0, 255] -> [0, 1]
        images = [
            self.rescale(
                image=image,
                scale=self.rescale_factor,
                input_data_format="channels_first",
            )
            for image in images
        ]

        # normalize
        if self.do_normalize:
            images = [
                self.normalize(
                    image=image,
                    mean=self.image_mean,
                    std=self.image_std,
                    input_data_format="channels_first",
                )
                for image in images
            ]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)

    @property
    def default_shape(self):
        """Shape of one processed image: ``[3, image_size, image_size]``."""
        return [3, self.image_size, self.image_size]
def model_name_to_cls(cls_name):
    """Resolve a configured class name to the class that implements it.

    The name is matched by substring, in a fixed order; the first hit wins.

    Args:
        cls_name (str): class name as stored in the configuration.

    Raises:
        ValueError: if no known marker occurs in ``cls_name``.

    Returns:
        The matching class object.
    """
    if "MlpProjector" in cls_name:
        return MlpProjector

    if "CLIPVisionTower" in cls_name:
        return CLIPVisionTower

    if "VQ" in cls_name:
        # Imported only when needed; the registry is keyed by the full name.
        from fudoki.janus.models.vq_model import VQ_models

        return VQ_models[cls_name]

    if "vision_head" in cls_name:
        return vision_head

    raise ValueError(f"class_name {cls_name} is invalid.")
class MultiModalityPreTrainedModel(PreTrainedModel):
    """Shared ``PreTrainedModel`` base for the multi-modality model.

    Declares class-level settings only; it defines no methods of its own.
    """

    # Configuration class associated with this model family.
    config_class = MultiModalityConfig
    # NOTE(review): presumably the prefix transformers uses when matching
    # checkpoint keys to the base model -- confirm against the installed version.
    base_model_prefix = "multi_modality"
    # Empty list: no module class is declared here as unsplittable.
    _no_split_modules = []
    # NOTE(review): presumably tells transformers/accelerate not to move this
    # key between devices -- confirm.
    _skip_keys_device_placement = "past_key_values"
def add_number_head(self):
    """Attach a ``number_head`` MLP to the language model as ``num_head``.

    The head is sized from the language model's own config: ``hidden_size``
    is the input width and ``intermediate_size`` the feed-forward width.
    Biases are enabled.
    """
    lm_config = self.language_model.config
    head = number_head(
        d_model=lm_config.hidden_size,
        dim_feedforward=lm_config.intermediate_size,
        numhead_bias=True,
    )
    self.language_model.num_head = head
self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + query_states = query_states.contiguous().transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_heads, self.head_dim).contiguous() + key_states = key_states.contiguous().transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).contiguous() + value_states = value_states.contiguous().transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).contiguous() + # attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, p=self.attention_dropout, attn_bias=None) + + attn_bias, flattened_q = fmha.BlockDiagonalMask.from_tensor_list([query_states[bs, :attention_mask[bs].sum()][None] for bs in range(bsz)]) + _, flattened_k = fmha.BlockDiagonalMask.from_tensor_list([key_states[bs, :attention_mask[bs].sum()][None] for bs in range(bsz)]) + _, flattened_v = fmha.BlockDiagonalMask.from_tensor_list([value_states[bs, :attention_mask[bs].sum()][None] for bs in range(bsz)]) + + + output = xformers.ops.memory_efficient_attention(flattened_q, flattened_k, flattened_v, p=self.attention_dropout, attn_bias=attn_bias) + output = attn_bias.split(output) + attn_output = hidden_states.clone().reshape(bsz, q_len, self.num_heads, self.head_dim) + for bs in range(bsz): + attn_output[bs, :attention_mask[bs].sum()] = output[bs] + + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + else: + query, key = flattened_q, flattened_k + scale = 1.0 / query.shape[-1] ** 0.5 + query = query * scale + query = query.transpose(1, 2) + key = key.transpose(1, 2) + attn = query @ key.transpose(-2, -1) + attn = 
def prepare_inputs_embeds(
    self,
    input_ids: torch.LongTensor,
    pixel_values: torch.FloatTensor,
    images_seq_mask: torch.LongTensor,
    images_emb_mask: torch.LongTensor,
    **kwargs,
):
    """Build language-model input embeddings with image features spliced in.

    Args:
        input_ids (torch.LongTensor): [b, T]. Negative ids are looked up as
            id 0; the caller's tensor is NOT modified.
        pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
        images_seq_mask (torch.BoolTensor): [b, T]
        images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]
        **kwargs: accepted for call-site compatibility and ignored.

        torch.sum(images_seq_mask) must equal torch.sum(images_emb_mask).

    Returns:
        input_embeds (torch.Tensor): [b, T, D]
    """
    bs, n = pixel_values.shape[0:2]
    # [b, n, c, h, w] -> [b x n, c, h, w]
    images = pixel_values.flatten(0, 1)
    # [b x n, T2, D]
    images_embeds = self.aligner(self.vision_model(images))

    # [b x n, T2, D] -> [b, n x T2, D]
    images_embeds = images_embeds.reshape(
        bs, n * images_embeds.shape[1], images_embeds.shape[2]
    )
    # [b, n, T2] -> [b, n x T2]
    images_emb_mask = images_emb_mask.flatten(1, 2)

    # [b, T, D]
    # Negative ids mark image slots; map them to 0 on a copy so the embedding
    # lookup is valid without mutating the caller's ``input_ids``.
    safe_ids = input_ids.clamp(min=0)
    inputs_embeds = self.language_model.get_input_embeddings()(safe_ids)

    # replace with the image embeddings
    inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]

    return inputs_embeds
def forward(self, mmsamples, datainfo):
    """Run the language model on a mixed generation/understanding batch.

    Token ids are embedded once for the whole batch. Then, per sample, the
    positions flagged by ``datainfo['image_token_mask']`` are overwritten with
    image embeddings:

    * ``generation_or_understanding_mask == 1``: the ids at those positions
      are embedded with ``prepare_gen_img_embeds``.
    * ``generation_or_understanding_mask == 0`` and
      ``has_understanding_img == 1``: ``datainfo['understanding_img'][b]`` is
      encoded with ``vision_model`` followed by ``aligner``.
    * otherwise the plain token embeddings are kept.

    The batch embedding used to be recomputed inside the loop, which threw
    away the image embeddings already written for earlier samples.

    Args:
        mmsamples: integer token ids, indexed as [batch, position].
        datainfo: mapping providing ``image_token_mask``,
            ``generation_or_understanding_mask``, ``attention_mask`` and, for
            understanding samples, ``has_understanding_img`` /
            ``understanding_img``. ``token_drop`` reads further keys while
            training.

    Returns:
        ``(img_logits, txt_logits)``: outputs of ``gen_head`` and
        ``lm_head``, each shifted one position to the right along the
        sequence axis with a zero row at position 0.
    """
    if self.training:
        mmsamples = self.token_drop(mmsamples, datainfo, uncond_prob=0.1, uncond_id=100015)

    # Embed the batch ONCE; per-sample image embeddings are written into it.
    inputs_embeds = self.language_model.get_input_embeddings()(mmsamples)

    for b_index in range(mmsamples.shape[0]):
        task = datainfo['generation_or_understanding_mask'][b_index]

        if task == 1:
            mask = datainfo['image_token_mask'][b_index] == 1
            positions = torch.nonzero(mask, as_tuple=False)[:, 0]
            imgsamples = mmsamples[b_index, positions]
            img_embeds = self.prepare_gen_img_embeds(imgsamples.unsqueeze(0))
        elif task == 0 and datainfo['has_understanding_img'][b_index] == 1:
            mask = datainfo['image_token_mask'][b_index] == 1
            positions = torch.nonzero(mask, as_tuple=False)[:, 0]
            imgs = datainfo['understanding_img'][b_index]
            if imgs.dim() == 3:
                # A single image: add the leading image axis.
                imgs = imgs.unsqueeze(0)
            img_embeds = self.aligner(self.vision_model(imgs))
        else:
            # Text-only sample: keep the token embeddings.
            continue

        inputs_embeds[b_index, positions] = img_embeds.reshape(-1, img_embeds.shape[-1])

    outputs = self.language_model.model(
        inputs_embeds=inputs_embeds,
        use_cache=False,
        attention_mask=datainfo['attention_mask'],
    )
    hidden_states = outputs.last_hidden_state

    def shift_right(logits):
        # Prepend a zero row and drop the last step. The zeros are created in
        # the default dtype, exactly as before.
        pad = torch.zeros((logits.shape[0], 1, logits.shape[2]), device=logits.device)
        return torch.cat([pad, logits[:, :-1, :]], dim=1)

    img_logits = shift_right(self.gen_head(hidden_states))
    txt_logits = shift_right(self.language_model.lm_head(hidden_states))

    return img_logits, txt_logits
class DictOutput:
    """Mixin giving an object mapping-style access to its instance attributes.

    ``keys()`` together with ``__getitem__`` is enough for ``**obj``
    unpacking; ``__setitem__`` writes straight into the instance namespace.
    """

    def keys(self):
        """Return a view of the instance attribute names."""
        return vars(self).keys()

    def __getitem__(self, item):
        """Return attribute ``item``; raises ``KeyError`` if it is not set."""
        return vars(self)[item]

    def __setitem__(self, key, value):
        """Set attribute ``key`` to ``value``."""
        vars(self)[key] = value
def __init__(
    self,
    image_processor: VLMImageProcessor,
    tokenizer: LlamaTokenizerFast,
    image_tag: str = "",
    image_start_tag: str = "",
    image_end_tag: str = "",
    num_image_tokens: int = 576,
    add_special_token: bool = False,
    sft_format: str = "fudoki",
    mask_prompt: bool = True,
    ignore_id: int = -100,
    **kwargs,
):
    """Store the processor settings and make sure the image tag is a known token.

    If ``image_tag`` is not in the tokenizer vocabulary it is registered as an
    additional special token and a message is printed. All arguments are kept
    as same-named attributes and then forwarded to the parent initialiser.

    NOTE(review): the three tag defaults are empty strings here -- confirm
    that this is intended and that no placeholder token text was lost.
    """
    self.image_processor = image_processor
    self.tokenizer = tokenizer

    # Register the tag only when the tokenizer does not know it yet.
    if self.tokenizer.vocab.get(image_tag) is None:
        self.tokenizer.add_special_tokens(
            {"additional_special_tokens": [image_tag]}
        )
        print(f"Add image tag = {image_tag} to the tokenizer")

    self.image_tag = image_tag
    self.image_start_tag = image_start_tag
    self.image_end_tag = image_end_tag

    self.num_image_tokens = num_image_tokens
    self.add_special_token = add_special_token
    self.sft_format = sft_format
    self.mask_prompt = mask_prompt
    self.ignore_id = ignore_id

    parent_args = (
        image_processor,
        tokenizer,
        image_tag,
        num_image_tokens,
        add_special_token,
        sft_format,
        mask_prompt,
        ignore_id,
    )
    super().__init__(*parent_args, **kwargs)
def add_image_token(
    self,
    image_indices: List[int],
    input_ids: torch.LongTensor,
):
    """Expand every image placeholder into ``boi + image tokens + eoi``.

    Args:
        image_indices (List[int]): [index_0, index_1, ..., index_j],
            positions of the placeholders inside ``input_ids``.
        input_ids (torch.LongTensor): [N]

    Returns:
        input_ids (torch.LongTensor): [N + image tokens]
        num_image_tokens (torch.IntTensor): [n_images]
    """
    pieces = []
    cursor = 0

    for index in image_indices:
        # Text before the placeholder; the placeholder itself is kept only
        # when ``add_special_token`` is set.
        stop = index + 1 if self.add_special_token else index
        pieces.append(input_ids[cursor:stop])

        # boi, the run of image tokens, eoi
        pieces.extend(
            [
                self.image_start_id * torch.ones(1, dtype=torch.long),
                self.image_id * torch.ones(self.num_image_tokens, dtype=torch.long),
                self.image_end_id * torch.ones(1, dtype=torch.long),
            ]
        )
        cursor = index + 1

    # whatever follows the last placeholder
    pieces.append(input_ids[cursor:])

    expanded_ids = torch.cat(pieces, dim=0)
    per_image_counts = torch.full(
        (len(image_indices),), self.num_image_tokens, dtype=torch.int32
    )
    return expanded_ids, per_image_counts
def batchify(
    self, prepare_list: List[VLChatProcessorOutput]
) -> BatchedVLChatProcessorOutput:
    """Left-pad and stack single-sample outputs into one batch.

    Args:
        prepare_list (List[VLChatProcessorOutput]): A non-empty list of
            VLChatProcessorOutput.

    Returns:
        BatchedVLChatProcessorOutput: the batched inputs to use for multimodal
        inference. Sequences are left-padded with ``self.pad_id``; missing
        images are zero-filled and masked out in ``images_emb_mask``.

    Raises:
        ValueError: if ``prepare_list`` is empty.
    """
    if not prepare_list:
        raise ValueError("prepare_list must contain at least one item.")

    batch_size = len(prepare_list)
    seq_lens = [len(prepare) for prepare in prepare_list]
    n_images = [len(prepare.num_image_tokens) for prepare in prepare_list]

    input_token_max_len = max(seq_lens)
    # Keep at least one image slot so the pixel tensor is never zero-sized.
    max_n_images = max(1, max(n_images))

    batched_input_ids = torch.full(
        (batch_size, input_token_max_len), self.pad_id, dtype=torch.long
    )
    batched_attention_mask = torch.zeros(
        (batch_size, input_token_max_len), dtype=torch.long
    )
    batched_pixel_values = torch.zeros(
        (batch_size, max_n_images, *self.image_processor.default_shape),
        dtype=torch.float32,
    )
    batched_images_seq_mask = torch.zeros(
        (batch_size, input_token_max_len), dtype=torch.bool
    )
    batched_images_emb_mask = torch.zeros(
        (batch_size, max_n_images, self.num_image_tokens), dtype=torch.bool
    )

    sft_format = []
    for i, (prepare, seq_len, n_image) in enumerate(
        zip(prepare_list, seq_lens, n_images)
    ):
        input_ids = torch.as_tensor(prepare.input_ids, dtype=torch.long)
        # left-padding; an explicit start offset is used because the slice
        # ``-seq_len:`` would select the WHOLE row when seq_len == 0.
        start = input_token_max_len - seq_len
        batched_attention_mask[i, start:] = 1
        batched_input_ids[i, start:] = input_ids
        batched_images_seq_mask[i, start:] = input_ids == self.image_id

        if n_image > 0:
            batched_pixel_values[i, :n_image] = prepare.pixel_values
            for j, n_image_tokens in enumerate(prepare.num_image_tokens):
                batched_images_emb_mask[i, j, :n_image_tokens] = True

        sft_format.append(prepare.sft_format)

    return BatchedVLChatProcessorOutput(
        input_ids=batched_input_ids,
        attention_mask=batched_attention_mask,
        pixel_values=batched_pixel_values,
        images_seq_mask=batched_images_seq_mask,
        images_emb_mask=batched_images_emb_mask,
        sft_format=sft_format,
    )
b/fudoki/janus/models/projector.py @@ -0,0 +1,100 @@ +# Copyright (c) 2023-2024 DeepSeek. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
class MlpProjector(nn.Module):
    """Projects vision features to the embedding width ``cfg.n_embed``.

    ``cfg.projector_type`` selects the layout:

    * ``"identity"``: pass-through.
    * ``"linear"``: one ``Linear(input_dim, n_embed)``.
    * ``"mlp_gelu"``: ``Linear`` followed by ``depth - 1`` (GELU, Linear) pairs.
    * ``"low_high_hybrid_split_mlp_gelu"``: two ``Linear(input_dim, n_embed // 2)``
      up-projections whose outputs are concatenated, followed by
      ``depth - 1`` (GELU, Linear) pairs.

    ``depth`` is read with ``cfg.get("depth", 1)``.
    """

    def __init__(self, cfg):
        super().__init__()

        self.cfg = cfg
        kind = cfg.projector_type

        if kind == "identity":
            layers = nn.Identity()

        elif kind == "linear":
            layers = nn.Linear(cfg.input_dim, cfg.n_embed)

        elif kind == "mlp_gelu":
            depth = cfg.get("depth", 1)
            stack = [nn.Linear(cfg.input_dim, cfg.n_embed)]
            for _ in range(1, depth):
                stack += [nn.GELU(), nn.Linear(cfg.n_embed, cfg.n_embed)]
            layers = nn.Sequential(*stack)

        elif kind == "low_high_hybrid_split_mlp_gelu":
            depth = cfg.get("depth", 1)
            half_width = cfg.n_embed // 2
            self.high_up_proj = nn.Linear(cfg.input_dim, half_width)
            self.low_up_proj = nn.Linear(cfg.input_dim, half_width)

            stack = []
            for _ in range(1, depth):
                stack += [nn.GELU(), nn.Linear(cfg.n_embed, cfg.n_embed)]
            layers = nn.Sequential(*stack)

        else:
            raise ValueError(f"Unknown projector type: {cfg.projector_type}")

        self.layers = layers

    def forward(
        self, x_or_tuple: Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]
    ):
        """Apply the projector.

        Args:
            x_or_tuple: either a single feature tensor, or a tuple
                ``(high_res_x, low_res_x)`` from the hybrid vision encoder.
                A tuple is up-projected per branch and concatenated on the
                last axis before the shared layers.

        Returns:
            x (torch.Tensor): [b, s, c]
        """
        if not isinstance(x_or_tuple, tuple):
            return self.layers(x_or_tuple)

        high_x, low_x = x_or_tuple
        fused = torch.concat(
            [self.high_up_proj(high_x), self.low_up_proj(low_x)], dim=-1
        )
        return self.layers(fused)
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place with samples from a normal truncated to [a, b].

    Cut & paste from PyTorch official master until it's in a few official
    releases - RW. Method based on
    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    A warning is emitted when ``mean`` lies more than two ``std`` outside
    ``[a, b]``. Returns the same ``tensor`` object.
    """
    sqrt_two = math.sqrt(2.0)

    def standard_normal_cdf(x):
        # Cumulative distribution function of the standard normal.
        return (1.0 + math.erf(x / sqrt_two)) / 2.0

    lowest_sane_mean = a - 2 * std
    highest_sane_mean = b + 2 * std
    if mean < lowest_sane_mean or mean > highest_sane_mean:
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    with torch.no_grad():
        # CDF values at the two cut-offs.
        cdf_low = standard_normal_cdf((a - mean) / std)
        cdf_high = standard_normal_cdf((b - mean) / std)

        # Sample uniformly in [2l-1, 2u-1], then apply the inverse CDF to get
        # a truncated standard normal.
        tensor.uniform_(2 * cdf_low - 1, 2 * cdf_high - 1)
        tensor.erfinv_()

        # Rescale and shift to the requested std / mean.
        tensor.mul_(std * sqrt_two)
        tensor.add_(mean)

        # Guard against values that drifted outside the bounds numerically.
        tensor.clamp_(min=a, max=b)
        return tensor
class Attention(nn.Module):
    """Multi-head self-attention with an optional per-head q/k norm.

    ``dim`` must be divisible by ``num_heads``. When ``fused_attn`` is true
    (always set in ``__init__``) the computation is delegated to
    ``F.scaled_dot_product_attention``; otherwise an explicit
    softmax(q k^T * scale) v path is used.
    """

    fused_attn: Final[bool]

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        head_dim = dim // num_heads

        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = head_dim**-0.5
        # self.fused_attn = use_fused_attn()
        self.fused_attn = True

        # Sub-modules are created in this order on purpose: it fixes both the
        # state_dict layout and the order in which random init is drawn.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0.0 else nn.Identity()

    def _attend(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """Return attention output for per-head ``q``, ``k``, ``v``."""
        if self.fused_attn:
            # Dropout is only active in training mode.
            drop_p = self.attn_drop.p if self.training else 0.0
            return F.scaled_dot_product_attention(q, k, v, dropout_p=drop_p)

        scores = (q * self.scale) @ k.transpose(-2, -1)
        weights = self.attn_drop(scores.softmax(dim=-1))
        return weights @ v

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to ``x`` of shape [B, N, C]; same shape out."""
        batch, tokens, channels = x.shape

        # [B, N, 3C] -> [3, B, heads, N, head_dim]
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, self.head_dim)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)
        q = self.q_norm(q)
        k = self.k_norm(k)

        attended = self._attend(q, k, v)

        # [B, heads, N, head_dim] -> [B, N, C]
        merged = attended.transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))
if drop_path > 0.0 else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class VisionTransformer(nn.Module): + """Vision Transformer + + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + """ + + dynamic_img_size: Final[bool] + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: Literal["", "avg", "token", "map"] = "token", + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_norm: bool = False, + init_values: Optional[float] = None, + class_token: bool = True, + no_embed_class: bool = False, + reg_tokens: int = 0, + pre_norm: bool = False, + fc_norm: Optional[bool] = None, + dynamic_img_size: bool = False, + dynamic_img_pad: bool = False, + drop_rate: float = 0.0, + pos_drop_rate: float = 0.0, + patch_drop_rate: float = 0.0, + proj_drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "", + embed_layer: Callable = PatchEmbed, + norm_layer: Optional[LayerType] = None, + act_layer: Optional[LayerType] = None, + block_fn: Type[nn.Module] = Block, + mlp_layer: Type[nn.Module] = Mlp, + ignore_head: bool = False, + ) -> None: + """ + Args: + img_size: Input image size. + patch_size: Patch size. + in_chans: Number of image input channels. + num_classes: Mumber of classes for classification head. + global_pool: Type of global pooling for final sequence (default: 'token'). + embed_dim: Transformer embedding dimension. + depth: Depth of transformer. + num_heads: Number of attention heads. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. 
+ qkv_bias: Enable bias for qkv projections if True. + init_values: Layer-scale init values (layer-scale enabled if not None). + class_token: Use class token. + no_embed_class: Don't include position embeddings for class (or reg) tokens. + reg_tokens: Number of register tokens. + fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'. + drop_rate: Head dropout rate. + pos_drop_rate: Position embedding dropout rate. + attn_drop_rate: Attention dropout rate. + drop_path_rate: Stochastic depth rate. + weight_init: Weight initialization scheme. + embed_layer: Patch embedding layer. + norm_layer: Normalization layer. + act_layer: MLP activation layer. + block_fn: Transformer block layer. + """ + super().__init__() + assert global_pool in ("", "avg", "token", "map") + assert class_token or global_pool != "token" + use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm + # norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6) + # act_layer = get_act_layer(act_layer) or nn.GELU + norm_layer = partial(nn.LayerNorm, eps=1e-6) + act_layer = nn.GELU + + self.num_classes = num_classes + self.global_pool = global_pool + self.num_features = self.embed_dim = ( + embed_dim # num_features for consistency with other models + ) + self.num_prefix_tokens = 1 if class_token else 0 + self.num_prefix_tokens += reg_tokens + self.num_reg_tokens = reg_tokens + self.has_class_token = class_token + self.no_embed_class = ( + no_embed_class # don't embed prefix positions (includes reg) + ) + self.dynamic_img_size = dynamic_img_size + self.grad_checkpointing = False + self.ignore_head = ignore_head + + embed_args = {} + if dynamic_img_size: + # flatten deferred until after pos embed + embed_args.update(dict(strict_img_size=False, output_fmt="NHWC")) + self.patch_embed = embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + bias=not pre_norm, # disable bias if pre-norm is 
used (e.g. CLIP) + dynamic_img_pad=dynamic_img_pad, + **embed_args, + ) + num_patches = self.patch_embed.num_patches + + self.cls_token = ( + nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None + ) + self.reg_token = ( + nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None + ) + embed_len = ( + num_patches if no_embed_class else num_patches + self.num_prefix_tokens + ) + self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02) + self.pos_drop = nn.Dropout(p=pos_drop_rate) + if patch_drop_rate > 0: + self.patch_drop = PatchDropout( + patch_drop_rate, + num_prefix_tokens=self.num_prefix_tokens, + ) + else: + self.patch_drop = nn.Identity() + self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity() + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + self.blocks = nn.Sequential( + *[ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + init_values=init_values, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + mlp_layer=mlp_layer, + ) + for i in range(depth) + ] + ) + self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() + + # Classifier Head + if global_pool == "map": + AttentionPoolLatent.init_weights = init_weights + self.attn_pool = AttentionPoolLatent( + self.embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + ) + else: + self.attn_pool = None + self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() + self.head_drop = nn.Dropout(drop_rate) + self.head = ( + nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + ) + + if weight_init != "skip": + self.init_weights(weight_init) + + def init_weights(self, mode: Literal["jax", "jax_nlhb", "moco", ""] = "") -> None: + assert mode in ("jax", "jax_nlhb", "moco", "") + # 
head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0 + trunc_normal_(self.pos_embed, std=0.02) + if self.cls_token is not None: + nn.init.normal_(self.cls_token, std=1e-6) + named_apply(init_weights_vit_timm, self) + + @torch.jit.ignore + def no_weight_decay(self) -> Set: + return {"pos_embed", "cls_token", "dist_token"} + + @torch.jit.ignore + def group_matcher(self, coarse: bool = False) -> Dict: + return dict( + stem=r"^cls_token|pos_embed|patch_embed", # stem and embed + blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))], + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable: bool = True) -> None: + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self) -> nn.Module: + return self.head + + def reset_classifier(self, num_classes: int, global_pool=None) -> None: + self.num_classes = num_classes + if global_pool is not None: + assert global_pool in ("", "avg", "token", "map") + if global_pool == "map" and self.attn_pool is None: + assert ( + False + ), "Cannot currently add attention pooling in reset_classifier()." 
+ elif global_pool != "map " and self.attn_pool is not None: + self.attn_pool = None # remove attention pooling + self.global_pool = global_pool + self.head = ( + nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + ) + + def _pos_embed(self, x: torch.Tensor) -> torch.Tensor: + if self.dynamic_img_size: + B, H, W, C = x.shape + pos_embed = resample_abs_pos_embed( + self.pos_embed, + (H, W), + num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens, + ) + x = x.view(B, -1, C) + else: + pos_embed = self.pos_embed + + to_cat = [] + if self.cls_token is not None: + to_cat.append(self.cls_token.expand(x.shape[0], -1, -1)) + if self.reg_token is not None: + to_cat.append(self.reg_token.expand(x.shape[0], -1, -1)) + + if self.no_embed_class: + # deit-3, updated JAX (big vision) + # position embedding does not overlap with class token, add then concat + x = x + pos_embed + if to_cat: + x = torch.cat(to_cat + [x], dim=1) + else: + # original timm, JAX, and deit vit impl + # pos_embed has entry for class token, concat then add + if to_cat: + x = torch.cat(to_cat + [x], dim=1) + x = x + pos_embed + + return self.pos_drop(x) + + def _intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, + ) -> List[torch.Tensor]: + outputs, num_blocks = [], len(self.blocks) + take_indices = set( + range(num_blocks - n, num_blocks) if isinstance(n, int) else n + ) + + # forward pass + x = self.patch_embed(x) + x = self._pos_embed(x) + x = self.patch_drop(x) + x = self.norm_pre(x) + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in take_indices: + outputs.append(x) + + return outputs + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, + reshape: bool = False, + return_prefix_tokens: bool = False, + norm: bool = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + """Intermediate layer accessor (NOTE: This is a WIP experiment). 
+ Inspired by DINO / DINOv2 interface + """ + # take last n blocks if n is an int, if in is a sequence, select by matching indices + outputs = self._intermediate_layers(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + prefix_tokens = [out[:, 0 : self.num_prefix_tokens] for out in outputs] + outputs = [out[:, self.num_prefix_tokens :] for out in outputs] + + if reshape: + grid_size = self.patch_embed.grid_size + outputs = [ + out.reshape(x.shape[0], grid_size[0], grid_size[1], -1) + .permute(0, 3, 1, 2) + .contiguous() + for out in outputs + ] + + if return_prefix_tokens: + return tuple(zip(outputs, prefix_tokens)) + return tuple(outputs) + + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + x = self._pos_embed(x) + x = self.patch_drop(x) + x = self.norm_pre(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + x = self.norm(x) + return x + + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + if self.attn_pool is not None: + x = self.attn_pool(x) + elif self.global_pool == "avg": + x = x[:, self.num_prefix_tokens :].mean(dim=1) + elif self.global_pool: + x = x[:, 0] # class token + x = self.fc_norm(x) + x = self.head_drop(x) + return x if pre_logits else self.head(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.forward_features(x) + if not self.ignore_head: + x = self.forward_head(x) + return x + + +@dataclass +class SigLIPVisionCfg: + width: int = 1152 + layers: Union[Tuple[int, int, int, int], int] = 27 + heads: int = 16 + patch_size: int = 14 + image_size: Union[Tuple[int, int], int] = 336 + global_pool: str = "map" + mlp_ratio: float = 3.7362 + class_token: bool = False + num_classes: int = 0 + use_checkpoint: bool = False + + +SigLIP_MODEL_CONFIG = { + "siglip_so400m_patch14_384": { + "image_size": 336, + "patch_size": 14, + "width": 1152, + "layers": 27, + "heads": 16, + 
"mlp_ratio": 3.7362, + "global_pool": "map", + "use_checkpoint": False, + }, + "siglip_so400m_patch14_224": { + "image_size": 224, + "patch_size": 14, + "width": 1152, + "layers": 27, + "heads": 16, + "mlp_ratio": 3.7362, + "global_pool": "map", + "use_checkpoint": False, + }, + "siglip_large_patch16_384": { + "image_size": 384, + "patch_size": 16, + "width": 1024, + "layers": 24, + "heads": 16, + "mlp_ratio": 4, + "global_pool": "map", + "use_checkpoint": False, + }, +} + + +def create_siglip_vit( + model_name: str = "siglip_so400m_patch14_384", + image_size: int = 384, + select_layer: int = -1, + ckpt_path: str = "", + **kwargs, +): + assert ( + model_name in SigLIP_MODEL_CONFIG.keys() + ), f"model name should be in {SigLIP_MODEL_CONFIG.keys()}" + + vision_cfg = SigLIPVisionCfg(**SigLIP_MODEL_CONFIG[model_name]) + + if select_layer <= 0: + layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1) + else: + layers = min(vision_cfg.layers, select_layer) + + model = VisionTransformer( + img_size=image_size, + patch_size=vision_cfg.patch_size, + embed_dim=vision_cfg.width, + depth=layers, + num_heads=vision_cfg.heads, + mlp_ratio=vision_cfg.mlp_ratio, + class_token=vision_cfg.class_token, + global_pool=vision_cfg.global_pool, + ignore_head=kwargs.get("ignore_head", True), + weight_init=kwargs.get("weight_init", "skip"), + num_classes=0, + ) + + if ckpt_path: + state_dict = torch.load(ckpt_path, map_location="cpu") + + incompatible_keys = model.load_state_dict(state_dict, strict=False) + print( + f"SigLIP-ViT restores from {ckpt_path},\n" + f"\tincompatible_keys:', {incompatible_keys}." + ) + + return model diff --git a/fudoki/janus/models/vq_model.py b/fudoki/janus/models/vq_model.py new file mode 100644 index 0000000..90a47e2 --- /dev/null +++ b/fudoki/janus/models/vq_model.py @@ -0,0 +1,527 @@ +# Copyright (c) 2023-2024 DeepSeek. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
"""VQ image tokenizer: a convolutional encoder/decoder around a vector-quantized codebook."""

from dataclasses import dataclass, field
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F

from functools import partial


@dataclass
class ModelArgs:
    """Hyper-parameters consumed by :class:`VQModel`."""

    # Number of codebook entries (``n_e`` of the quantizer).
    codebook_size: int = 16384
    # Dimensionality of each codebook entry (``e_dim`` of the quantizer).
    codebook_embed_dim: int = 8
    # L2-normalise both the latents and the codebook before the nearest-neighbour search.
    codebook_l2_norm: bool = True
    # Register the ``codebook_used`` buffer on the quantizer.
    codebook_show_usage: bool = True
    # Weight of the commitment loss.
    commit_loss_beta: float = 0.25
    # Weight of the entropy loss.
    entropy_loss_ratio: float = 0.0

    # Per-resolution channel multipliers of the encoder / decoder.
    encoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
    decoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
    # Channels of the encoder output / decoder input feature map.
    z_channels: int = 256
    # Dropout probability inside the residual blocks.
    dropout_p: float = 0.0


class Encoder(nn.Module):
    """Convolutional encoder mapping an image to a ``z_channels`` feature map.

    One stage is built per entry of ``ch_mult``. Every stage holds
    ``num_res_blocks`` residual blocks; only the last stage pairs each residual
    block with an attention block, and every stage except the last ends with a
    ``Downsample``.
    """

    def __init__(
        self,
        in_channels=3,
        ch=128,
        ch_mult=(1, 1, 2, 2, 4),
        num_res_blocks=2,
        norm_type="group",
        dropout=0.0,
        resamp_with_conv=True,
        z_channels=256,
    ):
        super().__init__()
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1)

        # downsampling
        in_ch_mult = (1,) + tuple(ch_mult)
        self.conv_blocks = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            is_last_level = i_level == self.num_resolutions - 1
            conv_block = nn.Module()
            # res & attn
            res_block = nn.ModuleList()
            attn_block = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                res_block.append(
                    ResnetBlock(
                        block_in, block_out, dropout=dropout, norm_type=norm_type
                    )
                )
                block_in = block_out
                if is_last_level:
                    attn_block.append(AttnBlock(block_in, norm_type))
            conv_block.res = res_block
            conv_block.attn = attn_block
            # downsample
            if not is_last_level:
                conv_block.downsample = Downsample(block_in, resamp_with_conv)
            self.conv_blocks.append(conv_block)

        # middle
        self.mid = nn.ModuleList()
        self.mid.append(
            ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
        )
        self.mid.append(AttnBlock(block_in, norm_type=norm_type))
        self.mid.append(
            ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
        )

        # end
        self.norm_out = Normalize(block_in, norm_type)
        self.conv_out = nn.Conv2d(
            block_in, z_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x):
        """Run ``conv_in`` -> stages -> middle blocks -> norm/swish/``conv_out``."""
        h = self.conv_in(x)
        # downsampling
        for i_level, block in enumerate(self.conv_blocks):
            for i_block in range(self.num_res_blocks):
                h = block.res[i_block](h)
                # ``attn`` is empty except on the last stage (see __init__).
                if len(block.attn) > 0:
                    h = block.attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = block.downsample(h)

        # middle
        for mid_block in self.mid:
            h = mid_block(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    """Convolutional decoder mapping a ``z_channels`` feature map to an image.

    Stages are built from the last entry of ``ch_mult`` to the first. Every
    stage holds ``num_res_blocks + 1`` residual blocks; only the first stage
    built (the one for the last ``ch_mult`` entry) pairs them with attention
    blocks, and every stage except the final one ends with an ``Upsample``.
    """

    def __init__(
        self,
        z_channels=256,
        ch=128,
        ch_mult=(1, 1, 2, 2, 4),
        num_res_blocks=2,
        norm_type="group",
        dropout=0.0,
        resamp_with_conv=True,
        out_channels=3,
    ):
        super().__init__()
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks

        block_in = ch * ch_mult[self.num_resolutions - 1]
        # z to block_in
        self.conv_in = nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1
        )

        # middle
        self.mid = nn.ModuleList()
        self.mid.append(
            ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
        )
        self.mid.append(AttnBlock(block_in, norm_type=norm_type))
        self.mid.append(
            ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
        )

        # upsampling
        self.conv_blocks = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            conv_block = nn.Module()
            # res & attn
            res_block = nn.ModuleList()
            attn_block = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                res_block.append(
                    ResnetBlock(
                        block_in, block_out, dropout=dropout, norm_type=norm_type
                    )
                )
                block_in = block_out
                if i_level == self.num_resolutions - 1:
                    attn_block.append(AttnBlock(block_in, norm_type))
            conv_block.res = res_block
            conv_block.attn = attn_block
            # upsample
            if i_level != 0:
                conv_block.upsample = Upsample(block_in, resamp_with_conv)
            self.conv_blocks.append(conv_block)

        # end
        self.norm_out = Normalize(block_in, norm_type)
        self.conv_out = nn.Conv2d(
            block_in, out_channels, kernel_size=3, stride=1, padding=1
        )

    @property
    def last_layer(self):
        """Weight tensor of the final output convolution."""
        return self.conv_out.weight

    def forward(self, z):
        """Run ``conv_in`` -> middle blocks -> stages -> norm/swish/``conv_out``."""
        # z to block_in
        h = self.conv_in(z)

        # middle
        for mid_block in self.mid:
            h = mid_block(h)

        # upsampling
        for i_level, block in enumerate(self.conv_blocks):
            for i_block in range(self.num_res_blocks + 1):
                h = block.res[i_block](h)
                if len(block.attn) > 0:
                    h = block.attn[i_block](h)
            # ``conv_blocks`` was filled in reversed level order, so the stage
            # without an ``upsample`` module is the last one in the list.
            if i_level != self.num_resolutions - 1:
                h = block.upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class VectorQuantizer(nn.Module):
    """Nearest-neighbour codebook lookup with a straight-through gradient.

    Args:
        n_e: number of codebook entries.
        e_dim: dimensionality of each entry.
        beta: weight of the commitment loss.
        entropy_loss_ratio: weight of the entropy loss.
        l2_norm: L2-normalise latents and codebook before the lookup.
        show_usage: register the ``codebook_used`` buffer.
    """

    def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show_usage):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta
        self.entropy_loss_ratio = entropy_loss_ratio
        self.l2_norm = l2_norm
        self.show_usage = show_usage

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
        if self.l2_norm:
            self.embedding.weight.data = F.normalize(
                self.embedding.weight.data, p=2, dim=-1
            )
        if self.show_usage:
            # A buffer is plain (non-trainable) state: register a tensor, not an
            # nn.Parameter, so it does not carry requires_grad=True. The
            # state-dict key is unchanged. Nothing in this module writes to it.
            self.register_buffer("codebook_used", torch.zeros(65536))

    def forward(self, z):
        """Quantize a channel-first latent map.

        Args:
            z: tensor of shape ``(batch, e_dim, height, width)``.

        Returns:
            A 3-tuple ``(z_q, losses, info)`` where ``z_q`` has the shape of
            ``z``; ``losses`` is ``(vq_loss, commit_loss, entropy_loss)``, all
            ``None`` unless the module is in training mode; and ``info`` is
            ``(perplexity, min_encodings, min_encoding_indices)`` whose first
            two items are always ``None`` and whose last item is the flat index
            tensor of length ``batch * height * width``.
        """
        # reshape z -> (batch, height, width, channel) and flatten
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.e_dim)
        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z

        if self.l2_norm:
            z = F.normalize(z, p=2, dim=-1)
            z_flattened = F.normalize(z_flattened, p=2, dim=-1)
            embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
        else:
            embedding = self.embedding.weight

        d = (
            torch.sum(z_flattened**2, dim=1, keepdim=True)
            + torch.sum(embedding**2, dim=1)
            - 2 * torch.einsum("bd,dn->bn", z_flattened, embedding.permute(1, 0))
        )

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = embedding[min_encoding_indices].view(z.shape)
        perplexity = None
        min_encodings = None
        vq_loss = None
        commit_loss = None
        entropy_loss = None

        # compute loss for embedding
        if self.training:
            vq_loss = torch.mean((z_q - z.detach()) ** 2)
            commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2)
            entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(-d)

        # preserve gradients (straight-through: forward value z_q, gradient to z)
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2)

        return (
            z_q,
            (vq_loss, commit_loss, entropy_loss),
            (perplexity, min_encodings, min_encoding_indices),
        )

    def get_codebook_entry(self, indices, shape=None, channel_first=True):
        """Look up codebook vectors for ``indices``.

        Args:
            indices: integer tensor of codebook indices.
            shape: optional target shape; ``(batch, channel, height, width)``
                if ``channel_first`` else ``(batch, height, width, channel)``.
            channel_first: layout of ``shape`` and of the returned tensor.

        Returns:
            The selected (optionally L2-normalised) embeddings, reshaped to
            ``shape`` when it is given.
        """
        if self.l2_norm:
            embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
        else:
            embedding = self.embedding.weight
        z_q = embedding[indices]  # (b*h*w, c)

        if shape is not None:
            if channel_first:
                z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[1])
                # reshape back to match original input shape
                z_q = z_q.permute(0, 3, 1, 2).contiguous()
            else:
                z_q = z_q.view(shape)
        return z_q


class ResnetBlock(nn.Module):
    """Pre-activation residual block: (norm, swish, conv) twice plus a skip path.

    When the channel count changes, the skip path is a 3x3 convolution if
    ``conv_shortcut`` is true and a 1x1 convolution otherwise.
    """

    def __init__(
        self,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        norm_type="group",
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels, norm_type)
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        self.norm2 = Normalize(out_channels, norm_type)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = nn.Conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
                )
            else:
                self.nin_shortcut = nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(self, x):
        """Return ``shortcut(x) + residual(x)``."""
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)
        return x + h


class AttnBlock(nn.Module):
    """Single-head self-attention over the ``h * w`` spatial positions, with a residual."""

    def __init__(self, in_channels, norm_type="group"):
        super().__init__()
        self.norm = Normalize(in_channels, norm_type)
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        """Return ``x + proj_out(attention(norm(x)))``; the shape is unchanged."""
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)  # b,hw,c
        k = k.reshape(b, c, h * w)  # b,c,hw
        w_ = torch.bmm(q, k)  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = F.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_


def nonlinearity(x):
    """Swish activation: ``x * sigmoid(x)``."""
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels, norm_type="group"):
    """Build the normalisation layer named by ``norm_type``.

    ``"group"`` gives a 32-group ``GroupNorm`` (so ``in_channels`` must be a
    multiple of 32); ``"batch"`` gives ``SyncBatchNorm``.
    """
    assert norm_type in ["group", "batch"]
    if norm_type == "group":
        return nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        )
    elif norm_type == "batch":
        return nn.SyncBatchNorm(in_channels)


class Upsample(nn.Module):
    """2x nearest-neighbour upsampling, optionally followed by a 3x3 convolution."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1
            )

    def forward(self, x):
        """Upsample ``x`` by 2 in both spatial dims, keeping its dtype."""
        # Interpolate in float32, then cast back to the *input* dtype. The
        # previous hard-coded cast to float16 turned bfloat16 inputs into
        # float16. ``.to`` is a no-op when the dtype already matches.
        in_dtype = x.dtype
        x = F.interpolate(x.to(torch.float32), scale_factor=2.0, mode="nearest")
        x = x.to(in_dtype)

        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    """2x downsampling: padded stride-2 3x3 convolution, or 2x2 average pooling."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0
            )

    def forward(self, x):
        """Halve both spatial dims of ``x``."""
        if self.with_conv:
            # Pad one zero column on the right and one zero row at the bottom.
            pad = (0, 1, 0, 1)
            x = F.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = F.avg_pool2d(x, kernel_size=2, stride=2)
        return x


def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01):
    """Return mean per-sample entropy minus the entropy of the mean distribution.

    Args:
        affinity: tensor whose last dimension indexes the candidates (the
            quantizer passes negated distances); leading dims are flattened.
        loss_type: only ``"softmax"`` is supported.
        temperature: divisor applied to ``affinity`` before the softmax.

    Raises:
        ValueError: if ``loss_type`` is not ``"softmax"``.
    """
    if loss_type != "softmax":
        raise ValueError("Entropy loss {} not supported".format(loss_type))

    # Out-of-place division: ``reshape`` may return a view, so an in-place
    # ``/=`` would silently rescale the caller's tensor.
    flat_affinity = affinity.reshape(-1, affinity.shape[-1]) / temperature
    probs = F.softmax(flat_affinity, dim=-1)
    log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1)
    target_probs = probs
    avg_probs = torch.mean(target_probs, dim=0)
    avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + 1e-5))
    sample_entropy = -torch.mean(torch.sum(target_probs * log_probs, dim=-1))
    loss = sample_entropy - avg_entropy
    return loss


class VQModel(nn.Module):
    """Encoder -> 1x1 conv -> vector quantizer -> 1x1 conv -> decoder."""

    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.encoder = Encoder(
            ch_mult=config.encoder_ch_mult,
            z_channels=config.z_channels,
            dropout=config.dropout_p,
        )
        self.decoder = Decoder(
            ch_mult=config.decoder_ch_mult,
            z_channels=config.z_channels,
            dropout=config.dropout_p,
        )

        self.quantize = VectorQuantizer(
            config.codebook_size,
            config.codebook_embed_dim,
            config.commit_loss_beta,
            config.entropy_loss_ratio,
            config.codebook_l2_norm,
            config.codebook_show_usage,
        )
        # 1x1 convs between the encoder/decoder width and the codebook width.
        self.quant_conv = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1)
        self.post_quant_conv = nn.Conv2d(
            config.codebook_embed_dim, config.z_channels, 1
        )

    def encode(self, x):
        """Encode ``x`` and quantize it; returns the quantizer's 3-tuple."""
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def decode(self, quant):
        """Decode a quantized latent map back through the decoder."""
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b, shape=None, channel_first=True):
        """Decode codebook indices ``code_b`` (see ``get_codebook_entry`` for ``shape``)."""
        quant_b = self.quantize.get_codebook_entry(code_b, shape, channel_first)
        dec = self.decode(quant_b)
        return dec

    def forward(self, input):
        """Return ``(reconstruction, losses)`` for ``input``."""
        quant, diff, _ = self.encode(input)
        dec = self.decode(quant)
        return dec, diff


#################################################################################
#                              VQ Model Configs                                 #
#################################################################################
def VQ_16(**kwargs):
    """Build a :class:`VQModel` with channel multipliers ``[1, 1, 2, 2, 4]``."""
    return VQModel(
        ModelArgs(
            encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4], **kwargs
        )
    )


VQ_models = {"VQ-16": VQ_16}

# --- patch metadata that follows this module in the source diff (kept verbatim) ---
# diff --git a/fudoki/janus/utils/__init__.py b/fudoki/janus/utils/__init__.py
# new file mode 100644
# index 0000000..8cb7640
# --- /dev/null
# +++ b/fudoki/janus/utils/__init__.py
# @@ -0,0 +1,18 @@
# +# Copyright (c) 2023-2024 DeepSeek.
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/fudoki/janus/utils/conversation.py b/fudoki/janus/utils/conversation.py new file mode 100644 index 0000000..3b227fe --- /dev/null +++ b/fudoki/janus/utils/conversation.py @@ -0,0 +1,337 @@ +# Copyright (c) 2023-2024 DeepSeek. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. 
class SeparatorStyle(IntEnum):
    """Separator styles."""

    ADD_COLON_SINGLE = auto()
    ADD_COLON_TWO = auto()
    ADD_COLON_SPACE_SINGLE = auto()
    NO_COLON_SINGLE = auto()
    NO_COLON_TWO = auto()
    ADD_NEW_LINE_SINGLE = auto()
    LLAMA2 = auto()
    CHATGLM = auto()
    CHATML = auto()
    CHATINTERN = auto()
    DOLLY = auto()
    RWKV = auto()
    PHOENIX = auto()
    ROBIN = auto()
    DeepSeek = auto()
    PLAIN = auto()
    ALIGNMENT = auto()


@dataclasses.dataclass
class Conversation:
    """A class that manages prompt templates and keeps all conversation history.

    Only the ``DeepSeek``, ``LLAMA2``, ``PLAIN`` and ``ALIGNMENT`` separator
    styles are rendered by :meth:`get_prompt`; every other style raises
    ``ValueError``.
    """

    # The name of this template
    name: str
    # The template of the system prompt
    system_template: str = "{system_message}"
    # The system message
    system_message: str = ""
    # The names of two roles (index 0 speaks on even turns, index 1 on odd turns)
    roles: List[str] = ("USER", "ASSISTANT")
    # All messages. Each item is [role, message].
    messages: List[List[str]] = dataclasses.field(default_factory=list)
    # The number of few shot examples
    offset: int = 0
    # The separator style and configurations
    sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
    sep: str = "\n"
    sep2: str = None
    # Stop criteria (the default one is EOS token)
    stop_str: str = None
    # Stops generation if meeting any token in this list
    stop_token_ids: List[int] = None

    def __post_init__(self):
        # Templates are declared with ``messages=()``. Turn any non-list
        # sequence into a private list of [role, message] pairs so that
        # append_message / update_last_message work on every instance, not
        # only on the ones produced by copy().
        if not isinstance(self.messages, list):
            self.messages = [list(item) for item in self.messages]

    def get_prompt(self) -> str:
        """Get the prompt for generation.

        Returns:
            The full prompt string rendered according to ``sep_style``.

        Raises:
            ValueError: if ``sep_style`` is not one of the supported styles.
        """
        system_prompt = self.system_template.format(system_message=self.system_message)
        # Even turns are terminated with ``sep``, odd turns with ``sep2``.
        seps = [self.sep, self.sep2]

        if self.sep_style == SeparatorStyle.DeepSeek:
            ret = system_prompt + seps[0] if system_prompt else ""
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                else:
                    # Empty/None message: leave the turn open for generation.
                    ret += role + ":"
            return ret

        if self.sep_style == SeparatorStyle.LLAMA2:
            ret = system_prompt if self.system_message else "[INST] "
            for i, (_, message) in enumerate(self.messages):
                tag = self.roles[i % 2]
                if not message:
                    ret += tag
                    continue
                if isinstance(message, tuple):  # multimodal message
                    message, _ = message
                if i == 0:
                    ret += message + " "
                else:
                    ret += tag + " " + message + seps[i % 2]
            return ret

        if self.sep_style == SeparatorStyle.PLAIN:
            ret = ""
            for i, (_, message) in enumerate(self.messages):
                if not message:
                    continue
                if isinstance(message, tuple):
                    message, _, _ = message
                ret += message + seps[i % 2]
            return ret

        if self.sep_style == SeparatorStyle.ALIGNMENT:
            ret = ""
            for i, (_, message) in enumerate(self.messages):
                if not message:
                    continue
                if isinstance(message, tuple):
                    message, _, _ = message
                if i % 2 == 0:
                    # NOTE(review): even (user) turns contribute only a newline
                    # and drop the message text; this literal may have lost a
                    # placeholder tag - confirm against upstream.
                    ret += "\n" + seps[i % 2]
                else:
                    ret += message + seps[i % 2]
            return ret

        raise ValueError(f"Invalid style: {self.sep_style}")

    def get_prompt_for_current_round(self, content=None):
        """Get current round formatted question prompt during sft training.

        Args:
            content: the user question; must be a string for the ``DeepSeek``
                style (it is stripped), and is ignored for ``PLAIN``.

        Raises:
            ValueError: if ``sep_style`` is neither ``PLAIN`` nor ``DeepSeek``.
        """
        if self.sep_style == SeparatorStyle.PLAIN:
            formatted_question = "\n"
        elif self.sep_style == SeparatorStyle.DeepSeek:
            formatted_question = (
                f"{self.roles[0]}: " + content.strip() + self.sep + f"{self.roles[1]}:"
            )
        else:
            raise ValueError(f"Unsupported sep_style: {self.sep_style}")
        return formatted_question

    def set_system_message(self, system_message: str):
        """Set the system message."""
        self.system_message = system_message

    def append_message(self, role: str, message: str):
        """Append a new message."""
        self.messages.append([role, message])

    def reset_message(self):
        """Reset a new message."""
        self.messages = []

    def update_last_message(self, message: str):
        """Update the last output.

        The last message is typically set to be None when constructing the prompt,
        so we need to update it in-place after getting the response from a model.
        """
        self.messages[-1][1] = message

    def to_gradio_chatbot(self):
        """Convert the conversation to gradio chatbot format.

        Returns:
            A list of ``[user_message, assistant_message]`` pairs built from
            the messages after ``offset``; the assistant slot is ``None``
            until an odd-indexed message fills it.
        """
        ret = []
        for i, (_, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def to_openai_api_messages(self):
        """Convert the conversation to OpenAI chat completion format."""
        system_prompt = self.system_template.format(system_message=self.system_message)
        ret = [{"role": "system", "content": system_prompt}]

        for i, (_, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append({"role": "user", "content": msg})
            elif msg is not None:
                # A pending (None) assistant turn is not exported.
                ret.append({"role": "assistant", "content": msg})
        return ret

    def copy(self):
        """Return a new Conversation whose message list is independent of this one."""
        return Conversation(
            name=self.name,
            system_template=self.system_template,
            system_message=self.system_message,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            stop_str=self.stop_str,
            stop_token_ids=self.stop_token_ids,
        )

    def dict(self):
        """Return the template name, system message, roles, messages and offset as a dict."""
        return {
            "template_name": self.name,
            "system_message": self.system_message,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
        }


# A global registry for all conversation templates
conv_templates: Dict[str, Conversation] = {}


def register_conv_template(template: Conversation, override: bool = False):
    """Register a new conversation template.

    Raises:
        AssertionError: if the name is already registered and ``override`` is
            false. Raised explicitly so the check also runs under ``python -O``.
    """
    if not override and template.name in conv_templates:
        raise AssertionError(f"{template.name} has been registered.")

    conv_templates[template.name] = template


def get_conv_template(name: str) -> Conversation:
    """Get a conversation template (a copy, so callers may mutate it freely)."""
    return conv_templates[name].copy()


# NOTE(review): in the two Llama-2 style templates below the bare "<>" markers
# in system_template and the single-space sep2 look like literals whose tag
# text was lost (the upstream Llama-2 format uses <<SYS>> ... <</SYS>>) -
# confirm against upstream before relying on these templates. The same applies
# to the empty sep/sep2/stop_str values of the "plain" and "alignment" templates.

# llava_llama2 template
register_conv_template(
    Conversation(
        name="llava_llama2",
        system_message="You are a helpful language and vision assistant. "
        "You are able to understand the visual content that the user provides, "
        "and assist the user with a variety of tasks using natural language.",
        system_template="[INST] <>\n{system_message}\n<>\n\n",
        roles=("[INST]", "[/INST]"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.LLAMA2,
        sep=" ",
        sep2=" ",
        stop_token_ids=[2],
    )
)

# llama2 template
# reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
register_conv_template(
    Conversation(
        name="llama-2",
        system_template="[INST] <>\n{system_message}\n<>\n\n",
        roles=("[INST]", "[/INST]"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.LLAMA2,
        sep=" ",
        sep2=" ",
        stop_token_ids=[2],
    )
)


# deepseek template
register_conv_template(
    Conversation(
        name="fudoki",
        system_template="{system_message}",
        system_message="",
        roles=("User", "Assistant"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.DeepSeek,
        sep="\n\n",
        sep2="<|end▁of▁sentence|>",
        stop_token_ids=[100001],
        stop_str=["User:", "<|end▁of▁sentence|>"],
    )
)

register_conv_template(
    Conversation(
        name="plain",
        system_template="",
        system_message="",
        roles=("", ""),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.PLAIN,
        sep="",
        sep2="",
        stop_token_ids=[2],
        stop_str=[""],
    )
)


register_conv_template(
    Conversation(
        name="alignment",
        system_template="",
        system_message="",
        roles=("", ""),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.ALIGNMENT,
        sep="",
        sep2="",
        stop_token_ids=[2],
        stop_str=[""],
    )
)


if __name__ == "__main__":
    print("fudoki template:")
    conv = get_conv_template("fudoki")
    conv.append_message(conv.roles[0], "Hello!")
    conv.append_message(conv.roles[1], "Hi! This is Tony.")
    conv.append_message(conv.roles[0], "Who are you?")
    conv.append_message(conv.roles[1], "I am a helpful assistant.")
    conv.append_message(conv.roles[0], "How are you?")
    conv.append_message(conv.roles[1], None)
    print(conv.get_prompt())
def load_pretrained_model(model_path: str):
    """Load the chat processor, its tokenizer and the model from ``model_path``.

    The model is cast to bfloat16, moved to the current CUDA device and put in
    eval mode, so a CUDA device is required.

    Returns:
        A ``(tokenizer, vl_chat_processor, vl_gpt)`` tuple.
    """
    vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
    tokenizer = vl_chat_processor.tokenizer

    # NOTE: trust_remote_code=True executes Python shipped with the checkpoint;
    # only point this at checkpoints you trust.
    vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
        model_path, trust_remote_code=True
    )
    vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()

    return tokenizer, vl_chat_processor, vl_gpt


def load_pil_images(conversations: List[Dict[str, str]]) -> "List[PIL.Image.Image]":
    """Load every image referenced by a conversation as an RGB PIL image.

    Support file path or base64 images.

    Args:
        conversations (List[Dict[str, str]]): the conversations with a list of messages. An example is :
            [
                {
                    "role": "User",
                    "content": "\nExtract all information from this image and convert them into markdown format.",
                    "images": ["./examples/table_datasets.png"]
                },
                {"role": "Assistant", "content": ""},
            ]
            An entry of ``images`` is either a file path or a
            ``data:image...;base64,<payload>`` URI. Messages without an
            ``images`` key (or with an empty/None value) are skipped.

    Returns:
        pil_images (List[PIL.Image.Image]): the list of PIL images, in the
        order they appear in the conversation.
    """
    pil_images = []

    for message in conversations:
        for image_data in message.get("images") or ():
            if image_data.startswith("data:image"):
                # Image data is in base64 format: keep only the payload
                # after the first comma.
                _, encoded = image_data.split(",", 1)
                source = io.BytesIO(base64.b64decode(encoded))
            else:
                # Image data is a file path
                source = image_data
            # convert() returns an independent image, so the originally opened
            # image (and its file handle) can be closed right away.
            with PIL.Image.open(source) as raw_img:
                pil_images.append(raw_img.convert("RGB"))

    return pil_images


def load_json(filepath):
    """Read ``filepath`` as UTF-8 JSON and return the decoded object."""
    # An explicit encoding keeps the result independent of the platform locale.
    with open(filepath, "r", encoding="utf-8") as f:
        return json.load(f)
@dataclass
class ModelArgs:
    """Hyper-parameters consumed by ``VQModel``."""

    # Codebook / quantizer settings (forwarded to VectorQuantizer).
    codebook_size: int = 16384
    codebook_embed_dim: int = 8
    codebook_l2_norm: bool = True
    codebook_show_usage: bool = True
    commit_loss_beta: float = 0.25
    entropy_loss_ratio: float = 0.0

    # Per-resolution channel multipliers; their length sets the number of
    # resolution levels of the encoder / decoder.
    encoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
    decoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
    # Channel count of the encoder output / decoder input.
    z_channels: int = 256
    dropout_p: float = 0.0



class VQModel(nn.Module):
    """Encoder -> 1x1 conv -> vector quantizer -> 1x1 conv -> decoder."""

    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.encoder = Encoder(ch_mult=config.encoder_ch_mult, z_channels=config.z_channels, dropout=config.dropout_p)
        self.decoder = Decoder(ch_mult=config.decoder_ch_mult, z_channels=config.z_channels, dropout=config.dropout_p)

        self.quantize = VectorQuantizer(config.codebook_size, config.codebook_embed_dim,
                                        config.commit_loss_beta, config.entropy_loss_ratio,
                                        config.codebook_l2_norm, config.codebook_show_usage)
        # 1x1 convs mapping between the encoder/decoder width (z_channels)
        # and the codebook embedding width.
        self.quant_conv = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1)
        self.post_quant_conv = nn.Conv2d(config.codebook_embed_dim, config.z_channels, 1)

    def encode(self, x):
        """Encode ``x`` and quantize it; returns the quantizer's ``(quant, emb_loss, info)``."""
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def decode(self, quant):
        """Decode quantized latents (codebook-embedding channels) with the decoder."""
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b, shape=None, channel_first=True):
        """Look up codebook entries for the indices ``code_b`` and decode them."""
        quant_b = self.quantize.get_codebook_entry(code_b, shape, channel_first)
        dec = self.decode(quant_b)
        return dec

    def forward(self, input):
        """Return ``(reconstruction, quantizer_losses)`` for ``input``."""
        quant, diff, _ = self.encode(input)
        dec = self.decode(quant)
        return dec, diff



class Encoder(nn.Module):
    """Convolutional encoder: ``len(ch_mult)`` levels of residual blocks.

    Every level except the last ends with a ``Downsample``; attention blocks
    are only added at the last level.
    """

    def __init__(self, in_channels=3, ch=128, ch_mult=(1,1,2,2,4), num_res_blocks=2,
                 norm_type='group', dropout=0.0, resamp_with_conv=True, z_channels=256):
        super().__init__()
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1)

        # downsampling
        # in_ch_mult[i] is the multiplier of the input to level i (level 0 receives `ch`).
        in_ch_mult = (1,) + tuple(ch_mult)
        self.conv_blocks = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            conv_block = nn.Module()
            # res & attn
            res_block = nn.ModuleList()
            attn_block = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                res_block.append(ResnetBlock(block_in, block_out, dropout=dropout, norm_type=norm_type))
                block_in = block_out
                # one attention block per residual block, last level only
                if i_level == self.num_resolutions - 1:
                    attn_block.append(AttnBlock(block_in, norm_type))
            conv_block.res = res_block
            conv_block.attn = attn_block
            # downsample
            if i_level != self.num_resolutions-1:
                conv_block.downsample = Downsample(block_in, resamp_with_conv)
            self.conv_blocks.append(conv_block)

        # middle
        self.mid = nn.ModuleList()
        self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
        self.mid.append(AttnBlock(block_in, norm_type=norm_type))
        self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))

        # end
        self.norm_out = Normalize(block_in, norm_type)
        self.conv_out = nn.Conv2d(block_in, z_channels, kernel_size=3, stride=1, padding=1)


    def forward(self, x):
        """Map an image batch to a ``z_channels`` feature map."""
        h = self.conv_in(x)
        # downsampling
        for i_level, block in enumerate(self.conv_blocks):
            for i_block in range(self.num_res_blocks):
                h = block.res[i_block](h)
                # block.attn is empty except at the last level
                if len(block.attn) > 0:
                    h = block.attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = block.downsample(h)

        # middle
        for mid_block in self.mid:
            h = mid_block(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h



class Decoder(nn.Module):
    """Convolutional decoder mirroring ``Encoder``.

    ``conv_blocks`` is built from the deepest level to level 0, so
    ``conv_blocks[0]`` is the deepest level (the only one with attention) and
    every block except the last one ends with an ``Upsample``.
    """

    def __init__(self, z_channels=256, ch=128, ch_mult=(1,1,2,2,4), num_res_blocks=2, norm_type="group",
                 dropout=0.0, resamp_with_conv=True, out_channels=3):
        super().__init__()
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks

        block_in = ch*ch_mult[self.num_resolutions-1]
        # z to block_in
        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.ModuleList()
        self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
        self.mid.append(AttnBlock(block_in, norm_type=norm_type))
        self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))

        # upsampling
        self.conv_blocks = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            conv_block = nn.Module()
            # res & attn
            res_block = nn.ModuleList()
            attn_block = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            # the decoder uses one more residual block per level than the encoder
            for _ in range(self.num_res_blocks + 1):
                res_block.append(ResnetBlock(block_in, block_out, dropout=dropout, norm_type=norm_type))
                block_in = block_out
                if i_level == self.num_resolutions - 1:
                    attn_block.append(AttnBlock(block_in, norm_type))
            conv_block.res = res_block
            conv_block.attn = attn_block
            # upsample (every level except level 0, which is built last)
            if i_level != 0:
                conv_block.upsample = Upsample(block_in, resamp_with_conv)
            self.conv_blocks.append(conv_block)

        # end
        self.norm_out = Normalize(block_in, norm_type)
        self.conv_out = nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)

    @property
    def last_layer(self):
        """Weight tensor of the final output convolution."""
        return self.conv_out.weight

    def forward(self, z):
        """Map a ``z_channels`` feature map to an ``out_channels`` image batch."""
        # z to block_in
        h = self.conv_in(z)

        # middle
        for mid_block in self.mid:
            h = mid_block(h)

        # upsampling
        # i_level here is the position in conv_blocks (0 = deepest level),
        # not the resolution index used while building the blocks.
        for i_level, block in enumerate(self.conv_blocks):
            for i_block in range(self.num_res_blocks + 1):
                h = block.res[i_block](h)
                if len(block.attn) > 0:
                    h = block.attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = block.upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
class VectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantizer with a straight-through estimator.

    Args:
        n_e: number of codebook entries.
        e_dim: dimension of each codebook entry.
        beta: weight of the commitment loss.
        entropy_loss_ratio: weight of the entropy loss.
        l2_norm: if true, inputs and codebook entries are L2-normalized before
            the nearest-neighbour search.
        show_usage: if true, keep a sliding window of recently selected
            indices (buffer ``codebook_used``) and report codebook usage
            while training.
    """

    def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show_usage):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta
        self.entropy_loss_ratio = entropy_loss_ratio
        self.l2_norm = l2_norm
        self.show_usage = show_usage

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
        if self.l2_norm:
            self.embedding.weight.data = F.normalize(self.embedding.weight.data, p=2, dim=-1)
        if self.show_usage:
            # A plain tensor, not an nn.Parameter: the window is bookkeeping
            # that is overwritten in place, so it must not require grad
            # (in-place writes to a leaf that requires grad raise). The
            # state_dict key and shape are unchanged.
            self.register_buffer("codebook_used", torch.zeros(65536))


    def forward(self, z):
        """Quantize ``z`` of shape (batch, e_dim, height, width).

        Returns:
            A 3-tuple:
            - ``z_q``: quantized tensor with the same shape as ``z``; gradients
              flow to ``z`` through the straight-through estimator.
            - ``(vq_loss, commit_loss, entropy_loss, codebook_usage)``: the
              losses are ``None`` outside training; ``codebook_usage`` is 0
              unless ``show_usage`` is set and the module is training.
            - ``(perplexity, min_encodings, min_encoding_indices)``: the first
              two are always ``None``; the last is the flat tensor of selected
              codebook indices.
        """
        # reshape z -> (batch, height, width, channel) and flatten
        z = torch.einsum('b c h w -> b h w c', z).contiguous()
        z_flattened = z.view(-1, self.e_dim)
        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z

        if self.l2_norm:
            z = F.normalize(z, p=2, dim=-1)
            z_flattened = F.normalize(z_flattened, p=2, dim=-1)
            embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
        else:
            embedding = self.embedding.weight

        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(embedding**2, dim=1) - 2 * \
            torch.einsum('bd,dn->bn', z_flattened, torch.einsum('n d -> d n', embedding))

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = embedding[min_encoding_indices].view(z.shape)
        perplexity = None
        min_encodings = None
        vq_loss = None
        commit_loss = None
        entropy_loss = None
        codebook_usage = 0

        if self.show_usage and self.training:
            # Slide the usage window left and append the newest indices. Only
            # the most recent `window` indices are kept so that a batch larger
            # than the buffer cannot overflow it.
            window = self.codebook_used.shape[0]
            recent = min_encoding_indices[-window:]
            cur_len = recent.shape[0]
            self.codebook_used[:-cur_len] = self.codebook_used[cur_len:].clone()
            self.codebook_used[-cur_len:] = recent
            codebook_usage = len(torch.unique(self.codebook_used)) / self.n_e

        # compute loss for embedding
        if self.training:
            vq_loss = torch.mean((z_q - z.detach()) ** 2)
            commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2)
            entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(-d)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = torch.einsum('b h w c -> b c h w', z_q)

        return z_q, (vq_loss, commit_loss, entropy_loss, codebook_usage), (perplexity, min_encodings, min_encoding_indices)

    def get_codebook_entry(self, indices, shape=None, channel_first=True):
        """Return the codebook vectors for ``indices``.

        Args:
            indices: tensor of codebook indices.
            shape: optional target shape, ``(batch, channel, height, width)``
                if ``channel_first`` else ``(batch, height, width, channel)``.
                When ``None`` the raw lookup result is returned.
            channel_first: layout of ``shape`` and of the returned tensor.
        """
        if self.l2_norm:
            embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
        else:
            embedding = self.embedding.weight
        z_q = embedding[indices]  # (b*h*w, c)

        if shape is not None:
            if channel_first:
                z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[1])
                # reshape back to match original input shape
                z_q = z_q.permute(0, 3, 1, 2).contiguous()
            else:
                z_q = z_q.view(shape)
        return z_q


class ResnetBlock(nn.Module):
    """Pre-activation residual block: (norm -> swish -> conv) twice, plus skip.

    When ``in_channels != out_channels`` the skip path is projected with a
    3x3 conv (``conv_shortcut=True``) or a 1x1 conv (default).
    """

    def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, norm_type='group'):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels, norm_type)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = Normalize(out_channels, norm_type)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return ``shortcut(x) + residual(x)``; spatial size is preserved."""
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)
        return x+h
class AttnBlock(nn.Module):
    """Single-head self-attention over all spatial positions, with a residual."""

    def __init__(self, in_channels, norm_type='group'):
        super().__init__()
        self.norm = Normalize(in_channels, norm_type)
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)


    def forward(self, x):
        """Return ``x + proj_out(attention(norm(x)))``; the shape is preserved."""
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h*w)
        q = q.permute(0, 2, 1)   # b,hw,c
        k = k.reshape(b, c, h*w)  # b,c,hw
        w_ = torch.bmm(q, k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c)**(-0.5))  # scale by 1/sqrt(channels)
        w_ = F.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h*w)
        w_ = w_.permute(0, 2, 1)   # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x+h_


def nonlinearity(x):
    """Swish activation: ``x * sigmoid(x)``."""
    # swish
    return x*torch.sigmoid(x)


def Normalize(in_channels, norm_type='group'):
    """Build the normalization layer for ``in_channels`` channels.

    Args:
        in_channels: channel count; must be divisible by 32 for ``'group'``
            because the GroupNorm uses 32 groups.
        norm_type: ``'group'`` (GroupNorm) or ``'batch'`` (SyncBatchNorm).

    Raises:
        AssertionError: for any other ``norm_type``. Raised explicitly so the
            check is not stripped under ``python -O``.
    """
    if norm_type == 'group':
        return nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
    if norm_type == 'batch':
        return nn.SyncBatchNorm(in_channels)
    raise AssertionError(f"unsupported norm_type: {norm_type!r}")


class Upsample(nn.Module):
    """2x nearest-neighbour upsampling, optionally followed by a 3x3 conv."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    """2x downsampling: a stride-2 3x3 conv (``with_conv``) or 2x2 average pooling."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if self.with_conv:
            # pad one zero column on the right and one zero row at the bottom
            pad = (0, 1, 0, 1)
            x = F.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = F.avg_pool2d(x, kernel_size=2, stride=2)
        return x


def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01):
    """Entropy regularizer over code assignments.

    The affinities are flattened to ``(N, num_codes)``, divided by
    ``temperature`` and turned into per-sample probabilities. The result is
    ``mean per-sample entropy - entropy of the averaged distribution``.

    Args:
        affinity: tensor whose last dimension indexes the codes. It is not
            modified.
        loss_type: only ``"softmax"`` is supported.
        temperature: softmax temperature.

    Raises:
        ValueError: if ``loss_type`` is not ``"softmax"``.
    """
    if loss_type != "softmax":
        raise ValueError("Entropy loss {} not supported".format(loss_type))

    # Out-of-place division: reshape() may return a view, so an in-place "/="
    # would silently rescale the caller's tensor (and raises on a leaf tensor
    # that requires grad).
    flat_affinity = affinity.reshape(-1, affinity.shape[-1]) / temperature
    target_probs = F.softmax(flat_affinity, dim=-1)
    log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1)

    avg_probs = torch.mean(target_probs, dim=0)
    avg_entropy = - torch.sum(avg_probs * torch.log(avg_probs + 1e-5))
    sample_entropy = - torch.mean(torch.sum(target_probs * log_probs, dim=-1))
    loss = sample_entropy - avg_entropy
    return loss


#################################################################################
#                              VQ Model Configs                                 #
#################################################################################
def VQ_8(**kwargs):
    """Build a VQModel with 4 resolution levels (three 2x downsamplings)."""
    return VQModel(ModelArgs(encoder_ch_mult=[1, 2, 2, 4], decoder_ch_mult=[1, 2, 2, 4], **kwargs))

def VQ_16(**kwargs):
    """Build a VQModel with 5 resolution levels (four 2x downsamplings)."""
    return VQModel(ModelArgs(encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4], **kwargs))

VQ_models = {'VQ-16': VQ_16, 'VQ-8': VQ_8}
def parse_arguments(argv=None):
    """Parse the command-line options of the inference script.

    Args:
        argv: optional list of argument strings; ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Run the script with custom arguments.")
    parser.add_argument(
        "--seed", type=int, default=999,
        help="Random seed for reproducibility."
    )
    parser.add_argument(
        "--batch_size", type=int, default=1,
        help="Batch size for processing."
    )
    parser.add_argument(
        "--checkpoint_path", type=str, required=True,
        help="Path to the checkpoint directory."
    )
    parser.add_argument(
        "--processor_path", type=str, required=True,
        help="Path to the processor."
    )
    parser.add_argument(
        "--text_embedding_path", type=str, required=True,
        help="Path to the text embedding."
    )
    parser.add_argument(
        "--image_embedding_path", type=str, required=True,
        help="Path to the image embedding."
    )
    parser.add_argument(
        "--discrete_fm_steps", type=int, default=5,
        help="Inference steps for discrete flow matching"
    )
    parser.add_argument(
        "--txt_max_length", type=int, default=500,
        help="Text length maximum"
    )
    parser.add_argument(
        # main() splits this value on commas, so several images may be given.
        "--image_paths", type=str, required=True,
        help="Comma-separated path(s) to the input image(s)."
    )
    parser.add_argument(
        "--output_dir", type=str, required=False,
        help="Directory to save the output files."
    )
    return parser.parse_args(argv)


def extract_number_pairs(input_str, k=None):
    """
    Extract number pairs from a string and form tuples of (float(x), float(y))

    Values that are empty or not convertible to float are reported and
    skipped; the remaining numbers are paired in order.

    Parameters:
        input_str (str): Input string in the format "4.78,-0.01,9.69,-0.01,..."
        k (int, optional): Number of tuples to extract. If None, extract all possible tuples

    Returns:
        tuple: Contains two elements
            - list: List of successfully extracted tuples
            - list: List of error messages
    """
    pairs = []
    errors = []

    # Check if input is empty or not a string
    if not input_str or not isinstance(input_str, str):
        errors.append("Input is not a valid string")
        return pairs, errors

    # Convert every comma-separated element to float, recording failures
    numbers = []
    for index, elem in enumerate(input_str.split(',')):
        elem_clean = elem.strip()
        if not elem_clean:
            errors.append(f"Empty value at position {index}")
            continue
        try:
            numbers.append(float(elem_clean))
        except ValueError:
            errors.append(
                f"Value '{elem_clean}' at position {index} "
                "cannot be converted to a number"
            )

    # Determine number of pairs to extract
    max_possible = len(numbers) // 2
    if k is None:
        num_to_extract = max_possible
    else:
        try:
            k = int(k)
        except (ValueError, TypeError):
            errors.append(f"k value {k} is not a valid integer")
            return pairs, errors
        if k <= 0:
            errors.append(f"k value {k} must be a positive integer")
            return pairs, errors
        num_to_extract = min(k, max_possible)

    # Pair up numbers[0] with numbers[1], numbers[2] with numbers[3], ...
    limit = 2 * num_to_extract
    pairs.extend(zip(numbers[0:limit:2], numbers[1:limit:2]))

    # Check for unpaired numbers
    if len(numbers) % 2 != 0:
        errors.append(f"There are {len(numbers) % 2} unpaired number(s)")

    # Check if requested k value was achieved
    if k is not None and num_to_extract < k:
        errors.append(
            f"Only {num_to_extract} valid number pairs can be extracted, "
            f"less than the requested {k}"
        )

    return pairs, errors
world_size=int(os.environ["WORLD_SIZE"]), + ) + local_rank = int(os.environ["LOCAL_RANK"]) + torch.cuda.set_device(local_rank) + set_seed(args.seed) + cudnn.benchmark = True + device = torch.device(f"cuda:{local_rank}") + + image_paths = args.image_paths.split(',') + + vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(args.processor_path) + num_tokens_length = 0 + num_tokens = [f"{x:.2f}" for x in np.linspace(-100, 100, 20001)] + num_tokens_length = len(num_tokens) + vl_chat_processor.tokenizer.add_tokens(num_tokens) + + model = instantiate_model(args.checkpoint_path).to(device, dtype=torch.float32) + model.train(False) + + batch_size = args.batch_size + discrete_fm_steps = args.discrete_fm_steps + txt_max_length = args.txt_max_length + + cfg_weighted_model = CFGScaledModel(model=model, g_or_u='understanding') + with torch.no_grad(): + path_txt = MixtureDiscreteSoftmaxProbPath( + mode='text', + embedding_path=args.text_embedding_path + ) + path_txt.set_embedding(model.language_model.get_input_embeddings()) + path_img = MixtureDiscreteSoftmaxProbPath( + mode='image', + embedding_path=args.image_embedding_path + ) + solver = MixtureDiscreteSoftmaxEulerSolver( + model=cfg_weighted_model, + path_txt=path_txt, + path_img=path_img, + vocabulary_size_txt=VOCABULARY_SIZE_TXT + num_tokens_length, + vocabulary_size_img=VOCABULARY_SIZE_IMG, + ) + + imgs = [] + if isinstance(image_paths, str): + image_paths = [image_paths] + for path in image_paths: + img = Image.open(path).convert("RGB") + transform = transforms.Compose([ + transforms.Lambda(resize_pad), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) + ]) + imgs.append(transform(img)) + + if len(imgs) > 0: + imgs = torch.stack(imgs, dim=0) # [N, C, H, W] + img_len = len(imgs) * IMG_LEN + else: + imgs = None + img_len = IMG_LEN # default + + generation_understanding_indicator = 0 # this is an understanding sample + conversation = [ + { + "role": 
"User", + "content": ( + "Here is front views of a driving vehicle:\n\n" + "The navigation information is: straight\n" + "The current position is (0.00,0.00)\n" + "Current velocity is: (8.34,0.18) and current accelerate is: (-0.83,0.28)\n" + "Predict the optimal driving action for the next 4 seconds with 8 new waypoints." + ) + }, + { + "role": "Assistant", + "content": "" + } # "3.88,-0.06,7.50,-0.07,10.86,-0.10,13.95,-0.11,16.75,-0.13,19.29,-0.15,21.60,-0.12,23.67,-0.11" + ] + sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts( + conversations=conversation, + sft_format=vl_chat_processor.sft_format, + system_prompt=vl_chat_processor.system_prompt, + ) + + # tokenize + input_ids = vl_chat_processor.tokenizer.encode(sft_format) + input_ids = torch.LongTensor(input_ids) + # add image tokens to the input_ids + image_token_mask = input_ids == vl_chat_processor.image_id + image_indices = image_token_mask.nonzero() + assert len(image_indices) == len(image_paths), \ + f"Number of images ({len(image_paths)}) \ + does not match the number of image tokens ({len(image_indices)})" + + input_ids, _ = vl_chat_processor.add_image_token( + image_indices=image_indices, + input_ids=input_ids, + ) + + # pad tokens + original_input_id_len = input_ids.shape[0] + if original_input_id_len >= txt_max_length + img_len: + raise ValueError("Sentences too long, not supported so far...") + + rows_to_pad = txt_max_length + img_len - input_ids.shape[0] + input_ids = torch.concat([ + input_ids, + torch.LongTensor([vl_chat_processor.pad_id]).repeat(rows_to_pad) + ], dim=0) + attention_mask = torch.zeros((input_ids.shape[0]), dtype=torch.bool) + attention_mask[:] = True + + # obtain image token mask and fill in img token_ids + if imgs is not None: + image_expanded_token_mask = (input_ids == vl_chat_processor.image_id).to(dtype=int) + image_expanded_mask_indices = torch.where(image_expanded_token_mask == 1)[0] + input_ids[image_expanded_mask_indices] = 0 + else: + 
image_expanded_token_mask = torch.zeros_like(input_ids) + + # obtain text token mask + # We assume that there is only one turn for assistant to respond + text_expanded_token_mask = torch.zeros_like(image_expanded_token_mask) + split_token = vl_chat_processor.tokenizer.encode("Assistant:", add_special_tokens=False) + split_token_length = len(split_token) + + start_index = -1 + for j in range(len(input_ids) - split_token_length + 1): + if input_ids[j:j + split_token_length].numpy().tolist() == split_token: + start_index = j + break + if start_index != -1: + text_expanded_token_mask[(start_index+split_token_length):] = 1 + else: + raise ValueError("Split token not found in input_ids") + + generation_or_understanding_mask = generation_understanding_indicator + data_info = {} + data_info['text_token_mask'] = text_expanded_token_mask.unsqueeze(0).repeat(batch_size, 1).to(device) + data_info['image_token_mask'] = image_expanded_token_mask.unsqueeze(0).repeat(batch_size, 1).to(device) + data_info['generation_or_understanding_mask'] = \ + torch.Tensor([generation_or_understanding_mask]).unsqueeze(0).repeat(batch_size, 1).to(device).to(dtype=int) + + data_info['attention_mask'] = attention_mask.unsqueeze(0).repeat(batch_size, 1).to(device) + data_info['sft_format'] = sft_format + if imgs is not None: + data_info['understanding_img'] = imgs.unsqueeze(0).to(device, dtype=torch.float32).repeat(batch_size, 1, 1, 1, 1) + data_info['has_understanding_img'] = torch.Tensor([True]).to(dtype=int).unsqueeze(0).repeat(batch_size, 1).to(device) + else: + data_info['understanding_img'] = torch.zeros((3, 384, 384)).unsqueeze(0).repeat(batch_size, 1, 1, 1).to(device) + data_info['has_understanding_img'] = torch.Tensor([False]).to(dtype=int).unsqueeze(0).repeat(batch_size, 1).to(device) + input_ids = torch.LongTensor(input_ids).unsqueeze(0).repeat(batch_size, 1).to(device) + + + x_0_txt = torch.randint(VOCABULARY_SIZE_TXT + num_tokens_length, input_ids.shape, dtype=torch.long, device=device) 
+ x_init = x_0_txt * data_info['text_token_mask'] + input_ids * (1 - data_info['text_token_mask']) + + synthetic_samples = solver.sample( + x_init=x_init, + step_size=1.0/discrete_fm_steps, + verbose=True, + return_intermediates=False, + div_free=0, + dtype_categorical=torch.float32, + datainfo=data_info, + cfg_scale=0, + ) + sentence = vl_chat_processor.tokenizer.batch_decode( + synthetic_samples, + skip_special_tokens=True + )[0] + print("Sentence:", sentence) + + waypoint = extract_number_pairs(sentence, k=8)[0] + print("Waypoint: ", waypoint) + + +if __name__ == "__main__": + main() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..dab4802 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,17 @@ +omegaconf +transformers==4.42.4 +timm +tokenizers +attrdict +torch==2.0.1 +torchvision +accelerate +sentencepiece +einops +torchdiffeq +matplotlib +numpy==1.26.4 +icecream +xformers==0.0.22 +diffusers==0.32.2 +deepspeed \ No newline at end of file diff --git a/script/infer.sh b/script/infer.sh new file mode 100644 index 0000000..114b5f4 --- /dev/null +++ b/script/infer.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +CKPT_PATH="pretrained_model/wam-flow/navsim" +FUDOKI_PATH="pretrained_model/fudoki" +IMAGE_PATH="data/navsim_data/sensor_blobs/test/2021.09.09.17.18.51_veh-48_00889_01147/CAM_F0/9a6f0331d98258a0.jpg" + +torchrun --nproc_per_node 1 infer.py \ + --checkpoint_path $CKPT_PATH \ + --image_path $IMAGE_PATH \ + --processor_path $FUDOKI_PATH \ + --text_embedding_path $FUDOKI_PATH/text_embedding.pt \ + --image_embedding_path $FUDOKI_PATH/image_embedding.pt \ + --discrete_fm_steps 2 \ + --seed 123 \ No newline at end of file diff --git a/script/sft_debug.sh b/script/sft_debug.sh new file mode 100644 index 0000000..fd310a6 --- /dev/null +++ b/script/sft_debug.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +NUM_NODES=1 +NUM_GPUS=1 + +config=config/debug.yaml +output_dir=output/train/debug + +accelerate launch \ + --config_file 
@torch.no_grad()
def init_numeric_and_special_tokens(
    model,
    tokenizer,
    numeric_tokens,
    noise_scale: float = 0.01,
):
    """Seed the input-embedding rows of the numeric tokens, in place.

    A "numeric base" vector is built as the mean of two chunks: the mean
    embedding of the sub-word ids the tokenizer yields for the digits ``0``-``9``
    and the mean embedding of the ids it yields for ``"."``. Every token in
    ``numeric_tokens`` then receives ``numeric_base + noise_scale * N(0, 1)``.

    Despite the name, only numeric tokens are handled here; no special-token
    rows are touched.

    Args:
        model: Object exposing ``get_input_embeddings()`` whose ``.weight`` is
            the embedding matrix to modify. The matrix must already be large
            enough to hold the numeric-token ids.
        tokenizer: Tokenizer providing ``encode(text, add_special_tokens=False)``
            and ``convert_tokens_to_ids(token)``.
        numeric_tokens: Iterable of token strings to initialise.
        noise_scale: Standard deviation of the Gaussian jitter added to the
            base vector for each token.

    Returns:
        None. The embedding weight is modified in place.
    """
    emb = model.get_input_embeddings().weight
    device = emb.device
    dim = emb.shape[1]

    def tok_ids_for_text(text: str):
        # Sub-word ids without special tokens; drop anything that is not a
        # non-negative int (e.g. None or -100 placeholders).
        ids = tokenizer.encode(text, add_special_tokens=False)
        return [i for i in ids if isinstance(i, int) and i >= 0]

    # ---- Build a numeric "base" vector from digits and dot ----
    digit_ids = [i for d in "0123456789" for i in tok_ids_for_text(d)]
    dot_ids = tok_ids_for_text(".")

    # Average in float32 so the base is not degraded when the weights are
    # stored in fp16/bf16; the result is cast back on assignment.
    base_chunks = [
        emb[torch.tensor(ids, device=device)].float().mean(dim=0)
        for ids in (digit_ids, dot_ids)
        if ids
    ]
    if base_chunks:
        numeric_base = torch.stack(base_chunks, dim=0).mean(dim=0)
    else:
        # Fallback if the tokenizer lacks digits/dot as standalone pieces.
        numeric_base = torch.zeros(dim, device=device)

    # ---- Resolve the rows to initialise ----
    # Tokens the tokenizer does not know resolve to the UNK id; writing to that
    # row would silently corrupt the UNK embedding, so those are skipped.
    unk_id = getattr(tokenizer, "unk_token_id", None)
    token_ids = []
    for tok in numeric_tokens:
        tid = tokenizer.convert_tokens_to_ids(tok)
        if tid is None or tid < 0 or tid == unk_id:
            continue
        token_ids.append(tid)
    if not token_ids:
        return

    # ---- Initialise all numeric rows with one indexed write ----
    index = torch.tensor(token_ids, device=device)
    noise = noise_scale * torch.randn(len(token_ids), dim, device=device)
    emb[index] = (numeric_base + noise).to(emb.dtype)
def training_step(
    model,
    x_1,
    source_distribution,
    data_info,
    path,
    time_epsilon = 0.001,
    loss_fn = CrossEntropyLoss(),
    stage="s1",
    vl_chat_processor=None,
    args=None,
):
    """Compute the training loss for one batch.

    In stage ``"s1"`` the model is fed the clean target tokens; in stage
    ``"s2"`` the text positions are replaced by a sample ``x_t`` drawn from
    ``path`` between source noise ``x_0`` and the target ``x_1``. In both stages
    the cross-entropy is evaluated only at positions where
    ``data_info['text_token_mask']`` is 1. Stage ``"s2"`` can additionally add a
    weighted L2 penalty between the decoded predicted and target numeric values.

    Args:
        model: Callable ``model(x_t, data_info)`` returning a pair whose second
            element is the text logits (3-D; last axis is the vocabulary).
        x_1: Target token ids; the first axis is the batch.
        source_distribution: Object with ``sample_like(x_1)`` giving ``x_0``.
        data_info: Batch dict. ``'text_token_mask'`` and ``'understanding_img'``
            are read; ``'understanding_img'`` is overwritten with a copy cast to
            the model dtype (caller-visible side effect, kept from the original).
        path: Probability path with ``sample(x_0, x_1, t)`` returning an object
            that has ``x_t``. Only used in stage ``"s2"``.
        time_epsilon: ``t`` is drawn uniformly from ``[0, 1 - time_epsilon)``.
        loss_fn: Loss applied to ``(logits[N, C], targets[N])``.
        stage: ``"s1"`` or ``"s2"``.
        vl_chat_processor: Required in stage ``"s2"``; must expose
            ``num_start_id``, ``num_end_id`` (inclusive id of the last numeric
            token, as set up by ``main`` in this file), ``min_num``, ``max_num``
            and ``interval``.
        args: Optional config; ``args.l2_loss_weight`` enables the L2 term when
            it is > 0. A missing ``args`` or weight disables the term.

    Returns:
        ``(loss, loss_dict)`` where ``loss`` is the scalar tensor to
        back-propagate and ``loss_dict`` holds float values for ``"ce_loss"``,
        ``"loss"`` and, when applied, ``"l2_loss"``.

    Raises:
        ValueError: If ``stage`` is unknown, or ``stage == "s2"`` without a
            ``vl_chat_processor``.
    """
    if stage not in ("s1", "s2"):
        raise ValueError(f"Unknown training stage: {stage!r} (expected 's1' or 's2')")
    if stage == "s2" and vl_chat_processor is None:
        raise ValueError("stage 's2' requires a vl_chat_processor")

    # Sampled in both stages so the RNG stream does not depend on the stage.
    x_0 = source_distribution.sample_like(x_1)
    t = torch.rand(x_1.shape[0], device=x_1.device) * (1.0 - time_epsilon)

    if stage == "s1":
        x_t = x_1
    else:
        # update emb layer when using num tokenizer
        x_t = path.sample(x_0, x_1, t).x_t

    # text_token_mask==1 ==> generated text token; every other position keeps
    # its ground-truth id (prompt / image placeholders are never noised).
    text_token_mask = data_info['text_token_mask']
    x_t = x_t * text_token_mask + x_1 * (1 - text_token_mask)

    # Wrapped models (e.g. DDP) do not always expose ``.dtype``; fall back to
    # the dtype of the first parameter.
    model_dtype = getattr(model, "dtype", None)
    if model_dtype is None:
        model_dtype = next(model.parameters()).dtype
    data_info['understanding_img'] = data_info['understanding_img'].to(dtype=model_dtype)

    _, txt_logits = model(x_t, data_info)

    # Keep only the supervised (text) positions.
    # NOTE(review): the reshape assumes every row has the same number of
    # masked positions -- confirm against the dataset's padding scheme.
    b, _, c = txt_logits.shape
    mask = text_token_mask.unsqueeze(-1).bool()
    txt_logits = txt_logits.masked_select(mask).view(b, -1, c)
    x_1 = x_1.masked_select(mask.squeeze(-1)).view(b, -1)

    loss = ce_loss = loss_fn(txt_logits.flatten(0, 1), x_1.flatten(0, 1)).mean()
    loss_dict = {"ce_loss": ce_loss.detach().item()}

    if stage == "s2":
        l2_loss_weight = getattr(args, "l2_loss_weight", 0) or 0
        # Targets that are numeric tokens; ``num_end_id`` is inclusive.
        action_mask = (x_1 >= vl_chat_processor.num_start_id) & (x_1 <= vl_chat_processor.num_end_id)

        if l2_loss_weight > 0 and action_mask.any():
            # argmax of the logits equals argmax of their softmax.
            # NOTE(review): argmax is not differentiable, so this term shifts
            # the reported loss but contributes no gradient -- confirm intent.
            pred_ids = torch.argmax(txt_logits, dim=-1)
            pred_num_ids = pred_ids.masked_select(action_mask)
            pred_nums = vl_chat_processor.min_num \
                + (pred_num_ids - vl_chat_processor.num_start_id) * vl_chat_processor.interval
            # Predictions may be non-numeric ids; clamp to the valid value range.
            pred_nums = torch.clip(pred_nums, vl_chat_processor.min_num, vl_chat_processor.max_num)

            tgt_num_ids = x_1.masked_select(action_mask)  # [N]
            tgt_nums = vl_chat_processor.min_num \
                + (tgt_num_ids - vl_chat_processor.num_start_id) * vl_chat_processor.interval

            # The weight scales the penalty (it was previously added to it).
            l2_loss = l2_loss_weight * F.mse_loss(pred_nums, tgt_nums, reduction="mean")
            loss = loss + l2_loss
            loss_dict["l2_loss"] = l2_loss.detach().item()

    loss_dict["loss"] = loss.detach().item()
    return loss, loss_dict
def _path_exists(path):
    """Return True when ``path`` is a non-empty value naming an existing file or directory."""
    return bool(path) and os.path.exists(path)


def _prune_old_checkpoints(output_dir, total_limit):
    """Delete the oldest ``checkpoint-<step>`` directories so one more save stays within ``total_limit``."""
    checkpoints = [
        d for d in os.listdir(output_dir)
        # Only well-formed names; anything else would break the integer sort.
        if d.startswith("checkpoint-") and d.split("-")[1].isdigit()
    ]
    checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

    if len(checkpoints) < total_limit:
        return

    num_to_remove = len(checkpoints) - total_limit + 1
    removing_checkpoints = checkpoints[0:num_to_remove]

    logger.info(
        f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
    )
    logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}")

    for removing_checkpoint in removing_checkpoints:
        removing_checkpoint = os.path.join(output_dir, removing_checkpoint)
        if os.path.exists(removing_checkpoint):
            shutil.rmtree(removing_checkpoint)


def main(args):
    """Run supervised fine-tuning as described by the merged config ``args``.

    Steps: set up the ``Accelerator`` and logging, seed every process, snapshot
    the config into ``args.output_dir``, build the tokenizer (optionally
    extended with 20001 numeric tokens ``-100.00 .. 100.00``), dataset and
    model, select the trainable parameters for the stage, then run the training
    loop with periodic checkpointing.

    Args:
        args: Attribute-style config (an OmegaConf object when launched from
            this file). Keys read here include ``mixed_precision``,
            ``accumulate_grad_batches``, ``output_dir``, ``seed``,
            ``random_seed``, ``model_path``, ``use_quantize``, ``data_list``,
            ``txt_max_length``, ``batch_size``, ``dataloader_num_workers``,
            ``stage``, ``pretrain_model_path``, ``uncond_prob``,
            ``pretrain_path``, ``vocab_size``, ``new_embedding_path``,
            ``ckpt_path``, ``text_embedding_path``, ``train_llm_emb``,
            ``learning_rate``, ``lr_scheduler_type``, ``lr_warmup_steps``,
            ``max_train_steps``, ``source_distribution``, ``max_epochs``,
            ``max_grad_norm``, ``checkpointing_steps`` and
            ``checkpoints_total_limit``.
    """
    # NOTE(review): no ``accelerator.init_trackers(...)`` call is visible, so
    # ``accelerator.log`` below may have no tracker to write to -- confirm.
    accelerator = Accelerator(
        mixed_precision=args.mixed_precision,
        gradient_accumulation_steps=args.accumulate_grad_batches,
        log_with="tensorboard",
        project_dir=args.output_dir
    )

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # random seed: offset by process index so every rank draws different noise
    seed = args.seed + accelerator.process_index
    if args.random_seed:
        seed = seed + random.randint(0, 500)
    set_seed(seed)
    logger.info(f"accelerator.process_index: {accelerator.process_index}, seed: {seed}\n")

    # work dir: snapshot the effective configuration next to the checkpoints
    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)
        config_path = os.path.join(args.output_dir, "config.yaml")
        OmegaConf.save(args, config_path)
        accelerate_config_src = "./config/accelerate_config_ds2.yaml"
        accelerate_config_path = os.path.join(args.output_dir, "accelerate_config_ds2.yaml")
        # The source path is relative to the working directory; a missing copy
        # of the launcher config should not abort training.
        if os.path.isfile(accelerate_config_src):
            shutil.copyfile(accelerate_config_src, accelerate_config_path)
        else:
            logger.warning(f"{accelerate_config_src} not found, skipping copy")
    accelerator.wait_for_everyone()

    # dtype
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # prepare dataset
    vl_chat_processor = VLChatProcessor.from_pretrained(args.model_path)
    # Defined unconditionally: they are used further down even when
    # ``use_quantize`` is off (previously a NameError in that case).
    origin_len = len(vl_chat_processor.tokenizer)
    num_tokens = []
    num_tokens_length = 0
    if args.use_quantize:
        num_tokens = [f"{x:.2f}" for x in np.linspace(-100, 100, 20001)]
        num_tokens_length = len(num_tokens)
        vl_chat_processor.tokenizer.add_tokens(num_tokens)

        # Numeric tokens occupy ids [num_start_id, num_end_id], both inclusive.
        vl_chat_processor.num_start_id = origin_len
        vl_chat_processor.num_end_id = origin_len + num_tokens_length - 1
        vl_chat_processor.min_num = -100
        vl_chat_processor.max_num = 100
        vl_chat_processor.interval = 0.01

        logger.info(f"Total number tokens: {num_tokens_length}")

    # data
    dataset = SupervisedDataset(
        data_list=args.data_list,
        vl_chat_processor=vl_chat_processor,
        txt_max_length=args.txt_max_length
    )
    dataloader = DataLoader(
        dataset,
        shuffle=True,
        batch_size=args.batch_size,
        num_workers=args.dataloader_num_workers,
        pin_memory=True
    )
    logger.info(f"Max txt length: {args.txt_max_length}")
    logger.info(f"Total data samples: {len(dataset)}")

    # prepare model
    stage = args.stage
    logger.info(f"Training stage: {stage}")
    model = instantiate_model(
        args.pretrain_model_path
    ).to(weight_dtype)
    model.uncond_prob = args.uncond_prob

    # NOTE: torch.load unpickles its input; only point the *_path options at
    # trusted files.
    if _path_exists(args.pretrain_path):
        sd = torch.load(args.pretrain_path, map_location='cpu')
        model.load_state_dict(sd, strict=True)
        model = model.to(weight_dtype)
        logger.info(f"Loading pretrain ckpt from {args.pretrain_path}")

    extended_vocab_size = args.vocab_size + num_tokens_length

    if stage == "s1":
        model.language_model.resize_token_embeddings(extended_vocab_size)
        init_numeric_and_special_tokens(
            model.language_model,
            vl_chat_processor.tokenizer,
            numeric_tokens=num_tokens
        )
    elif stage == "s2":
        if _path_exists(args.new_embedding_path):
            old_emb = copy.deepcopy(model.language_model.get_input_embeddings())

            # ``safe_globals`` only exists in newer PyTorch releases; older
            # ones (such as the pinned 2.0.1) load the pickled module directly.
            safe_globals = getattr(torch.serialization, "safe_globals", None)
            if safe_globals is not None:
                with safe_globals([torch.nn.modules.sparse.Embedding]):
                    new_emb_state = torch.load(args.new_embedding_path, map_location="cpu")
            else:
                new_emb_state = torch.load(args.new_embedding_path, map_location="cpu")
            logger.info(f"Loading new embedding from {args.new_embedding_path}")

            if isinstance(new_emb_state, dict) and "weight" in new_emb_state:
                weight = new_emb_state["weight"]
                new_emb = torch.nn.Embedding(weight.size(0), weight.size(1))
                new_emb.load_state_dict(new_emb_state)
            else:
                new_emb = new_emb_state
            # Match the rest of the model so the swapped-in layer does not stay
            # in float32 / on CPU.
            new_emb = new_emb.to(device=old_emb.weight.device, dtype=old_emb.weight.dtype)

            model.language_model.resize_token_embeddings(extended_vocab_size)

            # Keep the rows of the original vocabulary; only the numeric rows
            # come from the loaded embedding.
            new_emb.weight.data[:origin_len, :] = old_emb.weight.data[:origin_len, :]

            model.language_model.set_input_embeddings(new_emb)

    # NOTE(review): the collapsed diff does not show whether this resume block
    # was nested under stage "s2"; it is applied for any stage here -- confirm.
    if _path_exists(args.ckpt_path):
        if model.language_model.model.embed_tokens.weight.shape[0] != extended_vocab_size:
            model.language_model.resize_token_embeddings(extended_vocab_size)
        sd = torch.load(args.ckpt_path, map_location='cpu')
        model.load_state_dict(sd, strict=True)
        model = model.to(weight_dtype)
        logger.info(f"Loading ckpt from {args.ckpt_path}")

    # prepare path
    path = MixtureDiscreteSoftmaxProbPath(
        mode='text',
        embedding_path=args.text_embedding_path
    )
    if args.use_quantize:
        path.set_embedding(model.language_model.get_input_embeddings())
    else:
        logger.info("No quantize!")

    logger.info(f"path.a = {path.a}")
    logger.info(f"path.c = {path.c}")

    # set trainable params
    model.requires_grad_(False)
    if stage == "s1":
        # s1 trains only the token embeddings and the LM head.
        model.language_model.requires_grad_(False)
        model.language_model.model.embed_tokens.requires_grad_(True)
        model.language_model.lm_head.requires_grad_(True)
    elif stage == "s2":
        # s2 trains the whole language model; embeddings are optional.
        model.language_model.requires_grad_(True)
        if args.train_llm_emb:
            model.language_model.model.embed_tokens.requires_grad_(True)
        else:
            model.language_model.model.embed_tokens.requires_grad_(False)

    trainable_params = [p for p in model.parameters() if p.requires_grad]

    num_trainable = sum(p.numel() for p in trainable_params)
    logger.info(f"Total trainable parameters: {num_trainable/1e9:.3} B")

    optimizer = torch.optim.AdamW(
        trainable_params,
        lr=args.learning_rate,
        betas=(0.9, 0.95),
        weight_decay=0.05,
    )

    # lr scheduler
    lr_scheduler = get_scheduler(
        args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes
    )

    # accelerator
    model, optimizer, dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, dataloader, lr_scheduler
    )

    source_distribution = get_source_distribution(
        source_distribution=args.source_distribution,
        vocab_size=extended_vocab_size if args.use_quantize else args.vocab_size,
    )

    global_step = 0
    initial_global_step = 0

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        disable=not accelerator.is_local_main_process,
    )

    # training loop
    for epoch in range(args.max_epochs):
        logger.info(f"Epoch {epoch + 1}/{args.max_epochs}")
        logger.info(f"training sample length: {len(dataloader)}")

        for batch in dataloader:
            with accelerator.accumulate(model):
                x_1 = batch["input_ids"].to(dtype=torch.long)
                loss, logs = training_step(
                    x_1=x_1,
                    model=model,
                    source_distribution=source_distribution,
                    data_info=batch,
                    path=path,
                    stage=stage,
                    vl_chat_processor=vl_chat_processor,
                    args=args
                )
                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(
                        trainable_params,
                        args.max_grad_norm,
                    )
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Only a completed optimizer step advances the step counter, and
            # only then is a checkpoint considered, so one step saves once.
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process and args.checkpoints_total_limit is not None:
                        _prune_old_checkpoints(args.output_dir, args.checkpoints_total_limit)

                    # Every process must reach the barrier, so it cannot sit
                    # behind a main-process-only condition.
                    accelerator.wait_for_everyone()
                    unwrap_net = accelerator.unwrap_model(model)
                    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                    if accelerator.is_main_process:
                        unwrap_net.save_pretrained(save_path, max_shard_size="20GB")
                        logger.info(f"Saved state to {save_path}")

            logs["lr"] = optimizer.param_groups[0]['lr']

            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

        if global_step >= args.max_train_steps:
            break

    accelerator.wait_for_everyone()
    accelerator.end_training()
    logger.info("training completed!")
def parse_args(argv=None):
    """Parse the launcher's command-line options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, so existing ``parse_args()`` calls
            behave as before.

    Returns:
        argparse.Namespace with ``config`` (required), ``output_dir``
        (required) and ``output_obs_dir`` (defaults to ``None``).
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--output_obs_dir", type=str, default=None)

    return parser.parse_args(argv)


if __name__ == '__main__':
    args = parse_args()

    config = OmegaConf.load(args.config)

    # merge args: every CLI option except the config path itself
    args_dict = vars(args).copy()
    args_dict.pop("config", None)
    config_keys = set(config.keys())
    cli_keys = set(args_dict.keys())

    # check conflict: keys given both on the command line and in the file
    conflict_keys = cli_keys & config_keys
    if conflict_keys:
        print(f"Args conflict: {conflict_keys}")

    # merge: OmegaConf.merge lets later arguments override earlier ones, so on
    # a conflict the value from the config file replaces the CLI value.
    merged_config = OmegaConf.merge(OmegaConf.create(args_dict), config)
    args = merged_config

    # training
    main(args)