diff --git a/.DS_Store b/.DS_Store
deleted file mode 100644
index 84f5661..0000000
Binary files a/.DS_Store and /dev/null differ
diff --git a/.gitignore b/.gitignore
index b7faf40..dfa1df5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -205,3 +205,30 @@ cython_debug/
marimo/_static/
marimo/_lsp/
__marimo__/
+
+
+
+pretrained_model
+__pycache__/
+*.pyc
+*.pyo
+*.pyd
+*.json
+*.jsonl
+env/
+venv/
+.venv/
+results/
+fusion_result.json
+*output*
+*kernel_meta/
+data/*.json
+maps
+data/metric_cache
+data/maps
+tmp
+navsim_v1
+nuscenes
+extra-info/
+pretrained_model/
+.DS_Store
\ No newline at end of file
diff --git a/.pylintrc b/.pylintrc
index 7e930ac..aa823bf 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -31,7 +31,7 @@ extension-pkg-allow-list=
# be loaded. Extensions are loading into the active Python interpreter and may
# run arbitrary code. (This is an alternative name to extension-pkg-allow-list
# for backward compatibility.)
-extension-pkg-whitelist=cv2
+extension-pkg-whitelist=
# Return non-zero exit code if any of these messages/categories are detected,
# even if score is above --fail-under value. Syntax same as enable. Messages
@@ -59,15 +59,16 @@ ignore-paths=
# Emacs file locks
ignore-patterns=^\.#
-# List of module names for which member attributes should not be checked
-# (useful for modules/projects where namespaces are manipulated during runtime
-# and thus existing member attributes cannot be deduced by static analysis). It
-# supports qualified module names, as well as Unix pattern matching.
-ignored-modules=cv2
+# List of module names for which member attributes should not be checked and
+# will not be imported (useful for modules/projects where namespaces are
+# manipulated during runtime and thus existing member attributes cannot be
+# deduced by static analysis). It supports qualified module names, as well as
+# Unix pattern matching.
+ignored-modules=torch
# Python code to execute, usually for sys.path manipulation such as
# pygtk.require().
-init-hook='import sys; sys.path.append(".")'
+#init-hook=
# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
# number of processors available to use, and will cap the count on Windows to
@@ -86,9 +87,13 @@ load-plugins=
# Pickle collected data for later comparisons.
persistent=yes
+# Resolve imports to .pyi stubs if available. May reduce no-member messages and
+# increase not-an-iterable messages.
+prefer-stubs=no
+
# Minimum Python version to use for version dependent checks. Will default to
# the version used to run pylint.
-py-version=3.10
+py-version=3.12
# Discover python modules and packages in the file system subtree.
recursive=no
@@ -99,10 +104,6 @@ recursive=no
# source root.
source-roots=
-# When enabled, pylint would attempt to guess common misconfiguration and emit
-# user-friendly hints instead of false-positive error messages.
-suggestion-mode=yes
-
# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no
@@ -229,6 +230,11 @@ name-group=
# not require a docstring.
no-docstring-rgx=^_
+# Regular expression matching correct parameter specification variable names.
+# If left empty, parameter specification variable names will be checked with
+# the set naming style.
+#paramspec-rgx=
+
# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
# These decorators are taken in consideration only for invalid-name.
@@ -242,13 +248,17 @@ property-classes=abc.abstractproperty
# variable names will be checked with the set naming style.
#typevar-rgx=
+# Regular expression matching correct type variable tuple names. If left empty,
+# type variable tuple names will be checked with the set naming style.
+#typevartuple-rgx=
+
# Naming style matching correct variable names.
variable-naming-style=snake_case
# Regular expression matching correct variable names. Overrides variable-
# naming-style. If left empty, variable names will be checked with the set
# naming style.
-variable-rgx=(_?[a-z][A-Za-z0-9]{0,30})|([A-Z0-9]{1,30})
+#variable-rgx=
[CLASSES]
@@ -285,23 +295,26 @@ exclude-too-few-public-methods=
ignored-parents=
# Maximum number of arguments for function / method.
-max-args=7
+max-args=10
# Maximum number of attributes for a class (see R0902).
-max-attributes=20
+max-attributes=7
# Maximum number of boolean expressions in an if statement (see R0916).
max-bool-expr=5
# Maximum number of branch for function / method body.
-max-branches=12
+max-branches=40
# Maximum number of locals for function / method body.
-max-locals=15
+max-locals=50
# Maximum number of parents for a class (see R0901).
max-parents=7
+# Maximum number of positional arguments for function / method.
+max-positional-arguments=10
+
# Maximum number of public methods for a class (see R0904).
max-public-methods=20
@@ -309,10 +322,10 @@ max-public-methods=20
max-returns=6
# Maximum number of statements in function / method body.
-max-statements=300
+max-statements=150
# Minimum number of public methods for a class (see R0903).
-min-public-methods=1
+min-public-methods=2
[EXCEPTIONS]
@@ -336,11 +349,13 @@ indent-after-paren=4
# tab).
indent-string=' '
-# Maximum number of characters on a single line.
+# Maximum number of characters on a single line. Pylint's default of 100 is
+# based on PEP 8's guidance that teams may choose line lengths up to 99
+# characters.
max-line-length=150
# Maximum number of lines in a module.
-max-module-lines=2000
+max-module-lines=1000
# Allow the body of a class to be on the same line as the declaration if body
# contains single statement.
@@ -421,11 +436,21 @@ confidence=HIGH,
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use "--disable=all --enable=classes
# --disable=W".
-disable=too-many-arguments,
- too-many-locals,
- too-many-branches,
- protected-access
-
+disable=raw-checker-failed,
+ bad-inline-option,
+ locally-disabled,
+ file-ignored,
+ suppressed-message,
+ useless-suppression,
+ deprecated-pragma,
+ use-symbolic-message-instead,
+ use-implicit-booleaness-not-comparison-to-string,
+ use-implicit-booleaness-not-comparison-to-zero,
+ C0116,
+ C0114,
+ W0621,
+ E0601,
+ W0718
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
@@ -443,9 +468,13 @@ timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.
[MISCELLANEOUS]
+# Whether or not to search for fixme's in docstrings.
+check-fixme-in-docstring=no
+
# List of note tags to take in consideration, separated by a comma.
notes=FIXME,
- XXX
+ XXX,
+ TODO
# Regular expression of note tags to take in consideration.
notes-rgx=
@@ -465,7 +494,7 @@ never-returning-functions=sys.exit,argparse.parse_error
# Let 'consider-using-join' be raised when the separator to join on would be
# non-empty (resulting in expected fixes of the type: ``"- " + " -
# ".join(items)``)
-# suggest-join-with-non-empty-separator=yes
+suggest-join-with-non-empty-separator=yes
[REPORTS]
@@ -481,10 +510,10 @@ evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor
# used to format the message information. See doc for all details.
msg-template=
-# Set the output format. Available formats are: text, parseable, colorized,
-# json2 (improved json format), json (old json format) and msvs (visual
-# studio). You can also give a reporter class, e.g.
-# mypackage.mymodule.MyReporterClass.
+# Set the output format. Available formats are: 'text', 'parseable',
+# 'colorized', 'json2' (improved json format), 'json' (old json format), msvs
+# (visual studio) and 'github' (GitHub actions). You can also give a reporter
+# class, e.g. mypackage.mymodule.MyReporterClass.
#output-format=
# Tells whether to display a full report or only the messages.
@@ -586,7 +615,7 @@ ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace
# of finding the hint is based on edit distance.
missing-member-hint=yes
-# The minimum edit distance a name should have in order to be considered a
+# The maximum edit distance a name should have in order to be considered a
# similar match for a missing member name.
missing-member-hint-distance=1
@@ -630,4 +659,4 @@ init-import=no
# List of qualified module names which can have objects that can redefine
# builtins.
-redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
\ No newline at end of file
+redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..261eeb9
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
index 50cfc2f..5271aa3 100644
--- a/README.md
+++ b/README.md
@@ -27,17 +27,23 @@
+
+
## π° News
- **`2025/12/06`**: πππ Paper submitted on [Arxiv](https://arxiv.org/pdf/2512.06112).
+
+
## π
οΈ Roadmap
| Status | Milestone | ETA |
| :----: | :----------------------------------------------------------------------------------------------------: | :--------: |
-| π | **[Releasing the inference source code](https://github.com/fudan-generative-vision/WAM-Flow)** | 2025.12.21 |
-| π | **[Pretrained models on Huggingface](https://huggingface.co/fudan-generative-ai/WAM-Flow)** | TBD |
-| π | **[Releasing the training scripts](#training)** | TBD |
+| β
| **[Release the SFT and inference code](https://github.com/fudan-generative-vision/WAM-Flow)** | 2025.12.19 |
+| π | **[Pretrained models on Huggingface](https://huggingface.co/fudan-generative-ai/WAM-Flow)** | TBD |
+| π | **[Release the evaluation code](https://huggingface.co/fudan-generative-ai/WAM-Flow)** | TBD |
+| π | **[Release the RL code](https://github.com/fudan-generative-vision/WAM-Flow)** | TBD |
+| π | **[Release the pre-processed training data](#training)** | TBD |
## πΈ Showcase
@@ -45,14 +51,68 @@
## π Qualitative Results on NAVSIM
### NAVSIM-v1 benchmark results
-
+
+

+
+
### NAVSIM-v2 benchmark results
-
+
+

+
+
+
## π§οΈ Framework

Our method takes as input a front-view image, a natural-language navigation command with a system prompt, and the ego-vehicle states, and outputs an 8-waypoint future trajectory spanning 4 seconds through parallel denoising. The model is first trained via supervised fine-tuning to learn accurate trajectory prediction. We then apply simulatorguided GRPO to further optimize closed-loop behavior. The GRPO reward function integrates safety constraints (collision avoidance, drivable-area compliance) with performance objectives (ego-progress, time-to-collision, comfort).
+
+## Preparation
+
+### Environment
+```sh
+conda create --name wam-flow python=3.10
+conda activate wam-flow
+pip install -r requirements.txt
+```
+
+### Checkpoint
+
+
+### Data
+
+
+
+## Training
+
+### Pretraining
+
+### SFT
+
+### RL
+
+### Debug
+```sh
+sh script/sft_debug.sh
+```
+
+
+## Inference
+```sh
+sh script/infer.sh
+```
+
+
+## Evaluation
+
+### NAVSIM-v1
+
+### NAVSIM-v2
+
+### nuScenes
+
+
+
## π Citation
If you find our work useful for your research, please consider citing the paper:
@@ -66,9 +126,13 @@ If you find our work useful for your research, please consider citing the paper:
}
```
+
+
## β οΈ Social Risks and Mitigations
The integration of Vision-Language-Action models into autonomous driving introduces ethical challenges, particularly regarding the opacity of neural decision-making and its impact on road safety. To mitigate these risks, it is imperative to implement explainable AI frameworks and robust safe protocols that ensure predictable vehicle behavior in long-tailed scenarios. Furthermore, addressing concerns over data privacy and public surveillance requires transparent data governance and rigorous de-identification practices. By prioritizing safety-critical alignment and ethical compliance, this research promotes the responsible development and deployment of VLA-based autonomous systems.
+
+
## π€ Acknowledgements
-We gratefully acknowledge the contributors to the [Janus](https://github.com/deepseek-ai/Janus), [FUDOKI](https://github.com/fudoki-hku/FUDOKI) and [flow_matching](https://github.com/facebookresearch/flow_matching) repositories, whose commitment to open source has provided us with their excellent codebases and pretrained models.
\ No newline at end of file
+We gratefully acknowledge the contributors to the [Recogdrive](https://github.com/xiaomi-research/recogdrive), [Janus](https://github.com/deepseek-ai/Janus), [FUDOKI](https://github.com/fudoki-hku/FUDOKI) and [flow_matching](https://github.com/facebookresearch/flow_matching) repositories, whose commitment to open source has provided us with their excellent codebases and pretrained models.
\ No newline at end of file
diff --git a/config/accelerate_config_ds2.yaml b/config/accelerate_config_ds2.yaml
new file mode 100644
index 0000000..325fe7c
--- /dev/null
+++ b/config/accelerate_config_ds2.yaml
@@ -0,0 +1,21 @@
+compute_environment: LOCAL_MACHINE
+debug: true
+deepspeed_config:
+ deepspeed_multinode_launcher: standard
+ gradient_accumulation_steps: 1
+ offload_optimizer_device: none
+ offload_param_device: none
+ zero3_init_flag: false
+ zero_stage: 0 # use 2 to accelerate
+distributed_type: DEEPSPEED
+downcast_bf16: "no"
+main_training_function: main
+mixed_precision: "fp16"
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
\ No newline at end of file
diff --git a/config/debug.yaml b/config/debug.yaml
new file mode 100644
index 0000000..4d7fe62
--- /dev/null
+++ b/config/debug.yaml
@@ -0,0 +1,36 @@
+model_path: pretrained_model/fudoki
+text_embedding_path: pretrained_model/fudoki/text_embedding.pt
+pretrain_model_path: pretrained_model/wam-flow/navsim
+
+stage: s2
+txt_max_length: 500
+new_embedding_path: ""
+
+pretrain_path: ""
+ckpt_path: ""
+train_llm_emb: false
+
+data_list: [
+ "data/navsim_debug.jsonl",
+]
+
+mixed_precision: "no"
+accumulate_grad_batches: 1
+max_grad_norm: 1.0
+seed: 42
+source_distribution: "uniform"
+vocab_size: 102400
+batch_size: 1
+dataloader_num_workers: 8
+learning_rate: 5e-6
+lr_scheduler_type: "cosine" # "constant", "cosine"
+lr_warmup_steps: 500
+max_train_steps: 40000
+max_epochs: 100
+uncond_prob: 0
+use_quantize: true
+random_seed: false
+l2_loss_weight: 0
+
+checkpointing_steps: 4000
+checkpoints_total_limit: 50
\ No newline at end of file
diff --git a/config/pretrain.yaml b/config/pretrain.yaml
new file mode 100644
index 0000000..6563464
--- /dev/null
+++ b/config/pretrain.yaml
@@ -0,0 +1,53 @@
+model_path: pretrained_model/fudoki
+pretrain_model_path: /cache/xyf_model/
+text_embedding_path: pretrained_model/fudoki/text_embedding.pt
+
+stage: s2
+txt_max_length: 500
+new_embedding_path: ""
+
+pretrain_path: ""
+ckpt_path: ""
+train_llm_emb: false
+
+data_list: [
+ "path/to/llava_v1_5_mix665k_2.jsonl"
+ "path/to/dataset_coda_lm.jsonl",
+ "path/to/dataset_drivegpt4.jsonl",
+ "path/to/dataset_lingoqa.jsonl",
+ "path/to/dataset_maplm.jsonl",
+ "path/to/dataset_nuscenes_qa.jsonl",
+ "path/to/dataset_omnidrive.jsonl",
+ "path/to/dataset_senna.jsonl",
+ "path/to/dataset_talk2car.jsonl",
+ "path/to/drivelm_change_box_type.jsonl",
+
+ "path/to/nuplan_recogdrive.jsonl",
+ "path/to/navsim_recogdrive.jsonl",
+ "path/to/navsim_recogdrive.jsonl",
+
+ "path/to/navsim_668k.jsonl",
+ "path/to/navsim_103k.jsonl",
+ "path/to/nuscenes_train.jsonl",
+]
+
+mixed_precision: "no"
+accumulate_grad_batches: 4
+max_grad_norm: 1.0
+seed: 42
+source_distribution: "uniform"
+vocab_size: 102400
+batch_size: 1
+dataloader_num_workers: 8
+learning_rate: 1e-5
+lr_scheduler_type: "constant" # cosine, constant
+lr_warmup_steps: 0
+max_train_steps: 80000
+max_epochs: 10
+uncond_prob: 0
+use_quantize: true
+random_seed: true
+l2_loss_weight: 0
+
+checkpointing_steps: 2500
+checkpoints_total_limit: 50
\ No newline at end of file
diff --git a/config/sft_navsim.yaml b/config/sft_navsim.yaml
new file mode 100644
index 0000000..54264a7
--- /dev/null
+++ b/config/sft_navsim.yaml
@@ -0,0 +1,36 @@
+model_path: pretrained_model/fudoki
+text_embedding_path: pretrained_model/fudoki/text_embedding.pt
+pretrain_model_path: pretrained_model/wam-flow/navsim
+
+stage: s2
+txt_max_length: 500
+new_embedding_path: ""
+
+pretrain_path: ""
+ckpt_path: ""
+train_llm_emb: false
+
+data_list: [
+ "path/to/navsim_668k.jsonl",
+]
+
+mixed_precision: "no"
+accumulate_grad_batches: 1
+max_grad_norm: 1.0
+seed: 42
+source_distribution: "uniform"
+vocab_size: 102400
+batch_size: 1
+dataloader_num_workers: 8
+learning_rate: 5e-6
+lr_scheduler_type: "cosine" # "constant", "cosine"
+lr_warmup_steps: 500
+max_train_steps: 40000
+max_epochs: 100
+uncond_prob: 0
+use_quantize: true
+random_seed: false
+l2_loss_weight: 0
+
+checkpointing_steps: 4000
+checkpoints_total_limit: 50
\ No newline at end of file
diff --git a/config/sft_nuscenes.yaml b/config/sft_nuscenes.yaml
new file mode 100644
index 0000000..17848ab
--- /dev/null
+++ b/config/sft_nuscenes.yaml
@@ -0,0 +1,36 @@
+model_path: pretrained_model/fudoki
+pretrain_model_path: /cache/xyf_model/
+text_embedding_path: /cache/models/LucasJinWang-FUDOKI/text_embedding.pt
+
+stage: s2
+txt_max_length: 500
+new_embedding_path: ""
+
+pretrain_path: ""
+ckpt_path: ""
+train_llm_emb: false
+
+data_list: [
+ "path/to/nuscenes_train.jsonl",
+]
+
+mixed_precision: "no"
+accumulate_grad_batches: 1
+max_grad_norm: 1.0
+seed: 42
+source_distribution: "uniform"
+vocab_size: 102400
+batch_size: 1
+dataloader_num_workers: 8
+learning_rate: 5e-6
+lr_scheduler_type: "cosine" # "constant", "cosine"
+lr_warmup_steps: 100
+max_train_steps: 10000
+max_epochs: 100
+uncond_prob: 0
+use_quantize: true
+random_seed: false
+l2_loss_weight: 0
+
+checkpointing_steps: 1000
+checkpoints_total_limit: 50
\ No newline at end of file
diff --git a/data/navsim_data/sensor_blobs/test/2021.09.09.17.18.51_veh-48_00889_01147/CAM_F0/9a6f0331d98258a0.jpg b/data/navsim_data/sensor_blobs/test/2021.09.09.17.18.51_veh-48_00889_01147/CAM_F0/9a6f0331d98258a0.jpg
new file mode 100644
index 0000000..786c25e
Binary files /dev/null and b/data/navsim_data/sensor_blobs/test/2021.09.09.17.18.51_veh-48_00889_01147/CAM_F0/9a6f0331d98258a0.jpg differ
diff --git a/data/navsim_data/sensor_blobs/trainval/2021.06.09.18.23.43_veh-35_03190_03392/CAM_F0/f997082b36a65b27.jpg b/data/navsim_data/sensor_blobs/trainval/2021.06.09.18.23.43_veh-35_03190_03392/CAM_F0/f997082b36a65b27.jpg
new file mode 100644
index 0000000..6a80392
Binary files /dev/null and b/data/navsim_data/sensor_blobs/trainval/2021.06.09.18.23.43_veh-35_03190_03392/CAM_F0/f997082b36a65b27.jpg differ
diff --git a/data/navsim_data/sensor_blobs/trainval/2021.06.09.19.40.26_veh-12_00279_01212/CAM_F0/e69a584baa1a571f.jpg b/data/navsim_data/sensor_blobs/trainval/2021.06.09.19.40.26_veh-12_00279_01212/CAM_F0/e69a584baa1a571f.jpg
new file mode 100644
index 0000000..a90ead4
Binary files /dev/null and b/data/navsim_data/sensor_blobs/trainval/2021.06.09.19.40.26_veh-12_00279_01212/CAM_F0/e69a584baa1a571f.jpg differ
diff --git a/data/navsim_data/sensor_blobs/trainval/2021.06.23.14.06.20_veh-26_00020_01142/CAM_F0/ab90c429b6c851ae.jpg b/data/navsim_data/sensor_blobs/trainval/2021.06.23.14.06.20_veh-26_00020_01142/CAM_F0/ab90c429b6c851ae.jpg
new file mode 100644
index 0000000..49fc0aa
Binary files /dev/null and b/data/navsim_data/sensor_blobs/trainval/2021.06.23.14.06.20_veh-26_00020_01142/CAM_F0/ab90c429b6c851ae.jpg differ
diff --git a/data/navsim_data/sensor_blobs/trainval/2021.07.09.23.23.48_veh-26_02228_04624/CAM_F0/12bfa1b0a8615565.jpg b/data/navsim_data/sensor_blobs/trainval/2021.07.09.23.23.48_veh-26_02228_04624/CAM_F0/12bfa1b0a8615565.jpg
new file mode 100644
index 0000000..3a107ee
Binary files /dev/null and b/data/navsim_data/sensor_blobs/trainval/2021.07.09.23.23.48_veh-26_02228_04624/CAM_F0/12bfa1b0a8615565.jpg differ
diff --git a/data/navsim_data/sensor_blobs/trainval/2021.07.16.00.51.05_veh-17_01352_01901/CAM_F0/6c1a5fe095f95f1e.jpg b/data/navsim_data/sensor_blobs/trainval/2021.07.16.00.51.05_veh-17_01352_01901/CAM_F0/6c1a5fe095f95f1e.jpg
new file mode 100644
index 0000000..7267ca2
Binary files /dev/null and b/data/navsim_data/sensor_blobs/trainval/2021.07.16.00.51.05_veh-17_01352_01901/CAM_F0/6c1a5fe095f95f1e.jpg differ
diff --git a/data/navsim_data/sensor_blobs/trainval/2021.07.16.16.01.30_veh-38_02497_03871/CAM_F0/c2a1255ee2935f44.jpg b/data/navsim_data/sensor_blobs/trainval/2021.07.16.16.01.30_veh-38_02497_03871/CAM_F0/c2a1255ee2935f44.jpg
new file mode 100644
index 0000000..4d24cc4
Binary files /dev/null and b/data/navsim_data/sensor_blobs/trainval/2021.07.16.16.01.30_veh-38_02497_03871/CAM_F0/c2a1255ee2935f44.jpg differ
diff --git a/data/navsim_data/sensor_blobs/trainval/2021.07.16.18.49.56_veh-26_00833_03384/CAM_F0/bcf630737fbf5d29.jpg b/data/navsim_data/sensor_blobs/trainval/2021.07.16.18.49.56_veh-26_00833_03384/CAM_F0/bcf630737fbf5d29.jpg
new file mode 100644
index 0000000..11c3237
Binary files /dev/null and b/data/navsim_data/sensor_blobs/trainval/2021.07.16.18.49.56_veh-26_00833_03384/CAM_F0/bcf630737fbf5d29.jpg differ
diff --git a/data/navsim_data/sensor_blobs/trainval/2021.10.11.02.57.41_veh-50_00352_00535/CAM_F0/e9c0df7e1fae5b0c.jpg b/data/navsim_data/sensor_blobs/trainval/2021.10.11.02.57.41_veh-50_00352_00535/CAM_F0/e9c0df7e1fae5b0c.jpg
new file mode 100644
index 0000000..66edd85
Binary files /dev/null and b/data/navsim_data/sensor_blobs/trainval/2021.10.11.02.57.41_veh-50_00352_00535/CAM_F0/e9c0df7e1fae5b0c.jpg differ
diff --git a/flow_matching/__init__.py b/flow_matching/__init__.py
new file mode 100644
index 0000000..7975227
--- /dev/null
+++ b/flow_matching/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the CC-by-NC license found in the
+# LICENSE file in the root directory of this source tree.
+
+__version__ = "1.0.9"
diff --git a/flow_matching/data/__init__.py b/flow_matching/data/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/flow_matching/data/navsim.py b/flow_matching/data/navsim.py
new file mode 100644
index 0000000..d6d2fdf
--- /dev/null
+++ b/flow_matching/data/navsim.py
@@ -0,0 +1,261 @@
+# LINT_ME
+import os
+import json
+import random
+
+import numpy as np
+from PIL import Image
+import torch
+from torch.utils.data import Dataset
+from torchvision import transforms
+
+from fudoki.janus.models import VLChatProcessor
+
+
+VOCABULARY_SIZE_TXT = 102400
+VOCABULARY_SIZE_IMG = 16384
+IMG_LEN = 576
+
+
+def resize_pad(image, image_size=384):
+ w, h = image.size
+ if w <= 0 or h <= 0:
+ return image.resize((image_size, image_size), Image.Resampling.BILINEAR)
+
+ resize_scale = image_size / max(w, h)
+ new_w = max(1, int(w * resize_scale))
+ new_h = max(1, int(h * resize_scale))
+
+ padding_color = (127, 127, 127)
+ new_image = Image.new('RGB', (image_size, image_size), padding_color)
+
+ if new_w <= 0 or new_h <= 0:
+ return image.resize((image_size, image_size), Image.Resampling.BILINEAR)
+
+ image = image.resize((new_w, new_h), Image.Resampling.BILINEAR)
+
+ paste_x = (image_size - new_w) // 2
+ paste_y = (image_size - new_h) // 2
+
+ new_image.paste(image, (paste_x, paste_y))
+ return new_image
+
+
+class SupervisedDataset(Dataset):
+ """
+ Dataset for supervised training on image-text conversation data.
+
+ Loads image-text conversation samples from JSON/JSONL files, processes images (resize/pad/normalize),
+ tokenizes text prompts with image placeholders, and formats data for training.
+ """
+ def __init__(
+ self,
+ data_list: list,
+ vl_chat_processor: VLChatProcessor,
+ txt_max_length=500
+ ):
+ super().__init__()
+ self.vl_chat_processor = vl_chat_processor
+ self.txt_max_length = txt_max_length
+ self.list_data_dict = []
+
+ self.split_token = None
+
+ self.transform_img = transforms.Compose([
+ transforms.Lambda(resize_pad),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+ ])
+
+ for path in data_list:
+ jsonl_list = []
+ ext = os.path.splitext(path)[-1].lower()
+
+ with open(path, "r", encoding="utf-8") as f:
+ if ext == ".jsonl":
+ for line in f:
+ line = line.strip()
+ if line:
+ jsonl_list.append(json.loads(line))
+ elif ext == ".json":
+ data = json.load(f)
+ if isinstance(data, list):
+ jsonl_list.extend(data)
+ else:
+ jsonl_list.append(data)
+ else:
+ raise ValueError(f"Unsupported file extension: {ext}")
+
+ self.list_data_dict.extend(jsonl_list)
+
+
+ def __len__(self):
+ return len(self.list_data_dict)
+
+ def __getitem__(self, i):
+ try:
+ sample = self._get_item(i)
+ return sample
+ except Exception as e:
+ print(f"[Warning] Error loading index {i}: {e}")
+ # random choice
+ rand_idx = np.random.randint(0, len(self.list_data_dict))
+ print(f"[Retry] Loading random index {rand_idx} instead.")
+ return self.__getitem__(rand_idx)
+
+
+ def _get_item(self, i):
+ item = self.list_data_dict[i]
+
+ if "image" not in item:
+ raise ValueError("Currently only image-based samples are supported.")
+
+ image_list = item["image"]
+ conversation = item["conversations"]
+
+ if conversation[0]["from"] == "system":
+ addition_system_prompt = conversation[0]["value"]
+ conversation = conversation[1:]
+ else:
+ addition_system_prompt = ""
+
+ # conv_length = len(conversation)
+ # assert conv_length == 2, "only support single turn"
+
+ conversation = self.construct_conv(conversation)
+
+ data_dict = self.process_image_item(
+ image_list,
+ conversation,
+ system_prompt=addition_system_prompt,
+ txt_max_length=self.txt_max_length
+ )
+ return data_dict
+
+ def construct_conv(self, conv):
+ _conv = []
+ for item in conv:
+ role = item["from"]
+ content = item["value"]
+ _item = {}
+
+ if role == "human":
+ _item["role"] = "User"
+ elif role == "gpt":
+ _item["role"] = "Assistant"
+ else:
+ raise ValueError("role must be human or gpt")
+
+ if "" in content:
+ content = content.replace("", "")
+
+ _item["content"] = content
+
+ _conv.append(_item)
+
+ return _conv
+
+ def _find_split_token(self, input_ids, split_token_length):
+ # start index for "Assistant:"
+ start_index = -1
+ for j in range(len(input_ids) - split_token_length, 0, -1):
+ if input_ids[j:j + split_token_length].numpy().tolist() == self.split_token:
+ start_index = j
+ break
+ return start_index
+
+ def process_image_item(
+ self,
+ image_paths,
+ conversation,
+ system_prompt="",
+ txt_max_length=500
+ ):
+ imgs = []
+ if isinstance(image_paths, str):
+ image_paths = [image_paths]
+
+ for path in image_paths:
+ img = Image.open(path).convert("RGB")
+ imgs.append(self.transform_img(img))
+
+ if len(imgs) > 0:
+ imgs = torch.stack(imgs, dim=0) # [N, C, H, W]
+ img_len = len(imgs) * IMG_LEN
+ else:
+ imgs = None
+ img_len = 0 # default
+
+ generation_understanding_indicator = 0
+
+ sft_format = self.vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
+ conversations=conversation,
+ sft_format=self.vl_chat_processor.sft_format,
+ system_prompt=self.vl_chat_processor.system_prompt + system_prompt,
+ )
+
+ # tokenize
+ input_ids = self.vl_chat_processor.tokenizer.encode(sft_format)
+ input_ids = torch.LongTensor(input_ids)
+
+ # add image tokens to the input_ids
+ image_token_mask = input_ids == self.vl_chat_processor.image_id
+ image_indices = image_token_mask.nonzero()
+ assert len(image_indices) == len(image_paths), \
+ f"Number of images ({len(image_paths)}) does not match number of image tokens ({len(image_indices)})"
+
+ input_ids, _ = self.vl_chat_processor.add_image_token(
+ image_indices=image_indices,
+ input_ids=input_ids,
+ )
+
+ # pad tokens
+ if input_ids.shape[0] >= txt_max_length + img_len:
+ rows_to_pad = random.randint(0, 50)
+ else:
+ rows_to_pad = txt_max_length + img_len - input_ids.shape[0]
+ input_ids = torch.cat([input_ids, torch.LongTensor([self.vl_chat_processor.pad_id]).repeat(rows_to_pad)], dim=0)
+ attention_mask = torch.zeros((input_ids.shape[0]), dtype=torch.bool)
+ attention_mask[:] = True
+
+ # obtain image token mask and fill in img token_ids
+ if imgs is not None:
+ image_expanded_token_mask = (input_ids == self.vl_chat_processor.image_id).to(dtype=int)
+ image_expanded_mask_indices = torch.where(image_expanded_token_mask == 1)[0]
+ input_ids[image_expanded_mask_indices] = 0
+ else:
+ image_expanded_token_mask = torch.zeros_like(input_ids)
+
+ # obtain text token mask
+ # support multi turn, indicating the last one
+ if self.split_token is None:
+ self.split_token = self.vl_chat_processor.tokenizer.encode("Assistant:", add_special_tokens=False)
+ split_token_length = len(self.split_token)
+ start_index = self._find_split_token(input_ids, split_token_length)
+
+ text_expanded_token_mask = torch.zeros_like(image_expanded_token_mask)
+ if start_index != -1:
+ text_expanded_token_mask[(start_index+split_token_length):] = 1
+ else:
+ raise ValueError("Split token not found in input_ids")
+
+
+ generation_or_understanding_mask = generation_understanding_indicator
+ data_info = {}
+ data_info['text_token_mask'] = text_expanded_token_mask
+ data_info['image_token_mask'] = image_expanded_token_mask
+ data_info['generation_or_understanding_mask'] = torch.Tensor([generation_or_understanding_mask])
+
+ data_info['attention_mask'] = attention_mask
+ data_info['sft_format'] = sft_format
+
+ data_info['understanding_img'] = imgs
+ data_info['has_understanding_img'] = torch.Tensor([True]).to(dtype=int)
+
+ data_info["input_ids"] = torch.LongTensor(input_ids)
+
+ # print("\n\n\n", sft_format)
+ # target = self.vl_chat_processor.tokenizer.batch_decode(input_ids[text_expanded_token_mask == 1])
+ # print("\n \n", ''.join(target).strip())
+ # exit()
+ return data_info
diff --git a/flow_matching/data/num.py b/flow_matching/data/num.py
new file mode 100644
index 0000000..34338fd
--- /dev/null
+++ b/flow_matching/data/num.py
@@ -0,0 +1,25 @@
+import torch
+from torch.utils.data import Dataset
+import numpy as np
+import random
+
+class SupervisedDataset(Dataset):
+ def __init__(self, vl_chat_processor, seq_len=512, total_samples=100000,
+ min_num=-100, max_num=100, interval=0.01):
+ self.tokenizer = vl_chat_processor.tokenizer
+ self.seq_len = seq_len
+ self.total_samples = total_samples
+ self.all_num = np.linspace(min_num, max_num, int((max_num - min_num) / interval) + 1)
+ self.min_num = min_num
+ self.max_num = max_num
+ self.interval = interval
+
+ def __len__(self):
+ return self.total_samples
+
+ def __getitem__(self, idx):
+ sampled_nums = random.sample(list(self.all_num), self.seq_len)
+ sampled_tokens = [f"{x:.2f}" for x in sampled_nums]
+ token_ids = self.tokenizer.encode(" ".join(sampled_tokens), add_special_tokens=False)
+ token_ids = torch.tensor(token_ids, dtype=torch.long)
+ return {"input_ids": token_ids}
diff --git a/flow_matching/loss/__init__.py b/flow_matching/loss/__init__.py
new file mode 100644
index 0000000..24ec1a9
--- /dev/null
+++ b/flow_matching/loss/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the CC-by-NC license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .generalized_loss import MixturePathGeneralizedKL
+
+__all__ = [
+ "MixturePathGeneralizedKL",
+]
diff --git a/flow_matching/loss/generalized_loss.py b/flow_matching/loss/generalized_loss.py
new file mode 100644
index 0000000..cc1507e
--- /dev/null
+++ b/flow_matching/loss/generalized_loss.py
@@ -0,0 +1,80 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the CC-by-NC license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import Tensor
+from torch.nn.modules.loss import _Loss
+
+from flow_matching.path import MixtureDiscreteProbPath
+
+
+class MixturePathGeneralizedKL(_Loss):
+ r"""A generalized KL loss for discrete flow matching.
+ A class that measures the generalized KL of a discrete flow model :math:`p_{1|t}` w.r.t. a probability path given by ``path``. Note: this class is assuming that the model is trained on the same path.
+
+ For a model trained on a space :math:`\mathcal{S} = \mathcal{T}^d`, :math:`\mathcal{T} = [K] = \set{1,2,\ldots,K}`, the loss is given by
+
+ .. math::
+ \ell_i(x_1, x_t, t) = -\frac{\dot{\kappa}_t}{1-\kappa_t} \biggr[ p_{1|t}(x_t^i|x_t) -\delta_{x^i_1}(x_t^i) + (1-\delta_{x^i_1}(x_t^i))\left(\log p_{1|t}(x_1^i|x_t)\right)\biggr],
+
+ where :math:`\kappa_t` is the scheduler associated with ``path``.
+
+ Args:
+ path (MixtureDiscreteProbPath): Probability path (x-prediction training).
+ reduction (str, optional): Specify the reduction to apply to the output ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction is applied to the output, ``'mean'``: the output is reduced by mean over sequence elements, ``'sum'``: the output is reduced by sum over sequence elements. Defaults to 'mean'.
+ """
+
+ def __init__(self, path: MixtureDiscreteProbPath, reduction: str = "mean") -> None:
+ super().__init__(None, None, reduction)
+ self.path = path
+
+ def forward(self, logits: Tensor, x_1: Tensor, x_t: Tensor, t: Tensor) -> Tensor:
+ r"""Evaluates the generalized KL loss.
+
+ Args:
+ logits (Tensor): posterior model output (i.e., softmax(``logits``) :math:`=p_{1|t}(x|x_t)`), shape (batch, d, K).
+ x_1 (Tensor): target data point :math:`x_1 \sim q`, shape (batch, d).
+ x_t (Tensor): conditional sample at :math:`x_t \sim p_t(\cdot|x_1)`, shape (batch, d).
+ t (Tensor): times in :math:`[0,1]`, shape (batch).
+
+ Raises:
+ ValueError: reduction value must be one of ``'none'`` | ``'mean'`` | ``'sum'``.
+
+ Returns:
+ Tensor: Generalized KL loss.
+ """
+ x_1_shape = x_1.shape
+
+ # extract x_1 value of log(p_{1|t}(x|x_t)).
+ log_p_1t = torch.log_softmax(logits, dim=-1)
+ log_p_1t_x1 = torch.gather(log_p_1t, dim=-1, index=x_1.unsqueeze(-1))
+ log_p_1t_x1 = log_p_1t_x1.view(*x_1_shape)
+
+ # extract x_t value of p_{1|t}(x|x_t).
+ p_1t = torch.exp(log_p_1t)
+ p_1t_xt = torch.gather(p_1t, dim=-1, index=x_t.unsqueeze(-1))
+ p_1t_xt = p_1t_xt.view(*x_1_shape)
+
+ scheduler_output = self.path.scheduler(t)
+
+ jump_coefficient = (
+ scheduler_output.d_alpha_t / (1 - scheduler_output.alpha_t)
+ )[(...,) + (None,) * (x_1.dim() - 1)]
+ jump_coefficient = jump_coefficient.repeat(1, *x_1_shape[1:])
+ delta_x1_xt = (x_t == x_1).to(log_p_1t.dtype)
+
+ loss = -jump_coefficient * (
+ p_1t_xt - delta_x1_xt + (1 - delta_x1_xt) * log_p_1t_x1
+ )
+
+ if self.reduction == "mean":
+ return torch.mean(loss)
+ elif self.reduction == "sum":
+ return torch.sum(loss)
+ elif self.reduction == "none":
+ return loss
+ else:
+ raise ValueError(f"{self.reduction} is not a valid value for reduction")
diff --git a/flow_matching/path/__init__.py b/flow_matching/path/__init__.py
new file mode 100644
index 0000000..c33cd0d
--- /dev/null
+++ b/flow_matching/path/__init__.py
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the CC-by-NC license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .affine import AffineProbPath, CondOTProbPath
+from .geodesic import GeodesicProbPath
+from .mixture import MixtureDiscreteProbPath, MixtureDiscreteSoftmaxProbPath
+from .path import ProbPath
+from .path_sample import DiscretePathSample, PathSample
+
+
+__all__ = [
+ "ProbPath",
+ "AffineProbPath",
+ "CondOTProbPath",
+ "MixtureDiscreteProbPath",
+ "GeodesicProbPath",
+ "PathSample",
+ "DiscretePathSample",
+]
diff --git a/flow_matching/path/affine.py b/flow_matching/path/affine.py
new file mode 100644
index 0000000..81cb7ed
--- /dev/null
+++ b/flow_matching/path/affine.py
@@ -0,0 +1,260 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the CC-by-NC license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torch import Tensor
+
+from flow_matching.path.path import ProbPath
+from flow_matching.path.path_sample import PathSample
+from flow_matching.path.scheduler.scheduler import CondOTScheduler, Scheduler
+from flow_matching.utils import expand_tensor_like
+
+
+class AffineProbPath(ProbPath):
+ r"""The ``AffineProbPath`` class represents a specific type of probability path where the transformation between distributions is affine.
+ An affine transformation can be represented as:
+
+ .. math::
+
+ X_t = \alpha_t X_1 + \sigma_t X_0,
+
+ where :math:`X_t` is the transformed data point at time `t`. :math:`X_0` and :math:`X_1` are the source and target data points, respectively. :math:`\alpha_t` and :math:`\sigma_t` are the parameters of the affine transformation at time `t`.
+
+ The scheduler is responsible for providing the time-dependent parameters :math:`\alpha_t` and :math:`\sigma_t`, as well as their derivatives, which define the affine transformation at any given time `t`.
+
+ Using ``AffineProbPath`` in the flow matching framework:
+
+ .. code-block:: python
+
+ # Instantiates a probability path
+ my_path = AffineProbPath(...)
+ mse_loss = torch.nn.MSELoss()
+
+ for x_1 in dataset:
+ # Sets x_0 to random noise
+ x_0 = torch.randn()
+
+ # Sets t to a random value in [0,1]
+ t = torch.rand()
+
+ # Samples the conditional path X_t ~ p_t(X_t|X_0,X_1)
+ path_sample = my_path.sample(x_0=x_0, x_1=x_1, t=t)
+
+ # Computes the MSE loss w.r.t. the velocity
+ loss = mse_loss(path_sample.dx_t, my_model(x_t, t))
+ loss.backward()
+
+ Args:
+ scheduler (Scheduler): An instance of a scheduler that provides the parameters :math:`\alpha_t`, :math:`\sigma_t`, and their derivatives over time.
+
+ """
+
+ def __init__(self, scheduler: Scheduler):
+ self.scheduler = scheduler
+
+ def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> PathSample:
+ r"""Sample from the affine probability path:
+
+ | given :math:`(X_0,X_1) \sim \pi(X_0,X_1)` and a scheduler :math:`(\alpha_t,\sigma_t)`.
+ | return :math:`X_0, X_1, X_t = \alpha_t X_1 + \sigma_t X_0`, and the conditional velocity at :math:`X_t, \dot{X}_t = \dot{\alpha}_t X_1 + \dot{\sigma}_t X_0`.
+
+ Args:
+ x_0 (Tensor): source data point, shape (batch_size, ...).
+ x_1 (Tensor): target data point, shape (batch_size, ...).
+ t (Tensor): times in [0,1], shape (batch_size).
+
+ Returns:
+ PathSample: a conditional sample at :math:`X_t \sim p_t`.
+ """
+ self.assert_sample_shape(x_0=x_0, x_1=x_1, t=t)
+
+ scheduler_output = self.scheduler(t)
+
+ alpha_t = expand_tensor_like(
+ input_tensor=scheduler_output.alpha_t, expand_to=x_1
+ )
+ sigma_t = expand_tensor_like(
+ input_tensor=scheduler_output.sigma_t, expand_to=x_1
+ )
+ d_alpha_t = expand_tensor_like(
+ input_tensor=scheduler_output.d_alpha_t, expand_to=x_1
+ )
+ d_sigma_t = expand_tensor_like(
+ input_tensor=scheduler_output.d_sigma_t, expand_to=x_1
+ )
+
+ # construct xt ~ p_t(x|x1).
+ x_t = sigma_t * x_0 + alpha_t * x_1
+ dx_t = d_sigma_t * x_0 + d_alpha_t * x_1
+
+ return PathSample(x_t=x_t, dx_t=dx_t, x_1=x_1, x_0=x_0, t=t)
+
+ def target_to_velocity(self, x_1: Tensor, x_t: Tensor, t: Tensor) -> Tensor:
+ r"""Convert from x_1 representation to velocity.
+
+ | given :math:`X_1`.
+ | return :math:`\dot{X}_t`.
+
+ Args:
+ x_1 (Tensor): target data point.
+ x_t (Tensor): path sample at time t.
+ t (Tensor): time in [0,1].
+
+ Returns:
+ Tensor: velocity.
+ """
+ scheduler_output = self.scheduler(t)
+
+ alpha_t = scheduler_output.alpha_t
+ d_alpha_t = scheduler_output.d_alpha_t
+ sigma_t = scheduler_output.sigma_t
+ d_sigma_t = scheduler_output.d_sigma_t
+
+ a_t = d_sigma_t / sigma_t
+ b_t = (d_alpha_t * sigma_t - d_sigma_t * alpha_t) / sigma_t
+
+ return a_t * x_t + b_t * x_1
+
+ def epsilon_to_velocity(self, epsilon: Tensor, x_t: Tensor, t: Tensor) -> Tensor:
+ r"""Convert from epsilon representation to velocity.
+
+ | given :math:`\epsilon`.
+ | return :math:`\dot{X}_t`.
+
+ Args:
+ epsilon (Tensor): noise in the path sample.
+ x_t (Tensor): path sample at time t.
+ t (Tensor): time in [0,1].
+
+ Returns:
+ Tensor: velocity.
+ """
+ scheduler_output = self.scheduler(t)
+
+ alpha_t = scheduler_output.alpha_t
+ d_alpha_t = scheduler_output.d_alpha_t
+ sigma_t = scheduler_output.sigma_t
+ d_sigma_t = scheduler_output.d_sigma_t
+
+ a_t = d_alpha_t / alpha_t
+ b_t = (d_sigma_t * alpha_t - d_alpha_t * sigma_t) / alpha_t
+
+ return a_t * x_t + b_t * epsilon
+
+ def velocity_to_target(self, velocity: Tensor, x_t: Tensor, t: Tensor) -> Tensor:
+ r"""Convert from velocity to x_1 representation.
+
+ | given :math:`\dot{X}_t`.
+ | return :math:`X_1`.
+
+ Args:
+ velocity (Tensor): velocity at the path sample.
+ x_t (Tensor): path sample at time t.
+ t (Tensor): time in [0,1].
+
+ Returns:
+ Tensor: target data point.
+ """
+ scheduler_output = self.scheduler(t)
+
+ alpha_t = scheduler_output.alpha_t
+ d_alpha_t = scheduler_output.d_alpha_t
+ sigma_t = scheduler_output.sigma_t
+ d_sigma_t = scheduler_output.d_sigma_t
+
+ a_t = -d_sigma_t / (d_alpha_t * sigma_t - d_sigma_t * alpha_t)
+ b_t = sigma_t / (d_alpha_t * sigma_t - d_sigma_t * alpha_t)
+
+ return a_t * x_t + b_t * velocity
+
+ def epsilon_to_target(self, epsilon: Tensor, x_t: Tensor, t: Tensor) -> Tensor:
+ r"""Convert from epsilon representation to x_1 representation.
+
+ | given :math:`\epsilon`.
+ | return :math:`X_1`.
+
+ Args:
+ epsilon (Tensor): noise in the path sample.
+ x_t (Tensor): path sample at time t.
+ t (Tensor): time in [0,1].
+
+ Returns:
+ Tensor: target data point.
+ """
+ scheduler_output = self.scheduler(t)
+
+ alpha_t = scheduler_output.alpha_t
+ sigma_t = scheduler_output.sigma_t
+
+ a_t = 1 / alpha_t
+ b_t = -sigma_t / alpha_t
+
+ return a_t * x_t + b_t * epsilon
+
+ def velocity_to_epsilon(self, velocity: Tensor, x_t: Tensor, t: Tensor) -> Tensor:
+ r"""Convert from velocity to noise representation.
+
+ | given :math:`\dot{X}_t`.
+ | return :math:`\epsilon`.
+
+ Args:
+ velocity (Tensor): velocity at the path sample.
+ x_t (Tensor): path sample at time t.
+ t (Tensor): time in [0,1].
+
+ Returns:
+ Tensor: noise in the path sample.
+ """
+ scheduler_output = self.scheduler(t)
+
+ alpha_t = scheduler_output.alpha_t
+ d_alpha_t = scheduler_output.d_alpha_t
+ sigma_t = scheduler_output.sigma_t
+ d_sigma_t = scheduler_output.d_sigma_t
+
+ a_t = -d_alpha_t / (d_sigma_t * alpha_t - d_alpha_t * sigma_t)
+ b_t = alpha_t / (d_sigma_t * alpha_t - d_alpha_t * sigma_t)
+
+ return a_t * x_t + b_t * velocity
+
+ def target_to_epsilon(self, x_1: Tensor, x_t: Tensor, t: Tensor) -> Tensor:
+ r"""Convert from x_1 representation to velocity.
+
+ | given :math:`X_1`.
+ | return :math:`\epsilon`.
+
+ Args:
+ x_1 (Tensor): target data point.
+ x_t (Tensor): path sample at time t.
+ t (Tensor): time in [0,1].
+
+ Returns:
+ Tensor: noise in the path sample.
+ """
+ scheduler_output = self.scheduler(t)
+
+ alpha_t = scheduler_output.alpha_t
+ sigma_t = scheduler_output.sigma_t
+
+ a_t = 1 / sigma_t
+ b_t = -alpha_t / sigma_t
+
+ return a_t * x_t + b_t * x_1
+
+
+class CondOTProbPath(AffineProbPath):
+ r"""The ``CondOTProbPath`` class represents a conditional optimal transport probability path.
+
+ This class is a specialized version of the ``AffineProbPath`` that uses a conditional optimal transport scheduler to determine the parameters of the affine transformation.
+
+ The parameters :math:`\alpha_t` and :math:`\sigma_t` for the conditional optimal transport path are defined as:
+
+ .. math::
+
+ \alpha_t = t \quad \text{and} \quad \sigma_t = 1 - t.
+ """
+
+ def __init__(self):
+ self.scheduler = CondOTScheduler()
diff --git a/flow_matching/path/geodesic.py b/flow_matching/path/geodesic.py
new file mode 100644
index 0000000..d04bf67
--- /dev/null
+++ b/flow_matching/path/geodesic.py
@@ -0,0 +1,100 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the CC-by-NC license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from torch import Tensor
+from torch.func import jvp, vmap
+
+from flow_matching.path.path import ProbPath
+
+from flow_matching.path.path_sample import PathSample
+from flow_matching.path.scheduler import ConvexScheduler
+from flow_matching.utils import expand_tensor_like
+
+from flow_matching.utils.manifolds import geodesic, Manifold
+
+
+class GeodesicProbPath(ProbPath):
+ r"""The ``GeodesicProbPath`` class represents a specific type of probability path where the transformation between distributions is defined through the geodesic path.
+ Mathematically, a geodesic path can be represented as:
+
+ .. math::
+
+ X_t = \psi_t(X_0 | X_1) = \exp_{X_1}(\kappa_t \log_{X_1}(X_0)),
+
+ where :math:`X_t` is the transformed data point at time `t`, :math:`X_0` and :math:`X_1` are the source and target data points, respectively, and :math:`\kappa_t` is a scheduler.
+
+ The scheduler is responsible for providing the time-dependent :math:`\kappa_t` and must be differentiable.
+
+ Using ``GeodesicProbPath`` in the flow matching framework:
+
+ .. code-block:: python
+ # Instantiates a manifold
+ manifold = FlatTorus()
+
+ # Instantiates a scheduler
+ scheduler = CondOTScheduler()
+
+ # Instantiates a probability path
+ my_path = GeodesicProbPath(scheduler, manifold)
+ mse_loss = torch.nn.MSELoss()
+
+ for x_1 in dataset:
+ # Sets x_0 to random noise
+ x_0 = torch.randn()
+
+ # Sets t to a random value in [0,1]
+ t = torch.rand()
+
+ # Samples the conditional path :math:`X_t \sim p_t(X_t|X_0,X_1)`
+ path_sample = my_path.sample(x_0=x_0, x_1=x_1, t=t)
+
+ # Computes the MSE loss w.r.t. the velocity
+ loss = mse_loss(path_sample.dx_t, my_model(x_t, t))
+ loss.backward()
+
+ Args:
+ scheduler (ConvexScheduler): The scheduler that provides :math:`\kappa_t`.
+ manifold (Manifold): The manifold on which the probability path is defined.
+
+ """
+
+ def __init__(self, scheduler: ConvexScheduler, manifold: Manifold):
+ self.scheduler = scheduler
+ self.manifold = manifold
+
+ def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> PathSample:
+ r"""Sample from the Riemannian probability path with geodesic interpolation:
+
+ | given :math:`(X_0,X_1) \sim \pi(X_0,X_1)` and a scheduler :math:`\kappa_t`.
+ | return :math:`X_0, X_1, X_t = \exp_{X_1}(\kappa_t \log_{X_1}(X_0))`, and the conditional velocity at :math:`X_t, \dot{X}_t`.
+
+ Args:
+ x_0 (Tensor): source data point, shape (batch_size, ...).
+ x_1 (Tensor): target data point, shape (batch_size, ...).
+ t (Tensor): times in [0,1], shape (batch_size).
+
+ Returns:
+ PathSample: A conditional sample at :math:`X_t \sim p_t`.
+ """
+ self.assert_sample_shape(x_0=x_0, x_1=x_1, t=t)
+ t = expand_tensor_like(input_tensor=t, expand_to=x_1[..., 0:1]).clone()
+
+ def cond_u(x_0, x_1, t):
+ path = geodesic(self.manifold, x_0, x_1)
+ x_t, dx_t = jvp(
+ lambda t: path(self.scheduler(t).alpha_t),
+ (t,),
+ (torch.ones_like(t).to(t),),
+ )
+ return x_t, dx_t
+
+ x_t, dx_t = vmap(cond_u)(x_0, x_1, t)
+ x_t = x_t.reshape_as(x_1)
+ dx_t = dx_t.reshape_as(x_1)
+
+ return PathSample(x_t=x_t, dx_t=dx_t, x_1=x_1, x_0=x_0, t=t)
diff --git a/flow_matching/path/mixture.py b/flow_matching/path/mixture.py
new file mode 100644
index 0000000..f04c01a
--- /dev/null
+++ b/flow_matching/path/mixture.py
@@ -0,0 +1,190 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the CC-by-NC license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn.functional as F
+
+from torch import Tensor
+
+from flow_matching.path.path import ProbPath
+
+from flow_matching.path.path_sample import DiscretePathSample
+from flow_matching.path.scheduler import ConvexScheduler
+from flow_matching.utils import expand_tensor_like, unsqueeze_to_match
+
+
+class MixtureDiscreteProbPath(ProbPath):
+ r"""The ``MixtureDiscreteProbPath`` class defines a factorized discrete probability path.
+
+ This path remains constant at the source data point :math:`X_0` until a random time, determined by the scheduler, when it flips to the target data point :math:`X_1`.
+ The scheduler determines the flip probability using the parameter :math:`\sigma_t`, which is a function of time `t`. Specifically, :math:`\sigma_t` represents the probability of remaining at :math:`X_0`, while :math:`1 - \sigma_t` is the probability of flipping to :math:`X_1`:
+
+ .. math::
+
+ P(X_t = X_0) = \sigma_t \quad \text{and} \quad P(X_t = X_1) = 1 - \sigma_t,
+
+ where :math:`\sigma_t` is provided by the scheduler.
+
+ Example:
+
+ .. code-block:: python
+
+ >>> x_0 = torch.zeros((1, 3, 3))
+ >>> x_1 = torch.ones((1, 3, 3))
+
+ >>> path = MixtureDiscreteProbPath(PolynomialConvexScheduler(n=1.0))
+ >>> result = path.sample(x_0, x_1, t=torch.tensor([0.1])).x_t
+ >>> result
+ tensor([[[0.0, 0.0, 0.0],
+ [0.0, 0.0, 1.0],
+ [0.0, 0.0, 0.0]]])
+
+ >>> result = path.sample(x_0, x_1, t=torch.tensor([0.5])).x_t
+ >>> result
+ tensor([[[1.0, 0.0, 1.0],
+ [0.0, 1.0, 0.0],
+ [0.0, 1.0, 0.0]]])
+
+ >>> result = path.sample(x_0, x_1, t=torch.tensor([1.0])).x_t
+ >>> result
+ tensor([[[1.0, 1.0, 1.0],
+ [1.0, 1.0, 1.0],
+ [1.0, 1.0, 1.0]]])
+
+ Args:
+ scheduler (ConvexScheduler): The scheduler that provides :math:`\sigma_t`.
+ """
+
+ def __init__(self, scheduler: ConvexScheduler):
+ assert isinstance(
+ scheduler, ConvexScheduler
+ ), "Scheduler for ConvexProbPath must be a ConvexScheduler."
+
+ self.scheduler = scheduler
+
+ def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> DiscretePathSample:
+ r"""Sample from the affine probability path:
+ | given :math:`(X_0,X_1) \sim \pi(X_0,X_1)` and a scheduler :math:`(\alpha_t,\sigma_t)`.
+ | return :math:`X_0, X_1, t`, and :math:`X_t \sim p_t`.
+ Args:
+ x_0 (Tensor): source data point, shape (batch_size, ...).
+ x_1 (Tensor): target data point, shape (batch_size, ...).
+ t (Tensor): times in [0,1], shape (batch_size).
+
+ Returns:
+ DiscretePathSample: a conditional sample at :math:`X_t ~ p_t`.
+ """
+ self.assert_sample_shape(x_0=x_0, x_1=x_1, t=t)
+
+ sigma_t = self.scheduler(t).sigma_t
+
+ sigma_t = expand_tensor_like(input_tensor=sigma_t, expand_to=x_1)
+
+ source_indices = torch.rand(size=x_1.shape, device=x_1.device) < sigma_t
+ x_t = torch.where(condition=source_indices, input=x_0, other=x_1)
+
+ return DiscretePathSample(x_t=x_t, x_1=x_1, x_0=x_0, t=t)
+
+ def posterior_to_velocity(
+ self, posterior_logits: Tensor, x_t: Tensor, t: Tensor
+ ) -> Tensor:
+ r"""Convert the factorized posterior to velocity.
+
+ | given :math:`p(X_1|X_t)`. In the factorized case: :math:`\prod_i p(X_1^i | X_t)`.
+ | return :math:`u_t`.
+
+ Args:
+ posterior_logits (Tensor): logits of the x_1 posterior conditional on x_t, shape (..., vocab size).
+ x_t (Tensor): path sample at time t, shape (...).
+ t (Tensor): time in [0,1].
+
+ Returns:
+ Tensor: velocity.
+ """
+ posterior = torch.softmax(posterior_logits, dim=-1)
+ vocabulary_size = posterior.shape[-1]
+ x_t = F.one_hot(x_t, num_classes=vocabulary_size)
+ t = unsqueeze_to_match(source=t, target=x_t)
+
+ scheduler_output = self.scheduler(t)
+
+ kappa_t = scheduler_output.alpha_t
+ d_kappa_t = scheduler_output.d_alpha_t
+
+ return (d_kappa_t / (1 - kappa_t)) * (posterior - x_t)
+
+
+class MixtureDiscreteSoftmaxProbPath(ProbPath):
+ def __init__(self, mode, embedding_path):
+ self.a = 0.9
+ self.c = 3
+ assert mode in ['image', 'text'], f"Unsupported mode probability path: {mode}"
+ self.mode = mode
+ self.embedding_path = embedding_path
+ self.embedding = self.get_embedding(embedding_path)
+ self.embedding.weight.requires_grad = False
+ self.embedding = self.embedding
+ torch.cuda.empty_cache()
+
+ def get_embedding(self, embedding_path):
+ # with torch.serialization.safe_globals([torch.nn.modules.sparse.Embedding]):
+ embedding = torch.load(embedding_path, map_location="cpu")
+ embedding.requires_grad_(False)
+ torch.cuda.empty_cache()
+ return embedding.cuda()
+
+ def set_embedding(self, new_embedding):
+ self.embedding = new_embedding
+
+ def metric(self, z):
+ z_flattened = z.view(-1, z.shape[-1])
+ z = F.normalize(z, p=2, dim=-1)
+ z_flattened = F.normalize(z_flattened, p=2, dim=-1)
+ embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
+ d = (torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+ torch.sum(embedding**2, dim=1) - 2 * \
+ torch.einsum('bd,dn->bn', z_flattened, torch.einsum('n d -> d n', embedding))) ** 2
+ return d
+
+ def get_prob_distribution(self, z, t):
+ b, s = z.shape[:2]
+ d = self.metric(z)
+ d = d.reshape(b, s, -1)
+ beta_t = self.c * ((t / (1 - t)) ** self.a)
+ if beta_t.shape[0] == b:
+ beta_t = beta_t.reshape(b, 1, 1)
+ # print(beta_t.shape)
+ d = d * (-1) * beta_t
+ d = torch.softmax(d, dim=-1)
+ return d
+
+ def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> DiscretePathSample:
+ # emb_x_0 = self.self.embedding(x_0).squeeze() # 2, 256, 8
+ emb_x_1 = self.embedding(x_1)
+ # prob_x_0 = self.get_prob_distribution(emb_x_0)
+ prob_x_t = self.get_prob_distribution(emb_x_1, t)
+ b, s = prob_x_t.shape[:2]
+ x_t = torch.multinomial(prob_x_t.reshape(b*s, -1), num_samples=1, replacement=False)
+ x_t = x_t.reshape(b, s)
+ return DiscretePathSample(x_t=x_t, x_1=x_1, x_0=x_0, t=t)
+
+ def posterior_to_velocity(
+ self, posterior_logits: Tensor, x_t: Tensor, t: Tensor
+ ) -> Tensor:
+ r"""Convert the factorized posterior to velocity.
+
+ | given :math:`p(X_1|X_t)`. In the factorized case: :math:`\prod_i p(X_1^i | X_t)`.
+ | return :math:`u_t`.
+
+ Args:
+ posterior_logits (Tensor): logits of the x_1 posterior conditional on x_t, shape (..., vocab size).
+ x_t (Tensor): path sample at time t, shape (...).
+ t (Tensor): time in [0,1].
+
+ Returns:
+ Tensor: velocity.
+ """
+ raise NotImplementedError
diff --git a/flow_matching/path/path.py b/flow_matching/path/path.py
new file mode 100644
index 0000000..c133a14
--- /dev/null
+++ b/flow_matching/path/path.py
@@ -0,0 +1,61 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the CC-by-NC license found in the
+# LICENSE file in the root directory of this source tree.
+
+from abc import ABC, abstractmethod
+
+from torch import Tensor
+
+from flow_matching.path.path_sample import PathSample
+
+
+class ProbPath(ABC):
+ r"""Abstract class, representing a probability path.
+
+ A probability path transforms the distribution :math:`p(X_0)` into :math:`p(X_1)` over :math:`t=0\rightarrow 1`.
+
+ The ``ProbPath`` class is designed to support model training in the flow matching framework. It supports two key functionalities: (1) sampling the conditional probability path and (2) conversion between various training objectives.
+ Here is a high-level example
+
+ .. code-block:: python
+
+ # Instantiate a probability path
+ my_path = ProbPath(...)
+
+ for x_0, x_1 in dataset:
+ # Sets t to a random value in [0,1]
+ t = torch.rand()
+
+ # Samples the conditional path X_t ~ p_t(X_t|X_0,X_1)
+ path_sample = my_path.sample(x_0=x_0, x_1=x_1, t=t)
+
+ # Optimizes the model. The loss function varies, depending on model and path.
+ loss(path_sample, my_model(x_t, t)).backward()
+
+ """
+
+ @abstractmethod
+ def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> PathSample:
+ r"""Sample from an abstract probability path:
+
+ | given :math:`(X_0,X_1) \sim \pi(X_0,X_1)`.
+ | returns :math:`X_0, X_1, X_t \sim p_t(X_t)`, and a conditional target :math:`Y`, all objects are under ``PathSample``.
+
+ Args:
+ x_0 (Tensor): source data point, shape (batch_size, ...).
+ x_1 (Tensor): target data point, shape (batch_size, ...).
+ t (Tensor): times in [0,1], shape (batch_size).
+
+ Returns:
+ PathSample: a conditional sample.
+ """
+
+ def assert_sample_shape(self, x_0: Tensor, x_1: Tensor, t: Tensor):
+ assert (
+ t.ndim == 1
+ ), f"The time vector t must have shape [batch_size]. Got {t.shape}."
+ assert (
+ t.shape[0] == x_0.shape[0] == x_1.shape[0]
+ ), f"Time t dimension must match the batch size [{x_1.shape[0]}]. Got {t.shape}"
diff --git a/flow_matching/path/path_sample.py b/flow_matching/path/path_sample.py
new file mode 100644
index 0000000..867032e
--- /dev/null
+++ b/flow_matching/path/path_sample.py
@@ -0,0 +1,53 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the CC-by-NC license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass, field
+
+from torch import Tensor
+
+
+@dataclass
+class PathSample:
+ r"""Represents a sample of a conditional-flow generated probability path.
+
+ Attributes:
+ x_1 (Tensor): the target sample :math:`X_1`.
+ x_0 (Tensor): the source sample :math:`X_0`.
+ t (Tensor): the time sample :math:`t`.
+ x_t (Tensor): samples :math:`X_t \sim p_t(X_t)`, shape (batch_size, ...).
+ dx_t (Tensor): conditional target :math:`\frac{\partial X}{\partial t}`, shape: (batch_size, ...).
+
+ """
+
+ x_1: Tensor = field(metadata={"help": "target samples X_1 (batch_size, ...)."})
+ x_0: Tensor = field(metadata={"help": "source samples X_0 (batch_size, ...)."})
+ t: Tensor = field(metadata={"help": "time samples t (batch_size, ...)."})
+ x_t: Tensor = field(
+ metadata={"help": "samples x_t ~ p_t(X_t), shape (batch_size, ...)."}
+ )
+ dx_t: Tensor = field(
+ metadata={"help": "conditional target dX_t, shape: (batch_size, ...)."}
+ )
+
+
+@dataclass
+class DiscretePathSample:
+ """
+ Represents a sample of a conditional-flow generated discrete probability path.
+
+ Attributes:
+ x_1 (Tensor): the target sample :math:`X_1`.
+ x_0 (Tensor): the source sample :math:`X_0`.
+ t (Tensor): the time sample :math:`t`.
+ x_t (Tensor): the sample along the path :math:`X_t \sim p_t`.
+ """
+
+ x_1: Tensor = field(metadata={"help": "target samples X_1 (batch_size, ...)."})
+ x_0: Tensor = field(metadata={"help": "source samples X_0 (batch_size, ...)."})
+ t: Tensor = field(metadata={"help": "time samples t (batch_size, ...)."})
+ x_t: Tensor = field(
+ metadata={"help": "samples X_t ~ p_t(X_t), shape (batch_size, ...)."}
+ )
diff --git a/flow_matching/path/scheduler/__init__.py b/flow_matching/path/scheduler/__init__.py
new file mode 100644
index 0000000..f3b1a43
--- /dev/null
+++ b/flow_matching/path/scheduler/__init__.py
@@ -0,0 +1,29 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the CC-by-NC license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .schedule_transform import ScheduleTransformedModel
+from .scheduler import (
+ CondOTScheduler,
+ ConvexScheduler,
+ CosineScheduler,
+ LinearVPScheduler,
+ PolynomialConvexScheduler,
+ Scheduler,
+ SchedulerOutput,
+ VPScheduler,
+)
+
+__all__ = [
+ "CondOTScheduler",
+ "CosineScheduler",
+ "ConvexScheduler",
+ "PolynomialConvexScheduler",
+ "ScheduleTransformedModel",
+ "Scheduler",
+ "VPScheduler",
+ "LinearVPScheduler",
+ "SchedulerOutput",
+]
diff --git a/flow_matching/path/scheduler/schedule_transform.py b/flow_matching/path/scheduler/schedule_transform.py
new file mode 100644
index 0000000..a366f19
--- /dev/null
+++ b/flow_matching/path/scheduler/schedule_transform.py
@@ -0,0 +1,148 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the CC-by-NC license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torch import Tensor
+
+from flow_matching.path.scheduler.scheduler import Scheduler
+from flow_matching.utils import ModelWrapper
+
+
+class ScheduleTransformedModel(ModelWrapper):
+ """
+ Change of scheduler for a velocity model.
+
+ This class wraps a given velocity model and transforms its scheduling
+ to a new scheduler function. It modifies the time
+ dynamics of the model according to the new scheduler while maintaining
+ the original model's behavior.
+
+ Example:
+
+ .. code-block:: python
+
+ import torch
+ from flow_matching.path.scheduler import CondOTScheduler, CosineScheduler, ScheduleTransformedModel
+ from flow_matching.solver import ODESolver
+
+ # Initialize the model and schedulers
+ model = ...
+
+ original_scheduler = CondOTScheduler()
+ new_scheduler = CosineScheduler()
+
+ # Create the transformed model
+ transformed_model = ScheduleTransformedModel(
+ velocity_model=model,
+ original_scheduler=original_scheduler,
+ new_scheduler=new_scheduler
+ )
+
+ # Set up the solver
+ solver = ODESolver(velocity_model=transformed_model)
+
+ x_0 = torch.randn([10, 2]) # Example initial condition
+
+ x_1 = solver.sample(
+ time_steps=torch.tensor([0.0, 1.0]),
+ x_init=x_0,
+ step_size=1/1000
+ )[1]
+
+ Args:
+ velocity_model (ModelWrapper): The original velocity model to be transformed.
+ original_scheduler (Scheduler): The scheduler used by the original model. Must implement the snr_inverse function.
+ new_scheduler (Scheduler): The new scheduler to be applied to the model.
+ """
+
+ def __init__(
+ self,
+ velocity_model: ModelWrapper,
+ original_scheduler: Scheduler,
+ new_scheduler: Scheduler,
+ ):
+ super().__init__(model=velocity_model)
+ self.original_scheduler = original_scheduler
+ self.new_scheduler = new_scheduler
+
+ assert hasattr(self.original_scheduler, "snr_inverse") and callable(
+ getattr(self.original_scheduler, "snr_inverse")
+ ), "The original scheduler must have a callable 'snr_inverse' method."
+
+ def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor:
+ r"""
+ Compute the transformed marginal velocity field for a new scheduler.
+ This method implements a post-training velocity scheduler change for
+ affine conditional flows. It transforms a generating marginal velocity
+ field :math:`u_t(x)` based on an original scheduler to a new marginal velocity
+ field :math:`\bar{u}_r(x)` based on a different scheduler, while maintaining
+ the same data coupling.
+ The transformation is based on the scale-time (ST) transformation
+ between the two conditional flows, defined as:
+
+ .. math::
+
+ \bar{X}_r = s_r X_{t_r},
+
+ where :math:`X_t` and :math:`\bar{X}_r` are defined by their respective schedulers.
+ The ST transformation is computed as:
+
+ .. math::
+
+ t_r = \rho^{-1}(\bar{\rho}(r)) \quad \text{and} \quad s_r = \frac{\bar{\sigma}_r}{\sigma_{t_r}}.
+
+ Here, :math:`\rho(t)` is the signal-to-noise ratio (SNR) defined as:
+
+ .. math::
+
+ \rho(t) = \frac{\alpha_t}{\sigma_t}.
+
+ :math:`\bar{\rho}(r)` is similarly defined for the new scheduler.
+ The marginal velocity for the new scheduler is then given by:
+
+ .. math::
+
+ \bar{u}_r(x) = \left(\frac{\dot{s}_r}{s_r}\right) x + s_r \dot{t}_r u_{t_r}\left(\frac{x}{s_r}\right).
+
+ Args:
+ x (Tensor): :math:`x_t`, the input tensor.
+ t (Tensor): The time tensor (denoted as :math:`r` above).
+ **extras: Additional arguments for the model.
+ Returns:
+ Tensor: The transformed velocity.
+ """
+ r = t
+
+ r_scheduler_output = self.new_scheduler(t=r)
+
+ alpha_r = r_scheduler_output.alpha_t
+ sigma_r = r_scheduler_output.sigma_t
+ d_alpha_r = r_scheduler_output.d_alpha_t
+ d_sigma_r = r_scheduler_output.d_sigma_t
+
+ t = self.original_scheduler.snr_inverse(alpha_r / sigma_r)
+
+ t_scheduler_output = self.original_scheduler(t=t)
+
+ alpha_t = t_scheduler_output.alpha_t
+ sigma_t = t_scheduler_output.sigma_t
+ d_alpha_t = t_scheduler_output.d_alpha_t
+ d_sigma_t = t_scheduler_output.d_sigma_t
+
+ s_r = sigma_r / sigma_t
+
+ dt_r = (
+ sigma_t
+ * sigma_t
+ * (sigma_r * d_alpha_r - alpha_r * d_sigma_r)
+ / (sigma_r * sigma_r * (sigma_t * d_alpha_t - alpha_t * d_sigma_t))
+ )
+
+ ds_r = (sigma_t * d_sigma_r - sigma_r * d_sigma_t * dt_r) / (sigma_t * sigma_t)
+
+ u_t = self.model(x=x / s_r, t=t, **extras)
+ u_r = ds_r * x / s_r + dt_r * s_r * u_t
+
+ return u_r
diff --git a/flow_matching/path/scheduler/scheduler.py b/flow_matching/path/scheduler/scheduler.py
new file mode 100644
index 0000000..422618a
--- /dev/null
+++ b/flow_matching/path/scheduler/scheduler.py
@@ -0,0 +1,199 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the CC-by-NC license found in the
+# LICENSE file in the root directory of this source tree.
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+
+from typing import Union
+
+import torch
+
+from torch import Tensor
+
+
+@dataclass
+class SchedulerOutput:
+ r"""Represents a sample of a conditional-flow generated probability path.
+
+ Attributes:
+ alpha_t (Tensor): :math:`\alpha_t`, shape (...).
+ sigma_t (Tensor): :math:`\sigma_t`, shape (...).
+ d_alpha_t (Tensor): :math:`\frac{\partial}{\partial t}\alpha_t`, shape (...).
+ d_sigma_t (Tensor): :math:`\frac{\partial}{\partial t}\sigma_t`, shape (...).
+
+ """
+
+ alpha_t: Tensor = field(metadata={"help": "alpha_t"})
+ sigma_t: Tensor = field(metadata={"help": "sigma_t"})
+ d_alpha_t: Tensor = field(metadata={"help": "Derivative of alpha_t."})
+ d_sigma_t: Tensor = field(metadata={"help": "Derivative of sigma_t."})
+
+
+class Scheduler(ABC):
+ """Base Scheduler class."""
+
+ @abstractmethod
+ def __call__(self, t: Tensor) -> SchedulerOutput:
+ r"""
+ Args:
+ t (Tensor): times in [0,1], shape (...).
+
+ Returns:
+ SchedulerOutput: :math:`\alpha_t,\sigma_t,\frac{\partial}{\partial t}\alpha_t,\frac{\partial}{\partial t}\sigma_t`
+ """
+ ...
+
+ @abstractmethod
+ def snr_inverse(self, snr: Tensor) -> Tensor:
+ r"""
+ Computes :math:`t` from the signal-to-noise ratio :math:`\frac{\alpha_t}{\sigma_t}`.
+
+ Args:
+ snr (Tensor): The signal-to-noise, shape (...)
+
+ Returns:
+ Tensor: t, shape (...)
+ """
+ ...
+
+
+class ConvexScheduler(Scheduler):
+ @abstractmethod
+ def __call__(self, t: Tensor) -> SchedulerOutput:
+ """Scheduler for convex paths.
+
+ Args:
+ t (Tensor): times in [0,1], shape (...).
+
+ Returns:
+ SchedulerOutput: :math:`\alpha_t,\sigma_t,\frac{\partial}{\partial t}\alpha_t,\frac{\partial}{\partial t}\sigma_t`
+ """
+ ...
+
+ @abstractmethod
+ def kappa_inverse(self, kappa: Tensor) -> Tensor:
+ """
+ Computes :math:`t` from :math:`\kappa_t`.
+
+ Args:
+ kappa (Tensor): :math:`\kappa`, shape (...)
+
+ Returns:
+ Tensor: t, shape (...)
+ """
+ ...
+
+ def snr_inverse(self, snr: Tensor) -> Tensor:
+ r"""
+ Computes :math:`t` from the signal-to-noise ratio :math:`\frac{\alpha_t}{\sigma_t}`.
+
+ Args:
+ snr (Tensor): The signal-to-noise, shape (...)
+
+ Returns:
+ Tensor: t, shape (...)
+ """
+ kappa_t = snr / (1.0 + snr)
+
+ return self.kappa_inverse(kappa=kappa_t)
+
+
+class CondOTScheduler(ConvexScheduler):
+ """CondOT Scheduler."""
+
+ def __call__(self, t: Tensor) -> SchedulerOutput:
+ return SchedulerOutput(
+ alpha_t=t,
+ sigma_t=1 - t,
+ d_alpha_t=torch.ones_like(t),
+ d_sigma_t=-torch.ones_like(t),
+ )
+
+ def kappa_inverse(self, kappa: Tensor) -> Tensor:
+ return kappa
+
+
+class PolynomialConvexScheduler(ConvexScheduler):
+ """Polynomial Scheduler."""
+
+ def __init__(self, n: Union[float, int]) -> None:
+ assert isinstance(
+ n, (float, int)
+ ), f"`n` must be a float or int. Got {type(n)=}."
+ assert n > 0, f"`n` must be positive. Got {n=}."
+
+ self.n = n
+
+ def __call__(self, t: Tensor) -> SchedulerOutput:
+ return SchedulerOutput(
+ alpha_t=t**self.n,
+ sigma_t=1 - t**self.n,
+ d_alpha_t=self.n * (t ** (self.n - 1)),
+ d_sigma_t=-self.n * (t ** (self.n - 1)),
+ )
+
+ def kappa_inverse(self, kappa: Tensor) -> Tensor:
+ return torch.pow(kappa, 1.0 / self.n)
+
+
+class VPScheduler(Scheduler):
+ """Variance Preserving Scheduler."""
+
+ def __init__(self, beta_min: float = 0.1, beta_max: float = 20.0) -> None:
+ self.beta_min = beta_min
+ self.beta_max = beta_max
+ super().__init__()
+
+ def __call__(self, t: Tensor) -> SchedulerOutput:
+ b = self.beta_min
+ B = self.beta_max
+ T = 0.5 * (1 - t) ** 2 * (B - b) + (1 - t) * b
+ dT = -(1 - t) * (B - b) - b
+
+ return SchedulerOutput(
+ alpha_t=torch.exp(-0.5 * T),
+ sigma_t=torch.sqrt(1 - torch.exp(-T)),
+ d_alpha_t=-0.5 * dT * torch.exp(-0.5 * T),
+ d_sigma_t=0.5 * dT * torch.exp(-T) / torch.sqrt(1 - torch.exp(-T)),
+ )
+
+ def snr_inverse(self, snr: Tensor) -> Tensor:
+ T = -torch.log(snr**2 / (snr**2 + 1))
+ b = self.beta_min
+ B = self.beta_max
+ t = 1 - ((-b + torch.sqrt(b**2 + 2 * (B - b) * T)) / (B - b))
+ return t
+
+
+class LinearVPScheduler(Scheduler):
+ """Linear Variance Preserving Scheduler."""
+
+ def __call__(self, t: Tensor) -> SchedulerOutput:
+ return SchedulerOutput(
+ alpha_t=t,
+ sigma_t=(1 - t**2) ** 0.5,
+ d_alpha_t=torch.ones_like(t),
+ d_sigma_t=-t / (1 - t**2) ** 0.5,
+ )
+
+ def snr_inverse(self, snr: Tensor) -> Tensor:
+ return torch.sqrt(snr**2 / (1 + snr**2))
+
+
+class CosineScheduler(Scheduler):
+ """Cosine Scheduler."""
+
+ def __call__(self, t: Tensor) -> SchedulerOutput:
+ pi = torch.pi
+ return SchedulerOutput(
+ alpha_t=torch.sin(pi / 2 * t),
+ sigma_t=torch.cos(pi / 2 * t),
+ d_alpha_t=pi / 2 * torch.cos(pi / 2 * t),
+ d_sigma_t=-pi / 2 * torch.sin(pi / 2 * t),
+ )
+
+ def snr_inverse(self, snr: Tensor) -> Tensor:
+ return 2.0 * torch.atan(snr) / torch.pi
diff --git a/flow_matching/solver/__init__.py b/flow_matching/solver/__init__.py
new file mode 100644
index 0000000..8e62a7e
--- /dev/null
+++ b/flow_matching/solver/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the CC-by-NC license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .discrete_solver import MixtureDiscreteEulerSolver, MixtureDiscreteSoftmaxEulerSolver
+from .ode_solver import ODESolver
+from .riemannian_ode_solver import RiemannianODESolver
+from .solver import Solver
+
+__all__ = [
+ "ODESolver",
+ "Solver",
+ "ModelWrapper",
+ "MixtureDiscreteEulerSolver",
+ "RiemannianODESolver",
+]
diff --git a/flow_matching/solver/discrete_solver.py b/flow_matching/solver/discrete_solver.py
new file mode 100644
index 0000000..e8f9431
--- /dev/null
+++ b/flow_matching/solver/discrete_solver.py
@@ -0,0 +1,656 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the CC-by-NC license found in the
+# LICENSE file in the root directory of this source tree.
+
+from contextlib import nullcontext
+from math import ceil
+from typing import Callable, Optional, Union
+
+import torch
+from torch import Tensor
+
+from torch.nn import functional as F
+from tqdm import tqdm
+
+from flow_matching.path import MixtureDiscreteProbPath, MixtureDiscreteSoftmaxProbPath
+
+from flow_matching.solver.solver import Solver
+from flow_matching.utils import categorical, ModelWrapper
+from .utils import get_nearest_times
+
+
+class MixtureDiscreteEulerSolver(Solver):
+ r"""Solver that simulates the CTMC process :math:`(X_t)_{t_{\text{init}}\leq t\leq t_{\text{final}}}` defined by :math:`p_t` the marginal probability path of ``path``.
+ Given :math:`X_t \sim p_t`, the algorithm of solver step from :math:`t` to :math:`t+h` for the i-th coordinate is:
+
+ .. math::
+
+ \begin{align*}
+ & X_1^i \sim p_{1|t}^i(\cdot|X_t)\\
+ & \lambda^i \gets \sum_{x^i\ne X_t^i} u_t^i(x^i, X_t^i|X_1^i)\\
+ & Z^i_{\text{change}} \sim U[0,1]\\
+ & X_{t+h}^i \sim \begin{cases}
+ \frac{u_t^i(\cdot, X_t^i|X_1^i)}{\lambda^i}(1-\delta_{X_t^i}(\cdot)) \text{ if $Z^i_{\text{change}}\le 1-e^{-h\lambda^i}$}\\
+ \delta_{X_t^i}(\cdot) \text{ else }
+ \end{cases}
+ \end{align*}
+
+ Where :math:`p_{1|t}(\cdot|X_t)` is the output of ``model``, and the conditional probability velocity is of the mixture probability path is:
+
+ .. math::
+
+ u_t^i(x^i, y^i|x_1^i) = \hat{u}_t^i(x^i, y^i|x_1^i) + c_{\text{div\_free}}\left[\hat{u}_t^i(x^i, y^i|x_1^i) - \check{u}_t^i(x^i, y^i|x_1^i) \right],
+
+ where
+
+ .. math::
+ \hat{u}_t^i(x^i, y^i|x_1^i) = \frac{\dot{\kappa}_t}{1-\kappa_t} \left[ \delta_{x_1^i}(x^i) - \delta_{y^i}(x^i) \right],
+
+ and
+
+ .. math::
+
+ \check{u}_t^i(x^i, y^i|x_1^i) = \frac{\dot{\kappa}_t}{\kappa_t}\left[ \delta_{y^i}(x^i) - p(x^i) \right].
+
+ The source distribution :math:`p(x^i)` is given by ``p``.
+
+ Args:
+ model (ModelWrapper): trained with x-prediction, outputting posterior probabilities (in the range :math:`[0,1]`), output must be [..., vocabulary_size].
+ path (MixtureDiscreteProbPath): Probability path used for x-prediction training.
+ vocabulary_size (int): size of the discrete vocabulary.
+ source_distribution_p (Optional[Tensor], optional): Source distribution, must be of shape [vocabulary_size]. Required only when divergence-free term for the probability velocity is non-zero. Defaults to None.
+ """
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ path: MixtureDiscreteProbPath,
+ vocabulary_size: int,
+ source_distribution_p: Optional[Tensor] = None,
+ ):
+ super().__init__()
+ self.model = model
+ self.path = path
+ self.vocabulary_size = vocabulary_size
+
+ if source_distribution_p is not None:
+ assert source_distribution_p.shape == torch.Size(
+ [vocabulary_size]
+ ), f"Source distribution p dimension must match the vocabulary size {vocabulary_size}. Got {source_distribution_p.shape}."
+
+ self.source_distribution_p = source_distribution_p
+
+ @torch.no_grad()
+ def sample(
+ self,
+ x_init: Tensor,
+ step_size: Optional[float],
+ div_free: Union[float, Callable[[float], float]] = 0.0,
+ dtype_categorical: torch.dtype = torch.float32,
+ time_grid: Tensor = torch.tensor([0.0, 1.0]),
+ return_intermediates: bool = False,
+ verbose: bool = False,
+ **model_extras,
+ ) -> Tensor:
+ """
+ Sample a sequence of discrete values from the given model.
+
+ .. code-block:: python
+
+ import torch
+ from flow_matching.utils import ModelWrapper
+ from flow_matching.solver import MixtureDiscreteEulerSolver
+
+ class DummyModel(ModelWrapper):
+ def __init__(self):
+ super().__init__(None)
+ def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
+ return ...
+
+ model = DummyModel()
+ solver = MixtureDiscreteEulerSolver(model=model)
+
+ x_init = torch.LongTensor([122, 725])
+ step_size = 0.001
+ time_grid = torch.tensor([0.0, 1.0])
+
+ result = solver.sample(x_init=x_init, step_size=step_size, time_grid=time_grid)
+
+ Args:
+ x_init (Tensor): The initial state.
+ step_size (Optional[float]): If float then time discretization is uniform with the given step size. If None then time discretization is set to be time_grid.
+ div_free (Union[float, Callable[[float], float]]): The coefficient of the divergence-free term in the probability velocity. Can be either a float or a time dependent function. Defaults to 0.0.
+ dtype_categorical (torch.dtype): Precision to use for categorical sampler. Defaults to torch.float32.
+ time_grid (Tensor): The CTMC process is solved in the interval [time_grid[0], time_grid[-1]] and if step_size is None then time discretization is set by the time grid. Defaults to torch.tensor([0.0,1.0]).
+ return_intermediates (bool): If True then return intermediate time steps according to time_grid. Defaults to False.
+ verbose (bool): Whether to print progress bars. Defaults to False.
+ **model_extras: Additional input for the model.
+
+ Returns:
+ Tensor: The sampled sequence of discrete values.
+ """
+ if not div_free == 0.0:
+ assert (
+ self.source_distribution_p is not None
+ ), "Source distribution p must be specified in order to add a divergence-free term to the probability velocity."
+
+ # Initialize the current state `x_t` with the initial state `X_0`.
+ time_grid = time_grid.to(device=x_init.device)
+
+ if step_size is None:
+ # If step_size is None then set the t discretization to time_grid.
+ t_discretization = time_grid
+ n_steps = len(time_grid) - 1
+ else:
+ # If step_size is float then t discretization is uniform with step size set by step_size.
+ t_init = time_grid[0].item()
+ t_final = time_grid[-1].item()
+ assert (
+ t_final - t_init
+ ) > step_size, f"Time interval [time_grid[0], time_grid[-1]] must be larger than step_size. Got a time interval [{t_init}, {t_final}] and step_size {step_size}."
+
+ n_steps = ceil((t_final - t_init) / step_size)
+ t_discretization = torch.tensor(
+ [t_init + step_size * i for i in range(n_steps)] + [t_final],
+ device=x_init.device,
+ )
+
+ if return_intermediates:
+ # get order of intermediate steps:
+ order = torch.argsort(time_grid)
+ # Compute intermediate steps to return via nearest points in t_discretization to time_grid.
+ time_grid = get_nearest_times(
+ time_grid=time_grid, t_discretization=t_discretization
+ )
+
+ x_t = x_init.clone()
+ steps_counter = 0
+ res = []
+
+ if return_intermediates:
+ res = [x_init.clone()]
+
+ if verbose:
+ ctx = tqdm(total=t_final, desc=f"NFE: {steps_counter}")
+ else:
+ ctx = nullcontext()
+
+ with ctx:
+ for i in range(n_steps):
+ t = t_discretization[i : i + 1]
+ h = t_discretization[i + 1 : i + 2] - t_discretization[i : i + 1]
+
+ # Sample x_1 ~ p_1|t( \cdot |x_t)
+ p_1t = self.model(x=x_t, t=t.repeat(x_t.shape[0]), **model_extras)
+ x_1 = categorical(p_1t.to(dtype=dtype_categorical))
+
+ # Checks if final step
+ if i == n_steps - 1:
+ x_t = x_1
+ else:
+ # Compute u_t(x|x_t,x_1)
+ scheduler_output = self.path.scheduler(t=t)
+
+ k_t = scheduler_output.alpha_t
+ d_k_t = scheduler_output.d_alpha_t
+
+ delta_1 = F.one_hot(x_1, num_classes=self.vocabulary_size).to(
+ k_t.dtype
+ )
+ u = d_k_t / (1 - k_t) * delta_1
+
+ # Add divergence-free part
+ div_free_t = div_free(t) if callable(div_free) else div_free
+
+ if div_free_t > 0:
+ p_0 = self.source_distribution_p[(None,) * x_t.dim()]
+ u = u + div_free_t * d_k_t / (k_t * (1 - k_t)) * (
+ (1 - k_t) * p_0 + k_t * delta_1
+ )
+
+ # Set u_t(x_t|x_t,x_1) = 0
+ delta_t = F.one_hot(x_t, num_classes=self.vocabulary_size)
+ u = torch.where(
+ delta_t.to(dtype=torch.bool), torch.zeros_like(u), u
+ )
+
+ # Sample x_t ~ u_t( \cdot |x_t,x_1)
+ intensity = u.sum(dim=-1) # Assuming u_t(xt|xt,x1) := 0
+ mask_jump = torch.rand(
+ size=x_t.shape, device=x_t.device
+ ) < 1 - torch.exp(-h * intensity)
+
+ if mask_jump.sum() > 0:
+ x_t[mask_jump] = categorical(
+ u[mask_jump].to(dtype=dtype_categorical)
+ )
+
+ steps_counter += 1
+ t = t + h
+
+ if return_intermediates and (t in time_grid):
+ res.append(x_t.clone())
+
+ if verbose:
+ ctx.n = t.item()
+ ctx.refresh()
+ ctx.set_description(f"NFE: {steps_counter}")
+
+ if return_intermediates:
+ if step_size is None:
+ return torch.stack(res, dim=0)
+ else:
+ return torch.stack(res, dim=0)[order]
+ else:
+ return x_t
+
+
+class MixtureDiscreteSoftmaxEulerSolver(Solver):
+ r"""Solver that simulates the CTMC process :math:`(X_t)_{t_{\text{init}}\leq t\leq t_{\text{final}}}` defined by :math:`p_t` the marginal probability path of ``path``.
+ Given :math:`X_t \sim p_t`, the algorithm of solver step from :math:`t` to :math:`t+h` for the i-th coordinate is:
+
+ .. math::
+
+ \begin{align*}
+ & X_1^i \sim p_{1|t}^i(\cdot|X_t)\\
+ & \lambda^i \gets \sum_{x^i\ne X_t^i} u_t^i(x^i, X_t^i|X_1^i)\\
+ & Z^i_{\text{change}} \sim U[0,1]\\
+ & X_{t+h}^i \sim \begin{cases}
+ \frac{u_t^i(\cdot, X_t^i|X_1^i)}{\lambda^i}(1-\delta_{X_t^i}(\cdot)) \text{ if $Z^i_{\text{change}}\le 1-e^{-h\lambda^i}$}\\
+ \delta_{X_t^i}(\cdot) \text{ else }
+ \end{cases}
+ \end{align*}
+
+ Where :math:`p_{1|t}(\cdot|X_t)` is the output of ``model``, and the conditional probability velocity is of the mixture probability path is:
+
+ .. math::
+
+ u_t^i(x^i, y^i|x_1^i) = \hat{u}_t^i(x^i, y^i|x_1^i) + c_{\text{div\_free}}\left[\hat{u}_t^i(x^i, y^i|x_1^i) - \check{u}_t^i(x^i, y^i|x_1^i) \right],
+
+ where
+
+ .. math::
+ \hat{u}_t^i(x^i, y^i|x_1^i) = \frac{\dot{\kappa}_t}{1-\kappa_t} \left[ \delta_{x_1^i}(x^i) - \delta_{y^i}(x^i) \right],
+
+ and
+
+ .. math::
+
+ \check{u}_t^i(x^i, y^i|x_1^i) = \frac{\dot{\kappa}_t}{\kappa_t}\left[ \delta_{y^i}(x^i) - p(x^i) \right].
+
+ The source distribution :math:`p(x^i)` is given by ``p``.
+
+ Args:
+ model (ModelWrapper): trained with x-prediction, outputting posterior probabilities (in the range :math:`[0,1]`), output must be [..., vocabulary_size].
+ path (MixtureDiscreteProbPath): Probability path used for x-prediction training.
+ vocabulary_size (int): size of the discrete vocabulary.
+ source_distribution_p (Optional[Tensor], optional): Source distribution, must be of shape [vocabulary_size]. Required only when divergence-free term for the probability velocity is non-zero. Defaults to None.
+ """
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ path_txt: MixtureDiscreteSoftmaxProbPath,
+ path_img: MixtureDiscreteSoftmaxProbPath,
+ vocabulary_size_txt: int,
+ vocabulary_size_img: int,
+ ):
+ super().__init__()
+ self.model = model
+ self.path_txt = path_txt
+ self.path_img = path_img
+ self.vocabulary_size_txt = vocabulary_size_txt
+ self.vocabulary_size_img = vocabulary_size_img
+
+ @torch.no_grad()
+ def sample(
+ self,
+ x_init: Tensor,
+ step_size: Optional[float],
+ div_free: Union[float, Callable[[float], float]] = 0.0,
+ dtype_categorical: torch.dtype = torch.float32,
+ time_grid: Tensor = torch.tensor([0.0, 1.0]),
+ return_intermediates: bool = False,
+ verbose: bool = False,
+ # callback: bool = False,
+ **model_extras,
+ ) -> Tensor:
+ """
+ Sample a sequence of discrete values from the given model.
+
+ .. code-block:: python
+
+ import torch
+ from flow_matching.utils import ModelWrapper
+ from flow_matching.solver import MixtureDiscreteEulerSolver
+
+ class DummyModel(ModelWrapper):
+ def __init__(self):
+ super().__init__(None)
+ def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
+ return ...
+
+ model = DummyModel()
+ solver = MixtureDiscreteEulerSolver(model=model)
+
+ x_init = torch.LongTensor([122, 725])
+ step_size = 0.001
+ time_grid = torch.tensor([0.0, 1.0])
+
+ result = solver.sample(x_init=x_init, step_size=step_size, time_grid=time_grid)
+
+ Args:
+ x_init (Tensor): The initial state.
+ step_size (Optional[float]): If float then time discretization is uniform with the given step size. If None then time discretization is set to be time_grid.
+ div_free (Union[float, Callable[[float], float]]): The coefficient of the divergence-free term in the probability velocity. Can be either a float or a time dependent function. Defaults to 0.0.
+ dtype_categorical (torch.dtype): Precision to use for categorical sampler. Defaults to torch.float32.
+ time_grid (Tensor): The CTMC process is solved in the interval [time_grid[0], time_grid[-1]] and if step_size is None then time discretization is set by the time grid. Defaults to torch.tensor([0.0,1.0]).
+ return_intermediates (bool): If True then return intermediate time steps according to time_grid. Defaults to False.
+ verbose (bool): Whether to print progress bars. Defaults to False.
+ **model_extras: Additional input for the model.
+
+ Returns:
+ Tensor: The sampled sequence of discrete values.
+ """
+ if not div_free == 0.0:
+ assert (
+ self.source_distribution_p is not None
+ ), "Source distribution p must be specified in order to add a divergence-free term to the probability velocity."
+
+ # Initialize the current state `x_t` with the initial state `X_0`.
+ time_grid = time_grid.to(device=x_init.device)
+
+ if step_size is None:
+ # If step_size is None then set the t discretization to time_grid.
+ t_discretization = time_grid
+ n_steps = len(time_grid) - 1
+ else:
+ # If step_size is float then t discretization is uniform with step size set by step_size.
+ t_init = time_grid[0].item()
+ t_final = time_grid[-1].item()
+ assert (
+ t_final - t_init
+ ) > step_size, f"Time interval [time_grid[0], time_grid[-1]] must be larger than step_size. Got a time interval [{t_init}, {t_final}] and step_size {step_size}."
+
+ n_steps = ceil((t_final - t_init) / step_size)
+ t_discretization = torch.tensor(
+ [t_init + step_size * i for i in range(n_steps)] + [t_final],
+ device=x_init.device,
+ )
+
+ if return_intermediates:
+ # get order of intermediate steps:
+ order = torch.argsort(time_grid)
+ # Compute intermediate steps to return via nearest points in t_discretization to time_grid.
+ time_grid = get_nearest_times(
+ time_grid=time_grid, t_discretization=t_discretization
+ )
+
+ x_t = x_init.clone()
+ steps_counter = 0
+ res = []
+
+ if return_intermediates:
+ if self.model.g_or_u == 'generation':
+ res = [x_init.clone()[model_extras['datainfo']['image_token_mask']==1].reshape(x_init.shape[0], -1)]
+ elif self.model.g_or_u =='understanding':
+ res = [x_init.clone()[model_extras['datainfo']['text_token_mask']==1].reshape(x_init.shape[0], -1)]
+ else:
+ res = [x_init.clone()]
+
+
+ if verbose:
+ ctx = tqdm(total=time_grid[-1].item(), desc=f"NFE: {steps_counter}")
+ else:
+ ctx = nullcontext()
+
+ with ctx:
+ original_x_t = x_t.clone()
+ batch_size = original_x_t.shape[0]
+ for i in range(n_steps):
+ t = t_discretization[i : i + 1]
+ h = t_discretization[i + 1 : i + 2] - t_discretization[i : i + 1]
+
+ # Sample x_1 ~ p_1|t( \cdot |x_t)
+ p_1t_txt, p_1t_img, data_info = self.model(x=x_t, **model_extras)
+ if p_1t_txt is None:
+ x_1 = categorical(p_1t_img.to(dtype=dtype_categorical))
+ x_1 = x_1[data_info['image_token_mask']==1].reshape(batch_size, -1)
+ x_t = x_t[data_info['image_token_mask']==1].reshape(batch_size, -1)
+ # x_1 = x_1 * data_info['image_token_mask'] + x_t * (1 - data_info['image_token_mask'])
+ elif p_1t_img is None:
+ x_1 = categorical(p_1t_txt.to(dtype=dtype_categorical))
+ x_1 = x_1[data_info['text_token_mask']==1].reshape(batch_size, -1)
+ x_t = x_t[data_info['text_token_mask']==1].reshape(batch_size, -1)
+ # x_1 = x_1 * data_info['text_token_mask'] + x_t * (1 - data_info['text_token_mask'])
+ else:
+ x_1_img = categorical(p_1t_img.to(dtype=dtype_categorical))
+ x_1_txt = categorical(p_1t_txt.to(dtype=dtype_categorical))
+ x_1_img = x_1_img[data_info['image_token_mask']==1].reshape(batch_size, -1)
+ x_1_txt = x_1_txt[data_info['text_token_mask']==1].reshape(batch_size, -1)
+ x_t_img = x_t[data_info['image_token_mask']==1].reshape(batch_size, -1)
+ x_t_txt = x_t[data_info['text_token_mask']==1].reshape(batch_size, -1)
+ # x_1_txt = x_1_txt * data_info['text_token_mask'] + x_t * (1 - data_info['text_token_mask'])
+ # x_1_img = x_1_img * data_info['image_token_mask'] + x_t * (1 - data_info['image_token_mask'])
+ # x_1 = x_1_txt * (1 - data_info['image_token_mask']) + x_1_img * data_info['image_token_mask']
+
+
+ # Checks if final step
+ if i == n_steps - 1:
+ if p_1t_txt is None:
+ x_t = x_1
+ elif p_1t_img is None:
+ x_t = x_1
+ else:
+ x_t = original_x_t.clone()
+ x_t[data_info['image_token_mask']==1] = x_1_img.flatten()
+ x_t[data_info['text_token_mask']==1] = x_1_txt.flatten()
+
+ if return_intermediates:
+ res.append(x_t.clone())
+ else:
+ if p_1t_txt is None:
+ # Compute p_t(x|x_1)
+ emb_x_1 = self.path_img.embedding(x_1)
+ prob_x_t = self.path_img.get_prob_distribution(emb_x_1, t)
+ prob_x_t = prob_x_t.reshape(-1, prob_x_t.shape[-1])
+ # Comptute the metric
+ emb_x_t = self.path_img.embedding(x_t)
+ emb_x_t_flattened = F.normalize(emb_x_t.view(-1, emb_x_t.shape[-1]), p=2, dim=-1)
+ emb_x_1_flattened = F.normalize(emb_x_1.view(-1, emb_x_1.shape[-1]), p=2, dim=-1)
+ distance_x_t_2_x_1 = (torch.sum(emb_x_t_flattened ** 2, dim=1, keepdim=True) + torch.sum(emb_x_1_flattened**2, dim=1, keepdim=True) - 2 * torch.einsum('bd,bd->b', emb_x_t_flattened, emb_x_1_flattened).unsqueeze(1)) ** 2
+ distance_x_1_2_x = self.path_img.metric(emb_x_1)
+ distance = F.relu(distance_x_t_2_x_1 - distance_x_1_2_x)
+ if t ==0 :
+ d_beta_t = 0
+ else:
+ d_beta_t = self.path_img.c * self.path_img.a * ((t / (1 - t)) ** (self.path_img.a - 1)) * 1 / ((1 - t) ** 2)
+ # get u
+ u = prob_x_t * d_beta_t * distance
+ # print(f"prob_x_t:{prob_x_t}")
+ # print(f"d_beta_t:{d_beta_t}")
+ # print(f"distance:{distance}")
+ # print(f"t:{t}, {t.dtype}")
+ u = u.reshape(x_1.shape[0], x_1.shape[1], -1)
+
+ # Set u_t(x_t|x_t,x_1) = 0
+ delta_t = F.one_hot(x_t, num_classes=self.vocabulary_size_img)
+ u = torch.where(
+ delta_t.to(dtype=torch.bool), torch.zeros_like(u), u
+ )
+
+ # Sample x_t ~ u_t( \cdot |x_t,x_1)
+ intensity = u.sum(dim=-1) # Assuming u_t(xt|xt,x1) := 0
+ # print(f"intensity:{intensity.sum()}")
+ mask_jump = torch.rand(
+ size=x_t.shape, device=x_t.device
+ ) < 1 - torch.exp(-h * intensity)
+ # torch.save(u, f'u_{u.device}.pt')
+ if mask_jump.sum() > 0:
+ x_t[mask_jump] = categorical(
+ u[mask_jump].to(dtype=dtype_categorical)
+ )
+ if return_intermediates:
+ res.append(x_t.clone())
+ # if callback:
+ # yield x_t
+ # res.append(x_1.clone())
+ original_x_t[data_info['image_token_mask']==1] = x_t.flatten()
+ # original_x_t[data_info['image_token_mask']==1] = x_1.flatten()
+ x_t = original_x_t.clone()
+ elif p_1t_img is None:
+ # Compute p_t(x|x_1)
+ emb_x_1 = self.path_txt.embedding(x_1)
+ prob_x_t = self.path_txt.get_prob_distribution(emb_x_1, t)
+ prob_x_t = prob_x_t.reshape(-1, prob_x_t.shape[-1])
+ # Comptute the metric
+ emb_x_t = self.path_txt.embedding(x_t)
+ emb_x_t_flattened = F.normalize(emb_x_t.view(-1, emb_x_t.shape[-1]), p=2, dim=-1)
+ emb_x_1_flattened = F.normalize(emb_x_1.view(-1, emb_x_1.shape[-1]), p=2, dim=-1)
+ distance_x_t_2_x_1 = (torch.sum(emb_x_t_flattened ** 2, dim=1, keepdim=True) + torch.sum(emb_x_1_flattened**2, dim=1, keepdim=True) - 2 * torch.einsum('bd,bd->b', emb_x_t_flattened, emb_x_1_flattened).unsqueeze(1)) ** 2
+ distance_x_1_2_x = self.path_txt.metric(emb_x_1)
+ distance = F.relu(distance_x_t_2_x_1 - distance_x_1_2_x)
+ if t ==0 :
+ d_beta_t = 0
+ else:
+ d_beta_t = self.path_txt.c * self.path_txt.a * ((t / (1 - t)) ** (self.path_txt.a - 1)) * 1 / ((1 - t) ** 2)
+ # get u
+ u = prob_x_t * d_beta_t * distance
+ # print(f"prob_x_t:{prob_x_t}")
+ # print(f"d_beta_t:{d_beta_t}")
+ # print(f"distance:{distance}")
+ # print(f"t:{t}, {t.dtype}")
+ u = u.reshape(x_1.shape[0], x_1.shape[1], -1)
+
+ # Set u_t(x_t|x_t,x_1) = 0
+ delta_t = F.one_hot(x_t, num_classes=self.vocabulary_size_txt)
+ u = torch.where(
+ delta_t.to(dtype=torch.bool), torch.zeros_like(u), u
+ )
+
+ # Sample x_t ~ u_t( \cdot |x_t,x_1)
+ intensity = u.sum(dim=-1) # Assuming u_t(xt|xt,x1) := 0
+ mask_jump = torch.rand(
+ size=x_t.shape, device=x_t.device
+ ) < 1 - torch.exp(-h * intensity)
+ # torch.save(u, f'u_{u.device}.pt')
+ if mask_jump.sum() > 0:
+ x_t[mask_jump] = categorical(
+ u[mask_jump].to(dtype=dtype_categorical)
+ )
+ if return_intermediates:
+ res.append(x_t.clone())
+ # if callback:
+ # yield x_t
+ original_x_t[data_info['text_token_mask']==1] = x_t.flatten()
+ x_t = original_x_t.clone()
+ else:
+ # The text part
+ x_t = x_t_txt.clone()
+ x_1 = x_1_txt.clone()
+ # Compute p_t(x|x_1)
+ emb_x_1 = self.path_txt.embedding(x_1)
+ prob_x_t = self.path_txt.get_prob_distribution(emb_x_1, t)
+ prob_x_t = prob_x_t.reshape(-1, prob_x_t.shape[-1])
+ # Comptute the metric
+ emb_x_t = self.path_txt.embedding(x_t)
+ emb_x_t_flattened = F.normalize(emb_x_t.view(-1, emb_x_t.shape[-1]), p=2, dim=-1)
+ emb_x_1_flattened = F.normalize(emb_x_1.view(-1, emb_x_1.shape[-1]), p=2, dim=-1)
+ distance_x_t_2_x_1 = (torch.sum(emb_x_t_flattened ** 2, dim=1, keepdim=True) + torch.sum(emb_x_1_flattened**2, dim=1, keepdim=True) - 2 * torch.einsum('bd,bd->b', emb_x_t_flattened, emb_x_1_flattened).unsqueeze(1)) ** 2
+ distance_x_1_2_x = self.path_txt.metric(emb_x_1)
+ distance = F.relu(distance_x_t_2_x_1 - distance_x_1_2_x)
+ if t ==0 :
+ d_beta_t = 0
+ else:
+ d_beta_t = self.path_txt.c * self.path_txt.a * ((t / (1 - t)) ** (self.path_txt.a - 1)) * 1 / ((1 - t) ** 2)
+ # get u
+ u = prob_x_t * d_beta_t * distance
+ # print(f"prob_x_t:{prob_x_t}")
+ # print(f"d_beta_t:{d_beta_t}")
+ # print(f"distance:{distance}")
+ # print(f"t:{t}, {t.dtype}")
+ u = u.reshape(x_1.shape[0], x_1.shape[1], -1)
+
+ # Set u_t(x_t|x_t,x_1) = 0
+ delta_t = F.one_hot(x_t, num_classes=self.vocabulary_size_txt)
+ u = torch.where(
+ delta_t.to(dtype=torch.bool), torch.zeros_like(u), u
+ )
+
+ # Sample x_t ~ u_t( \cdot |x_t,x_1)
+ intensity = u.sum(dim=-1) # Assuming u_t(xt|xt,x1) := 0
+ mask_jump = torch.rand(
+ size=x_t.shape, device=x_t.device
+ ) < 1 - torch.exp(-h * intensity)
+ # torch.save(u, f'u_{u.device}.pt')
+ if mask_jump.sum() > 0:
+ x_t[mask_jump] = categorical(
+ u[mask_jump].to(dtype=dtype_categorical)
+ )
+ original_x_t[data_info['text_token_mask']==1] = x_t.flatten()
+
+ # The image part
+ x_t = x_t_img.clone()
+ x_1 = x_1_img.clone()
+ scheduler_output = self.path_img.scheduler(t=t)
+ emb_x_1 = self.path_img.embedding(x_1)
+ prob_x_t = self.path_img.get_prob_distribution(emb_x_1, t)
+ prob_x_t = prob_x_t.reshape(-1, prob_x_t.shape[-1])
+ # Comptute the metric
+ emb_x_t = self.path_img.embedding(x_t)
+ emb_x_t_flattened = F.normalize(emb_x_t.view(-1, emb_x_t.shape[-1]), p=2, dim=-1)
+ emb_x_1_flattened = F.normalize(emb_x_1.view(-1, emb_x_1.shape[-1]), p=2, dim=-1)
+ distance_x_t_2_x_1 = (torch.sum(emb_x_t_flattened ** 2, dim=1, keepdim=True) + torch.sum(emb_x_1_flattened**2, dim=1, keepdim=True) - 2 * torch.einsum('bd,bd->b', emb_x_t_flattened, emb_x_1_flattened).unsqueeze(1)) ** 2
+ distance_x_1_2_x = self.path_img.metric(emb_x_1)
+ distance = F.relu(distance_x_t_2_x_1 - distance_x_1_2_x)
+ if t ==0 :
+ d_beta_t = 0
+ else:
+ d_beta_t = self.path_img.c * self.path_img.a * ((t / (1 - t)) ** (self.path_img.a - 1)) * 1 / ((1 - t) ** 2)
+ # get u
+ u = prob_x_t * d_beta_t * distance
+ # print(f"prob_x_t:{prob_x_t}")
+ # print(f"d_beta_t:{d_beta_t}")
+ # print(f"distance:{distance}")
+ # print(f"t:{t}, {t.dtype}")
+ u = u.reshape(x_1.shape[0], x_1.shape[1], -1)
+
+ # Set u_t(x_t|x_t,x_1) = 0
+ delta_t = F.one_hot(x_t, num_classes=self.vocabulary_size_img)
+ u = torch.where(
+ delta_t.to(dtype=torch.bool), torch.zeros_like(u), u
+ )
+
+ # Sample x_t ~ u_t( \cdot |x_t,x_1)
+ intensity = u.sum(dim=-1) # Assuming u_t(xt|xt,x1) := 0
+ mask_jump = torch.rand(
+ size=x_t.shape, device=x_t.device
+ ) < 1 - torch.exp(-h * intensity)
+ # torch.save(u, f'u_{u.device}.pt')
+ if mask_jump.sum() > 0:
+ x_t[mask_jump] = categorical(
+ u[mask_jump].to(dtype=dtype_categorical)
+ )
+ original_x_t[data_info['image_token_mask']==1] = x_t.flatten()
+
+ x_t = original_x_t.clone()
+ if return_intermediates:
+ res.append(x_t.clone())
+
+ steps_counter += 1
+ t = t + h
+
+ if verbose:
+ ctx.n = t.item()
+ ctx.refresh()
+ ctx.set_description(f"NFE: {steps_counter}")
+
+ # if return_intermediates and not callback:
+ if return_intermediates:
+ return torch.stack(res, dim=0)[:, 0, :].reshape(n_steps+1, -1)
+ # elif callback:
+ # yield x_t
+ else:
+ return x_t
diff --git a/flow_matching/solver/ode_solver.py b/flow_matching/solver/ode_solver.py
new file mode 100644
index 0000000..d2c1040
--- /dev/null
+++ b/flow_matching/solver/ode_solver.py
@@ -0,0 +1,194 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the CC-by-NC license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Callable, Optional, Sequence, Tuple, Union
+
+import torch
+from torch import Tensor
+from torchdiffeq import odeint
+
+from flow_matching.solver.solver import Solver
+from flow_matching.utils import gradient, ModelWrapper
+
+
+class ODESolver(Solver):
+ """A class to solve ordinary differential equations (ODEs) using a specified velocity model.
+
+ This class utilizes a velocity field model to solve ODEs over a given time grid using numerical ode solvers.
+
+ Args:
+ velocity_model (Union[ModelWrapper, Callable]): a velocity field model receiving :math:`(x,t)` and returning :math:`u_t(x)`
+ """
+
+ def __init__(self, velocity_model: Union[ModelWrapper, Callable]):
+ super().__init__()
+ self.velocity_model = velocity_model
+
+ @torch.no_grad()
+ def sample(
+ self,
+ x_init: Tensor,
+ step_size: Optional[float],
+ method: str = "euler",
+ atol: float = 1e-5,
+ rtol: float = 1e-5,
+ time_grid: Tensor = torch.tensor([0.0, 1.0]),
+ return_intermediates: bool = False,
+ **model_extras,
+ ) -> Union[Tensor, Sequence[Tensor]]:
+ r"""Solve the ODE with the velocity field.
+
+ Example:
+
+ .. code-block:: python
+
+ import torch
+ from flow_matching.utils import ModelWrapper
+ from flow_matching.solver import ODESolver
+
+ class DummyModel(ModelWrapper):
+ def __init__(self):
+ super().__init__(None)
+
+ def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
+ return torch.ones_like(x) * 3.0 * t**2
+
+ velocity_model = DummyModel()
+ solver = ODESolver(velocity_model=velocity_model)
+ x_init = torch.tensor([0.0, 0.0])
+ step_size = 0.001
+ time_grid = torch.tensor([0.0, 1.0])
+
+ result = solver.sample(x_init=x_init, step_size=step_size, time_grid=time_grid)
+
+ Args:
+ x_init (Tensor): initial conditions (e.g., source samples :math:`X_0 \sim p`). Shape: [batch_size, ...].
+ step_size (Optional[float]): The step size. Must be None for adaptive step solvers.
+ method (str): A method supported by torchdiffeq. Defaults to "euler". Other commonly used solvers are "dopri5", "midpoint" and "heun3". For a complete list, see torchdiffeq.
+ atol (float): Absolute tolerance, used for adaptive step solvers.
+ rtol (float): Relative tolerance, used for adaptive step solvers.
+ time_grid (Tensor): The process is solved in the interval [min(time_grid, max(time_grid)] and if step_size is None then time discretization is set by the time grid. May specify a descending time_grid to solve in the reverse direction. Defaults to torch.tensor([0.0, 1.0]).
+ return_intermediates (bool, optional): If True then return intermediate time steps according to time_grid. Defaults to False.
+ **model_extras: Additional input for the model.
+
+ Returns:
+ Union[Tensor, Sequence[Tensor]]: The last timestep when return_intermediates=False, otherwise all values specified in time_grid.
+ """
+
+ time_grid = time_grid.to(x_init.device)
+
+ def ode_func(t, x):
+ return self.velocity_model(x=x, t=t, **model_extras)
+
+ ode_opts = {"step_size": step_size} if step_size is not None else {}
+
+ # Approximate ODE solution with numerical ODE solver
+ sol = odeint(
+ ode_func,
+ x_init,
+ time_grid,
+ method=method,
+ options=ode_opts,
+ atol=atol,
+ rtol=rtol,
+ )
+
+ if return_intermediates:
+ return sol
+ else:
+ return sol[-1]
+
+ @torch.no_grad()
+ def compute_likelihood(
+ self,
+ x_1: Tensor,
+ log_p0: Callable[[Tensor], Tensor],
+ step_size: Optional[float],
+ method: str = "euler",
+ atol: float = 1e-5,
+ rtol: float = 1e-5,
+ time_grid: Tensor = torch.tensor([1.0, 0.0]),
+ return_intermediates: bool = False,
+ exact_divergence: bool = False,
+ **model_extras,
+ ) -> Union[Tuple[Tensor, Tensor], Tuple[Sequence[Tensor], Tensor]]:
+ r"""Solve for log likelihood given a target sample at :math:`t=0`.
+
+ Works similarly to sample, but solves the ODE in reverse to compute the log-likelihood. The velocity model must be differentiable with respect to x.
+ The function assumes log_p0 is the log probability of the source distribution at :math:`t=0`.
+
+ Args:
+ x_1 (Tensor): target sample (e.g., samples :math:`X_1 \sim p_1`).
+ log_p0 (Callable[[Tensor], Tensor]): Log probability function of the source distribution.
+ step_size (Optional[float]): The step size. Must be None for adaptive step solvers.
+ method (str): A method supported by torchdiffeq. Defaults to "euler". Other commonly used solvers are "dopri5", "midpoint" and "heun3". For a complete list, see torchdiffeq.
+ atol (float): Absolute tolerance, used for adaptive step solvers.
+ rtol (float): Relative tolerance, used for adaptive step solvers.
+ time_grid (Tensor): If step_size is None then time discretization is set by the time grid. Must start at 1.0 and end at 0.0, otherwise the likelihood computation is not valid. Defaults to torch.tensor([1.0, 0.0]).
+ return_intermediates (bool, optional): If True then return intermediate time steps according to time_grid. Otherwise only return the final sample. Defaults to False.
+ exact_divergence (bool): Whether to compute the exact divergence or use the Hutchinson estimator.
+ **model_extras: Additional input for the model.
+
+ Returns:
+ Union[Tuple[Tensor, Tensor], Tuple[Sequence[Tensor], Tensor]]: Samples at time_grid and log likelihood values of given x_1.
+ """
+ assert (
+ time_grid[0] == 1.0 and time_grid[-1] == 0.0
+ ), f"Time grid must start at 1.0 and end at 0.0. Got {time_grid}"
+
+ # Fix the random projection for the Hutchinson divergence estimator
+ if not exact_divergence:
+ z = (torch.randn_like(x_1).to(x_1.device) < 0) * 2.0 - 1.0
+
+ def ode_func(x, t):
+ return self.velocity_model(x=x, t=t, **model_extras)
+
+ def dynamics_func(t, states):
+ xt = states[0]
+ with torch.set_grad_enabled(True):
+ xt.requires_grad_()
+ ut = ode_func(xt, t)
+
+ if exact_divergence:
+ # Compute exact divergence
+ div = 0
+ for i in range(ut.flatten(1).shape[1]):
+ div += gradient(ut[:, i], xt, create_graph=True)[:, i]
+ else:
+ # Compute Hutchinson divergence estimator E[z^T D_x(ut) z]
+ ut_dot_z = torch.einsum(
+ "ij,ij->i", ut.flatten(start_dim=1), z.flatten(start_dim=1)
+ )
+ grad_ut_dot_z = gradient(ut_dot_z, xt)
+ div = torch.einsum(
+ "ij,ij->i",
+ grad_ut_dot_z.flatten(start_dim=1),
+ z.flatten(start_dim=1),
+ )
+
+ return ut.detach(), div.detach()
+
+ y_init = (x_1, torch.zeros(x_1.shape[0], device=x_1.device))
+ ode_opts = {"step_size": step_size} if step_size is not None else {}
+
+ with torch.no_grad():
+ sol, log_det = odeint(
+ dynamics_func,
+ y_init,
+ time_grid,
+ method=method,
+ options=ode_opts,
+ atol=atol,
+ rtol=rtol,
+ )
+
+ x_source = sol[-1]
+ source_log_p = log_p0(x_source)
+
+ if return_intermediates:
+ return sol, source_log_p + log_det[-1]
+ else:
+ return sol[-1], source_log_p + log_det[-1]
diff --git a/flow_matching/solver/riemannian_ode_solver.py b/flow_matching/solver/riemannian_ode_solver.py
new file mode 100644
index 0000000..6eb3e5e
--- /dev/null
+++ b/flow_matching/solver/riemannian_ode_solver.py
@@ -0,0 +1,243 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the CC-by-NC license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from typing import Callable
+
+import torch
+from torch import Tensor
+from tqdm import tqdm
+
+from flow_matching.solver.solver import Solver
+from flow_matching.utils import ModelWrapper
+from flow_matching.utils.manifolds import geodesic, Manifold
+
+
+class RiemannianODESolver(Solver):
+ r"""Riemannian ODE solver
+ Initialize the ``RiemannianODESolver``.
+
+ Args:
+ manifold (Manifold): the manifold to solve on.
+ velocity_model (ModelWrapper): a velocity field model receiving :math:`(x,t)`
+ and returning :math:`u_t(x)` which is assumed to lie on the tangent plane at `x`.
+ """
+
+ def __init__(self, manifold: Manifold, velocity_model: ModelWrapper):
+ super().__init__()
+ self.manifold = manifold
+ self.velocity_model = velocity_model
+
+ @torch.no_grad()
+ def sample(
+ self,
+ x_init: Tensor,
+ step_size: float,
+ projx: bool = True,
+ proju: bool = True,
+ method: str = "euler",
+ time_grid: Tensor = torch.tensor([0.0, 1.0]),
+ return_intermediates: bool = False,
+ verbose: bool = False,
+ **model_extras,
+ ) -> Tensor:
+ r"""Solve the ODE with the `velocity_field` on the manifold.
+
+ Args:
+ x_init (Tensor): initial conditions (e.g., source samples :math:`X_0 \sim p`).
+ step_size (float): The step size.
+ projx (bool): Whether to project the point onto the manifold at each step. Defaults to True.
+ proju (bool): Whether to project the vector field onto the tangent plane at each step. Defaults to True.
+ method (str): One of ["euler", "midpoint", "rk4"]. Defaults to "euler".
+ time_grid (Tensor, optional): The process is solved in the interval [min(time_grid, max(time_grid)] and if step_size is None then time discretization is set by the time grid. Defaults to torch.tensor([0.0,1.0]).
+ return_intermediates (bool, optional): If True then return intermediate time steps according to time_grid. Defaults to False.
+ verbose (bool, optional): Whether to print progress bars. Defaults to False.
+ **model_extras: Additional input for the model.
+
+ Returns:
+ Tensor: The sampled sequence. Defaults to returning samples at :math:`t=1`.
+ """
+ step_fns = {
+ "euler": _euler_step,
+ "midpoint": _midpoint_step,
+ "rk4": _rk4_step,
+ }
+ assert method in step_fns.keys(), f"Unknown method {method}"
+ step_fn = step_fns[method]
+
+ # --- Factor this out.
+ time_grid = torch.sort(time_grid.to(device=x_init.device)).values
+
+ if step_size is None:
+ # If step_size is None then set the t discretization to time_grid.
+ t_discretization = time_grid
+ n_steps = len(time_grid) - 1
+ else:
+ # If step_size is float then t discretization is uniform with step size set by step_size.
+ t_init = time_grid[0].item()
+ t_final = time_grid[-1].item()
+ assert (
+ t_final - t_init
+ ) > step_size, f"Time interval [min(time_grid), max(time_grid)] must be larger than step_size. Got a time interval [{t_init}, {t_final}] and step_size {step_size}."
+
+ n_steps = math.ceil((t_final - t_init) / step_size)
+ t_discretization = torch.tensor(
+ [step_size * i for i in range(n_steps)] + [t_final],
+ device=x_init.device,
+ )
+ # ---
+ t0s = t_discretization[:-1]
+
+ if verbose:
+ t0s = tqdm(t0s)
+
+ if return_intermediates:
+ xts = []
+ i_ret = 0
+
+ xt = x_init
+ for t0, t1 in zip(t0s, t_discretization[1:]):
+ dt = t1 - t0
+ xt_next = step_fn(
+ self.velocity_model,
+ xt,
+ t0,
+ dt,
+ manifold=self.manifold,
+ projx=projx,
+ proju=proju,
+ )
+ if return_intermediates:
+ while (
+ i_ret < len(time_grid)
+ and t0 <= time_grid[i_ret]
+ and time_grid[i_ret] <= t1
+ ):
+ xts.append(
+ interp(self.manifold, xt, xt_next, t0, t1, time_grid[i_ret])
+ )
+ i_ret += 1
+ xt = xt_next
+
+ if return_intermediates:
+ return torch.stack(xts, dim=0)
+ else:
+ return xt
+
+
+def interp(manifold, xt, xt_next, t, t_next, t_ret):
+ return geodesic(manifold, xt, xt_next)(
+ (t_ret - t) / (t_next - t).reshape(1)
+ ).reshape_as(xt)
+
+
+def _euler_step(
+ velocity_model: Callable,
+ xt: Tensor,
+ t0: Tensor,
+ dt: Tensor,
+ manifold: Manifold,
+ projx: bool = True,
+ proju: bool = True,
+) -> Tensor:
+ r"""Perform an Euler step on a manifold.
+
+ Args:
+ velocity_model (Callable): the velocity model
+ xt (Tensor): tensor containing the state at time t0
+ t0 (Tensor): the time at which this step is taken
+ dt (Tensor): the step size
+ manifold (Manifold): a manifold object
+ projx (bool, optional): whether to project the state onto the manifold. Defaults to True.
+ proju (bool, optional): whether to project the velocity onto the tangent plane. Defaults to True.
+
+ Returns:
+ Tensor: tensor containing the state after the step
+ """
+ velocity_fn = lambda x, t: (
+ manifold.proju(x, velocity_model(x, t)) if proju else velocity_model(x, t)
+ )
+ projx_fn = lambda x: manifold.projx(x) if projx else x
+
+ vt = velocity_fn(xt, t0)
+
+ xt = xt + dt * vt
+
+ return projx_fn(xt)
+
+
+def _midpoint_step(
+ velocity_model: Callable,
+ xt: Tensor,
+ t0: Tensor,
+ dt: Tensor,
+ manifold: Manifold,
+ projx: bool = True,
+ proju: bool = True,
+) -> Tensor:
+ r"""Perform a midpoint step on a manifold.
+
+ Args:
+ velocity_model (Callable): the velocity model
+ xt (Tensor): tensor containing the state at time t0
+ t0 (Tensor): the time at which this step is taken
+ dt (Tensor): the step size
+ manifold (Manifold): a manifold object
+ projx (bool, optional): whether to project the state onto the manifold. Defaults to True.
+ proju (bool, optional): whether to project the velocity onto the tangent plane. Defaults to True.
+
+ Returns:
+ Tensor: tensor containing the state after the step
+ """
+ velocity_fn = lambda x, t: (
+ manifold.proju(x, velocity_model(x, t)) if proju else velocity_model(x, t)
+ )
+ projx_fn = lambda x: manifold.projx(x) if projx else x
+
+ half_dt = 0.5 * dt
+ vt = velocity_fn(xt, t0)
+ x_mid = xt + half_dt * vt
+ x_mid = projx_fn(x_mid)
+
+ xt = xt + dt * velocity_fn(x_mid, t0 + half_dt)
+
+ return projx_fn(xt)
+
+
+def _rk4_step(
+ velocity_model: Callable,
+ xt: Tensor,
+ t0: Tensor,
+ dt: Tensor,
+ manifold: Manifold,
+ projx: bool = True,
+ proju: bool = True,
+) -> Tensor:
+ r"""Perform an RK4 step on a manifold.
+
+ Args:
+ velocity_model (Callable): the velocity model
+ xt (Tensor): tensor containing the state at time t0
+ t0 (Tensor): the time at which this step is taken
+ dt (Tensor): the step size
+ manifold (Manifold): a manifold object
+ projx (bool, optional): whether to project the state onto the manifold. Defaults to True.
+ proju (bool, optional): whether to project the velocity onto the tangent plane. Defaults to True.
+
+ Returns:
+ Tensor: tensor containing the state after the step
+ """
+ velocity_fn = lambda x, t: (
+ manifold.proju(x, velocity_model(x, t)) if proju else velocity_model(x, t)
+ )
+ projx_fn = lambda x: manifold.projx(x) if projx else x
+
+ k1 = velocity_fn(xt, t0)
+ k2 = velocity_fn(projx_fn(xt + dt * k1 / 3), t0 + dt / 3)
+ k3 = velocity_fn(projx_fn(xt + dt * (k2 - k1 / 3)), t0 + dt * 2 / 3)
+ k4 = velocity_fn(projx_fn(xt + dt * (k1 - k2 + k3)), t0 + dt)
+
+ return projx_fn(xt + (k1 + 3 * (k2 + k3) + k4) * dt * 0.125)
diff --git a/flow_matching/solver/solver.py b/flow_matching/solver/solver.py
new file mode 100644
index 0000000..4819e1c
--- /dev/null
+++ b/flow_matching/solver/solver.py
@@ -0,0 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the CC-by-NC license found in the
+# LICENSE file in the root directory of this source tree.
+
+from abc import ABC, abstractmethod
+
+from torch import nn, Tensor
+
+
+class Solver(ABC, nn.Module):
+ """Abstract base class for solvers."""
+
+ @abstractmethod
+ def sample(self, x_0: Tensor = None) -> Tensor:
+ ...
diff --git a/flow_matching/solver/utils.py b/flow_matching/solver/utils.py
new file mode 100644
index 0000000..f3a34ee
--- /dev/null
+++ b/flow_matching/solver/utils.py
@@ -0,0 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the CC-by-NC license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import Tensor
+
+
+def get_nearest_times(time_grid: Tensor, t_discretization: Tensor) -> Tensor:
+ distances = torch.cdist(
+ time_grid.unsqueeze(1),
+ t_discretization.unsqueeze(1),
+ compute_mode="donot_use_mm_for_euclid_dist",
+ )
+ nearest_indices = distances.argmin(dim=1)
+
+ return t_discretization[nearest_indices]
diff --git a/flow_matching/utils/__init__.py b/flow_matching/utils/__init__.py
new file mode 100644
index 0000000..0085c44
--- /dev/null
+++ b/flow_matching/utils/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the CC-by-NC license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .categorical_sampler import categorical
+from .model_wrapper import ModelWrapper
+from .utils import expand_tensor_like, gradient, unsqueeze_to_match
+
+__all__ = [
+ "unsqueeze_to_match",
+ "expand_tensor_like",
+ "gradient",
+ "categorical",
+ "ModelWrapper",
+]
diff --git a/flow_matching/utils/categorical_sampler.py b/flow_matching/utils/categorical_sampler.py
new file mode 100644
index 0000000..761a44e
--- /dev/null
+++ b/flow_matching/utils/categorical_sampler.py
@@ -0,0 +1,25 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the CC-by-NC license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import Tensor
+
+
+def categorical(probs: Tensor) -> Tensor:
+ r"""Categorical sampler according to weights in the last dimension of ``probs`` using :func:`torch.multinomial`.
+
+ Args:
+ probs (Tensor): probabilities.
+
+ Returns:
+ Tensor: Samples.
+ """
+
+ return torch.multinomial(probs.flatten(0, -2), 1, replacement=True).view(
+ *probs.shape[:-1]
+ )
+
+ # return torch.argmax(probs, dim=-1)
diff --git a/flow_matching/utils/flow.py b/flow_matching/utils/flow.py
new file mode 100644
index 0000000..5e73df4
--- /dev/null
+++ b/flow_matching/utils/flow.py
@@ -0,0 +1,90 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the CC-by-NC license found in the
+# LICENSE file in the root directory of this source tree.
+
+from abc import ABC
+from typing import Optional, Tuple
+
+import torch
+from torch import Tensor
+from torch.nn.modules.loss import _Loss
+
+from flow_matching.loss import MixturePathGeneralizedKL
+from flow_matching.path import MixtureDiscreteProbPath, ProbPath
+from flow_matching.path.scheduler import PolynomialConvexScheduler
+
+
+class SourceDistribution(ABC):
+ def __init__(
+ self,
+ ) -> None:
+ ...
+
+ def sample(self, tensor_size: Tuple[int, ...], device: torch.device) -> Tensor:
+ ...
+
+ def sample_like(self, tensor_like: Tensor) -> Tensor:
+ ...
+
+
+class MaskedSourceDistribution(SourceDistribution):
+ def __init__(self, mask_token: int) -> None:
+ self.mask_token = mask_token
+
+ @property
+ def masked(self) -> bool:
+ return True
+
+ def sample(self, tensor_size: Tuple[int, ...], device: torch.device) -> Tensor:
+ return torch.zeros(tensor_size, device=device).fill_(self.mask_token).long()
+
+ def sample_like(self, tensor_like: Tensor) -> Tensor:
+ return torch.zeros_like(tensor_like).fill_(self.mask_token).long()
+
+
+class UniformSourceDistribution(SourceDistribution):
+ def __init__(self, vocab_size: int) -> None:
+ self.vocab_size = vocab_size
+
+ @property
+ def masked(self) -> bool:
+ return False
+
+ def sample(self, tensor_size: Tuple[int, ...], device: torch.device) -> Tensor:
+ return torch.randint(size=tensor_size, high=self.vocab_size, device=device)
+
+ def sample_like(self, tensor_like: Tensor) -> Tensor:
+ return torch.randint_like(tensor_like, high=self.vocab_size)
+
+
+def get_path(scheduler_type: str, exponent: Optional[float] = None) -> ProbPath:
+ if scheduler_type == "polynomial":
+ scheduler = PolynomialConvexScheduler(n=exponent)
+ else:
+ raise ValueError(f"{scheduler_type} is not supported")
+
+ return MixtureDiscreteProbPath(scheduler=scheduler)
+
+
+def get_source_distribution(
+ source_distribution: str, vocab_size: int
+) -> SourceDistribution:
+ if source_distribution == "mask":
+ return MaskedSourceDistribution(mask_token=vocab_size)
+ elif source_distribution == "uniform":
+ return UniformSourceDistribution(vocab_size=vocab_size)
+ else:
+ raise ValueError(f"{source_distribution} is not supported")
+
+
+def get_loss_function(loss_function: str, path: Optional[ProbPath] = None) -> _Loss:
+ if loss_function == "cross_entropy":
+ return torch.nn.CrossEntropyLoss()
+ elif loss_function == "generalized_kl":
+ assert path is not None
+
+ return MixturePathGeneralizedKL(path=path)
+ else:
+ raise ValueError(f"{loss_function} is not supported")
\ No newline at end of file
diff --git a/flow_matching/utils/manifolds/__init__.py b/flow_matching/utils/manifolds/__init__.py
new file mode 100644
index 0000000..1148872
--- /dev/null
+++ b/flow_matching/utils/manifolds/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the CC-by-NC license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .manifold import Euclidean, Manifold
+from .sphere import Sphere
+from .torus import FlatTorus
+from .utils import geodesic
+
+__all__ = [
+ "Euclidean",
+ "Manifold",
+ "Sphere",
+ "FlatTorus",
+ "geodesic",
+]
diff --git a/flow_matching/utils/manifolds/manifold.py b/flow_matching/utils/manifolds/manifold.py
new file mode 100644
index 0000000..52a6a1b
--- /dev/null
+++ b/flow_matching/utils/manifolds/manifold.py
@@ -0,0 +1,93 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the CC-by-NC license found in the
+# LICENSE file in the root directory of this source tree.
+
+import abc
+
+import torch.nn as nn
+from torch import Tensor
+
+
+class Manifold(nn.Module, metaclass=abc.ABCMeta):
+ """A manifold class that contains projection operations and logarithm and exponential maps."""
+
+ @abc.abstractmethod
+ def expmap(self, x: Tensor, u: Tensor) -> Tensor:
+ r"""Computes exponential map :math:`\exp_x(u)`.
+
+ Args:
+ x (Tensor): point on the manifold
+ u (Tensor): tangent vector at point :math:`x`
+
+ Raises:
+ NotImplementedError: if not implemented
+
+ Returns:
+ Tensor: transported point
+ """
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def logmap(self, x: Tensor, y: Tensor) -> Tensor:
+ r"""Computes logarithmic map :math:`\log_x(y)`.
+
+ Args:
+ x (Tensor): point on the manifold
+ y (Tensor): point on the manifold
+
+ Raises:
+ NotImplementedError: if not implemented
+
+ Returns:
+ Tensor: tangent vector at point :math:`x`
+ """
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def projx(self, x: Tensor) -> Tensor:
+ """Project point :math:`x` on the manifold.
+
+ Args:
+ x (Tensor): point to be projected
+
+ Raises:
+ NotImplementedError: if not implemented
+
+ Returns:
+ Tensor: projected point on the manifold
+ """
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def proju(self, x: Tensor, u: Tensor) -> Tensor:
+ """Project vector :math:`u` on a tangent space for :math:`x`.
+
+ Args:
+ x (Tensor): point on the manifold
+ u (Tensor): vector to be projected
+
+ Raises:
+ NotImplementedError: if not implemented
+
+ Returns:
+ Tensor: projected tangent vector
+ """
+ raise NotImplementedError
+
+
+class Euclidean(Manifold):
+ """The Euclidean manifold."""
+
+ def expmap(self, x: Tensor, u: Tensor) -> Tensor:
+ return x + u
+
+ def logmap(self, x: Tensor, y: Tensor) -> Tensor:
+ return y - x
+
+ def projx(self, x: Tensor) -> Tensor:
+ return x
+
+ def proju(self, x: Tensor, u: Tensor) -> Tensor:
+ return u
diff --git a/flow_matching/utils/manifolds/sphere.py b/flow_matching/utils/manifolds/sphere.py
new file mode 100644
index 0000000..76bf748
--- /dev/null
+++ b/flow_matching/utils/manifolds/sphere.py
@@ -0,0 +1,45 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the CC-by-NC license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import Tensor
+
+from flow_matching.utils.manifolds import Manifold
+
+
+class Sphere(Manifold):
+ """Represents a hyperpshere in :math:`R^D`. Isometric to the product of 1-D spheres."""
+
+ EPS = {torch.float32: 1e-4, torch.float64: 1e-7}
+
+ def expmap(self, x: Tensor, u: Tensor) -> Tensor:
+ norm_u = u.norm(dim=-1, keepdim=True)
+ exp = x * torch.cos(norm_u) + u * torch.sin(norm_u) / norm_u
+ retr = self.projx(x + u)
+ cond = norm_u > self.EPS[norm_u.dtype]
+
+ return torch.where(cond, exp, retr)
+
+ def logmap(self, x: Tensor, y: Tensor) -> Tensor:
+ u = self.proju(x, y - x)
+ dist = self.dist(x, y, keepdim=True)
+ cond = dist.gt(self.EPS[x.dtype])
+ result = torch.where(
+ cond,
+ u * dist / u.norm(dim=-1, keepdim=True).clamp_min(self.EPS[x.dtype]),
+ u,
+ )
+ return result
+
+ def projx(self, x: Tensor) -> Tensor:
+ return x / x.norm(dim=-1, keepdim=True)
+
+ def proju(self, x: Tensor, u: Tensor) -> Tensor:
+ return u - (x * u).sum(dim=-1, keepdim=True) * x
+
+ def dist(self, x: Tensor, y: Tensor, *, keepdim=False) -> Tensor:
+ inner = (x * y).sum(-1, keepdim=keepdim)
+ return torch.acos(inner)
diff --git a/flow_matching/utils/manifolds/torus.py b/flow_matching/utils/manifolds/torus.py
new file mode 100644
index 0000000..3587ed7
--- /dev/null
+++ b/flow_matching/utils/manifolds/torus.py
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the CC-by-NC license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+
+import torch
+from torch import Tensor
+
+from flow_matching.utils.manifolds import Manifold
+
+
+class FlatTorus(Manifold):
+ r"""Represents a flat torus on the :math:`[0, 2\pi]^D` subspace. Isometric to the product of 1-D spheres."""
+
+ def expmap(self, x: Tensor, u: Tensor) -> Tensor:
+ return (x + u) % (2 * math.pi)
+
+ def logmap(self, x: Tensor, y: Tensor) -> Tensor:
+ return torch.atan2(torch.sin(y - x), torch.cos(y - x))
+
+ def projx(self, x: Tensor) -> Tensor:
+ return x % (2 * math.pi)
+
+ def proju(self, x: Tensor, u: Tensor) -> Tensor:
+ return u
diff --git a/flow_matching/utils/manifolds/utils.py b/flow_matching/utils/manifolds/utils.py
new file mode 100644
index 0000000..b83d2fa
--- /dev/null
+++ b/flow_matching/utils/manifolds/utils.py
@@ -0,0 +1,45 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the CC-by-NC license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Callable
+
+import torch
+from torch import Tensor
+
+from flow_matching.utils.manifolds import Manifold
+
+
+def geodesic(
+ manifold: Manifold, start_point: Tensor, end_point: Tensor
+) -> Callable[[Tensor], Tensor]:
+ """Generate parameterized function for geodesic curve.
+
+ Args:
+ manifold (Manifold): the manifold to compute geodesic on.
+ start_point (Tensor): point on the manifold at :math:`t=0`.
+ end_point (Tensor): point on the manifold at :math:`t=1`.
+
+ Returns:
+ Callable[[Tensor], Tensor]: a function that takes in :math:`t` and outputs the geodesic at time :math:`t`.
+ """
+
+ shooting_tangent_vec = manifold.logmap(start_point, end_point)
+
+ def path(t: Tensor) -> Tensor:
+ """Generate parameterized function for geodesic curve.
+
+ Args:
+ t (Tensor): Times at which to compute points of the geodesics.
+
+ Returns:
+ Tensor: geodesic path evaluated at time t.
+ """
+ tangent_vecs = torch.einsum("i,...k->...ik", t, shooting_tangent_vec)
+ points_at_time_t = manifold.expmap(start_point.unsqueeze(-2), tangent_vecs)
+
+ return points_at_time_t
+
+ return path
diff --git a/flow_matching/utils/model_wrapper.py b/flow_matching/utils/model_wrapper.py
new file mode 100644
index 0000000..ac7d932
--- /dev/null
+++ b/flow_matching/utils/model_wrapper.py
@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the CC-by-NC license found in the
+# LICENSE file in the root directory of this source tree.
+
+from abc import ABC
+
+from torch import nn, Tensor
+
+
+class ModelWrapper(ABC, nn.Module):
+ """
+ This class is used to wrap around another model, adding custom forward pass logic.
+ """
+
+ def __init__(self, model: nn.Module):
+ super().__init__()
+ self.model = model
+
+ def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor:
+ r"""
+ This method defines how inputs should be passed through the wrapped model.
+ Here, we're assuming that the wrapped model takes both :math:`x` and :math:`t` as input,
+ along with any additional keyword arguments.
+
+ Optional things to do here:
+ - check that t is in the dimensions that the model is expecting.
+ - add a custom forward pass logic.
+ - call the wrapped model.
+
+ | given x, t
+ | returns the model output for input x at time t, with extra information `extra`.
+
+ Args:
+ x (Tensor): input data to the model (batch_size, ...).
+ t (Tensor): time (batch_size).
+ **extras: additional information forwarded to the model, e.g., text condition.
+
+ Returns:
+ Tensor: model output.
+ """
+ return self.model(x=x, t=t, **extras)
diff --git a/flow_matching/utils/utils.py b/flow_matching/utils/utils.py
new file mode 100644
index 0000000..beb31ff
--- /dev/null
+++ b/flow_matching/utils/utils.py
@@ -0,0 +1,90 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the CC-by-NC license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional
+
+import torch
+from torch import Tensor
+
+
+def unsqueeze_to_match(source: Tensor, target: Tensor, how: str = "suffix") -> Tensor:
+ """
+ Unsqueeze the source tensor to match the dimensionality of the target tensor.
+
+ Args:
+ source (Tensor): The source tensor to be unsqueezed.
+ target (Tensor): The target tensor to match the dimensionality of.
+ how (str, optional): Whether to unsqueeze the source tensor at the beginning
+ ("prefix") or end ("suffix"). Defaults to "suffix".
+
+ Returns:
+ Tensor: The unsqueezed source tensor.
+ """
+ assert (
+ how == "prefix" or how == "suffix"
+ ), f"{how} is not supported, only 'prefix' and 'suffix' are supported."
+
+ dim_diff = target.dim() - source.dim()
+
+ for _ in range(dim_diff):
+ if how == "prefix":
+ source = source.unsqueeze(0)
+ elif how == "suffix":
+ source = source.unsqueeze(-1)
+
+ return source
+
+
+def expand_tensor_like(input_tensor: Tensor, expand_to: Tensor) -> Tensor:
+ """`input_tensor` is a 1d vector of length equal to the batch size of `expand_to`,
+ expand `input_tensor` to have the same shape as `expand_to` along all remaining dimensions.
+
+ Args:
+ input_tensor (Tensor): (batch_size,).
+ expand_to (Tensor): (batch_size, ...).
+
+ Returns:
+ Tensor: (batch_size, ...).
+ """
+ assert input_tensor.ndim == 1, "Input tensor must be a 1d vector."
+ assert (
+ input_tensor.shape[0] == expand_to.shape[0]
+ ), f"The first (batch_size) dimension must match. Got shape {input_tensor.shape} and {expand_to.shape}."
+
+ dim_diff = expand_to.ndim - input_tensor.ndim
+
+ t_expanded = input_tensor.clone()
+ t_expanded = t_expanded.reshape(-1, *([1] * dim_diff))
+
+ return t_expanded.expand_as(expand_to)
+
+
+def gradient(
+ output: Tensor,
+ x: Tensor,
+ grad_outputs: Optional[Tensor] = None,
+ create_graph: bool = False,
+) -> Tensor:
+ """
+ Compute the gradient of the inner product of output and grad_outputs w.r.t :math:`x`.
+
+ Args:
+ output (Tensor): [N, D] Output of the function.
+ x (Tensor): [N, d_1, d_2, ... ] input
+ grad_outputs (Optional[Tensor]): [N, D] Gradient of outputs, if `None`,
+ then will use a tensor of ones
+ create_graph (bool): If True, graph of the derivative will be constructed, allowing
+ to compute higher order derivative products. Defaults to False.
+ Returns:
+ Tensor: [N, d_1, d_2, ... ]. the gradient w.r.t x.
+ """
+
+ if grad_outputs is None:
+ grad_outputs = torch.ones_like(output).detach()
+ grad = torch.autograd.grad(
+ output, x, grad_outputs=grad_outputs, create_graph=create_graph
+ )[0]
+ return grad
diff --git a/fudoki/__init__.py b/fudoki/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/fudoki/eval_loop.py b/fudoki/eval_loop.py
new file mode 100644
index 0000000..454a37c
--- /dev/null
+++ b/fudoki/eval_loop.py
@@ -0,0 +1,96 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the CC-by-NC license found in the
+# LICENSE file in the root directory of this source tree.
+import logging
+
+import torch
+from torch.nn.modules import Module
+
+from flow_matching.utils import ModelWrapper
+
+PRINT_FREQUENCY = 50
+
+
+logger = logging.getLogger(__name__)
+
+
+def top_k_logits(logits, top_k=None):
+ top_k = min(top_k, logits.size(-1)) # Safety check
+ # Remove all tokens with a probability less than the last token of the top-k
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+ logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
+ return logits
+
+
+class CFGScaledModel(ModelWrapper):
+ def __init__(self, model: Module, g_or_u, mode='eval', top=1):
+ super().__init__(model)
+ self.nfe_counter = 0
+ self.g_or_u = g_or_u
+ self.mode = mode
+ self.top = top
+ assert self.g_or_u in ['understanding', 'generation', 'generation and understanding']
+ assert self.mode in ['train', 'eval', "train-top"]
+
+ def forward(
+ self, x: torch.Tensor, cfg_scale: float, datainfo, uncond_id=100015
+ ):
+ with torch.no_grad():
+ conditional_img_logits, conditional_txt_logits = self.model(x, datainfo)
+ if self.g_or_u == 'understanding':
+ conditional_logits = conditional_txt_logits
+ result_txt = conditional_logits
+ if self.mode == 'eval':
+ result_txt = top_k_logits(result_txt, top_k=1)
+ elif self.mode == 'train-top':
+ result_txt = self.add_gumbel_noise(result_txt, temperature=0.1, dtype=result_txt.dtype)
+ result_txt = top_k_logits(result_txt, top_k=self.top)
+ result_img = None
+ elif self.g_or_u == 'generation':
+ conditional_logits = conditional_img_logits
+
+ uncondition_x = x.clone()
+ text_token_mask = datainfo['text_token_mask']
+ for bs in range(text_token_mask.shape[0]):
+ nz = datainfo['text_token_mask'][bs].nonzero()
+ if nz.numel() > 0: # Make sure there's at least one nonzero
+ text_nonzero_idx_begin = nz[0, 0] # first nonzero along dim=0
+ text_nonzero_idx_end = nz[-1, 0]
+ uncondition_x[bs, text_nonzero_idx_begin:text_nonzero_idx_end+1] = uncond_id
+ unconditional_img_logits, _ = self.model(uncondition_x, datainfo)
+ unconditional_logits = unconditional_img_logits
+ result_img = (1.0 + cfg_scale) * conditional_logits - cfg_scale * unconditional_logits
+ # result_img = top_k_logits(result_img, top_k=1) # forcing this reduces diversity
+ result_txt = None
+ else:
+ result_img = conditional_img_logits
+ result_txt = conditional_txt_logits
+
+ self.nfe_counter += 1
+ if self.g_or_u == 'understanding':
+ return torch.softmax(result_txt.to(dtype=torch.float32), dim=-1), result_img, datainfo
+ elif self.g_or_u == 'generation':
+ return result_txt, torch.softmax(result_img.to(dtype=torch.float32), dim=-1), datainfo
+ else:
+ return torch.softmax(result_txt.to(dtype=torch.float32), dim=-1), torch.softmax(result_img.to(dtype=torch.float32), dim=-1), datainfo
+
+ def reset_nfe_counter(self) -> None:
+ self.nfe_counter = 0
+
+ def get_nfe(self) -> int:
+ return self.nfe_counter
+
+ def add_gumbel_noise(self, logits, temperature, dtype):
+ """
+ The Gumbel max is a method for sampling categorical distributions.
+ According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality.
+ Thus, we use float64.
+ """
+ if temperature == 0.0:
+ return logits # Skip noise when temperature is 0
+ logits = logits.to(dtype)
+ noise = torch.rand_like(logits, dtype=dtype)
+ gumbel_noise = (-torch.log(noise)) ** temperature
+ return logits.exp() / gumbel_noise
diff --git a/fudoki/janus/__init__.py b/fudoki/janus/__init__.py
new file mode 100644
index 0000000..09cc08c
--- /dev/null
+++ b/fudoki/janus/__init__.py
@@ -0,0 +1,31 @@
+# Copyright (c) 2023-2024 DeepSeek.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of
+# this software and associated documentation files (the "Software"), to deal in
+# the Software without restriction, including without limitation the rights to
+# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+# the Software, and to permit persons to whom the Software is furnished to do so,
+# subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+
+# check if python version is above 3.10
+import sys
+
+if sys.version_info >= (3, 10):
+ print("Python version is above 3.10, patching the collections module.")
+ # Monkey patch collections
+ import collections
+ import collections.abc
+
+ for type_name in collections.abc.__all__:
+ setattr(collections, type_name, getattr(collections.abc, type_name))
diff --git a/fudoki/janus/models/__init__.py b/fudoki/janus/models/__init__.py
new file mode 100644
index 0000000..9469193
--- /dev/null
+++ b/fudoki/janus/models/__init__.py
@@ -0,0 +1,28 @@
+# Copyright (c) 2023-2024 DeepSeek.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of
+# this software and associated documentation files (the "Software"), to deal in
+# the Software without restriction, including without limitation the rights to
+# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+# the Software, and to permit persons to whom the Software is furnished to do so,
+# subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+from .image_processing_vlm import VLMImageProcessor
+from .modeling_vlm import MultiModalityCausalLM
+from .processing_vlm import VLChatProcessor
+
+__all__ = [
+ "VLMImageProcessor",
+ "VLChatProcessor",
+ "MultiModalityCausalLM",
+]
diff --git a/fudoki/janus/models/action_tokenizer.py b/fudoki/janus/models/action_tokenizer.py
new file mode 100644
index 0000000..b653d51
--- /dev/null
+++ b/fudoki/janus/models/action_tokenizer.py
@@ -0,0 +1,83 @@
+"""
+action_tokenizer.py
+
+Extension class; wraps base LLM/VLM tokenizer with logic to discretize and tokenize continuous robot actions.
+"""
+
+from typing import List, Union
+
+import numpy as np
+from transformers import PreTrainedTokenizerBase
+
+
+class ActionTokenizer:
+ def __init__(
+ self, tokenizer: PreTrainedTokenizerBase, bins: int = 256, min_action: int = -1, max_action: int = 1
+ ) -> None:
+ """
+ Discretizes continuous robot actions into N bins per dimension and maps to the least used tokens.
+
+ NOTE =>> by default, assumes a BPE-style tokenizer akin to the LlamaTokenizer, where *the least used tokens*
+ appear at the end of the vocabulary!
+
+ :param tokenizer: Base LLM/VLM tokenizer to extend.
+ :param bins: Number of bins for each continuous value; we'll adopt a uniform binning strategy.
+ :param min_action: Minimum action value (for clipping, setting lower bound on bin interval).
+ :param max_action: Maximum action value (for clipping, setting upper bound on bin interval).
+ """
+ self.tokenizer, self.n_bins, self.min_action, self.max_action = tokenizer, bins, min_action, max_action
+
+ # Create Uniform Bins + Compute Bin Centers
+ self.bins = np.linspace(min_action, max_action, self.n_bins)
+ self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0
+
+ # [Contract] Set "action_token_begin_idx" based on `self.tokenizer.vocab_size - (self.n_bins + 1)`
+ # =>> Assumes we're always overwriting the final `n_bins` tokens of the vocabulary!
+ self.action_token_begin_idx: int = int(self.tokenizer.vocab_size - (self.n_bins + 1))
+
+ def __call__(self, action: np.ndarray) -> Union[str, List[str]]:
+ """Clip & bin actions to *the last `n_bins` tokens* of the vocabulary (e.g., tokenizer.vocab[-256:])."""
+ action = np.clip(action, a_min=float(self.min_action), a_max=float(self.max_action))
+ discretized_action = np.digitize(action, self.bins)
+
+ # Handle single element vs. batch
+ if len(discretized_action.shape) == 1:
+ return self.tokenizer.decode(list(self.tokenizer.vocab_size - discretized_action))
+ else:
+ return self.tokenizer.batch_decode((self.tokenizer.vocab_size - discretized_action).tolist())
+
+ def encode_actions_to_token_ids(self, action: np.ndarray) -> Union[str, List[str]]:
+ action = np.clip(action, a_min=float(self.min_action), a_max=float(self.max_action))
+ discretized_action = np.digitize(action, self.bins)
+
+ # Handle single element vs. batch
+ if len(discretized_action.shape) == 1:
+ return list(self.tokenizer.vocab_size - discretized_action)
+
+ else:
+ return (self.tokenizer.vocab_size - discretized_action).tolist()
+
+ def decode_token_ids_to_actions(self, action_token_ids: np.ndarray) -> np.ndarray:
+ """
+ Returns continuous actions for discrete action token IDs.
+
+ NOTE =>> Because of the way the actions are discretized w.r.t. the bins (and not the bin centers), the
+ digitization returns bin indices between [1, # bins], inclusive, when there are actually only
+ (# bins - 1) bin intervals.
+
+ Therefore, if the digitization returns the last possible index, we map this to the last bin interval.
+
+ EXAMPLE =>> Let's say self._bins has 256 values. Then self._bin_centers has 255 values. Digitization returns
+ indices between [1, 256]. We subtract 1 from all indices so that they are between [0, 255]. There
+ is still one index (i==255) that would cause an out-of-bounds error if used to index into
+ self._bin_centers. Therefore, if i==255, we subtract 1 from it so that it just becomes the index of
+ the last bin center. We implement this simply via clipping between [0, 255 - 1].
+ """
+ discretized_actions = self.tokenizer.vocab_size - action_token_ids
+ discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)
+
+ return self.bin_centers[discretized_actions]
+
+ @property
+ def vocab_size(self) -> int:
+ return self.n_bins
\ No newline at end of file
diff --git a/fudoki/janus/models/clip_encoder.py b/fudoki/janus/models/clip_encoder.py
new file mode 100644
index 0000000..c436498
--- /dev/null
+++ b/fudoki/janus/models/clip_encoder.py
@@ -0,0 +1,122 @@
+# Copyright (c) 2023-2024 DeepSeek.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of
+# this software and associated documentation files (the "Software"), to deal in
+# the Software without restriction, including without limitation the rights to
+# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+# the Software, and to permit persons to whom the Software is furnished to do so,
+# subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+from typing import Dict, List, Literal, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torchvision.transforms
+from einops import rearrange
+
+from fudoki.janus.models.siglip_vit import create_siglip_vit
+
+
+class CLIPVisionTower(nn.Module):
+ def __init__(
+ self,
+ model_name: str = "siglip_large_patch16_384",
+ image_size: Union[Tuple[int, int], int] = 336,
+ select_feature: str = "patch",
+ select_layer: int = -2,
+ select_layers: list = None,
+ ckpt_path: str = "",
+ pixel_mean: Optional[List[float]] = None,
+ pixel_std: Optional[List[float]] = None,
+ **kwargs,
+ ):
+ super().__init__()
+
+ self.model_name = model_name
+ self.select_feature = select_feature
+ self.select_layer = select_layer
+ self.select_layers = select_layers
+
+ vision_tower_params = {
+ "model_name": model_name,
+ "image_size": image_size,
+ "ckpt_path": ckpt_path,
+ "select_layer": select_layer,
+ }
+ vision_tower_params.update(kwargs)
+ self.vision_tower, self.forward_kwargs = self.build_vision_tower(
+ vision_tower_params
+ )
+
+ if pixel_mean is not None and pixel_std is not None:
+ image_norm = torchvision.transforms.Normalize(
+ mean=pixel_mean, std=pixel_std
+ )
+ else:
+ image_norm = None
+
+ self.image_norm = image_norm
+
+ def build_vision_tower(self, vision_tower_params):
+ if self.model_name.startswith("siglip"):
+ self.select_feature = "same"
+ vision_tower = create_siglip_vit(**vision_tower_params)
+ forward_kwargs = dict()
+
+ elif self.model_name.startswith("sam"):
+ vision_tower = create_sam_vit(**vision_tower_params)
+ forward_kwargs = dict()
+
+ else: # huggingface
+ from transformers import CLIPVisionModel
+
+ vision_tower = CLIPVisionModel.from_pretrained(**vision_tower_params)
+ forward_kwargs = dict(output_hidden_states=True)
+
+ return vision_tower, forward_kwargs
+
+ def feature_select(self, image_forward_outs):
+ if isinstance(image_forward_outs, torch.Tensor):
+ # the output has been the self.select_layer"s features
+ image_features = image_forward_outs
+ else:
+ image_features = image_forward_outs.hidden_states[self.select_layer]
+
+ if self.select_feature == "patch":
+ # if the output has cls_token
+ image_features = image_features[:, 1:]
+ elif self.select_feature == "cls_patch":
+ image_features = image_features
+ elif self.select_feature == "same":
+ image_features = image_features
+
+ else:
+ raise ValueError(f"Unexpected select feature: {self.select_feature}")
+ return image_features
+
+ def forward(self, images):
+ """
+
+ Args:
+ images (torch.Tensor): [b, 3, H, W]
+
+ Returns:
+ image_features (torch.Tensor): [b, n_patch, d]
+ """
+
+ if self.image_norm is not None:
+ images = self.image_norm(images)
+
+ image_forward_outs = self.vision_tower(images, **self.forward_kwargs)
+ image_features = self.feature_select(image_forward_outs)
+ return image_features
diff --git a/fudoki/janus/models/image_processing_vlm.py b/fudoki/janus/models/image_processing_vlm.py
new file mode 100644
index 0000000..367dee1
--- /dev/null
+++ b/fudoki/janus/models/image_processing_vlm.py
@@ -0,0 +1,208 @@
+# Copyright (c) 2023-2024 DeepSeek.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of
+# this software and associated documentation files (the "Software"), to deal in
+# the Software without restriction, including without limitation the rights to
+# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+# the Software, and to permit persons to whom the Software is furnished to do so,
+# subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+from typing import List, Tuple, Union
+
+import numpy as np
+import torch
+import torchvision
+import torchvision.transforms.functional
+from PIL import Image
+from transformers import AutoImageProcessor, PretrainedConfig
+from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
+from transformers.image_utils import to_numpy_array
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+ImageType = Union[np.ndarray, torch.Tensor, Image.Image]
+IMAGENET_MEAN = (0.48145466, 0.4578275, 0.40821073)
+IMAGENET_STD = (0.26862954, 0.26130258, 0.27577711)
+IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
+IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
+
+
+def expand2square(pil_img, background_color):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+
+
+class VLMImageProcessorConfig(PretrainedConfig):
+ model_type = "deepseek_vlm"
+ image_size: int
+ min_size: int
+ image_mean: Union[Tuple[float, float, float], List[float]]
+ image_std: Union[Tuple[float, float, float], List[float]]
+ rescale_factor: float
+ do_normalize: bool
+
+ def __init__(
+ self,
+ image_size: int,
+ min_size: int = 14,
+ image_mean: Union[Tuple[float, float, float], List[float]] = (
+ 0.48145466,
+ 0.4578275,
+ 0.40821073,
+ ),
+ image_std: Union[Tuple[float, float, float], List[float]] = (
+ 0.26862954,
+ 0.26130258,
+ 0.27577711,
+ ),
+ rescale_factor: float = 1.0 / 255.0,
+ do_normalize: bool = True,
+ **kwargs,
+ ):
+ self.image_size = image_size
+ self.min_size = min_size
+ self.image_mean = image_mean
+ self.image_std = image_std
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+
+ super().__init__(**kwargs)
+
+
+class VLMImageProcessor(BaseImageProcessor):
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ image_size: int,
+ min_size: int = 14,
+ image_mean: Union[Tuple[float, float, float], List[float]] = (
+ 0.48145466,
+ 0.4578275,
+ 0.40821073,
+ ),
+ image_std: Union[Tuple[float, float, float], List[float]] = (
+ 0.26862954,
+ 0.26130258,
+ 0.27577711,
+ ),
+ rescale_factor: float = 1.0 / 255.0,
+ do_normalize: bool = True,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.image_size = image_size
+ self.rescale_factor = rescale_factor
+ self.image_mean = image_mean
+ self.image_std = image_std
+ self.min_size = min_size
+ self.do_normalize = do_normalize
+
+ if image_mean is None:
+ self.background_color = (127, 127, 127)
+ else:
+ self.background_color = tuple([int(x * 255) for x in image_mean])
+
+ def resize(self, pil_img: Image) -> np.ndarray:
+ """
+
+ Args:
+ pil_img (PIL.Image): [H, W, 3] in PIL.Image in RGB
+
+ Returns:
+ x (np.ndarray): [3, self.image_size, self.image_size]
+ """
+
+ width, height = pil_img.size
+ max_size = max(width, height)
+
+ size = [
+ max(int(height / max_size * self.image_size), self.min_size),
+ max(int(width / max_size * self.image_size), self.min_size),
+ ]
+
+ if width <= 0 or height <= 0 or size[0] <= 0 or size[1] <= 0:
+ print(f"orig size = {pil_img.size}, new size = {size}")
+ raise ValueError("Invalid size!")
+
+ pil_img = torchvision.transforms.functional.resize(
+ pil_img,
+ size,
+ interpolation=torchvision.transforms.functional.InterpolationMode.BICUBIC,
+ antialias=True,
+ )
+
+ pil_img = expand2square(pil_img, self.background_color)
+ x = to_numpy_array(pil_img)
+
+ # [H, W, 3] -> [3, H, W]
+ x = np.transpose(x, (2, 0, 1))
+
+ return x
+
+ def preprocess(self, images, return_tensors: str = "pt", **kwargs) -> BatchFeature:
+ # resize and pad to [self.image_size, self.image_size]
+ # then convert from [H, W, 3] to [3, H, W]
+ images: List[np.ndarray] = [self.resize(image) for image in images]
+
+ # resacle from [0, 255] -> [0, 1]
+ images = [
+ self.rescale(
+ image=image,
+ scale=self.rescale_factor,
+ input_data_format="channels_first",
+ )
+ for image in images
+ ]
+
+ # normalize
+ if self.do_normalize:
+ images = [
+ self.normalize(
+ image=image,
+ mean=self.image_mean,
+ std=self.image_std,
+ input_data_format="channels_first",
+ )
+ for image in images
+ ]
+
+ data = {"pixel_values": images}
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+ @property
+ def default_shape(self):
+ return [3, self.image_size, self.image_size]
+
+
+AutoImageProcessor.register(VLMImageProcessorConfig, VLMImageProcessor)
+
+
+if __name__ == "__main__":
+ image_processor = VLMImageProcessor(
+ image_size=1024,
+ image_mean=IMAGENET_INCEPTION_MEAN,
+ image_std=IMAGENET_INCEPTION_STD,
+ do_normalize=True,
+ )
diff --git a/fudoki/janus/models/modeling_vlm.py b/fudoki/janus/models/modeling_vlm.py
new file mode 100644
index 0000000..07dd37e
--- /dev/null
+++ b/fudoki/janus/models/modeling_vlm.py
@@ -0,0 +1,439 @@
+# Copyright (c) 2023-2024 DeepSeek.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of
+# this software and associated documentation files (the "Software"), to deal in
+# the Software without restriction, including without limitation the rights to
+# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+# the Software, and to permit persons to whom the Software is furnished to do so,
+# subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+import torch
+from attrdict import AttrDict
+from einops import rearrange
+from transformers import (
+ AutoConfig,
+ AutoModelForCausalLM,
+ LlamaConfig,
+ LlamaForCausalLM,
+ PreTrainedModel,
+)
+from transformers.configuration_utils import PretrainedConfig
+
+from fudoki.janus.models.clip_encoder import CLIPVisionTower
+from fudoki.janus.models.projector import MlpProjector
+
+from icecream import ic
+
+class vision_head(torch.nn.Module):
+ def __init__(self, params):
+ super().__init__()
+ self.output_mlp_projector = torch.nn.Linear(
+ params.n_embed, params.image_token_embed
+ )
+ self.vision_activation = torch.nn.GELU()
+ self.vision_head = torch.nn.Linear(
+ params.image_token_embed, params.image_token_size
+ )
+
+ def forward(self, x):
+ x = self.output_mlp_projector(x)
+ x = self.vision_activation(x)
+ x = self.vision_head(x)
+ return x
+
+class number_head(torch.nn.Module):
+ def __init__(self, d_model=2048, dim_feedforward=5632, numhead_bias=True):
+ super().__init__()
+ self.num_head = torch.nn.Sequential(
+ torch.nn.Linear(d_model, dim_feedforward, bias=numhead_bias),
+ torch.nn.GELU(),
+ torch.nn.Linear(dim_feedforward, 1, bias=numhead_bias),
+ )
+
+ def forward(self, x):
+ x = self.num_head(x)
+ return x
+
+def model_name_to_cls(cls_name):
+ if "MlpProjector" in cls_name:
+ cls = MlpProjector
+
+ elif "CLIPVisionTower" in cls_name:
+ cls = CLIPVisionTower
+
+ elif "VQ" in cls_name:
+ from fudoki.janus.models.vq_model import VQ_models
+
+ cls = VQ_models[cls_name]
+ elif "vision_head" in cls_name:
+ cls = vision_head
+ else:
+ raise ValueError(f"class_name {cls_name} is invalid.")
+
+ return cls
+
+
+class VisionConfig(PretrainedConfig):
+ model_type = "vision"
+ cls: str = ""
+ params: AttrDict = {}
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ self.cls = kwargs.get("cls", "")
+ if not isinstance(self.cls, str):
+ self.cls = self.cls.__name__
+
+ self.params = AttrDict(kwargs.get("params", {}))
+
+
+class AlignerConfig(PretrainedConfig):
+ model_type = "aligner"
+ cls: str = ""
+ params: AttrDict = {}
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ self.cls = kwargs.get("cls", "")
+ if not isinstance(self.cls, str):
+ self.cls = self.cls.__name__
+
+ self.params = AttrDict(kwargs.get("params", {}))
+
+
+class GenVisionConfig(PretrainedConfig):
+ model_type = "gen_vision"
+ cls: str = ""
+ params: AttrDict = {}
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ self.cls = kwargs.get("cls", "")
+ if not isinstance(self.cls, str):
+ self.cls = self.cls.__name__
+
+ self.params = AttrDict(kwargs.get("params", {}))
+
+
+class GenAlignerConfig(PretrainedConfig):
+ model_type = "gen_aligner"
+ cls: str = ""
+ params: AttrDict = {}
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ self.cls = kwargs.get("cls", "")
+ if not isinstance(self.cls, str):
+ self.cls = self.cls.__name__
+
+ self.params = AttrDict(kwargs.get("params", {}))
+
+
+class GenHeadConfig(PretrainedConfig):
+ model_type = "gen_head"
+ cls: str = ""
+ params: AttrDict = {}
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ self.cls = kwargs.get("cls", "")
+ if not isinstance(self.cls, str):
+ self.cls = self.cls.__name__
+
+ self.params = AttrDict(kwargs.get("params", {}))
+
+
+class MultiModalityConfig(PretrainedConfig):
+ model_type = "multi_modality"
+ vision_config: VisionConfig
+ aligner_config: AlignerConfig
+
+ gen_vision_config: GenVisionConfig
+ gen_aligner_config: GenAlignerConfig
+ gen_head_config: GenHeadConfig
+
+ language_config: LlamaConfig
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ vision_config = kwargs.get("vision_config", {})
+ self.vision_config = VisionConfig(**vision_config)
+
+ aligner_config = kwargs.get("aligner_config", {})
+ self.aligner_config = AlignerConfig(**aligner_config)
+
+ gen_vision_config = kwargs.get("gen_vision_config", {})
+ self.gen_vision_config = GenVisionConfig(**gen_vision_config)
+
+ gen_aligner_config = kwargs.get("gen_aligner_config", {})
+ self.gen_aligner_config = GenAlignerConfig(**gen_aligner_config)
+
+ gen_head_config = kwargs.get("gen_head_config", {})
+ self.gen_head_config = GenHeadConfig(**gen_head_config)
+
+ language_config = kwargs.get("language_config", {})
+ if isinstance(language_config, LlamaConfig):
+ self.language_config = language_config
+ else:
+ self.language_config = LlamaConfig(**language_config)
+
+
+class MultiModalityPreTrainedModel(PreTrainedModel):
+ config_class = MultiModalityConfig
+ base_model_prefix = "multi_modality"
+ _no_split_modules = []
+ _skip_keys_device_placement = "past_key_values"
+
+
+class MultiModalityCausalLM(MultiModalityPreTrainedModel):
+ def __init__(self, config: MultiModalityConfig):
+ super().__init__(config)
+
+ vision_config = config.vision_config
+ vision_cls = model_name_to_cls(vision_config.cls)
+ self.vision_model = vision_cls(**vision_config.params)
+
+ aligner_config = config.aligner_config
+ aligner_cls = model_name_to_cls(aligner_config.cls)
+ self.aligner = aligner_cls(aligner_config.params)
+
+ gen_vision_config = config.gen_vision_config
+ gen_vision_cls = model_name_to_cls(gen_vision_config.cls)
+ self.gen_vision_model = gen_vision_cls()
+
+ gen_aligner_config = config.gen_aligner_config
+ gen_aligner_cls = model_name_to_cls(gen_aligner_config.cls)
+ self.gen_aligner = gen_aligner_cls(gen_aligner_config.params)
+
+ gen_head_config = config.gen_head_config
+ gen_head_cls = model_name_to_cls(gen_head_config.cls)
+ self.gen_head = gen_head_cls(gen_head_config.params)
+
+ self.gen_embed = torch.nn.Embedding(
+ gen_vision_config.params.image_token_size, gen_vision_config.params.n_embed
+ )
+
+ language_config = config.language_config
+ self.language_model = LlamaForCausalLM(language_config)
+ # Convert causal attention to full attention
+ self._convert_to_full_attention()
+
+ def add_number_head(self):
+ self.language_model.num_head = number_head(
+ d_model=self.language_model.config.hidden_size,
+ dim_feedforward=self.language_model.config.intermediate_size,
+ numhead_bias=True,
+ )
+
+ def _convert_to_full_attention(self):
+ """Convert all causal attention layers to full attention using BlockDiagonalMask"""
+ import types
+ import xformers
+ import xformers.ops.fmha as fmha
+ from transformers.models.llama.modeling_llama import repeat_kv, apply_rotary_pos_emb
+ from transformers.cache_utils import Cache
+ import torch.nn.functional as F
+ from typing import Optional, Tuple
+
+ for layer_idx, layer in enumerate(self.language_model.model.layers):
+ if hasattr(layer, 'self_attn'):
+ def full_attention_forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ query_states = query_states.contiguous().transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_heads, self.head_dim).contiguous()
+ key_states = key_states.contiguous().transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).contiguous()
+ value_states = value_states.contiguous().transpose(1, 2).contiguous().reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).contiguous()
+ # attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, p=self.attention_dropout, attn_bias=None)
+
+ attn_bias, flattened_q = fmha.BlockDiagonalMask.from_tensor_list([query_states[bs, :attention_mask[bs].sum()][None] for bs in range(bsz)])
+ _, flattened_k = fmha.BlockDiagonalMask.from_tensor_list([key_states[bs, :attention_mask[bs].sum()][None] for bs in range(bsz)])
+ _, flattened_v = fmha.BlockDiagonalMask.from_tensor_list([value_states[bs, :attention_mask[bs].sum()][None] for bs in range(bsz)])
+
+
+ output = xformers.ops.memory_efficient_attention(flattened_q, flattened_k, flattened_v, p=self.attention_dropout, attn_bias=attn_bias)
+ output = attn_bias.split(output)
+ attn_output = hidden_states.clone().reshape(bsz, q_len, self.num_heads, self.head_dim)
+ for bs in range(bsz):
+ attn_output[bs, :attention_mask[bs].sum()] = output[bs]
+
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+ else:
+ query, key = flattened_q, flattened_k
+ scale = 1.0 / query.shape[-1] ** 0.5
+ query = query * scale
+ query = query.transpose(1, 2)
+ key = key.transpose(1, 2)
+ attn = query @ key.transpose(-2, -1)
+ attn = attn.softmax(-1)
+ attn_weights = attn
+
+ return attn_output, attn_weights, past_key_value
+
+ # Bind the modified forward method
+ layer.self_attn.forward = types.MethodType(full_attention_forward, layer.self_attn)
+
+ if hasattr(self.language_model.model, '_update_causal_mask'):
+ def _update_causal_mask(
+ self,
+ attention_mask: torch.Tensor,
+ input_tensor: torch.Tensor,
+ cache_position: torch.Tensor,
+ past_key_values: Optional[Cache] = None,
+ output_attentions: bool = False,
+ ):
+ # Simply return the original attention mask without any causal modifications
+ return attention_mask
+
+ # Override the method in the language model
+ self.language_model.model._update_causal_mask = types.MethodType(
+ _update_causal_mask, self.language_model.model
+ )
+
+ def prepare_inputs_embeds(
+ self,
+ input_ids: torch.LongTensor,
+ pixel_values: torch.FloatTensor,
+ images_seq_mask: torch.LongTensor,
+ images_emb_mask: torch.LongTensor,
+ **kwargs,
+ ):
+ """
+
+ Args:
+ input_ids (torch.LongTensor): [b, T]
+ pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
+ images_seq_mask (torch.BoolTensor): [b, T]
+ images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]
+
+ assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)
+
+ Returns:
+ input_embeds (torch.Tensor): [b, T, D]
+ """
+
+ bs, n = pixel_values.shape[0:2]
+ images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
+ # [b x n, T2, D]
+ images_embeds = self.aligner(self.vision_model(images))
+
+ # [b x n, T2, D] -> [b, n x T2, D]
+ images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
+ # [b, n, T2] -> [b, n x T2]
+ images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")
+
+ # [b, T, D]
+ input_ids[input_ids < 0] = 0 # ignore the image embeddings
+ inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
+
+ # replace with the image embeddings
+ inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
+
+ return inputs_embeds
+
+ def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
+ return self.gen_aligner(self.gen_embed(image_ids))
+
+ def get_fsdp_wrap_module_list(self):
+ return list(self.language_model.model.layers)
+
+ def token_drop(self, mmsamples, datainfo, uncond_prob=0.1, uncond_id=100015):
+ batch_size = mmsamples.shape[0]
+ drop_ids = torch.rand(batch_size, device=mmsamples.device) < uncond_prob
+ uncondition_context = mmsamples.clone()
+ generation_mask = (datainfo['generation_or_understanding_mask'] == 1)
+ for b in range(batch_size):
+ if drop_ids[b]:
+ if generation_mask[b]:
+ nz = datainfo['text_token_mask'][b].nonzero()
+ if nz.numel() > 0:
+ text_nonzero_idx_begin = nz[0, 0]
+ text_nonzero_idx_end = nz[-1, 0]
+ uncondition_context[b, text_nonzero_idx_begin:text_nonzero_idx_end+1] = uncond_id
+ return uncondition_context
+
+ def forward(self, mmsamples, datainfo):
+ if self.training:
+ mmsamples = self.token_drop(mmsamples, datainfo, uncond_prob=0.1, uncond_id=100015)
+ for b_index in range(mmsamples.shape[0]):
+ mask = datainfo['image_token_mask'][b_index] == 1
+ indices = torch.nonzero(mask, as_tuple=False)
+
+ if datainfo['generation_or_understanding_mask'][b_index] == 1:
+ imgsamples = mmsamples[b_index, indices[:, 0]]
+ img_embeds = self.prepare_gen_img_embeds(imgsamples.unsqueeze(0))
+ inputs_embeds = self.language_model.get_input_embeddings()(mmsamples)
+ inputs_embeds[b_index, indices[:, 0]] = img_embeds.reshape(-1, img_embeds.shape[-1])
+ elif datainfo['generation_or_understanding_mask'][b_index] == 0:
+ imgsamples = datainfo['understanding_img']
+ if datainfo['has_understanding_img'][b_index] == 1:
+ imgs = imgsamples[b_index]
+ if imgs.dim() == 3:
+ imgs = imgs.unsqueeze(0)
+ img_embeds = self.aligner(self.vision_model(imgs))
+ inputs_embeds = self.language_model.get_input_embeddings()(mmsamples)
+ inputs_embeds[b_index, indices[:, 0]] = img_embeds.reshape(-1, img_embeds.shape[-1])
+ else:
+ inputs_embeds = self.language_model.get_input_embeddings()(mmsamples)
+
+ outputs = self.language_model.model(inputs_embeds=inputs_embeds, use_cache=False, attention_mask=datainfo['attention_mask'])
+ hidden_states = outputs.last_hidden_state
+
+ img_logits = self.gen_head(hidden_states)
+ txt_logits = self.language_model.lm_head(hidden_states)
+
+ img_logits = torch.cat([torch.zeros((img_logits.shape[0], 1, img_logits.shape[2]), device=img_logits.device), img_logits[:, :-1, :]], dim=1)
+ txt_logits = torch.cat([torch.zeros((txt_logits.shape[0], 1, txt_logits.shape[2]), device=txt_logits.device), txt_logits[:, :-1, :]], dim=1)
+
+ return img_logits, txt_logits
+
+
+AutoConfig.register("vision", VisionConfig)
+AutoConfig.register("aligner", AlignerConfig)
+AutoConfig.register("gen_vision", GenVisionConfig)
+AutoConfig.register("gen_aligner", GenAlignerConfig)
+AutoConfig.register("gen_head", GenHeadConfig)
+AutoConfig.register("multi_modality", MultiModalityConfig)
+AutoModelForCausalLM.register(MultiModalityConfig, MultiModalityCausalLM)
diff --git a/fudoki/janus/models/processing_vlm.py b/fudoki/janus/models/processing_vlm.py
new file mode 100644
index 0000000..5b1ddb4
--- /dev/null
+++ b/fudoki/janus/models/processing_vlm.py
@@ -0,0 +1,419 @@
+# Copyright (c) 2023-2024 DeepSeek.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of
+# this software and associated documentation files (the "Software"), to deal in
+# the Software without restriction, including without limitation the rights to
+# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+# the Software, and to permit persons to whom the Software is furnished to do so,
+# subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+from dataclasses import dataclass
+from typing import Dict, List
+
+import torch
+from PIL.Image import Image
+from transformers import LlamaTokenizerFast
+from transformers.processing_utils import ProcessorMixin
+
+from fudoki.janus.models.image_processing_vlm import VLMImageProcessor
+from fudoki.janus.utils.conversation import get_conv_template
+
+
+class DictOutput(object):
+ def keys(self):
+ return self.__dict__.keys()
+
+ def __getitem__(self, item):
+ return self.__dict__[item]
+
+ def __setitem__(self, key, value):
+ self.__dict__[key] = value
+
+
+@dataclass
+class VLChatProcessorOutput(DictOutput):
+ sft_format: str
+ input_ids: torch.Tensor
+ pixel_values: torch.Tensor
+ num_image_tokens: torch.IntTensor
+
+ def __len__(self):
+ return len(self.input_ids)
+
+
+@dataclass
+class BatchedVLChatProcessorOutput(DictOutput):
+ sft_format: List[str]
+ input_ids: torch.Tensor
+ pixel_values: torch.Tensor
+ attention_mask: torch.Tensor
+ images_seq_mask: torch.BoolTensor
+ images_emb_mask: torch.BoolTensor
+
+ def to(self, device, dtype=torch.bfloat16):
+ self.input_ids = self.input_ids.to(device)
+ self.attention_mask = self.attention_mask.to(device)
+ self.images_seq_mask = self.images_seq_mask.to(device)
+ self.images_emb_mask = self.images_emb_mask.to(device)
+ self.pixel_values = self.pixel_values.to(device=device, dtype=dtype)
+ return self
+
+
+class VLChatProcessor(ProcessorMixin):
+ image_processor_class = "AutoImageProcessor"
+ tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
+
+ attributes = ["image_processor", "tokenizer"]
+
+ system_prompt = (
+ "You are a helpful language and vision assistant. "
+ "You are able to understand the visual content that the user provides, "
+ "and assist the user with a variety of tasks using natural language."
+ )
+
+ def __init__(
+ self,
+ image_processor: VLMImageProcessor,
+ tokenizer: LlamaTokenizerFast,
+ image_tag: str = "",
+ image_start_tag: str = "",
+ image_end_tag: str = "",
+ num_image_tokens: int = 576,
+ add_special_token: bool = False,
+ sft_format: str = "fudoki",
+ mask_prompt: bool = True,
+ ignore_id: int = -100,
+ **kwargs,
+ ):
+ self.image_processor = image_processor
+ self.tokenizer = tokenizer
+
+ image_id = self.tokenizer.vocab.get(image_tag)
+ if image_id is None:
+ special_tokens = [image_tag]
+ special_tokens_dict = {"additional_special_tokens": special_tokens}
+ self.tokenizer.add_special_tokens(special_tokens_dict)
+ print(f"Add image tag = {image_tag} to the tokenizer")
+
+ self.image_tag = image_tag
+ self.image_start_tag = image_start_tag
+ self.image_end_tag = image_end_tag
+
+ self.num_image_tokens = num_image_tokens
+ self.add_special_token = add_special_token
+ self.sft_format = sft_format
+ self.mask_prompt = mask_prompt
+ self.ignore_id = ignore_id
+
+ super().__init__(
+ image_processor,
+ tokenizer,
+ image_tag,
+ num_image_tokens,
+ add_special_token,
+ sft_format,
+ mask_prompt,
+ ignore_id,
+ **kwargs,
+ )
+
+ def new_chat_template(self):
+ conv = get_conv_template(self.sft_format)
+ conv.set_system_message(self.system_prompt)
+ return conv
+
+ def apply_sft_template_for_multi_turn_prompts(
+ self,
+ conversations: List[Dict[str, str]],
+ sft_format: str = "fudoki",
+ system_prompt: str = "",
+ ):
+ """
+ Applies the SFT template to conversation.
+
+ An example of conversation:
+ conversation = [
+ {
+ "role": "User",
+ "content": " is Figure 1.\n is Figure 2.\nWhich image is brighter?",
+ "images": [
+ "./multi-images/attribute_comparison_1.png",
+ "./multi-images/attribute_comparison_2.png"
+ ]
+ },
+ {
+ "role": "Assistant",
+ "content": ""
+ }
+ ]
+
+ Args:
+ conversations (List[Dict]): A conversation with a List of Dict[str, str] text.
+ sft_format (str, optional): The format of the SFT template to use. Defaults to "fudoki".
+ system_prompt (str, optional): The system prompt to use in the SFT template. Defaults to "".
+
+ Returns:
+ sft_prompt (str): The formatted text.
+ """
+
+ conv = get_conv_template(sft_format)
+ conv.set_system_message(system_prompt)
+ for message in conversations:
+ conv.append_message(message["role"], message["content"].strip())
+ sft_prompt = conv.get_prompt().strip()
+
+ return sft_prompt
+
+ @property
+ def image_token(self):
+ return self.image_tag
+
+ @property
+ def image_id(self):
+ image_id = self.tokenizer.vocab.get(self.image_tag)
+ return image_id
+
+ @property
+ def image_start_id(self):
+ image_start_id = self.tokenizer.vocab.get(self.image_start_tag)
+ return image_start_id
+
+ @property
+ def image_end_id(self):
+ image_end_id = self.tokenizer.vocab.get(self.image_end_tag)
+ return image_end_id
+
+ @property
+ def image_start_token(self):
+ return self.image_start_tag
+
+ @property
+ def image_end_token(self):
+ return self.image_end_tag
+
+ @property
+ def pad_id(self):
+ pad_id = self.tokenizer.pad_token_id
+ if pad_id is None:
+ pad_id = self.tokenizer.eos_token_id
+
+ return pad_id
+
+ def num_id(self, tag=""):
+ num_id = self.tokenizer.vocab.get(tag)
+ return num_id
+
+ def add_image_token(
+ self,
+ image_indices: List[int],
+ input_ids: torch.LongTensor,
+ ):
+ """
+
+ Args:
+ image_indices (List[int]): [index_0, index_1, ..., index_j]
+ input_ids (torch.LongTensor): [N]
+
+ Returns:
+ input_ids (torch.LongTensor): [N + image tokens]
+ num_image_tokens (torch.IntTensor): [n_images]
+ """
+
+ input_slices = []
+
+ start = 0
+ for index in image_indices:
+ if self.add_special_token:
+ end = index + 1
+ else:
+ end = index
+
+ # original text tokens
+ input_slices.append(input_ids[start:end])
+
+ # add boi, image tokens, eoi and set the mask as False
+ input_slices.append(self.image_start_id * torch.ones((1), dtype=torch.long))
+ input_slices.append(
+ self.image_id * torch.ones((self.num_image_tokens,), dtype=torch.long)
+ )
+ input_slices.append(self.image_end_id * torch.ones((1), dtype=torch.long))
+ start = index + 1
+
+ # the left part
+ input_slices.append(input_ids[start:])
+
+ # concat all slices
+ input_ids = torch.cat(input_slices, dim=0)
+ num_image_tokens = torch.IntTensor([self.num_image_tokens] * len(image_indices))
+
+ return input_ids, num_image_tokens
+
+ def process_one(
+ self,
+ prompt: str = None,
+ conversations: List[Dict[str, str]] = None,
+ images: List[Image] = None,
+ **kwargs,
+ ):
+ """
+
+ Args:
+ prompt (str): the formatted prompt;
+ conversations (List[Dict]): conversations with a list of messages;
+ images (List[ImageType]): the list of images;
+ **kwargs:
+
+ Returns:
+ outputs (BaseProcessorOutput): the output of the processor,
+ - input_ids (torch.LongTensor): [N + image tokens]
+ - target_ids (torch.LongTensor): [N + image tokens]
+ - images (torch.FloatTensor): [n_images, 3, H, W]
+ - image_id (int): the id of the image token
+ - num_image_tokens (List[int]): the number of image tokens
+ """
+
+ assert (
+ prompt is None or conversations is None
+ ), "prompt and conversations cannot be used at the same time."
+
+ if prompt is None:
+ # apply sft format
+ sft_format = self.apply_sft_template_for_multi_turn_prompts(
+ conversations=conversations,
+ sft_format=self.sft_format,
+ system_prompt=self.system_prompt,
+ )
+ else:
+ sft_format = prompt
+
+ # tokenize
+ input_ids = self.tokenizer.encode(sft_format)
+ input_ids = torch.LongTensor(input_ids)
+
+ # add image tokens to the input_ids
+ image_token_mask: torch.BoolTensor = input_ids == self.image_id
+ image_indices = image_token_mask.nonzero()
+ input_ids, num_image_tokens = self.add_image_token(
+ image_indices=image_indices,
+ input_ids=input_ids,
+ )
+
+ # load images
+ images_outputs = self.image_processor(images, return_tensors="pt")
+
+ prepare = VLChatProcessorOutput(
+ sft_format=sft_format,
+ input_ids=input_ids,
+ pixel_values=images_outputs.pixel_values,
+ num_image_tokens=num_image_tokens,
+ )
+
+ return prepare
+
+ def __call__(
+ self,
+ *,
+ prompt: str = None,
+ conversations: List[Dict[str, str]] = None,
+ images: List[Image] = None,
+ force_batchify: bool = True,
+ **kwargs,
+ ):
+ """
+
+ Args:
+ prompt (str): the formatted prompt;
+ conversations (List[Dict]): conversations with a list of messages;
+ images (List[ImageType]): the list of images;
+ force_batchify (bool): force batchify the inputs;
+ **kwargs:
+
+ Returns:
+ outputs (BaseProcessorOutput): the output of the processor,
+ - input_ids (torch.LongTensor): [N + image tokens]
+ - images (torch.FloatTensor): [n_images, 3, H, W]
+ - image_id (int): the id of the image token
+ - num_image_tokens (List[int]): the number of image tokens
+ """
+
+ prepare = self.process_one(
+ prompt=prompt, conversations=conversations, images=images
+ )
+
+ if force_batchify:
+ prepare = self.batchify([prepare])
+
+ return prepare
+
+ def batchify(
+ self, prepare_list: List[VLChatProcessorOutput]
+ ) -> BatchedVLChatProcessorOutput:
+ """
+ Preprocesses the inputs for multimodal inference.
+
+ Args:
+ prepare_list (List[VLChatProcessorOutput]): A list of VLChatProcessorOutput.
+
+ Returns:
+ BatchedVLChatProcessorOutput: A dictionary of the inputs to use for multimodal inference.
+ """
+
+ batch_size = len(prepare_list)
+ sft_format = []
+ n_images = []
+ seq_lens = []
+ for prepare in prepare_list:
+ n_images.append(len(prepare.num_image_tokens))
+ seq_lens.append(len(prepare))
+
+ input_token_max_len = max(seq_lens)
+ max_n_images = max(1, max(n_images))
+
+ batched_input_ids = torch.full(
+ (batch_size, input_token_max_len), self.pad_id
+ ).long() # FIXME
+ batched_attention_mask = torch.zeros((batch_size, input_token_max_len)).long()
+ batched_pixel_values = torch.zeros(
+ (batch_size, max_n_images, *self.image_processor.default_shape)
+ ).float()
+ batched_images_seq_mask = torch.zeros((batch_size, input_token_max_len)).bool()
+ batched_images_emb_mask = torch.zeros(
+ (batch_size, max_n_images, self.num_image_tokens)
+ ).bool()
+
+ for i, prepare in enumerate(prepare_list):
+ input_ids = prepare.input_ids
+ seq_len = len(prepare)
+ n_image = len(prepare.num_image_tokens)
+ # left-padding
+ batched_attention_mask[i, -seq_len:] = 1
+ batched_input_ids[i, -seq_len:] = torch.LongTensor(input_ids)
+ batched_images_seq_mask[i, -seq_len:] = input_ids == self.image_id
+
+ if n_image > 0:
+ batched_pixel_values[i, :n_image] = prepare.pixel_values
+ for j, n_image_tokens in enumerate(prepare.num_image_tokens):
+ batched_images_emb_mask[i, j, :n_image_tokens] = True
+
+ sft_format.append(prepare.sft_format)
+
+ batched_prepares = BatchedVLChatProcessorOutput(
+ input_ids=batched_input_ids,
+ attention_mask=batched_attention_mask,
+ pixel_values=batched_pixel_values,
+ images_seq_mask=batched_images_seq_mask,
+ images_emb_mask=batched_images_emb_mask,
+ sft_format=sft_format,
+ )
+
+ return batched_prepares
diff --git a/fudoki/janus/models/projector.py b/fudoki/janus/models/projector.py
new file mode 100644
index 0000000..15f4ca3
--- /dev/null
+++ b/fudoki/janus/models/projector.py
@@ -0,0 +1,100 @@
+# Copyright (c) 2023-2024 DeepSeek.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of
+# this software and associated documentation files (the "Software"), to deal in
+# the Software without restriction, including without limitation the rights to
+# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+# the Software, and to permit persons to whom the Software is furnished to do so,
+# subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+from typing import Tuple, Union
+
+import torch
+import torch.nn as nn
+from attrdict import AttrDict
+
+
+class MlpProjector(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+
+ self.cfg = cfg
+
+ if cfg.projector_type == "identity":
+ modules = nn.Identity()
+
+ elif cfg.projector_type == "linear":
+ modules = nn.Linear(cfg.input_dim, cfg.n_embed)
+
+ elif cfg.projector_type == "mlp_gelu":
+ mlp_depth = cfg.get("depth", 1)
+ modules = [nn.Linear(cfg.input_dim, cfg.n_embed)]
+ for _ in range(1, mlp_depth):
+ modules.append(nn.GELU())
+ modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
+ modules = nn.Sequential(*modules)
+
+ elif cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
+ mlp_depth = cfg.get("depth", 1)
+ self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
+ self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
+
+ modules = []
+ for _ in range(1, mlp_depth):
+ modules.append(nn.GELU())
+ modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
+ modules = nn.Sequential(*modules)
+
+ else:
+ raise ValueError(f"Unknown projector type: {cfg.projector_type}")
+
+ self.layers = modules
+
+ def forward(
+ self, x_or_tuple: Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]
+ ):
+ """
+
+ Args:
+ x_or_tuple (Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: if it is a tuple of torch.Tensor,
+ then it comes from the hybrid vision encoder, and x = high_res_x, low_res_x);
+ otherwise it is the feature from the single vision encoder.
+
+ Returns:
+ x (torch.Tensor): [b, s, c]
+ """
+
+ if isinstance(x_or_tuple, tuple):
+ # self.cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
+ high_x, low_x = x_or_tuple
+ high_x = self.high_up_proj(high_x)
+ low_x = self.low_up_proj(low_x)
+ x = torch.concat([high_x, low_x], dim=-1)
+ else:
+ x = x_or_tuple
+
+ return self.layers(x)
+
+
+if __name__ == "__main__":
+ cfg = AttrDict(
+ input_dim=1024,
+ n_embed=2048,
+ depth=2,
+ projector_type="low_high_hybrid_split_mlp_gelu",
+ )
+ inputs = (torch.rand(4, 576, 1024), torch.rand(4, 576, 1024))
+
+ m = MlpProjector(cfg)
+ out = m(inputs)
+ print(out.shape)
diff --git a/fudoki/janus/models/siglip_vit.py b/fudoki/janus/models/siglip_vit.py
new file mode 100644
index 0000000..ba426d6
--- /dev/null
+++ b/fudoki/janus/models/siglip_vit.py
@@ -0,0 +1,681 @@
+# Copyright (c) 2023-2024 DeepSeek.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of
+# this software and associated documentation files (the "Software"), to deal in
+# the Software without restriction, including without limitation the rights to
+# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+# the Software, and to permit persons to whom the Software is furnished to do so,
+# subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+# https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
+import math
+import warnings
+from dataclasses import dataclass
+from functools import partial
+from typing import (
+ Callable,
+ Dict,
+ Final,
+ List,
+ Literal,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+ Type,
+ Union,
+)
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from timm.layers import (
+ AttentionPoolLatent,
+ DropPath,
+ LayerType,
+ Mlp,
+ PatchDropout,
+ PatchEmbed,
+ resample_abs_pos_embed,
+)
+from timm.models._manipulate import checkpoint_seq, named_apply
+
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn(
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+ "The distribution of values may be incorrect.",
+ stacklevel=2,
+ )
+
+ with torch.no_grad():
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ l = norm_cdf((a - mean) / std) # noqa: E741
+ u = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [l, u], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.0))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+ return tensor
+
+
+def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
+ # type: (torch.Tensor, float, float, float, float) -> torch.Tensor
+ r"""The original timm.models.layers.weight_init.trunc_normal_ can not handle bfloat16 yet, here we first
+ convert the tensor to float32, apply the trunc_normal_() in float32, and then convert it back to its original dtype.
+ Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn
+ from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+ with values outside :math:`[a, b]` redrawn until they are within
+ the bounds. The method used for generating the random values works
+ best when :math:`a \leq \text{mean} \leq b`.
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+ mean: the mean of the normal distribution
+ std: the standard deviation of the normal distribution
+ a: the minimum cutoff value
+ b: the maximum cutoff value
+ Examples:
+ >>> w = torch.empty(3, 5)
+ >>> nn.init.trunc_normal_(w)
+ """
+
+ with torch.no_grad():
+ dtype = tensor.dtype
+ tensor_fp32 = tensor.float()
+ tensor_fp32 = _no_grad_trunc_normal_(tensor_fp32, mean, std, a, b)
+ tensor_dtype = tensor_fp32.to(dtype=dtype)
+ tensor.copy_(tensor_dtype)
+
+
+def init_weights(self):
+ if self.pos_embed is not None:
+ trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
+ trunc_normal_(self.latent, std=self.latent_dim**-0.5)
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = "") -> None:
+ """ViT weight initialization, original timm impl (for reproducibility)"""
+ if isinstance(module, nn.Linear):
+ trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+ elif hasattr(module, "init_weights"):
+ module.init_weights()
+
+
+class Attention(nn.Module):
+ fused_attn: Final[bool]
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ qk_norm: bool = False,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ norm_layer: nn.Module = nn.LayerNorm,
+ ) -> None:
+ super().__init__()
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.scale = self.head_dim**-0.5
+ # self.fused_attn = use_fused_attn()
+ self.fused_attn = True
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0.0 else nn.Identity()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ B, N, C = x.shape
+ qkv = (
+ self.qkv(x)
+ .reshape(B, N, 3, self.num_heads, self.head_dim)
+ .permute(2, 0, 3, 1, 4)
+ )
+ q, k, v = qkv.unbind(0)
+ q, k = self.q_norm(q), self.k_norm(k)
+
+ if self.fused_attn:
+ x = F.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ dropout_p=self.attn_drop.p if self.training else 0.0,
+ )
+ else:
+ q = q * self.scale
+ attn = q @ k.transpose(-2, -1)
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = attn @ v
+
+ x = x.transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class LayerScale(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ init_values: float = 1e-5,
+ inplace: bool = False,
+ ) -> None:
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ qk_norm: bool = False,
+ proj_drop: float = 0.0,
+ attn_drop: float = 0.0,
+ init_values: Optional[float] = None,
+ drop_path: float = 0.0,
+ act_layer: nn.Module = nn.GELU,
+ norm_layer: nn.Module = nn.LayerNorm,
+ mlp_layer: nn.Module = Mlp,
+ ) -> None:
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_norm=qk_norm,
+ attn_drop=attn_drop,
+ proj_drop=proj_drop,
+ norm_layer=norm_layer,
+ )
+ self.ls1 = (
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ )
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ self.mlp = mlp_layer(
+ in_features=dim,
+ hidden_features=int(dim * mlp_ratio),
+ act_layer=act_layer,
+ drop=proj_drop,
+ )
+ self.ls2 = (
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ )
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
+ return x
+
+
+class VisionTransformer(nn.Module):
+ """Vision Transformer
+
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
+ - https://arxiv.org/abs/2010.11929
+ """
+
+ dynamic_img_size: Final[bool]
+
+ def __init__(
+ self,
+ img_size: Union[int, Tuple[int, int]] = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ in_chans: int = 3,
+ num_classes: int = 1000,
+ global_pool: Literal["", "avg", "token", "map"] = "token",
+ embed_dim: int = 768,
+ depth: int = 12,
+ num_heads: int = 12,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = True,
+ qk_norm: bool = False,
+ init_values: Optional[float] = None,
+ class_token: bool = True,
+ no_embed_class: bool = False,
+ reg_tokens: int = 0,
+ pre_norm: bool = False,
+ fc_norm: Optional[bool] = None,
+ dynamic_img_size: bool = False,
+ dynamic_img_pad: bool = False,
+ drop_rate: float = 0.0,
+ pos_drop_rate: float = 0.0,
+ patch_drop_rate: float = 0.0,
+ proj_drop_rate: float = 0.0,
+ attn_drop_rate: float = 0.0,
+ drop_path_rate: float = 0.0,
+ weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "",
+ embed_layer: Callable = PatchEmbed,
+ norm_layer: Optional[LayerType] = None,
+ act_layer: Optional[LayerType] = None,
+ block_fn: Type[nn.Module] = Block,
+ mlp_layer: Type[nn.Module] = Mlp,
+ ignore_head: bool = False,
+ ) -> None:
+ """
+ Args:
+ img_size: Input image size.
+ patch_size: Patch size.
+ in_chans: Number of image input channels.
+ num_classes: Mumber of classes for classification head.
+ global_pool: Type of global pooling for final sequence (default: 'token').
+ embed_dim: Transformer embedding dimension.
+ depth: Depth of transformer.
+ num_heads: Number of attention heads.
+ mlp_ratio: Ratio of mlp hidden dim to embedding dim.
+ qkv_bias: Enable bias for qkv projections if True.
+ init_values: Layer-scale init values (layer-scale enabled if not None).
+ class_token: Use class token.
+ no_embed_class: Don't include position embeddings for class (or reg) tokens.
+ reg_tokens: Number of register tokens.
+ fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
+ drop_rate: Head dropout rate.
+ pos_drop_rate: Position embedding dropout rate.
+ attn_drop_rate: Attention dropout rate.
+ drop_path_rate: Stochastic depth rate.
+ weight_init: Weight initialization scheme.
+ embed_layer: Patch embedding layer.
+ norm_layer: Normalization layer.
+ act_layer: MLP activation layer.
+ block_fn: Transformer block layer.
+ """
+ super().__init__()
+ assert global_pool in ("", "avg", "token", "map")
+ assert class_token or global_pool != "token"
+ use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm
+ # norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
+ # act_layer = get_act_layer(act_layer) or nn.GELU
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+ act_layer = nn.GELU
+
+ self.num_classes = num_classes
+ self.global_pool = global_pool
+ self.num_features = self.embed_dim = (
+ embed_dim # num_features for consistency with other models
+ )
+ self.num_prefix_tokens = 1 if class_token else 0
+ self.num_prefix_tokens += reg_tokens
+ self.num_reg_tokens = reg_tokens
+ self.has_class_token = class_token
+ self.no_embed_class = (
+ no_embed_class # don't embed prefix positions (includes reg)
+ )
+ self.dynamic_img_size = dynamic_img_size
+ self.grad_checkpointing = False
+ self.ignore_head = ignore_head
+
+ embed_args = {}
+ if dynamic_img_size:
+ # flatten deferred until after pos embed
+ embed_args.update(dict(strict_img_size=False, output_fmt="NHWC"))
+ self.patch_embed = embed_layer(
+ img_size=img_size,
+ patch_size=patch_size,
+ in_chans=in_chans,
+ embed_dim=embed_dim,
+ bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
+ dynamic_img_pad=dynamic_img_pad,
+ **embed_args,
+ )
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = (
+ nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
+ )
+ self.reg_token = (
+ nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
+ )
+ embed_len = (
+ num_patches if no_embed_class else num_patches + self.num_prefix_tokens
+ )
+ self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
+ self.pos_drop = nn.Dropout(p=pos_drop_rate)
+ if patch_drop_rate > 0:
+ self.patch_drop = PatchDropout(
+ patch_drop_rate,
+ num_prefix_tokens=self.num_prefix_tokens,
+ )
+ else:
+ self.patch_drop = nn.Identity()
+ self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
+
+ dpr = [
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
+ ] # stochastic depth decay rule
+ self.blocks = nn.Sequential(
+ *[
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_norm=qk_norm,
+ init_values=init_values,
+ proj_drop=proj_drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ mlp_layer=mlp_layer,
+ )
+ for i in range(depth)
+ ]
+ )
+ self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
+
+ # Classifier Head
+ if global_pool == "map":
+ AttentionPoolLatent.init_weights = init_weights
+ self.attn_pool = AttentionPoolLatent(
+ self.embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ norm_layer=norm_layer,
+ )
+ else:
+ self.attn_pool = None
+ self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
+ self.head_drop = nn.Dropout(drop_rate)
+ self.head = (
+ nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+ )
+
+ if weight_init != "skip":
+ self.init_weights(weight_init)
+
+ def init_weights(self, mode: Literal["jax", "jax_nlhb", "moco", ""] = "") -> None:
+ assert mode in ("jax", "jax_nlhb", "moco", "")
+ # head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0
+ trunc_normal_(self.pos_embed, std=0.02)
+ if self.cls_token is not None:
+ nn.init.normal_(self.cls_token, std=1e-6)
+ named_apply(init_weights_vit_timm, self)
+
+ @torch.jit.ignore
+ def no_weight_decay(self) -> Set:
+ return {"pos_embed", "cls_token", "dist_token"}
+
+ @torch.jit.ignore
+ def group_matcher(self, coarse: bool = False) -> Dict:
+ return dict(
+ stem=r"^cls_token|pos_embed|patch_embed", # stem and embed
+ blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))],
+ )
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable: bool = True) -> None:
+ self.grad_checkpointing = enable
+
+ @torch.jit.ignore
+ def get_classifier(self) -> nn.Module:
+ return self.head
+
+ def reset_classifier(self, num_classes: int, global_pool=None) -> None:
+ self.num_classes = num_classes
+ if global_pool is not None:
+ assert global_pool in ("", "avg", "token", "map")
+ if global_pool == "map" and self.attn_pool is None:
+ assert (
+ False
+ ), "Cannot currently add attention pooling in reset_classifier()."
+ elif global_pool != "map " and self.attn_pool is not None:
+ self.attn_pool = None # remove attention pooling
+ self.global_pool = global_pool
+ self.head = (
+ nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+ )
+
+ def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
+ if self.dynamic_img_size:
+ B, H, W, C = x.shape
+ pos_embed = resample_abs_pos_embed(
+ self.pos_embed,
+ (H, W),
+ num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
+ )
+ x = x.view(B, -1, C)
+ else:
+ pos_embed = self.pos_embed
+
+ to_cat = []
+ if self.cls_token is not None:
+ to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
+ if self.reg_token is not None:
+ to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
+
+ if self.no_embed_class:
+ # deit-3, updated JAX (big vision)
+ # position embedding does not overlap with class token, add then concat
+ x = x + pos_embed
+ if to_cat:
+ x = torch.cat(to_cat + [x], dim=1)
+ else:
+ # original timm, JAX, and deit vit impl
+ # pos_embed has entry for class token, concat then add
+ if to_cat:
+ x = torch.cat(to_cat + [x], dim=1)
+ x = x + pos_embed
+
+ return self.pos_drop(x)
+
+ def _intermediate_layers(
+ self,
+ x: torch.Tensor,
+ n: Union[int, Sequence] = 1,
+ ) -> List[torch.Tensor]:
+ outputs, num_blocks = [], len(self.blocks)
+ take_indices = set(
+ range(num_blocks - n, num_blocks) if isinstance(n, int) else n
+ )
+
+ # forward pass
+ x = self.patch_embed(x)
+ x = self._pos_embed(x)
+ x = self.patch_drop(x)
+ x = self.norm_pre(x)
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if i in take_indices:
+ outputs.append(x)
+
+ return outputs
+
+ def get_intermediate_layers(
+ self,
+ x: torch.Tensor,
+ n: Union[int, Sequence] = 1,
+ reshape: bool = False,
+ return_prefix_tokens: bool = False,
+ norm: bool = False,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+ """Intermediate layer accessor (NOTE: This is a WIP experiment).
+ Inspired by DINO / DINOv2 interface
+ """
+ # take last n blocks if n is an int, if in is a sequence, select by matching indices
+ outputs = self._intermediate_layers(x, n)
+ if norm:
+ outputs = [self.norm(out) for out in outputs]
+ prefix_tokens = [out[:, 0 : self.num_prefix_tokens] for out in outputs]
+ outputs = [out[:, self.num_prefix_tokens :] for out in outputs]
+
+ if reshape:
+ grid_size = self.patch_embed.grid_size
+ outputs = [
+ out.reshape(x.shape[0], grid_size[0], grid_size[1], -1)
+ .permute(0, 3, 1, 2)
+ .contiguous()
+ for out in outputs
+ ]
+
+ if return_prefix_tokens:
+ return tuple(zip(outputs, prefix_tokens))
+ return tuple(outputs)
+
+ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.patch_embed(x)
+ x = self._pos_embed(x)
+ x = self.patch_drop(x)
+ x = self.norm_pre(x)
+ if self.grad_checkpointing and not torch.jit.is_scripting():
+ x = checkpoint_seq(self.blocks, x)
+ else:
+ x = self.blocks(x)
+ x = self.norm(x)
+ return x
+
+ def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
+ if self.attn_pool is not None:
+ x = self.attn_pool(x)
+ elif self.global_pool == "avg":
+ x = x[:, self.num_prefix_tokens :].mean(dim=1)
+ elif self.global_pool:
+ x = x[:, 0] # class token
+ x = self.fc_norm(x)
+ x = self.head_drop(x)
+ return x if pre_logits else self.head(x)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.forward_features(x)
+ if not self.ignore_head:
+ x = self.forward_head(x)
+ return x
+
+
+@dataclass
+class SigLIPVisionCfg:
+ width: int = 1152
+ layers: Union[Tuple[int, int, int, int], int] = 27
+ heads: int = 16
+ patch_size: int = 14
+ image_size: Union[Tuple[int, int], int] = 336
+ global_pool: str = "map"
+ mlp_ratio: float = 3.7362
+ class_token: bool = False
+ num_classes: int = 0
+ use_checkpoint: bool = False
+
+
+SigLIP_MODEL_CONFIG = {
+ "siglip_so400m_patch14_384": {
+ "image_size": 336,
+ "patch_size": 14,
+ "width": 1152,
+ "layers": 27,
+ "heads": 16,
+ "mlp_ratio": 3.7362,
+ "global_pool": "map",
+ "use_checkpoint": False,
+ },
+ "siglip_so400m_patch14_224": {
+ "image_size": 224,
+ "patch_size": 14,
+ "width": 1152,
+ "layers": 27,
+ "heads": 16,
+ "mlp_ratio": 3.7362,
+ "global_pool": "map",
+ "use_checkpoint": False,
+ },
+ "siglip_large_patch16_384": {
+ "image_size": 384,
+ "patch_size": 16,
+ "width": 1024,
+ "layers": 24,
+ "heads": 16,
+ "mlp_ratio": 4,
+ "global_pool": "map",
+ "use_checkpoint": False,
+ },
+}
+
+
+def create_siglip_vit(
+ model_name: str = "siglip_so400m_patch14_384",
+ image_size: int = 384,
+ select_layer: int = -1,
+ ckpt_path: str = "",
+ **kwargs,
+):
+ assert (
+ model_name in SigLIP_MODEL_CONFIG.keys()
+ ), f"model name should be in {SigLIP_MODEL_CONFIG.keys()}"
+
+ vision_cfg = SigLIPVisionCfg(**SigLIP_MODEL_CONFIG[model_name])
+
+ if select_layer <= 0:
+ layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1)
+ else:
+ layers = min(vision_cfg.layers, select_layer)
+
+ model = VisionTransformer(
+ img_size=image_size,
+ patch_size=vision_cfg.patch_size,
+ embed_dim=vision_cfg.width,
+ depth=layers,
+ num_heads=vision_cfg.heads,
+ mlp_ratio=vision_cfg.mlp_ratio,
+ class_token=vision_cfg.class_token,
+ global_pool=vision_cfg.global_pool,
+ ignore_head=kwargs.get("ignore_head", True),
+ weight_init=kwargs.get("weight_init", "skip"),
+ num_classes=0,
+ )
+
+ if ckpt_path:
+ state_dict = torch.load(ckpt_path, map_location="cpu")
+
+ incompatible_keys = model.load_state_dict(state_dict, strict=False)
+ print(
+ f"SigLIP-ViT restores from {ckpt_path},\n"
+ f"\tincompatible_keys:', {incompatible_keys}."
+ )
+
+ return model
diff --git a/fudoki/janus/models/vq_model.py b/fudoki/janus/models/vq_model.py
new file mode 100644
index 0000000..90a47e2
--- /dev/null
+++ b/fudoki/janus/models/vq_model.py
@@ -0,0 +1,527 @@
+# Copyright (c) 2023-2024 DeepSeek.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of
+# this software and associated documentation files (the "Software"), to deal in
+# the Software without restriction, including without limitation the rights to
+# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+# the Software, and to permit persons to whom the Software is furnished to do so,
+# subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+
+from dataclasses import dataclass, field
+from typing import List
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from functools import partial
+
+
+@dataclass
+class ModelArgs:
+ codebook_size: int = 16384
+ codebook_embed_dim: int = 8
+ codebook_l2_norm: bool = True
+ codebook_show_usage: bool = True
+ commit_loss_beta: float = 0.25
+ entropy_loss_ratio: float = 0.0
+
+ encoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
+ decoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
+ z_channels: int = 256
+ dropout_p: float = 0.0
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ in_channels=3,
+ ch=128,
+ ch_mult=(1, 1, 2, 2, 4),
+ num_res_blocks=2,
+ norm_type="group",
+ dropout=0.0,
+ resamp_with_conv=True,
+ z_channels=256,
+ ):
+ super().__init__()
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1)
+
+ # downsampling
+ in_ch_mult = (1,) + tuple(ch_mult)
+ self.conv_blocks = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ conv_block = nn.Module()
+ # res & attn
+ res_block = nn.ModuleList()
+ attn_block = nn.ModuleList()
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
+ for _ in range(self.num_res_blocks):
+ res_block.append(
+ ResnetBlock(
+ block_in, block_out, dropout=dropout, norm_type=norm_type
+ )
+ )
+ block_in = block_out
+ if i_level == self.num_resolutions - 1:
+ attn_block.append(AttnBlock(block_in, norm_type))
+ conv_block.res = res_block
+ conv_block.attn = attn_block
+ # downsample
+ if i_level != self.num_resolutions - 1:
+ conv_block.downsample = Downsample(block_in, resamp_with_conv)
+ self.conv_blocks.append(conv_block)
+
+ # middle
+ self.mid = nn.ModuleList()
+ self.mid.append(
+ ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
+ )
+ self.mid.append(AttnBlock(block_in, norm_type=norm_type))
+ self.mid.append(
+ ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
+ )
+
+ # end
+ self.norm_out = Normalize(block_in, norm_type)
+ self.conv_out = nn.Conv2d(
+ block_in, z_channels, kernel_size=3, stride=1, padding=1
+ )
+
+ def forward(self, x):
+ h = self.conv_in(x)
+ # downsampling
+ for i_level, block in enumerate(self.conv_blocks):
+ for i_block in range(self.num_res_blocks):
+ h = block.res[i_block](h)
+ if len(block.attn) > 0:
+ h = block.attn[i_block](h)
+ if i_level != self.num_resolutions - 1:
+ h = block.downsample(h)
+
+ # middle
+ for mid_block in self.mid:
+ h = mid_block(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ z_channels=256,
+ ch=128,
+ ch_mult=(1, 1, 2, 2, 4),
+ num_res_blocks=2,
+ norm_type="group",
+ dropout=0.0,
+ resamp_with_conv=True,
+ out_channels=3,
+ ):
+ super().__init__()
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+
+ block_in = ch * ch_mult[self.num_resolutions - 1]
+ # z to block_in
+ self.conv_in = nn.Conv2d(
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
+ )
+
+ # middle
+ self.mid = nn.ModuleList()
+ self.mid.append(
+ ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
+ )
+ self.mid.append(AttnBlock(block_in, norm_type=norm_type))
+ self.mid.append(
+ ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
+ )
+
+ # upsampling
+ self.conv_blocks = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ conv_block = nn.Module()
+ # res & attn
+ res_block = nn.ModuleList()
+ attn_block = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ for _ in range(self.num_res_blocks + 1):
+ res_block.append(
+ ResnetBlock(
+ block_in, block_out, dropout=dropout, norm_type=norm_type
+ )
+ )
+ block_in = block_out
+ if i_level == self.num_resolutions - 1:
+ attn_block.append(AttnBlock(block_in, norm_type))
+ conv_block.res = res_block
+ conv_block.attn = attn_block
+ # downsample
+ if i_level != 0:
+ conv_block.upsample = Upsample(block_in, resamp_with_conv)
+ self.conv_blocks.append(conv_block)
+
+ # end
+ self.norm_out = Normalize(block_in, norm_type)
+ self.conv_out = nn.Conv2d(
+ block_in, out_channels, kernel_size=3, stride=1, padding=1
+ )
+
+ @property
+ def last_layer(self):
+ return self.conv_out.weight
+
+ def forward(self, z):
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ for mid_block in self.mid:
+ h = mid_block(h)
+
+ # upsampling
+ for i_level, block in enumerate(self.conv_blocks):
+ for i_block in range(self.num_res_blocks + 1):
+ h = block.res[i_block](h)
+ if len(block.attn) > 0:
+ h = block.attn[i_block](h)
+ if i_level != self.num_resolutions - 1:
+ h = block.upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class VectorQuantizer(nn.Module):
+ def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show_usage):
+ super().__init__()
+ self.n_e = n_e
+ self.e_dim = e_dim
+ self.beta = beta
+ self.entropy_loss_ratio = entropy_loss_ratio
+ self.l2_norm = l2_norm
+ self.show_usage = show_usage
+
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+ if self.l2_norm:
+ self.embedding.weight.data = F.normalize(
+ self.embedding.weight.data, p=2, dim=-1
+ )
+ if self.show_usage:
+ self.register_buffer("codebook_used", nn.Parameter(torch.zeros(65536)))
+
+ def forward(self, z):
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = torch.einsum("b c h w -> b h w c", z).contiguous()
+ z_flattened = z.view(-1, self.e_dim)
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
+ if self.l2_norm:
+ z = F.normalize(z, p=2, dim=-1)
+ z_flattened = F.normalize(z_flattened, p=2, dim=-1)
+ embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
+ else:
+ embedding = self.embedding.weight
+
+ d = (
+ torch.sum(z_flattened**2, dim=1, keepdim=True)
+ + torch.sum(embedding**2, dim=1)
+ - 2
+ * torch.einsum(
+ "bd,dn->bn", z_flattened, torch.einsum("n d -> d n", embedding)
+ )
+ )
+
+ min_encoding_indices = torch.argmin(d, dim=1)
+ z_q = embedding[min_encoding_indices].view(z.shape)
+ perplexity = None
+ min_encodings = None
+ vq_loss = None
+ commit_loss = None
+ entropy_loss = None
+
+ # compute loss for embedding
+ if self.training:
+ vq_loss = torch.mean((z_q - z.detach()) ** 2)
+ commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2)
+ entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(-d)
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # reshape back to match original input shape
+ z_q = torch.einsum("b h w c -> b c h w", z_q)
+
+ return (
+ z_q,
+ (vq_loss, commit_loss, entropy_loss),
+ (perplexity, min_encodings, min_encoding_indices),
+ )
+
+ def get_codebook_entry(self, indices, shape=None, channel_first=True):
+ # shape = (batch, channel, height, width) if channel_first else (batch, height, width, channel)
+ if self.l2_norm:
+ embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
+ else:
+ embedding = self.embedding.weight
+ z_q = embedding[indices] # (b*h*w, c)
+
+ if shape is not None:
+ if channel_first:
+ z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[1])
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+ else:
+ z_q = z_q.view(shape)
+ return z_q
+
+
+class ResnetBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels=None,
+ conv_shortcut=False,
+ dropout=0.0,
+ norm_type="group",
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels, norm_type)
+ self.conv1 = nn.Conv2d(
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+ self.norm2 = Normalize(out_channels, norm_type)
+ self.dropout = nn.Dropout(dropout)
+ self.conv2 = nn.Conv2d(
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = nn.Conv2d(
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+ else:
+ self.nin_shortcut = nn.Conv2d(
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
+ )
+
+ def forward(self, x):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+ return x + h
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels, norm_type="group"):
+ super().__init__()
+ self.norm = Normalize(in_channels, norm_type)
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.proj_out = nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = q.reshape(b, c, h * w)
+ q = q.permute(0, 2, 1) # b,hw,c
+ k = k.reshape(b, c, h * w) # b,c,hw
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c) ** (-0.5))
+ w_ = F.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b, c, h * w)
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b, c, h, w)
+
+ h_ = self.proj_out(h_)
+
+ return x + h_
+
+
+def nonlinearity(x):
+ # swish
+ return x * torch.sigmoid(x)
+
+
+def Normalize(in_channels, norm_type="group"):
+ assert norm_type in ["group", "batch"]
+ if norm_type == "group":
+ return nn.GroupNorm(
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
+ )
+ elif norm_type == "batch":
+ return nn.SyncBatchNorm(in_channels)
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = nn.Conv2d(
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
+ )
+
+ def forward(self, x):
+ if x.dtype != torch.float32:
+ x = F.interpolate(x.to(torch.float), scale_factor=2.0, mode="nearest").to(
+ torch.float16
+ )
+ else:
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
+
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = nn.Conv2d(
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
+ )
+
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0, 1, 0, 1)
+ x = F.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = F.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01):
+ flat_affinity = affinity.reshape(-1, affinity.shape[-1])
+ flat_affinity /= temperature
+ probs = F.softmax(flat_affinity, dim=-1)
+ log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1)
+ if loss_type == "softmax":
+ target_probs = probs
+ else:
+ raise ValueError("Entropy loss {} not supported".format(loss_type))
+ avg_probs = torch.mean(target_probs, dim=0)
+ avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + 1e-5))
+ sample_entropy = -torch.mean(torch.sum(target_probs * log_probs, dim=-1))
+ loss = sample_entropy - avg_entropy
+ return loss
+
+
+class VQModel(nn.Module):
+ def __init__(self, config: ModelArgs):
+ super().__init__()
+ self.config = config
+ self.encoder = Encoder(
+ ch_mult=config.encoder_ch_mult,
+ z_channels=config.z_channels,
+ dropout=config.dropout_p,
+ )
+ self.decoder = Decoder(
+ ch_mult=config.decoder_ch_mult,
+ z_channels=config.z_channels,
+ dropout=config.dropout_p,
+ )
+
+ self.quantize = VectorQuantizer(
+ config.codebook_size,
+ config.codebook_embed_dim,
+ config.commit_loss_beta,
+ config.entropy_loss_ratio,
+ config.codebook_l2_norm,
+ config.codebook_show_usage,
+ )
+ self.quant_conv = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1)
+ self.post_quant_conv = nn.Conv2d(
+ config.codebook_embed_dim, config.z_channels, 1
+ )
+
+ def encode(self, x):
+ h = self.encoder(x)
+ h = self.quant_conv(h)
+ quant, emb_loss, info = self.quantize(h)
+ return quant, emb_loss, info
+
+ def decode(self, quant):
+ quant = self.post_quant_conv(quant)
+ dec = self.decoder(quant)
+ return dec
+
+ def decode_code(self, code_b, shape=None, channel_first=True):
+ quant_b = self.quantize.get_codebook_entry(code_b, shape, channel_first)
+ dec = self.decode(quant_b)
+ return dec
+
+ def forward(self, input):
+ quant, diff, _ = self.encode(input)
+ dec = self.decode(quant)
+ return dec, diff
+
+
+#################################################################################
+# VQ Model Configs #
+#################################################################################
+def VQ_16(**kwargs):
+ return VQModel(
+ ModelArgs(
+ encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4], **kwargs
+ )
+ )
+
+
+VQ_models = {"VQ-16": VQ_16}
diff --git a/fudoki/janus/utils/__init__.py b/fudoki/janus/utils/__init__.py
new file mode 100644
index 0000000..8cb7640
--- /dev/null
+++ b/fudoki/janus/utils/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2023-2024 DeepSeek.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of
+# this software and associated documentation files (the "Software"), to deal in
+# the Software without restriction, including without limitation the rights to
+# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+# the Software, and to permit persons to whom the Software is furnished to do so,
+# subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
diff --git a/fudoki/janus/utils/conversation.py b/fudoki/janus/utils/conversation.py
new file mode 100644
index 0000000..3b227fe
--- /dev/null
+++ b/fudoki/janus/utils/conversation.py
@@ -0,0 +1,337 @@
+# Copyright (c) 2023-2024 DeepSeek.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of
+# this software and associated documentation files (the "Software"), to deal in
+# the Software without restriction, including without limitation the rights to
+# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+# the Software, and to permit persons to whom the Software is furnished to do so,
+# subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+"""
+From https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
+"""
+
+import dataclasses
+from enum import IntEnum, auto
+from typing import Dict, List
+
+
+class SeparatorStyle(IntEnum):
+ """Separator styles."""
+
+ ADD_COLON_SINGLE = auto()
+ ADD_COLON_TWO = auto()
+ ADD_COLON_SPACE_SINGLE = auto()
+ NO_COLON_SINGLE = auto()
+ NO_COLON_TWO = auto()
+ ADD_NEW_LINE_SINGLE = auto()
+ LLAMA2 = auto()
+ CHATGLM = auto()
+ CHATML = auto()
+ CHATINTERN = auto()
+ DOLLY = auto()
+ RWKV = auto()
+ PHOENIX = auto()
+ ROBIN = auto()
+ DeepSeek = auto()
+ PLAIN = auto()
+ ALIGNMENT = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+ """A class that manages prompt templates and keeps all conversation history."""
+
+ # The name of this template
+ name: str
+ # The template of the system prompt
+ system_template: str = "{system_message}"
+ # The system message
+ system_message: str = ""
+ # The names of two roles
+ roles: List[str] = (("USER", "ASSISTANT"),)
+ # All messages. Each item is (role, message).
+ messages: List[List[str]] = ()
+ # The number of few shot examples
+ offset: int = 0
+ # The separator style and configurations
+ sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
+ sep: str = "\n"
+ sep2: str = None
+ # Stop criteria (the default one is EOS token)
+ stop_str: str = None
+ # Stops generation if meeting any token in this list
+ stop_token_ids: List[int] = None
+
+ def get_prompt(self) -> str:
+ """Get the prompt for generation."""
+ system_prompt = self.system_template.format(system_message=self.system_message)
+
+ if self.sep_style == SeparatorStyle.DeepSeek:
+ seps = [self.sep, self.sep2]
+ if system_prompt == "" or system_prompt is None:
+ ret = ""
+ else:
+ ret = system_prompt + seps[0]
+ for i, (role, message) in enumerate(self.messages):
+ if message:
+ ret += role + ": " + message + seps[i % 2]
+ else:
+ ret += role + ":"
+ return ret
+ elif self.sep_style == SeparatorStyle.LLAMA2:
+ seps = [self.sep, self.sep2]
+ if self.system_message:
+ ret = system_prompt
+ else:
+ ret = "[INST] "
+ for i, (role, message) in enumerate(self.messages):
+ tag = self.roles[i % 2]
+ if message:
+ if type(message) is tuple: # multimodal message
+ message, _ = message
+ if i == 0:
+ ret += message + " "
+ else:
+ ret += tag + " " + message + seps[i % 2]
+ else:
+ ret += tag
+ return ret
+ elif self.sep_style == SeparatorStyle.PLAIN:
+ seps = [self.sep, self.sep2]
+ ret = ""
+ for i, (role, message) in enumerate(self.messages):
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ if i % 2 == 0:
+ ret += message + seps[i % 2]
+ else:
+ ret += message + seps[i % 2]
+ else:
+ ret += ""
+ return ret
+ elif self.sep_style == SeparatorStyle.ALIGNMENT:
+ seps = [self.sep, self.sep2]
+ ret = ""
+ for i, (role, message) in enumerate(self.messages):
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ if i % 2 == 0:
+ ret += "\n" + seps[i % 2]
+ else:
+ ret += message + seps[i % 2]
+ else:
+ ret += ""
+ return ret
+ else:
+ raise ValueError(f"Invalid style: {self.sep_style}")
+
+ def get_prompt_for_current_round(self, content=None):
+ """Get current round formatted question prompt during sft training"""
+ if self.sep_style == SeparatorStyle.PLAIN:
+ formatted_question = "\n"
+ elif self.sep_style == SeparatorStyle.DeepSeek:
+ formatted_question = (
+ f"{self.roles[0]}: " + content.strip() + self.sep + f"{self.roles[1]}:"
+ )
+ else:
+ raise ValueError(f"Unsupported sep_style: {self.sep_style}")
+ return formatted_question
+
+ def set_system_message(self, system_message: str):
+ """Set the system message."""
+ self.system_message = system_message
+
+ def append_message(self, role: str, message: str):
+ """Append a new message."""
+ self.messages.append([role, message])
+
+ def reset_message(self):
+ """Reset a new message."""
+ self.messages = []
+
+ def update_last_message(self, message: str):
+ """Update the last output.
+
+ The last message is typically set to be None when constructing the prompt,
+ so we need to update it in-place after getting the response from a model.
+ """
+ self.messages[-1][1] = message
+
+ def to_gradio_chatbot(self):
+ """Convert the conversation to gradio chatbot format."""
+ ret = []
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
+ if i % 2 == 0:
+ ret.append([msg, None])
+ else:
+ ret[-1][-1] = msg
+ return ret
+
+ def to_openai_api_messages(self):
+ """Convert the conversation to OpenAI chat completion format."""
+ system_prompt = self.system_template.format(system_message=self.system_message)
+ ret = [{"role": "system", "content": system_prompt}]
+
+ for i, (_, msg) in enumerate(self.messages[self.offset :]):
+ if i % 2 == 0:
+ ret.append({"role": "user", "content": msg})
+ else:
+ if msg is not None:
+ ret.append({"role": "assistant", "content": msg})
+ return ret
+
+ def copy(self):
+ return Conversation(
+ name=self.name,
+ system_template=self.system_template,
+ system_message=self.system_message,
+ roles=self.roles,
+ messages=[[x, y] for x, y in self.messages],
+ offset=self.offset,
+ sep_style=self.sep_style,
+ sep=self.sep,
+ sep2=self.sep2,
+ stop_str=self.stop_str,
+ stop_token_ids=self.stop_token_ids,
+ )
+
+ def dict(self):
+ return {
+ "template_name": self.name,
+ "system_message": self.system_message,
+ "roles": self.roles,
+ "messages": self.messages,
+ "offset": self.offset,
+ }
+
+
+# A global registry for all conversation templates
+conv_templates: Dict[str, Conversation] = {}
+
+
+def register_conv_template(template: Conversation, override: bool = False):
+ """Register a new conversation template."""
+ if not override:
+ assert (
+ template.name not in conv_templates
+ ), f"{template.name} has been registered."
+
+ conv_templates[template.name] = template
+
+
+def get_conv_template(name: str) -> Conversation:
+ """Get a conversation template."""
+ return conv_templates[name].copy()
+
+
+# llava_llama2 template
+register_conv_template(
+ Conversation(
+ name="llava_llama2",
+ system_message="You are a helpful language and vision assistant. "
+ "You are able to understand the visual content that the user provides, "
+ "and assist the user with a variety of tasks using natural language.",
+ system_template="[INST] <>\n{system_message}\n<>\n\n",
+ roles=("[INST]", "[/INST]"),
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.LLAMA2,
+ sep=" ",
+ sep2=" ",
+ stop_token_ids=[2],
+ )
+)
+
+# llama2 template
+# reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
+register_conv_template(
+ Conversation(
+ name="llama-2",
+ system_template="[INST] <>\n{system_message}\n<>\n\n",
+ roles=("[INST]", "[/INST]"),
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.LLAMA2,
+ sep=" ",
+ sep2=" ",
+ stop_token_ids=[2],
+ )
+)
+
+
+# deepseek template
+register_conv_template(
+ Conversation(
+ name="fudoki",
+ system_template="{system_message}",
+ # system_message="You are a helpful assistant. Please answer truthfully and write out your "
+ # "thinking step by step to be sure you get the right answer.",
+ system_message="",
+ roles=("User", "Assistant"),
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.DeepSeek,
+ sep="\n\n",
+ sep2="<ο½endβofβsentenceο½>",
+ stop_token_ids=[100001],
+ stop_str=["User:", "<ο½endβofβsentenceο½>"],
+ )
+)
+
+register_conv_template(
+ Conversation(
+ name="plain",
+ system_template="",
+ system_message="",
+ roles=("", ""),
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.PLAIN,
+ sep="",
+ sep2="",
+ stop_token_ids=[2],
+ stop_str=[""],
+ )
+)
+
+
+register_conv_template(
+ Conversation(
+ name="alignment",
+ system_template="",
+ system_message="",
+ roles=("", ""),
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.ALIGNMENT,
+ sep="",
+ sep2="",
+ stop_token_ids=[2],
+ stop_str=[""],
+ )
+)
+
+
+if __name__ == "__main__":
+ print("fudoki template:")
+ conv = get_conv_template("fudoki")
+ conv.append_message(conv.roles[0], "Hello!")
+ conv.append_message(conv.roles[1], "Hi! This is Tony.")
+ conv.append_message(conv.roles[0], "Who are you?")
+ conv.append_message(conv.roles[1], "I am a helpful assistant.")
+ conv.append_message(conv.roles[0], "How are you?")
+ conv.append_message(conv.roles[1], None)
+ print(conv.get_prompt())
diff --git a/fudoki/janus/utils/io.py b/fudoki/janus/utils/io.py
new file mode 100644
index 0000000..e8e9da7
--- /dev/null
+++ b/fudoki/janus/utils/io.py
@@ -0,0 +1,89 @@
+# Copyright (c) 2023-2024 DeepSeek.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of
+# this software and associated documentation files (the "Software"), to deal in
+# the Software without restriction, including without limitation the rights to
+# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+# the Software, and to permit persons to whom the Software is furnished to do so,
+# subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+import json
+from typing import Dict, List
+
+import PIL.Image
+import torch
+import base64
+import io
+from transformers import AutoModelForCausalLM
+
+from janus.models import MultiModalityCausalLM, VLChatProcessor
+
+
+def load_pretrained_model(model_path: str):
+ vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
+ tokenizer = vl_chat_processor.tokenizer
+
+ vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
+ model_path, trust_remote_code=True
+ )
+ vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
+
+ return tokenizer, vl_chat_processor, vl_gpt
+
+
+def load_pil_images(conversations: List[Dict[str, str]]) -> List[PIL.Image.Image]:
+ """
+
+ Support file path or base64 images.
+
+ Args:
+ conversations (List[Dict[str, str]]): the conversations with a list of messages. An example is :
+ [
+ {
+ "role": "User",
+ "content": "\nExtract all information from this image and convert them into markdown format.",
+ "images": ["./examples/table_datasets.png"]
+ },
+ {"role": "Assistant", "content": ""},
+ ]
+
+ Returns:
+ pil_images (List[PIL.Image.Image]): the list of PIL images.
+
+ """
+
+ pil_images = []
+
+ for message in conversations:
+ if "images" not in message:
+ continue
+
+ for image_data in message["images"]:
+ if image_data.startswith("data:image"):
+ # Image data is in base64 format
+ _, image_data = image_data.split(",", 1)
+ image_bytes = base64.b64decode(image_data)
+ pil_img = PIL.Image.open(io.BytesIO(image_bytes))
+ else:
+ # Image data is a file path
+ pil_img = PIL.Image.open(image_data)
+ pil_img = pil_img.convert("RGB")
+ pil_images.append(pil_img)
+
+ return pil_images
+
+
+def load_json(filepath):
+ with open(filepath, "r") as f:
+ data = json.load(f)
+ return data
diff --git a/fudoki/model.py b/fudoki/model.py
new file mode 100644
index 0000000..4cfe97e
--- /dev/null
+++ b/fudoki/model.py
@@ -0,0 +1,11 @@
+from transformers import AutoModelForCausalLM
+from fudoki.janus.models import MultiModalityCausalLM
+
+
+def instantiate_model(pretrained_weight_path):
+
+ vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
+ pretrained_weight_path, trust_remote_code=True
+ )
+ model = vl_gpt
+ return model
diff --git a/fudoki/vq_model.py b/fudoki/vq_model.py
new file mode 100644
index 0000000..4ee2b48
--- /dev/null
+++ b/fudoki/vq_model.py
@@ -0,0 +1,424 @@
+# Inherited from LlamaGen (https://arxiv.org/abs/2406.06525), which is modified from:
+# taming-transformers: https://github.com/CompVis/taming-transformers
+# maskgit: https://github.com/google-research/maskgit
+from dataclasses import dataclass, field
+from typing import List
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+@dataclass
+class ModelArgs:
+ codebook_size: int = 16384
+ codebook_embed_dim: int = 8
+ codebook_l2_norm: bool = True
+ codebook_show_usage: bool = True
+ commit_loss_beta: float = 0.25
+ entropy_loss_ratio: float = 0.0
+
+ encoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
+ decoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
+ z_channels: int = 256
+ dropout_p: float = 0.0
+
+
+
+class VQModel(nn.Module):
+ def __init__(self, config: ModelArgs):
+ super().__init__()
+ self.config = config
+ self.encoder = Encoder(ch_mult=config.encoder_ch_mult, z_channels=config.z_channels, dropout=config.dropout_p)
+ self.decoder = Decoder(ch_mult=config.decoder_ch_mult, z_channels=config.z_channels, dropout=config.dropout_p)
+
+ self.quantize = VectorQuantizer(config.codebook_size, config.codebook_embed_dim,
+ config.commit_loss_beta, config.entropy_loss_ratio,
+ config.codebook_l2_norm, config.codebook_show_usage)
+ self.quant_conv = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1)
+ self.post_quant_conv = nn.Conv2d(config.codebook_embed_dim, config.z_channels, 1)
+
+ def encode(self, x):
+ h = self.encoder(x)
+ h = self.quant_conv(h)
+ quant, emb_loss, info = self.quantize(h)
+ return quant, emb_loss, info
+
+ def decode(self, quant):
+ quant = self.post_quant_conv(quant)
+ dec = self.decoder(quant)
+ return dec
+
+ def decode_code(self, code_b, shape=None, channel_first=True):
+ quant_b = self.quantize.get_codebook_entry(code_b, shape, channel_first)
+ dec = self.decode(quant_b)
+ return dec
+
+ def forward(self, input):
+ quant, diff, _ = self.encode(input)
+ dec = self.decode(quant)
+ return dec, diff
+
+
+
+class Encoder(nn.Module):
+ def __init__(self, in_channels=3, ch=128, ch_mult=(1,1,2,2,4), num_res_blocks=2,
+ norm_type='group', dropout=0.0, resamp_with_conv=True, z_channels=256):
+ super().__init__()
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1)
+
+ # downsampling
+ in_ch_mult = (1,) + tuple(ch_mult)
+ self.conv_blocks = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ conv_block = nn.Module()
+ # res & attn
+ res_block = nn.ModuleList()
+ attn_block = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for _ in range(self.num_res_blocks):
+ res_block.append(ResnetBlock(block_in, block_out, dropout=dropout, norm_type=norm_type))
+ block_in = block_out
+ if i_level == self.num_resolutions - 1:
+ attn_block.append(AttnBlock(block_in, norm_type))
+ conv_block.res = res_block
+ conv_block.attn = attn_block
+ # downsample
+ if i_level != self.num_resolutions-1:
+ conv_block.downsample = Downsample(block_in, resamp_with_conv)
+ self.conv_blocks.append(conv_block)
+
+ # middle
+ self.mid = nn.ModuleList()
+ self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
+ self.mid.append(AttnBlock(block_in, norm_type=norm_type))
+ self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
+
+ # end
+ self.norm_out = Normalize(block_in, norm_type)
+ self.conv_out = nn.Conv2d(block_in, z_channels, kernel_size=3, stride=1, padding=1)
+
+
+ def forward(self, x):
+ h = self.conv_in(x)
+ # downsampling
+ for i_level, block in enumerate(self.conv_blocks):
+ for i_block in range(self.num_res_blocks):
+ h = block.res[i_block](h)
+ if len(block.attn) > 0:
+ h = block.attn[i_block](h)
+ if i_level != self.num_resolutions - 1:
+ h = block.downsample(h)
+
+ # middle
+ for mid_block in self.mid:
+ h = mid_block(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+
+class Decoder(nn.Module):
+ def __init__(self, z_channels=256, ch=128, ch_mult=(1,1,2,2,4), num_res_blocks=2, norm_type="group",
+ dropout=0.0, resamp_with_conv=True, out_channels=3):
+ super().__init__()
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+
+ block_in = ch*ch_mult[self.num_resolutions-1]
+ # z to block_in
+ self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
+
+ # middle
+ self.mid = nn.ModuleList()
+ self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
+ self.mid.append(AttnBlock(block_in, norm_type=norm_type))
+ self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
+
+ # upsampling
+ self.conv_blocks = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ conv_block = nn.Module()
+ # res & attn
+ res_block = nn.ModuleList()
+ attn_block = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ for _ in range(self.num_res_blocks + 1):
+ res_block.append(ResnetBlock(block_in, block_out, dropout=dropout, norm_type=norm_type))
+ block_in = block_out
+ if i_level == self.num_resolutions - 1:
+ attn_block.append(AttnBlock(block_in, norm_type))
+ conv_block.res = res_block
+ conv_block.attn = attn_block
+ # downsample
+ if i_level != 0:
+ conv_block.upsample = Upsample(block_in, resamp_with_conv)
+ self.conv_blocks.append(conv_block)
+
+ # end
+ self.norm_out = Normalize(block_in, norm_type)
+ self.conv_out = nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)
+
+ @property
+ def last_layer(self):
+ return self.conv_out.weight
+
+ def forward(self, z):
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ for mid_block in self.mid:
+ h = mid_block(h)
+
+ # upsampling
+ for i_level, block in enumerate(self.conv_blocks):
+ for i_block in range(self.num_res_blocks + 1):
+ h = block.res[i_block](h)
+ if len(block.attn) > 0:
+ h = block.attn[i_block](h)
+ if i_level != self.num_resolutions - 1:
+ h = block.upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class VectorQuantizer(nn.Module):
+ def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show_usage):
+ super().__init__()
+ self.n_e = n_e
+ self.e_dim = e_dim
+ self.beta = beta
+ self.entropy_loss_ratio = entropy_loss_ratio
+ self.l2_norm = l2_norm
+ self.show_usage = show_usage
+
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+ if self.l2_norm:
+ self.embedding.weight.data = F.normalize(self.embedding.weight.data, p=2, dim=-1)
+ if self.show_usage:
+ self.register_buffer("codebook_used", nn.Parameter(torch.zeros(65536)))
+
+
+ def forward(self, z):
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = torch.einsum('b c h w -> b h w c', z).contiguous()
+ z_flattened = z.view(-1, self.e_dim)
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
+ if self.l2_norm:
+ z = F.normalize(z, p=2, dim=-1)
+ z_flattened = F.normalize(z_flattened, p=2, dim=-1)
+ embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
+ else:
+ embedding = self.embedding.weight
+
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+ torch.sum(embedding**2, dim=1) - 2 * \
+ torch.einsum('bd,dn->bn', z_flattened, torch.einsum('n d -> d n', embedding))
+
+ min_encoding_indices = torch.argmin(d, dim=1)
+ z_q = embedding[min_encoding_indices].view(z.shape)
+ perplexity = None
+ min_encodings = None
+ vq_loss = None
+ commit_loss = None
+ entropy_loss = None
+ codebook_usage = 0
+
+ if self.show_usage and self.training:
+ cur_len = min_encoding_indices.shape[0]
+ self.codebook_used[:-cur_len] = self.codebook_used[cur_len:].clone()
+ self.codebook_used[-cur_len:] = min_encoding_indices
+ codebook_usage = len(torch.unique(self.codebook_used)) / self.n_e
+
+ # compute loss for embedding
+ if self.training:
+ vq_loss = torch.mean((z_q - z.detach()) ** 2)
+ commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2)
+ entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(-d)
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # reshape back to match original input shape
+ z_q = torch.einsum('b h w c -> b c h w', z_q)
+
+ return z_q, (vq_loss, commit_loss, entropy_loss, codebook_usage), (perplexity, min_encodings, min_encoding_indices)
+
+ def get_codebook_entry(self, indices, shape=None, channel_first=True):
+ # shape = (batch, channel, height, width) if channel_first else (batch, height, width, channel)
+ if self.l2_norm:
+ embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
+ else:
+ embedding = self.embedding.weight
+ z_q = embedding[indices] # (b*h*w, c)
+
+ if shape is not None:
+ if channel_first:
+ z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[1])
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+ else:
+ z_q = z_q.view(shape)
+ return z_q
+
+
+class ResnetBlock(nn.Module):
+ def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, norm_type='group'):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels, norm_type)
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ self.norm2 = Normalize(out_channels, norm_type)
+ self.dropout = nn.Dropout(dropout)
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ else:
+ self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+ return x+h
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels, norm_type='group'):
+ super().__init__()
+ self.norm = Normalize(in_channels, norm_type)
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b,c,h,w = q.shape
+ q = q.reshape(b,c,h*w)
+ q = q.permute(0,2,1) # b,hw,c
+ k = k.reshape(b,c,h*w) # b,c,hw
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = F.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b,c,h*w)
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b,c,h,w)
+
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+
+def nonlinearity(x):
+ # swish
+ return x*torch.sigmoid(x)
+
+
+def Normalize(in_channels, norm_type='group'):
+ assert norm_type in ['group', 'batch']
+ if norm_type == 'group':
+ return nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+ elif norm_type == 'batch':
+ return nn.SyncBatchNorm(in_channels)
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, x):
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0,1,0,1)
+ x = F.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = F.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01):
+ flat_affinity = affinity.reshape(-1, affinity.shape[-1])
+ flat_affinity /= temperature
+ probs = F.softmax(flat_affinity, dim=-1)
+ log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1)
+ if loss_type == "softmax":
+ target_probs = probs
+ else:
+ raise ValueError("Entropy loss {} not supported".format(loss_type))
+ avg_probs = torch.mean(target_probs, dim=0)
+ avg_entropy = - torch.sum(avg_probs * torch.log(avg_probs + 1e-5))
+ sample_entropy = - torch.mean(torch.sum(target_probs * log_probs, dim=-1))
+ loss = sample_entropy - avg_entropy
+ return loss
+
+
+#################################################################################
+# VQ Model Configs #
+#################################################################################
+def VQ_8(**kwargs):
+ return VQModel(ModelArgs(encoder_ch_mult=[1, 2, 2, 4], decoder_ch_mult=[1, 2, 2, 4], **kwargs))
+
+def VQ_16(**kwargs):
+ return VQModel(ModelArgs(encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4], **kwargs))
+
+VQ_models = {'VQ-16': VQ_16, 'VQ-8': VQ_8}
\ No newline at end of file
diff --git a/infer.py b/infer.py
new file mode 100644
index 0000000..051d12e
--- /dev/null
+++ b/infer.py
@@ -0,0 +1,332 @@
+# LINT_ME
+import os
+import argparse
+
+import numpy as np
+from PIL import Image
+import torch
+import torch.distributed as dist
+from torchvision import transforms
+from torch.backends import cudnn
+from transformers import set_seed
+
+from flow_matching.data.navsim import resize_pad
+from flow_matching.path import MixtureDiscreteSoftmaxProbPath
+from flow_matching.solver import MixtureDiscreteSoftmaxEulerSolver
+from fudoki.eval_loop import CFGScaledModel
+from fudoki.janus.models import VLChatProcessor
+from fudoki.model import instantiate_model
+
+
+VOCABULARY_SIZE_TXT = 102400
+VOCABULARY_SIZE_IMG = 16384
+IMG_LEN = 576
+
+
+def parse_arguments():
+ parser = argparse.ArgumentParser(description="Run the script with custom arguments.")
+ parser.add_argument(
+ "--seed", type=int, default=999,
+ help="Random seed for reproducibility."
+ )
+ parser.add_argument(
+ "--batch_size", type=int, default=1,
+ help="Batch size for processing."
+ )
+ parser.add_argument(
+ "--checkpoint_path", type=str, required=True,
+ help="Path to the checkpoint directory."
+ )
+ parser.add_argument(
+ "--processor_path", type=str, required=True,
+ help="Path to the processor."
+ )
+ parser.add_argument(
+ "--text_embedding_path", type=str, required=True,
+ help="Path to the text embedding."
+ )
+ parser.add_argument(
+ "--image_embedding_path", type=str, required=True,
+ help="Path to the image embedding."
+ )
+ parser.add_argument(
+ "--discrete_fm_steps", type=int, default=5,
+ help="Inference steps for discrete flow matching"
+ )
+ parser.add_argument(
+ "--txt_max_length", type=int, default=500,
+ help="Text length maximum"
+ )
+ parser.add_argument(
+ "--image_paths", type=str, required=True,
+ help="Path to the input image."
+ )
+ parser.add_argument(
+ "--output_dir", type=str, required=False,
+ help="Directory to save the output files."
+ )
+ return parser.parse_args()
+
+
+def extract_number_pairs(input_str, k=None):
+ """
+ Extract number pairs from a string and form tuples of (float(x), float(y))
+
+ Parameters:
+ input_str (str): Input string in the format "4.78,-0.01,9.69,-0.01,..."
+ k (int, optional): Number of tuples to extract. If None, extract all possible tuples
+
+ Returns:
+ tuple: Contains two elements
+ - list: List of successfully extracted tuples
+ - list: List of error messages
+ """
+ pairs = []
+ errors = []
+
+ # Check if input is empty or not a string
+ if not input_str or not isinstance(input_str, str):
+ errors.append("Input is not a valid string")
+ return pairs, errors
+
+ # Split the string
+ elements = input_str.split(',')
+
+ # Process each element and attempt to convert to float
+ numbers = []
+ for index, elem in enumerate(elements):
+ # Remove possible whitespace characters
+ elem_clean = elem.strip()
+ if not elem_clean:
+ errors.append(f"Empty value at position {index}")
+ continue
+
+ try:
+ num = float(elem_clean)
+ numbers.append(num)
+ except ValueError:
+ errors.append(f"Value '{elem_clean}' at position {index} \
+ cannot be converted to a number")
+
+ # Calculate maximum possible pairs
+ max_possible = len(numbers) // 2
+
+ # Determine number of pairs to extract
+ if k is None:
+ # Extract all possible pairs
+ num_to_extract = max_possible
+ else:
+ # Ensure k is a positive integer
+ try:
+ k = int(k)
+ if k <= 0:
+ errors.append(f"k value {k} must be a positive integer")
+ return pairs, errors
+ num_to_extract = min(k, max_possible)
+ except (ValueError, TypeError):
+ errors.append(f"k value {k} is not a valid integer")
+ return pairs, errors
+
+ # Extract number pairs
+ for i in range(num_to_extract):
+ x = numbers[2*i]
+ y = numbers[2*i + 1]
+ pairs.append((x, y))
+
+ # Check for unpaired numbers
+ if len(numbers) % 2 != 0:
+ errors.append(f"There are {len(numbers) % 2} unpaired number(s)")
+
+ # Check if requested k value was achieved
+ if k is not None and num_to_extract < k:
+ errors.append(f"Only {num_to_extract} valid number pairs can be extracted, \
+ less than the requested {k}")
+
+ return pairs, errors
+
+
+def main():
+ args = parse_arguments()
+
+ dist.init_process_group(
+ "nccl",
+ rank=int(os.environ["RANK"]),
+ world_size=int(os.environ["WORLD_SIZE"]),
+ )
+ local_rank = int(os.environ["LOCAL_RANK"])
+ torch.cuda.set_device(local_rank)
+ set_seed(args.seed)
+ cudnn.benchmark = True
+ device = torch.device(f"cuda:{local_rank}")
+
+ image_paths = args.image_paths.split(',')
+
+ vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(args.processor_path)
+ num_tokens_length = 0
+ num_tokens = [f"{x:.2f}" for x in np.linspace(-100, 100, 20001)]
+ num_tokens_length = len(num_tokens)
+ vl_chat_processor.tokenizer.add_tokens(num_tokens)
+
+ model = instantiate_model(args.checkpoint_path).to(device, dtype=torch.float32)
+ model.train(False)
+
+ batch_size = args.batch_size
+ discrete_fm_steps = args.discrete_fm_steps
+ txt_max_length = args.txt_max_length
+
+ cfg_weighted_model = CFGScaledModel(model=model, g_or_u='understanding')
+ with torch.no_grad():
+ path_txt = MixtureDiscreteSoftmaxProbPath(
+ mode='text',
+ embedding_path=args.text_embedding_path
+ )
+ path_txt.set_embedding(model.language_model.get_input_embeddings())
+ path_img = MixtureDiscreteSoftmaxProbPath(
+ mode='image',
+ embedding_path=args.image_embedding_path
+ )
+ solver = MixtureDiscreteSoftmaxEulerSolver(
+ model=cfg_weighted_model,
+ path_txt=path_txt,
+ path_img=path_img,
+ vocabulary_size_txt=VOCABULARY_SIZE_TXT + num_tokens_length,
+ vocabulary_size_img=VOCABULARY_SIZE_IMG,
+ )
+
+ imgs = []
+ if isinstance(image_paths, str):
+ image_paths = [image_paths]
+ for path in image_paths:
+ img = Image.open(path).convert("RGB")
+ transform = transforms.Compose([
+ transforms.Lambda(resize_pad),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+ ])
+ imgs.append(transform(img))
+
+ if len(imgs) > 0:
+ imgs = torch.stack(imgs, dim=0) # [N, C, H, W]
+ img_len = len(imgs) * IMG_LEN
+ else:
+ imgs = None
+ img_len = IMG_LEN # default
+
+ generation_understanding_indicator = 0 # this is an understanding sample
+ conversation = [
+ {
+ "role": "User",
+ "content": (
+ "Here is front views of a driving vehicle:\n\n"
+ "The navigation information is: straight\n"
+ "The current position is (0.00,0.00)\n"
+ "Current velocity is: (8.34,0.18) and current accelerate is: (-0.83,0.28)\n"
+ "Predict the optimal driving action for the next 4 seconds with 8 new waypoints."
+ )
+ },
+ {
+ "role": "Assistant",
+ "content": ""
+ } # "3.88,-0.06,7.50,-0.07,10.86,-0.10,13.95,-0.11,16.75,-0.13,19.29,-0.15,21.60,-0.12,23.67,-0.11"
+ ]
+ sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
+ conversations=conversation,
+ sft_format=vl_chat_processor.sft_format,
+ system_prompt=vl_chat_processor.system_prompt,
+ )
+
+ # tokenize
+ input_ids = vl_chat_processor.tokenizer.encode(sft_format)
+ input_ids = torch.LongTensor(input_ids)
+ # add image tokens to the input_ids
+ image_token_mask = input_ids == vl_chat_processor.image_id
+ image_indices = image_token_mask.nonzero()
+ assert len(image_indices) == len(image_paths), \
+ f"Number of images ({len(image_paths)}) \
+ does not match the number of image tokens ({len(image_indices)})"
+
+ input_ids, _ = vl_chat_processor.add_image_token(
+ image_indices=image_indices,
+ input_ids=input_ids,
+ )
+
+ # pad tokens
+ original_input_id_len = input_ids.shape[0]
+ if original_input_id_len >= txt_max_length + img_len:
+ raise ValueError("Sentences too long, not supported so far...")
+
+ rows_to_pad = txt_max_length + img_len - input_ids.shape[0]
+ input_ids = torch.concat([
+ input_ids,
+ torch.LongTensor([vl_chat_processor.pad_id]).repeat(rows_to_pad)
+ ], dim=0)
+ attention_mask = torch.zeros((input_ids.shape[0]), dtype=torch.bool)
+ attention_mask[:] = True
+
+ # obtain image token mask and fill in img token_ids
+ if imgs is not None:
+ image_expanded_token_mask = (input_ids == vl_chat_processor.image_id).to(dtype=int)
+ image_expanded_mask_indices = torch.where(image_expanded_token_mask == 1)[0]
+ input_ids[image_expanded_mask_indices] = 0
+ else:
+ image_expanded_token_mask = torch.zeros_like(input_ids)
+
+ # obtain text token mask
+ # We assume that there is only one turn for assistant to respond
+ text_expanded_token_mask = torch.zeros_like(image_expanded_token_mask)
+ split_token = vl_chat_processor.tokenizer.encode("Assistant:", add_special_tokens=False)
+ split_token_length = len(split_token)
+
+ start_index = -1
+ for j in range(len(input_ids) - split_token_length + 1):
+ if input_ids[j:j + split_token_length].numpy().tolist() == split_token:
+ start_index = j
+ break
+ if start_index != -1:
+ text_expanded_token_mask[(start_index+split_token_length):] = 1
+ else:
+ raise ValueError("Split token not found in input_ids")
+
+ generation_or_understanding_mask = generation_understanding_indicator
+ data_info = {}
+ data_info['text_token_mask'] = text_expanded_token_mask.unsqueeze(0).repeat(batch_size, 1).to(device)
+ data_info['image_token_mask'] = image_expanded_token_mask.unsqueeze(0).repeat(batch_size, 1).to(device)
+ data_info['generation_or_understanding_mask'] = \
+ torch.Tensor([generation_or_understanding_mask]).unsqueeze(0).repeat(batch_size, 1).to(device).to(dtype=int)
+
+ data_info['attention_mask'] = attention_mask.unsqueeze(0).repeat(batch_size, 1).to(device)
+ data_info['sft_format'] = sft_format
+ if imgs is not None:
+ data_info['understanding_img'] = imgs.unsqueeze(0).to(device, dtype=torch.float32).repeat(batch_size, 1, 1, 1, 1)
+ data_info['has_understanding_img'] = torch.Tensor([True]).to(dtype=int).unsqueeze(0).repeat(batch_size, 1).to(device)
+ else:
+ data_info['understanding_img'] = torch.zeros((3, 384, 384)).unsqueeze(0).repeat(batch_size, 1, 1, 1).to(device)
+ data_info['has_understanding_img'] = torch.Tensor([False]).to(dtype=int).unsqueeze(0).repeat(batch_size, 1).to(device)
+ input_ids = torch.LongTensor(input_ids).unsqueeze(0).repeat(batch_size, 1).to(device)
+
+
+ x_0_txt = torch.randint(VOCABULARY_SIZE_TXT + num_tokens_length, input_ids.shape, dtype=torch.long, device=device)
+ x_init = x_0_txt * data_info['text_token_mask'] + input_ids * (1 - data_info['text_token_mask'])
+
+ synthetic_samples = solver.sample(
+ x_init=x_init,
+ step_size=1.0/discrete_fm_steps,
+ verbose=True,
+ return_intermediates=False,
+ div_free=0,
+ dtype_categorical=torch.float32,
+ datainfo=data_info,
+ cfg_scale=0,
+ )
+ sentence = vl_chat_processor.tokenizer.batch_decode(
+ synthetic_samples,
+ skip_special_tokens=True
+ )[0]
+ print("Sentence:", sentence)
+
+ waypoint = extract_number_pairs(sentence, k=8)[0]
+ print("Waypoint: ", waypoint)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..dab4802
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,17 @@
+omegaconf
+transformers==4.42.4
+timm
+tokenizers
+attrdict
+torch==2.0.1
+torchvision
+accelerate
+sentencepiece
+einops
+torchdiffeq
+matplotlib
+numpy==1.26.4
+icecream
+xformers==0.0.22
+diffusers==0.32.2
+deepspeed
\ No newline at end of file
diff --git a/script/infer.sh b/script/infer.sh
new file mode 100644
index 0000000..114b5f4
--- /dev/null
+++ b/script/infer.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+
+CKPT_PATH="pretrained_model/wam-flow/navsim"
+FUDOKI_PATH="pretrained_model/fudoki"
+IMAGE_PATH="data/navsim_data/sensor_blobs/test/2021.09.09.17.18.51_veh-48_00889_01147/CAM_F0/9a6f0331d98258a0.jpg"
+
+torchrun --nproc_per_node 1 infer.py \
+ --checkpoint_path $CKPT_PATH \
+ --image_path $IMAGE_PATH \
+ --processor_path $FUDOKI_PATH \
+ --text_embedding_path $FUDOKI_PATH/text_embedding.pt \
+ --image_embedding_path $FUDOKI_PATH/image_embedding.pt \
+ --discrete_fm_steps 2 \
+ --seed 123
\ No newline at end of file
diff --git a/script/sft_debug.sh b/script/sft_debug.sh
new file mode 100644
index 0000000..fd310a6
--- /dev/null
+++ b/script/sft_debug.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+NUM_NODES=1
+NUM_GPUS=1
+
+config=config/debug.yaml
+output_dir=output/train/debug
+
+accelerate launch \
+ --config_file ./config/accelerate_config_ds2.yaml \
+ --machine_rank 0 \
+ --main_process_port 12345 \
+ --num_machines $NUM_NODES \
+ --num_processes $NUM_GPUS \
+ train.py \
+ --config $config \
+ --output_dir $output_dir
diff --git a/script/sft_navsim.sh b/script/sft_navsim.sh
new file mode 100644
index 0000000..aedda23
--- /dev/null
+++ b/script/sft_navsim.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+NUM_NODES=1
+NUM_GPUS=1
+
+config=config/sft_navsim.yaml
+output_dir=output/train/debug
+
+accelerate launch \
+ --config_file ./config/accelerate_config_ds2.yaml \
+ --machine_rank 0 \
+ --main_process_port 12345 \
+ --num_machines $NUM_NODES \
+ --num_processes $NUM_GPUS \
+ train.py \
+ --config $config \
+ --output_dir $output_dir
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..6c8398f
--- /dev/null
+++ b/train.py
@@ -0,0 +1,451 @@
+# LINT_ME
+import argparse
+import os
+import logging
+import shutil
+import copy
+import random
+
+import numpy as np
+import transformers
+import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from torch.nn import CrossEntropyLoss
+import diffusers
+from diffusers.optimization import get_scheduler
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import set_seed
+from omegaconf import OmegaConf
+from tqdm.auto import tqdm
+
+from fudoki.model import instantiate_model
+from fudoki.janus.models import VLChatProcessor
+from flow_matching.path import MixtureDiscreteSoftmaxProbPath
+from flow_matching.data.navsim import SupervisedDataset
+from flow_matching.utils.flow import get_source_distribution
+
+
+logger = get_logger(__name__, log_level="INFO")
+
+
+@torch.no_grad()
+def init_numeric_and_special_tokens(
+ model,
+ tokenizer,
+ numeric_tokens,
+ noise_scale: float = 0.01,
+):
+ emb = model.get_input_embeddings().weight
+ device = emb.device
+ dim = emb.shape[1]
+
+ def tok_ids_for_text(text: str):
+ # Get subword ids (no specials), filter out -100/None if any
+ ids = tokenizer.encode(text,add_special_tokens=False)
+ return [i for i in ids if isinstance(i, int) and i >= 0]
+
+ # ---- Build a numeric "base" vector from digits and dot ----
+ digit_ids = []
+ for d in "0123456789":
+ _ids = tok_ids_for_text(d)
+ digit_ids.extend(_ids)
+ dot_ids = tok_ids_for_text(".")
+
+ base_chunks = []
+ if digit_ids:
+ base_chunks.append(emb[torch.tensor(digit_ids, device=device)].mean(dim=0))
+ if dot_ids:
+ base_chunks.append(emb[torch.tensor(dot_ids, device=device)].mean(dim=0))
+ if base_chunks:
+ numeric_base = torch.stack(base_chunks, dim=0).mean(dim=0)
+ else:
+ # fallback if tokenizer lacks digits/dot as standalone pieces
+ numeric_base = torch.zeros(dim, device=device)
+
+ # ---- Initialize numeric tokens ----
+ for t in numeric_tokens:
+ tid = tokenizer.convert_tokens_to_ids(t)
+ if tid is None or tid < 0:
+ continue
+ noise = noise_scale * torch.randn(dim, device=device)
+ emb[tid] = numeric_base + noise
+
+
+def training_step(
+ model,
+ x_1,
+ source_distribution,
+ data_info,
+ path,
+ time_epsilon = 0.001,
+ loss_fn = CrossEntropyLoss(),
+ stage="s1",
+ vl_chat_processor=None,
+ args=None,
+):
+ x_0 = source_distribution.sample_like(x_1)
+ t = torch.rand(x_1.shape[0], device=x_1.device) * (1.0 - time_epsilon)
+
+ if stage == "s1":
+ x_t = x_1
+ elif stage == "s2":
+ # update emb layer when using num tokenizer
+ path_sample = path.sample(x_0, x_1, t)
+ x_t = path_sample.x_t
+
+ # text_token_mask==1 ==> generated text token
+ x_t = x_t * data_info['text_token_mask'] + x_1 * (1 - data_info['text_token_mask'])
+ data_info['understanding_img'] = data_info['understanding_img'].to(dtype=model.dtype)
+
+ _, txt_logits = model(x_t, data_info)
+
+ b, _, c = txt_logits.shape
+ mask = data_info['text_token_mask'].unsqueeze(-1).bool()
+ txt_logits = txt_logits.masked_select(mask)
+ txt_logits = txt_logits.view(b, -1, c)
+ x_1 = x_1.masked_select(mask.squeeze(-1)).view(b, -1)
+
+ loss = ce_loss = loss_fn(txt_logits.flatten(0, 1), x_1.flatten(0, 1)).mean()
+ loss_dict = {"ce_loss": ce_loss.detach().item()}
+
+ if stage == "s2":
+ start_mask = x_1 >= vl_chat_processor.num_start_id
+ end_mask = x_1 <= vl_chat_processor.num_end_id - 1
+ action_mask = start_mask & end_mask
+
+ if action_mask.any():
+ if args.l2_loss_weight > 0:
+ pred_probabilities = F.softmax(txt_logits, dim=-1)
+ pred_ids = torch.argmax(pred_probabilities, dim=-1)
+ pred_num_ids = pred_ids.masked_select(action_mask)
+ pred_nums = vl_chat_processor.min_num + (pred_num_ids - vl_chat_processor.num_start_id) * vl_chat_processor.interval
+ pred_nums = torch.clip(pred_nums, vl_chat_processor.min_num, vl_chat_processor.max_num)
+
+ tgt_num_ids = x_1.masked_select(action_mask) # [N]
+ tgt_nums = vl_chat_processor.min_num + (tgt_num_ids - vl_chat_processor.num_start_id) * vl_chat_processor.interval
+
+ l2_loss = args.l2_loss_weight + F.mse_loss(pred_nums, tgt_nums, reduction="mean")
+ loss = loss + l2_loss
+ loss_dict["l2_loss"] = l2_loss.detach().item()
+
+ loss_dict["loss"] = loss.detach().item()
+ return loss, loss_dict
+
+
+def main(args):
+ accelerator = Accelerator(
+ mixed_precision=args.mixed_precision,
+ gradient_accumulation_steps=args.accumulate_grad_batches,
+ log_with="tensorboard",
+ project_dir=args.output_dir
+ )
+
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # random seed
+ seed = args.seed + accelerator.process_index
+ if args.random_seed:
+ seed = seed + random.randint(0, 500)
+ set_seed(seed)
+ logger.info(f"accelerator.process_index: {accelerator.process_index}, seed: {seed} \n")
+
+ # work dir
+ if accelerator.is_main_process:
+ os.makedirs(args.output_dir, exist_ok=True)
+ config_path = os.path.join(args.output_dir, "config.yaml")
+ OmegaConf.save(args, config_path)
+ accelerate_config_path = os.path.join(args.output_dir, "accelerate_config_ds2.yaml")
+ shutil.copyfile(
+ "./config/accelerate_config_ds2.yaml",
+ accelerate_config_path
+ )
+ accelerator.wait_for_everyone()
+
+ # dtype
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # prepare dataset
+ vl_chat_processor = VLChatProcessor.from_pretrained(args.model_path)
+ if args.use_quantize:
+ origin_len = len(vl_chat_processor.tokenizer)
+ num_tokens = [f"{x:.2f}" for x in np.linspace(-100, 100, 20001)]
+ num_tokens_length = len(num_tokens)
+ vl_chat_processor.tokenizer.add_tokens(num_tokens)
+
+ vl_chat_processor.num_start_id = origin_len
+ vl_chat_processor.num_end_id = origin_len + num_tokens_length - 1
+ vl_chat_processor.min_num = -100
+ vl_chat_processor.max_num = 100
+ vl_chat_processor.interval = 0.01
+
+ logger.info(f"Total number tokens: {num_tokens_length}")
+
+ # data
+ dataset = SupervisedDataset(
+ data_list=args.data_list,
+ vl_chat_processor=vl_chat_processor,
+ txt_max_length=args.txt_max_length
+ )
+ dataloader = DataLoader(
+ dataset,
+ shuffle=True,
+ batch_size=args.batch_size,
+ num_workers=args.dataloader_num_workers,
+ pin_memory=True
+ )
+ logger.info(f"Max txt length: {args.txt_max_length}")
+ logger.info(f"Total data samples: {len(dataset)}")
+
+ # prepare model
+ stage = args.stage
+ logger.info(f"Training stege: {stage}")
+ model = instantiate_model(
+ args.pretrain_model_path
+ ).to(weight_dtype)
+ model.uncond_prob = args.uncond_prob
+
+ if os.path.exists(args.pretrain_path):
+ sd = torch.load(args.pretrain_path, map_location='cpu')
+ model.load_state_dict(sd, strict=True)
+ model = model.to(weight_dtype)
+ logger.info(f"Loading pretrain ckpt from {args.pretrain_path}")
+
+ if stage == "s1":
+ model.language_model.resize_token_embeddings(args.vocab_size + num_tokens_length)
+ init_numeric_and_special_tokens(
+ model.language_model,
+ vl_chat_processor.tokenizer,
+ numeric_tokens=num_tokens
+ )
+ elif stage == "s2":
+ if os.path.exists(args.new_embedding_path):
+ old_emb = copy.deepcopy(model.language_model.get_input_embeddings())
+
+ with torch.serialization.safe_globals([torch.nn.modules.sparse.Embedding]):
+ new_emb_state = torch.load(args.new_embedding_path, map_location="cpu")
+ logger.info(f"Loading new embedding from {args.new_embedding_path}")
+
+ if isinstance(new_emb_state, dict) and "weight" in new_emb_state:
+ weight = new_emb_state["weight"]
+ new_emb = torch.nn.Embedding(weight.size(0), weight.size(1))
+ new_emb.load_state_dict(new_emb_state)
+ else:
+ new_emb = new_emb_state
+
+ model.language_model.resize_token_embeddings(args.vocab_size + num_tokens_length)
+
+ # origin_len = old_emb.weight.shape[0]
+ origin_len = vl_chat_processor.num_start_id
+ new_emb.weight.data[:origin_len, :] = old_emb.weight.data[:origin_len, :]
+
+ model.language_model.set_input_embeddings(new_emb)
+
+ if os.path.exists(args.ckpt_path):
+ if model.language_model.model.embed_tokens.weight.shape[0] != args.vocab_size + num_tokens_length:
+ model.language_model.resize_token_embeddings(args.vocab_size + num_tokens_length)
+ sd = torch.load(args.ckpt_path, map_location='cpu')
+ model.load_state_dict(sd, strict=True)
+ model = model.to(weight_dtype)
+ logger.info(f"Loading ckpt from {args.ckpt_path}")
+
+ # prepare path
+ path = MixtureDiscreteSoftmaxProbPath(
+ mode='text',
+ embedding_path=args.text_embedding_path
+ )
+ if args.use_quantize:
+ path.set_embedding(model.language_model.get_input_embeddings())
+ else:
+ logger.info("No quantize!")
+
+ logger.info(f"path.a = {path.a}")
+ logger.info(f"path.c = {path.c}")
+
+ # set trainable params
+ model.requires_grad_(False)
+ if stage == "s1":
+ model.language_model.requires_grad_(False)
+ model.language_model.model.embed_tokens.requires_grad_(True)
+ model.language_model.lm_head.requires_grad_(True)
+ elif stage == "s2":
+ model.language_model.requires_grad_(True)
+ if args.train_llm_emb:
+ model.language_model.model.embed_tokens.requires_grad_(True)
+ else:
+ model.language_model.model.embed_tokens.requires_grad_(False)
+
+ trainable_params = list(
+ filter(lambda p: p.requires_grad, model.parameters())
+ )
+
+ # log trainable params
+ # for name, param in model.named_parameters():
+ # if param.requires_grad:
+ # logger.info(f"Trainable params: {name}")
+
+ num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ logger.info(f"Total trainable parameters: {num_trainable/1e9:.3} B")
+
+ optimizer = torch.optim.AdamW(
+ trainable_params,
+ lr=args.learning_rate,
+ betas=(0.9, 0.95),
+ weight_decay=0.05,
+ )
+
+ # lr scheduler
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler_type,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes
+ )
+
+ # accelerator
+ model, optimizer, dataloader, lr_scheduler = accelerator.prepare(
+ model, optimizer, dataloader, lr_scheduler
+ )
+
+ source_distribution = get_source_distribution(
+ source_distribution=args.source_distribution,
+ vocab_size=args.vocab_size + num_tokens_length if args.use_quantize else args.vocab_size,
+ )
+
+
+ global_step = 0
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ disable=not accelerator.is_local_main_process,
+ )
+
+ # training loop
+ for epoch in range(args.max_epochs):
+ logger.info(f"Epoch {epoch + 1}/{args.max_epochs}")
+ logger.info(f"training sample length: {len(dataloader)}")
+
+ for _, batch in enumerate(dataloader):
+ with accelerator.accumulate(model):
+ x_1 = batch["input_ids"].to(dtype=torch.long)
+ loss, logs = training_step(
+ x_1=x_1,
+ model=model,
+ source_distribution=source_distribution,
+ data_info=batch,
+ path=path,
+ stage=stage,
+ vl_chat_processor=vl_chat_processor,
+ args=args
+ )
+ accelerator.backward(loss)
+
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(
+ trainable_params,
+ args.max_grad_norm,
+ )
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ logs["lr"] = optimizer.param_groups[0]['lr']
+
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step % args.checkpointing_steps == 0 \
+ and accelerator.is_main_process and args.checkpoints_total_limit is not None:
+
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ if os.path.exists(removing_checkpoint):
+ shutil.rmtree(removing_checkpoint)
+
+ accelerator.wait_for_everyone()
+ unwrap_net = accelerator.unwrap_model(model)
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ if accelerator.is_main_process:
+ unwrap_net.save_pretrained(save_path, max_shard_size="20GB")
+ logger.info(f"Saved state to {save_path}")
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if global_step >= args.max_train_steps:
+ break
+
+ accelerator.wait_for_everyone()
+ accelerator.end_training()
+ logger.info("training completed!")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--config", type=str, required=True)
+ parser.add_argument("--output_dir", type=str, required=True)
+ parser.add_argument("--output_obs_dir", type=str, default=None)
+
+ return parser.parse_args()
+
+
+if __name__ == '__main__':
+ args = parse_args()
+
+ config = OmegaConf.load(args.config)
+
+ # merge args
+ args_dict = vars(args).copy()
+ args_dict.pop("config", None)
+ config_keys = set(config.keys())
+ cli_keys = set(args_dict.keys())
+
+ # check conflict
+ conflict_keys = cli_keys & config_keys
+ if conflict_keys:
+ print(f"Args conflict: {conflict_keys}")
+
+ # merge
+ merged_config = OmegaConf.merge(OmegaConf.create(args_dict), config)
+ args = merged_config
+
+ # training
+ main(args)