diff --git a/.gitignore b/.gitignore
index 32f4aa5..ec68d71 100644
--- a/.gitignore
+++ b/.gitignore
@@ -148,3 +148,6 @@ dmypy.json
.cache/
.DS_Store
+
+AGENTS.md
+CLAUDE.md
diff --git a/README.md b/README.md
index 13c72fc..be47d10 100644
--- a/README.md
+++ b/README.md
@@ -36,6 +36,24 @@ Why xTuring:
pip install xturing
```
+### Development Installation
+
+If you want to contribute to xTuring or run from source:
+
+```bash
+# Clone the repository
+git clone https://github.com/stochasticai/xturing.git
+cd xturing
+
+# Install in editable mode with development dependencies
+pip install -e .
+pip install -r requirements-dev.txt
+
+# Set up pre-commit hooks (required before contributing)
+pre-commit install
+pre-commit install --hook-type commit-msg
+```
+
## π Quickstart
@@ -158,7 +176,7 @@ dataset = InstructionDataset('../llama/alpaca_data')
model = GenericLoraKbitModel('tiiuae/falcon-7b')
# Generate outputs on desired prompts
-outputs = model.generate(dataset = dataset, batch_size=10)
+outputs = model.generate(dataset = dataset, batch_size=10)
```
@@ -173,6 +191,16 @@ model.finetune(dataset=dataset)
```
> See `examples/models/qwen3/qwen3_lora_finetune.py` for a runnable script.
+8. __Qwen3-Omni dataset generation__ — Run the multimodal checkpoint locally (download from Hugging Face) to bootstrap instruction corpora without leaving your machine.
+```python
+from xturing.datasets import InstructionDataset
+from xturing.model_apis.qwen import Qwen3OmniTextGenerationAPI
+
+# Download `Qwen/Qwen3-Omni-30B-A3B-Instruct` (or another HF variant) ahead of time
+engine = Qwen3OmniTextGenerationAPI(model_name_or_path="Qwen/Qwen3-Omni-30B-A3B-Instruct")
+dataset = InstructionDataset.generate_dataset("./tasks.jsonl", engine=engine)
+```
+
An exploration of the [Llama LoRA INT4 working example](examples/features/int4_finetuning/LLaMA_lora_int4.ipynb) is recommended for an understanding of its application.
For an extended insight, consider examining the [GenericModel working example](examples/features/generic/generic_model.py) available in the repository.
@@ -182,9 +210,17 @@ For an extended insight, consider examining the [GenericModel working example](e
## CLI playground
+The `xturing` CLI provides interactive tools for working with fine-tuned models:
+
```bash
-$ xturing chat -m ""
+# Chat with a fine-tuned model
+xturing chat -m "<path-to-model-directory>"
+
+# Launch the UI playground (alternative to programmatic Playground)
+xturing ui
+# Get help and see all available commands
+xturing --help
```
## UI playground
@@ -210,6 +246,8 @@ Playground().launch() ## launches localhost UI
## π Tutorials
- [Preparing your dataset](examples/datasets/preparing_your_dataset.py)
+- [SIFT-50M dataset helpers](examples/datasets/README.md)
+- [Qwen3-Omni HF/PEFT template (A100/H100)](examples/models/qwen3_omni/README.md)
- [Cerebras-GPT fine-tuning with LoRA and INT8](examples/models/cerebras/cerebras_lora_int8.ipynb) [](https://colab.research.google.com/drive/1eKq3oF7dnK8KuIfsTE70Gvvniwr1O9D0?usp=sharing)
- [Cerebras-GPT fine-tuning with LoRA](examples/models/cerebras/cerebras_lora.ipynb) [](https://colab.research.google.com/drive/1VjqQhstm5pT4EjPjx4Je7b3W2X1V3vDo?usp=sharing)
- [LLaMA fine-tuning with LoRA and INT8](examples/models/llama/llama_lora_int8.py) [](https://colab.research.google.com/drive/1SQUXq1AMZPSLD4mk3A3swUIc6Y2dclme?usp=sharing)
@@ -250,13 +288,27 @@ Contribute to this by submitting your performance results on other GPUs by creat
## π Fineβtuned model checkpoints
We have already fine-tuned some models that you can use as your base or start playing with.
-Here is how you would load them:
+### Loading Models
+
+**Load from xTuring hub:**
```python
from xturing.models import BaseModel
model = BaseModel.load("x/distilgpt2_lora_finetuned_alpaca")
```
+**Load from local directory:**
+```python
+model = BaseModel.load("/path/to/saved/model")
+```
+
+**Create a new model for fine-tuning:**
+```python
+model = BaseModel.create("llama_lora")
+```
+
+### Available Pre-trained Models
+
| model | dataset | Path |
|---------------------|--------|---------------|
| DistilGPT-2 LoRA | alpaca | `x/distilgpt2_lora_finetuned_alpaca` |
@@ -281,6 +333,7 @@ Below is a list of all the supported models via `BaseModel` class of `xTuring` a
|LLaMA2 | llama2|
|MiniMaxM2 | minimax_m2|
|OPT-1.3B | opt|
+|Qwen3-0.6B | qwen3_0_6b|
The above are the base variants. Use these templates for `LoRA`, `INT8`, and `INT8 + LoRA` versions:
@@ -314,6 +367,36 @@ Replace `` with a local directory or a Hugging Face model like `face
+## π§ͺ Running Tests
+
+The project uses pytest for testing. Test files are located in the `tests/` directory.
+
+Run all tests:
+```bash
+pytest
+```
+
+Run a specific test file:
+```bash
+pytest tests/xturing/models/test_qwen_model.py
+```
+
+Skip slow tests:
+```bash
+pytest -m "not slow"
+```
+
+Skip GPU tests (for CPU-only environments):
+```bash
+pytest -m "not gpu"
+```
+
+Test markers used in this project:
+- `@pytest.mark.slow` - Tests that take significant time to run
+- `@pytest.mark.gpu` - Tests requiring GPU hardware
+
+
+
## π€ Help and Support
If you have any questions, you can create an issue on this repository.
@@ -321,6 +404,37 @@ You can also join our [Discord server](https://discord.gg/TgHXuSJEk6) and start
+## ποΈ Project Structure
+
+Understanding the codebase organization:
+
+```
+src/xturing/
+βββ models/ # Model classes and registry (BaseModel, LLaMA, GPT-2, etc.)
+βββ engines/ # Low-level model loading, tokenization, and operations
+βββ datasets/ # Dataset loaders (InstructionDataset, TextDataset)
+βββ trainers/ # Training loops (LightningTrainer with DeepSpeed support)
+βββ preprocessors/ # Data preprocessing and tokenization
+βββ config/ # YAML configurations for finetuning and generation
+βββ cli/ # CLI commands (chat, ui, api)
+βββ ui/ # Gradio UI playground
+βββ self_instruct/ # Dataset generation utilities
+βββ utils/ # Shared utilities
+
+tests/xturing/ # Test suite mirroring src structure
+examples/ # Example scripts organized by model and feature
+```
+
+**Key architectural patterns:**
+- **Registry Pattern**: Models and engines use a registry-based factory pattern via `BaseModel.create()` and `BaseEngine.create()`
+- **Model Variants**: Each model family has multiple variants following the naming template `<model>_[lora]_[int8|kbit]`
+ - Example: `llama`, `llama_lora`, `llama_int8`, `llama_lora_int8`
+- **Configuration**: Training and generation parameters are defined in YAML files per model in `src/xturing/config/`
+- **Engines**: Handle the low-level operations (loading weights, tokenization, DeepSpeed integration)
+- **Models**: Provide high-level API (`finetune()`, `generate()`, `evaluate()`, `save()`, `load()`)
+
+
+
## π License
This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details.
@@ -328,3 +442,26 @@ This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENS
## π Contributing
As an open source project in a rapidly evolving field, we welcome contributions of all kinds, including new features and better documentation. Please read our [contributing guide](CONTRIBUTING.md) to learn how you can get involved.
+
+### Quick Contribution Guidelines
+
+**Important:** All pull requests should target the `dev` branch, not `main`.
+
+The project uses pre-commit hooks to enforce code quality:
+- **black** - Code formatting
+- **isort** - Import sorting (black profile)
+- **autoflake** - Remove unused imports
+- **absolufy-imports** - Convert relative to absolute imports
+- **gitlint** - Commit message linting
+
+You can manually format code:
+```bash
+black src/ tests/
+isort src/ tests/
+```
+
+Pre-commit hooks will automatically run these checks when you commit. Make sure to install them:
+```bash
+pre-commit install
+pre-commit install --hook-type commit-msg
+```
diff --git a/docs/docs/advanced/generate.md b/docs/docs/advanced/generate.md
index ac27ac8..ef8ba80 100644
--- a/docs/docs/advanced/generate.md
+++ b/docs/docs/advanced/generate.md
@@ -41,6 +41,16 @@ engine = Davinci("your-api-key")
engine = ClaudeSonnet("your-api-key")
```
+
+
+
+ Download the desired checkpoint from [Hugging Face](https://huggingface.co/Qwen/Qwen3-Omni-30B-A3B-Instruct) (or point to a local directory) and load it directly.
+
+ ```python
+ from xturing.model_apis.qwen import Qwen3OmniTextGenerationAPI
+ engine = Qwen3OmniTextGenerationAPI(model_name_or_path="Qwen/Qwen3-Omni-30B-A3B-Instruct")
+ ```
+
diff --git a/examples/README.md b/examples/README.md
index 910873c..de9f5b8 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -16,6 +16,10 @@ examples/
### datsets/
This directory consists of multiple ways to generate your custom dataset from a given set of examples.
+Also includes SIFT-50M helpers:
+- `examples/datasets/sift50m_subset_builder.py` builds a small English subset.
+- `examples/datasets/sift50m_audio_mapper.py` resolves `audio_path` to local files.
+- `examples/datasets/README.md` contains full CLI recipes.
### features/
This directory consists of files with exapmles highlighting speific major features of the library, which can be replicated to any LLM you want.
diff --git a/examples/datasets/README.md b/examples/datasets/README.md
new file mode 100644
index 0000000..3fe8bea
--- /dev/null
+++ b/examples/datasets/README.md
@@ -0,0 +1,58 @@
+# Datasets
+
+This folder includes dataset helpers and recipes used by xTuring examples.
+
+## SIFT-50M helpers (English subsets)
+
+### 1) Build a small English subset
+
+Filters `amazon-agi/SIFT-50M` to English plus:
+- `closed_ended_content_level`
+- `open_ended`
+- optional `controllable_generation`
+
+```bash
+python examples/datasets/sift50m_subset_builder.py \
+ --output-dir ./data/sift50m_en_small \
+ --max-examples 100000 \
+ --include-controllable-generation \
+ --jsonl
+```
+
+Notes:
+- Use `--language-col` or `--category-col` if the dataset schema changes.
+- Set `--max-examples 0` to keep all rows after filtering.
+
+### 2) Resolve audio paths to local files
+
+SIFT-50M includes `audio_path` and (often) `data_source`. This script adds a
+resolved `audio_file` column and can drop rows with missing files.
+
+```bash
+python examples/datasets/sift50m_audio_mapper.py \
+ --input-dir ./data/sift50m_en_small \
+ --output-dir ./data/sift50m_en_small_mapped \
+ --audio-root mls=/data/mls \
+ --audio-root cv15=/data/commonvoice15 \
+ --audio-root vctk=/data/vctk \
+ --verify-exists \
+ --drop-missing \
+ --jsonl
+```
+
+If your dataset uses different columns:
+
+```bash
+python examples/datasets/sift50m_audio_mapper.py \
+ --input-dir ./data/sift50m_en_small \
+ --output-dir ./data/sift50m_en_small_mapped \
+ --audio-path-col audio_path \
+ --data-source-col data_source
+```
+
+## Outputs
+
+Each script writes:
+- a Hugging Face dataset directory (via `save_to_disk`)
+- `subset.jsonl` (if `--jsonl` is set)
+- a `*_meta.json` file with the filter settings and detected columns
diff --git a/examples/datasets/sift50m_audio_mapper.py b/examples/datasets/sift50m_audio_mapper.py
new file mode 100644
index 0000000..019c219
--- /dev/null
+++ b/examples/datasets/sift50m_audio_mapper.py
@@ -0,0 +1,199 @@
+"""Resolve SIFT-50M audio paths to local files.
+
+Given a subset saved with `save_to_disk`, this script attaches a resolved
+`audio_file` column based on `audio_path` and an optional `data_source`.
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import os
+from pathlib import Path
+from typing import Dict, Optional
+
+from datasets import Dataset, DatasetDict, load_from_disk
+
+
+def _parse_audio_roots(values) -> Dict[str, str]:
+ roots: Dict[str, str] = {}
+ for item in values or []:
+ if "=" in item:
+ key, path = item.split("=", 1)
+ roots[key.strip()] = path.strip()
+ else:
+ # Allow a single default root via --audio-root /path
+ roots[""] = item.strip()
+ return roots
+
+
+def _resolve_path(
+ audio_path: Optional[str],
+ data_source: Optional[str],
+ roots: Dict[str, str],
+ default_root: Optional[str],
+) -> Optional[str]:
+ if not audio_path:
+ return None
+
+ # Already absolute or explicitly rooted.
+ if os.path.isabs(audio_path):
+ return audio_path
+
+ base = None
+ if data_source and data_source in roots:
+ base = roots[data_source]
+ elif "" in roots:
+ base = roots[""]
+ elif default_root:
+ base = default_root
+
+ if not base:
+ return None
+
+ return str(Path(base) / audio_path)
+
+
+def _as_jsonl(dataset: Dataset, path: Path) -> None:
+ path.parent.mkdir(parents=True, exist_ok=True)
+ with path.open("w", encoding="utf-8") as f:
+ for row in dataset:
+ f.write(json.dumps(row, ensure_ascii=False) + "\n")
+
+
+def load_any(path: str) -> Dataset:
+ data = load_from_disk(path)
+ if isinstance(data, DatasetDict):
+ return data["train"]
+ return data
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(
+ description="Attach resolved audio paths to a SIFT-50M subset."
+ )
+ parser.add_argument(
+ "--input-dir",
+ required=True,
+ help="Path to a dataset saved with save_to_disk.",
+ )
+ parser.add_argument(
+ "--output-dir",
+ required=True,
+ help="Where to write the updated dataset.",
+ )
+ parser.add_argument(
+ "--audio-path-col",
+ default="audio_path",
+ help="Column containing relative audio paths.",
+ )
+ parser.add_argument(
+ "--data-source-col",
+ default="data_source",
+ help="Column containing data source (e.g., mls, cv15, vctk).",
+ )
+ parser.add_argument(
+ "--resolved-audio-col",
+ default="audio_file",
+ help="Name of output column for resolved paths.",
+ )
+ parser.add_argument(
+ "--audio-root",
+ action="append",
+ default=[],
+ help="Root mapping like data_source=/path or just /path for default.",
+ )
+ parser.add_argument(
+ "--default-audio-root",
+ default=None,
+ help="Fallback root if no mapping matches.",
+ )
+ parser.add_argument(
+ "--verify-exists",
+ action="store_true",
+ help="Set missing files to null and optionally drop them.",
+ )
+ parser.add_argument(
+ "--drop-missing",
+ action="store_true",
+ help="Drop rows where resolved audio files do not exist.",
+ )
+ parser.add_argument(
+ "--jsonl",
+ action="store_true",
+ help="Also export subset.jsonl.",
+ )
+
+ args = parser.parse_args()
+
+ roots = _parse_audio_roots(args.audio_root)
+ dataset = load_any(args.input_dir)
+
+ if args.audio_path_col not in dataset.column_names:
+ raise ValueError(
+ f"Missing audio path column '{args.audio_path_col}'. "
+ f"Available: {dataset.column_names}"
+ )
+
+ if args.data_source_col not in dataset.column_names:
+ # Not fatal; proceed with None data_source.
+ data_source_col = None
+ else:
+ data_source_col = args.data_source_col
+
+ def _map(batch):
+ audio_paths = batch[args.audio_path_col]
+ data_sources = (
+ batch[data_source_col] if data_source_col else [None] * len(audio_paths)
+ )
+
+ resolved = []
+ exists = []
+ for ap, ds in zip(audio_paths, data_sources):
+ path = _resolve_path(ap, ds, roots, args.default_audio_root)
+ if path and args.verify_exists:
+ ok = Path(path).exists()
+ else:
+ ok = True
+ resolved.append(path if ok else None)
+ exists.append(ok)
+
+ return {args.resolved_audio_col: resolved, "_audio_exists": exists}
+
+ mapped = dataset.map(_map, batched=True, desc="Resolve audio paths")
+
+ if args.drop_missing:
+ mapped = mapped.filter(
+ lambda row: row["_audio_exists"]
+ and row[args.resolved_audio_col] is not None,
+ desc="Drop missing audio",
+ )
+
+ if "_audio_exists" in mapped.column_names:
+ mapped = mapped.remove_columns(["_audio_exists"])
+
+ output_dir = Path(args.output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+ mapped.save_to_disk(str(output_dir))
+
+ if args.jsonl:
+ _as_jsonl(mapped, output_dir / "subset.jsonl")
+
+ meta = {
+ "input_dir": args.input_dir,
+ "audio_path_col": args.audio_path_col,
+ "data_source_col": args.data_source_col,
+ "resolved_audio_col": args.resolved_audio_col,
+ "audio_roots": roots,
+ "default_audio_root": args.default_audio_root,
+ "verify_exists": args.verify_exists,
+ "drop_missing": args.drop_missing,
+ "num_rows": len(mapped),
+ "columns": mapped.column_names,
+ }
+ with (output_dir / "audio_map_meta.json").open("w", encoding="utf-8") as f:
+ json.dump(meta, f, indent=2, ensure_ascii=False)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/datasets/sift50m_subset_builder.py b/examples/datasets/sift50m_subset_builder.py
new file mode 100644
index 0000000..3dfc8d3
--- /dev/null
+++ b/examples/datasets/sift50m_subset_builder.py
@@ -0,0 +1,189 @@
+"""Build a small English subset of SIFT-50M.
+
+This script downloads metadata from Hugging Face, filters to English and the
+specified categories, and saves the resulting dataset locally.
+
+It is designed to be robust to minor schema changes by allowing explicit
+column selection and by auto-detecting likely column names.
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+from pathlib import Path
+from typing import Iterable, Optional
+
+from datasets import Dataset, load_dataset
+
+LANG_COL_CANDIDATES = ["language", "lang", "locale", "language_code", "lang_code"]
+CATEGORY_COL_CANDIDATES = [
+ "category",
+ "task_category",
+ "instruction_category",
+ "task_type",
+ "type",
+]
+
+
+def _first_existing(cols: Iterable[str], candidates: Iterable[str]) -> Optional[str]:
+ cols_set = set(cols)
+ for name in candidates:
+ if name in cols_set:
+ return name
+ return None
+
+
+def _infer_column(name: str, cols: Iterable[str], candidates: Iterable[str]) -> str:
+ inferred = _first_existing(cols, candidates)
+ if inferred is None:
+ raise ValueError(
+ f"Could not infer {name} column. Available columns: {sorted(cols)}"
+ )
+ return inferred
+
+
+def _as_jsonl(dataset: Dataset, path: Path) -> None:
+ path.parent.mkdir(parents=True, exist_ok=True)
+ with path.open("w", encoding="utf-8") as f:
+ for row in dataset:
+ f.write(json.dumps(row, ensure_ascii=False) + "\n")
+
+
+def build_subset(
+ output_dir: Path,
+ max_examples: int,
+ seed: int,
+ include_controllable_generation: bool,
+ language: str,
+ language_col: Optional[str],
+ category_col: Optional[str],
+ jsonl: bool,
+ dataset_id: str,
+ split: str,
+) -> None:
+ dataset = load_dataset(dataset_id, split=split)
+
+ cols = dataset.column_names
+ lang_col = language_col or _infer_column("language", cols, LANG_COL_CANDIDATES)
+ cat_col = category_col or _infer_column("category", cols, CATEGORY_COL_CANDIDATES)
+
+ categories = ["closed_ended_content_level", "open_ended"]
+ if include_controllable_generation:
+ categories.append("controllable_generation")
+
+ def _lang_filter(example):
+ value = example.get(lang_col)
+ if value is None:
+ return False
+ return str(value).lower().startswith(language.lower())
+
+ def _category_filter(example):
+ value = example.get(cat_col)
+ if value is None:
+ return False
+ return value in categories
+
+ filtered = dataset.filter(_lang_filter, desc="Filter language")
+ filtered = filtered.filter(_category_filter, desc="Filter categories")
+
+ if max_examples > 0 and len(filtered) > max_examples:
+ filtered = filtered.shuffle(seed=seed).select(range(max_examples))
+
+ output_dir.mkdir(parents=True, exist_ok=True)
+ filtered.save_to_disk(str(output_dir))
+
+ if jsonl:
+ _as_jsonl(filtered, output_dir / "subset.jsonl")
+
+ meta = {
+ "dataset_id": dataset_id,
+ "split": split,
+ "language_col": lang_col,
+ "category_col": cat_col,
+ "language": language,
+ "categories": categories,
+ "max_examples": max_examples,
+ "seed": seed,
+ "num_rows": len(filtered),
+ "columns": filtered.column_names,
+ }
+ with (output_dir / "subset_meta.json").open("w", encoding="utf-8") as f:
+ json.dump(meta, f, indent=2, ensure_ascii=False)
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(
+ description="Build a small English subset of SIFT-50M."
+ )
+ parser.add_argument(
+ "--output-dir",
+ default="./data/sift50m_en_small",
+ help="Where to write the subset dataset.",
+ )
+ parser.add_argument(
+ "--max-examples",
+ type=int,
+ default=100_000,
+ help="Max number of rows to keep (0 means keep all).",
+ )
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="Random seed for shuffling before sampling.",
+ )
+ parser.add_argument(
+ "--include-controllable-generation",
+ action="store_true",
+ help="Include controllable_generation category.",
+ )
+ parser.add_argument(
+ "--language",
+ default="en",
+ help="Language prefix to match (default: en).",
+ )
+ parser.add_argument(
+ "--language-col",
+ default=None,
+ help="Override language column name.",
+ )
+ parser.add_argument(
+ "--category-col",
+ default=None,
+ help="Override category column name.",
+ )
+ parser.add_argument(
+ "--jsonl",
+ action="store_true",
+ help="Also export subset.jsonl.",
+ )
+ parser.add_argument(
+ "--dataset-id",
+ default="amazon-agi/SIFT-50M",
+ help="HF dataset id.",
+ )
+ parser.add_argument(
+ "--split",
+ default="train",
+ help="HF split to load.",
+ )
+
+ args = parser.parse_args()
+
+ build_subset(
+ output_dir=Path(args.output_dir),
+ max_examples=args.max_examples,
+ seed=args.seed,
+ include_controllable_generation=args.include_controllable_generation,
+ language=args.language,
+ language_col=args.language_col,
+ category_col=args.category_col,
+ jsonl=args.jsonl,
+ dataset_id=args.dataset_id,
+ split=args.split,
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/models/qwen3_omni/README.md b/examples/models/qwen3_omni/README.md
new file mode 100644
index 0000000..576b8b1
--- /dev/null
+++ b/examples/models/qwen3_omni/README.md
@@ -0,0 +1,42 @@
+# Qwen3-Omni (HF/PEFT template)
+
+This folder contains a minimal training config template for
+`Qwen/Qwen3-Omni-30B-A3B-Instruct` intended for A100/H100 class GPUs.
+
+## Files
+- `examples/models/qwen3_omni/peft_a100_h100.yaml`: LoRA template config.
+- `examples/models/qwen3_omni/train_qwen3_omni_peft.py`: minimal HF/PEFT trainer scaffold.
+
+## Notes
+- This repo does not include a Qwen3-Omni fine-tuning engine; use HF/PEFT on your
+ cloud machine.
+- The model expects **multimodal message-format inputs**. Your training script
+ should transform each row into the required chat format and attach audio.
+- Start small: 50k-200k examples with LoRA and gradient checkpointing.
+
+## Suggested usage
+Use this YAML as a base config for your own training script. For example, load
+it and pass the values into `transformers.TrainingArguments` and `peft.LoraConfig`.
+
+## Usage
+The training script supports on-the-fly multimodal preprocessing using
+`Qwen3OmniMoeProcessor`. Provide the dataset columns in
+`examples/models/qwen3_omni/peft_a100_h100.yaml`:
+- `audio_column` (path to audio)
+- `text_column` (user instruction)
+- `target_column` (assistant text response)
+- `messages_column` (optional, if your dataset already stores full chat messages)
+
+## Smoke test
+Quickly validate the processor with a single audio file:
+```bash
+python examples/models/qwen3_omni/smoke_test_processor.py \
+ --audio /path/to/example.wav \
+ --text "Please summarize the audio."
+```
+
+You can generate a quick dummy WAV:
+```bash
+python examples/models/qwen3_omni/generate_dummy_wav.py \
+ --output /tmp/dummy_tone.wav
+```
diff --git a/examples/models/qwen3_omni/generate_dummy_wav.py b/examples/models/qwen3_omni/generate_dummy_wav.py
new file mode 100644
index 0000000..71932e6
--- /dev/null
+++ b/examples/models/qwen3_omni/generate_dummy_wav.py
@@ -0,0 +1,45 @@
+"""Generate a short dummy WAV tone for testing."""
+
+from __future__ import annotations
+
+import argparse
+import math
+import wave
+from pathlib import Path
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(description="Generate a short dummy WAV tone")
+ parser.add_argument(
+ "--output",
+ default="/tmp/dummy_tone.wav",
+ help="Output WAV path.",
+ )
+ parser.add_argument(
+ "--seconds", type=float, default=1.0, help="Duration in seconds"
+ )
+ parser.add_argument("--freq", type=float, default=440.0, help="Tone frequency (Hz)")
+ parser.add_argument("--rate", type=int, default=16000, help="Sample rate")
+ parser.add_argument("--amp", type=float, default=0.2, help="Amplitude (0-1)")
+
+ args = parser.parse_args()
+
+ output = Path(args.output)
+ output.parent.mkdir(parents=True, exist_ok=True)
+
+ n_samples = int(args.seconds * args.rate)
+ with wave.open(str(output), "w") as wf:
+ wf.setnchannels(1)
+ wf.setsampwidth(2)
+ wf.setframerate(args.rate)
+
+ for i in range(n_samples):
+ sample = args.amp * math.sin(2 * math.pi * args.freq * (i / args.rate))
+ value = int(max(-1.0, min(1.0, sample)) * 32767)
+ wf.writeframesraw(value.to_bytes(2, byteorder="little", signed=True))
+
+ print(str(output))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/models/qwen3_omni/peft_a100_h100.yaml b/examples/models/qwen3_omni/peft_a100_h100.yaml
new file mode 100644
index 0000000..818e7e3
--- /dev/null
+++ b/examples/models/qwen3_omni/peft_a100_h100.yaml
@@ -0,0 +1,66 @@
+# Minimal HF/PEFT training template for Qwen3-Omni-30B-A3B-Instruct.
+# This is a starting point; tune for your dataset and budget.
+
+model:
+ name_or_path: Qwen/Qwen3-Omni-30B-A3B-Instruct
+ trust_remote_code: true
+ torch_dtype: bfloat16
+ use_flash_attention_2: true
+
+peft:
+ method: lora
+ r: 16
+ lora_alpha: 32
+ lora_dropout: 0.05
+ bias: none
+ task_type: CAUSAL_LM
+ target_modules:
+ - q_proj
+ - k_proj
+ - v_proj
+ - o_proj
+ - gate_proj
+ - up_proj
+ - down_proj
+
+# Data fields are placeholders. Make sure your training script maps
+# the dataset to the model's expected multimodal message format.
+data:
+ train_dataset: /path/to/sift50m_subset
+ train_split: train
+ audio_column: audio_file
+ text_column: text
+ target_column: target
+ messages_column: null
+ system_prompt: "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech."
+ use_audio_in_video: false
+ load_audio_from_video: false
+ video_fps: 2.0
+ max_samples: 100000
+
+training:
+ output_dir: ./outputs/qwen3_omni_30b_a3b_instruct_lora
+ per_device_train_batch_size: 1
+ gradient_accumulation_steps: 8
+ num_train_epochs: 1
+ learning_rate: 2.0e-5
+ warmup_ratio: 0.03
+ lr_scheduler_type: cosine
+ weight_decay: 0.0
+ max_grad_norm: 1.0
+ logging_steps: 10
+ save_steps: 500
+ save_total_limit: 2
+ bf16: true
+ gradient_checkpointing: true
+ optim: adamw_torch
+ dataloader_num_workers: 4
+ remove_unused_columns: false
+
+# Optional evaluation block; fill in if you have a validation split.
+eval:
+ enabled: false
+ eval_dataset: /path/to/val
+ eval_split: validation
+ per_device_eval_batch_size: 1
+ eval_steps: 500
diff --git a/examples/models/qwen3_omni/smoke_test_processor.py b/examples/models/qwen3_omni/smoke_test_processor.py
new file mode 100644
index 0000000..7c0c8e9
--- /dev/null
+++ b/examples/models/qwen3_omni/smoke_test_processor.py
@@ -0,0 +1,75 @@
+"""Smoke test for Qwen3OmniMoeProcessor on a single audio file.
+
+This validates that the processor can build model inputs from audio + text
+without running a full training loop.
+"""
+
+from __future__ import annotations
+
+import argparse
+from pathlib import Path
+
+from transformers import Qwen3OmniMoeProcessor
+
+DEFAULT_SYSTEM_PROMPT = (
+ "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, "
+ "capable of perceiving auditory and visual inputs, as well as generating text and speech."
+)
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(description="Smoke test Qwen3-Omni processor")
+ parser.add_argument(
+ "--model",
+ default="Qwen/Qwen3-Omni-30B-A3B-Instruct",
+ help="HF model name or path.",
+ )
+ parser.add_argument("--audio", required=True, help="Path to an audio file.")
+ parser.add_argument(
+ "--text",
+ default="Please summarize the audio.",
+ help="Instruction text.",
+ )
+ parser.add_argument(
+ "--system",
+ default=DEFAULT_SYSTEM_PROMPT,
+ help="System prompt text.",
+ )
+
+ args = parser.parse_args()
+
+ audio_path = Path(args.audio)
+ if not audio_path.exists():
+ raise FileNotFoundError(f"Audio file not found: {audio_path}")
+
+ processor = Qwen3OmniMoeProcessor.from_pretrained(
+ args.model, trust_remote_code=True
+ )
+ if processor.tokenizer.pad_token_id is None:
+ processor.tokenizer.pad_token_id = processor.tokenizer.eos_token_id
+
+ conversation = [
+ {"role": "system", "content": [{"type": "text", "text": args.system}]},
+ {
+ "role": "user",
+ "content": [
+ {"type": "audio", "path": str(audio_path)},
+ {"type": "text", "text": args.text},
+ ],
+ },
+ ]
+
+ inputs = processor.apply_chat_template(
+ [conversation],
+ tokenize=True,
+ return_dict=True,
+ return_tensors="pt",
+ padding=True,
+ )
+
+ print("input_ids shape:", tuple(inputs["input_ids"].shape))
+ print("attention_mask shape:", tuple(inputs["attention_mask"].shape))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/models/qwen3_omni/train_qwen3_omni_peft.py b/examples/models/qwen3_omni/train_qwen3_omni_peft.py
new file mode 100644
index 0000000..70aff3d
--- /dev/null
+++ b/examples/models/qwen3_omni/train_qwen3_omni_peft.py
@@ -0,0 +1,261 @@
+"""Minimal HF/PEFT trainer scaffold for Qwen3-Omni-30B-A3B-Instruct.
+
+Supports on-the-fly multimodal preprocessing using Qwen3OmniMoeProcessor.
+"""
+
+from __future__ import annotations
+
+import argparse
+import copy
+import json
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+import torch
+from datasets import Dataset, DatasetDict, load_from_disk
+from peft import LoraConfig, get_peft_model
+from transformers import (
+ AutoModelForMultimodalLM,
+ Qwen3OmniMoeProcessor,
+ Trainer,
+ TrainingArguments,
+)
+
+
def load_yaml(path: str) -> Dict[str, Any]:
    """Parse the YAML file at *path* into a dictionary.

    Raises:
        RuntimeError: if PyYAML is not installed.
    """
    try:
        import yaml  # type: ignore
    except Exception as import_error:  # pragma: no cover - best-effort dependency check
        raise RuntimeError(
            "PyYAML is required to load YAML configs. Install with `pip install pyyaml`."
        ) from import_error

    with open(path, "r", encoding="utf-8") as config_file:
        return yaml.safe_load(config_file)
+
+
def load_dataset_any(path: str) -> Dataset:
    """Load a dataset saved with ``save_to_disk``, unwrapping the train split.

    When *path* holds a ``DatasetDict`` the ``"train"`` split is returned;
    a plain ``Dataset`` is returned unchanged.
    """
    loaded = load_from_disk(path)
    return loaded["train"] if isinstance(loaded, DatasetDict) else loaded
+
+
def pad_batch(
    batch: List[Dict[str, Any]],
    pad_token_id: int,
) -> Dict[str, torch.Tensor]:
    """Right-pad pre-tokenized examples into a batch of long tensors.

    Each item must carry ``input_ids`` (a flat list of ints) and may carry
    ``labels``. When labels are absent, the unpadded input ids are used so
    the model is supervised on the full sequence. Padding positions get
    ``pad_token_id`` in ``input_ids``, 0 in ``attention_mask`` and -100 in
    ``labels`` so the loss ignores them.

    Args:
        batch: list of example dicts from the dataloader.
        pad_token_id: token id used to fill ``input_ids`` padding.

    Returns:
        Dict with ``input_ids``, ``attention_mask`` and ``labels`` tensors,
        all shaped ``(len(batch), max_len)``.

    Raises:
        ValueError: if *batch* is empty, or a provided ``labels`` list does
            not match its ``input_ids`` length (the original code only
            failed later, inside ``torch.tensor``, with an opaque error).
    """
    if not batch:
        raise ValueError("pad_batch received an empty batch")

    input_ids = [item["input_ids"] for item in batch]
    labels_in = [item.get("labels") for item in batch]

    max_len = max(len(ids) for ids in input_ids)
    padded_ids = []
    padded_labels = []
    attention_mask = []

    for ids, lbl in zip(input_ids, labels_in):
        pad_len = max_len - len(ids)
        padded_ids.append(ids + [pad_token_id] * pad_len)
        attention_mask.append([1] * len(ids) + [0] * pad_len)

        if lbl is None:
            # A shallow copy suffices for a flat list of ints; the original
            # deepcopy did the same thing at a higher cost.
            lbl = list(ids)
        elif len(lbl) != len(ids):
            raise ValueError("labels length must match input_ids length")
        padded_labels.append(lbl + [-100] * pad_len)

    return {
        "input_ids": torch.tensor(padded_ids, dtype=torch.long),
        "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
        "labels": torch.tensor(padded_labels, dtype=torch.long),
    }
+
+
# Default system prompt, used when the data config does not provide a
# `system_prompt` entry; duplicated in the processor smoke-test script so
# both entry points condition the model identically.
DEFAULT_SYSTEM_PROMPT = (
    "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, "
    "capable of perceiving auditory and visual inputs, as well as generating text and speech."
)
+
+
def build_conversation(
    example: Dict[str, Any],
    system_prompt: str,
    audio_col: Optional[str],
    text_col: Optional[str],
    target_col: Optional[str],
    messages_col: Optional[str],
    include_assistant: bool,
) -> List[Dict[str, Any]]:
    """Turn one dataset row into a Qwen chat-template conversation.

    A pre-built message list stored under *messages_col* is returned
    verbatim when present. Otherwise a system + user conversation is
    assembled from the audio and text columns, optionally followed by the
    assistant target (for supervised training).
    """
    # Prefer a ready-made conversation if the row carries one.
    if messages_col:
        prebuilt = example.get(messages_col)
        if isinstance(prebuilt, list):
            return prebuilt

    # Assemble the user turn from whichever modality columns are populated.
    user_content: List[Dict[str, Any]] = []
    if audio_col and example.get(audio_col):
        user_content.append({"type": "audio", "path": example[audio_col]})
    if text_col and example.get(text_col):
        user_content.append({"type": "text", "text": example[text_col]})

    conversation: List[Dict[str, Any]] = [
        {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
        {"role": "user", "content": user_content},
    ]

    # Append the gold answer only when training (include_assistant) and the
    # target column actually holds something truthy.
    target = example.get(target_col) if target_col else None
    if include_assistant and target:
        conversation.append(
            {"role": "assistant", "content": [{"type": "text", "text": target}]}
        )

    return conversation
+
+
def main() -> None:
    """CLI entry point: LoRA-finetune Qwen3-Omni from a YAML config.

    Loads processor + model, wraps the model with PEFT LoRA adapters,
    builds an on-the-fly multimodal collator, runs ``Trainer.train`` and
    writes the adapter plus a ``training_meta.json`` next to it.
    """
    parser = argparse.ArgumentParser(
        description="Minimal PEFT trainer for Qwen3-Omni-30B-A3B-Instruct."
    )
    parser.add_argument(
        "--config",
        default="examples/models/qwen3_omni/peft_a100_h100.yaml",
        help="Path to YAML config.",
    )
    args = parser.parse_args()

    cfg = load_yaml(args.config)

    # Required top-level config sections; a missing section raises KeyError.
    model_cfg = cfg["model"]
    peft_cfg = cfg["peft"]
    data_cfg = cfg["data"]
    train_cfg = cfg["training"]

    processor = Qwen3OmniMoeProcessor.from_pretrained(
        model_cfg["name_or_path"],
        trust_remote_code=model_cfg.get("trust_remote_code", True),
    )
    # Fall back to EOS padding so batched tokenization below cannot fail.
    if processor.tokenizer.pad_token_id is None:
        processor.tokenizer.pad_token_id = processor.tokenizer.eos_token_id

    # NOTE(review): `AutoModelForMultimodalLM` is not a documented
    # transformers auto class -- confirm it resolves for this checkpoint
    # (the Qwen3-Omni model card loads
    # `Qwen3OmniMoeForConditionalGeneration`).
    model = AutoModelForMultimodalLM.from_pretrained(
        model_cfg["name_or_path"],
        trust_remote_code=model_cfg.get("trust_remote_code", True),
        torch_dtype=getattr(torch, model_cfg.get("torch_dtype", "bfloat16")),
        attn_implementation=(
            "flash_attention_2" if model_cfg.get("use_flash_attention_2") else None
        ),
    )

    # All LoRA hyperparameters are mandatory in the `peft` config section.
    lora = LoraConfig(
        r=peft_cfg["r"],
        lora_alpha=peft_cfg["lora_alpha"],
        lora_dropout=peft_cfg["lora_dropout"],
        bias=peft_cfg["bias"],
        task_type=peft_cfg["task_type"],
        target_modules=peft_cfg["target_modules"],
    )
    model = get_peft_model(model, lora)

    dataset = load_dataset_any(data_cfg["train_dataset"])
    # NOTE(review): dead branch -- load_dataset_any already unwraps the
    # "train" split of a DatasetDict, so `dataset` can never be a
    # DatasetDict here and a configured `train_split` is silently ignored.
    if data_cfg.get("train_split") and isinstance(dataset, DatasetDict):
        dataset = dataset[data_cfg["train_split"]]

    # Optional down-sampling for quick experiments (shuffle for an unbiased
    # subset; seed shared with TrainingArguments via cfg["training"]).
    if data_cfg.get("max_samples", 0) and len(dataset) > data_cfg["max_samples"]:
        dataset = dataset.shuffle(seed=train_cfg.get("seed", 42)).select(
            range(data_cfg["max_samples"])
        )

    # Column mapping and chat-template options for the collator below.
    audio_col = data_cfg.get("audio_column")
    text_col = data_cfg.get("text_column")
    target_col = data_cfg.get("target_column")
    messages_col = data_cfg.get("messages_column")
    system_prompt = data_cfg.get("system_prompt", DEFAULT_SYSTEM_PROMPT)
    use_audio_in_video = data_cfg.get("use_audio_in_video", False)
    load_audio_from_video = data_cfg.get("load_audio_from_video", False)
    video_fps = data_cfg.get("video_fps", 2.0)

    def collate_fn(batch: List[Dict[str, Any]]):
        # Pre-tokenized datasets bypass the multimodal processor entirely.
        if "input_ids" in batch[0]:
            return pad_batch(batch, processor.tokenizer.pad_token_id)

        # Tokenize each example twice: once with the assistant reply (full
        # training sequence) and once prompt-only, to measure how many
        # leading tokens to mask from the loss.
        # NOTE(review): this runs audio/video preprocessing twice per batch;
        # acceptable for a scaffold but a known cost.
        full_conversations = [
            build_conversation(
                ex,
                system_prompt=system_prompt,
                audio_col=audio_col,
                text_col=text_col,
                target_col=target_col,
                messages_col=messages_col,
                include_assistant=True,
            )
            for ex in batch
        ]
        prompt_conversations = [
            build_conversation(
                ex,
                system_prompt=system_prompt,
                audio_col=audio_col,
                text_col=text_col,
                target_col=target_col,
                messages_col=messages_col,
                include_assistant=False,
            )
            for ex in batch
        ]

        full_inputs = processor.apply_chat_template(
            full_conversations,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            padding=True,
            use_audio_in_video=use_audio_in_video,
            load_audio_from_video=load_audio_from_video,
            video_fps=video_fps,
        )
        prompt_inputs = processor.apply_chat_template(
            prompt_conversations,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            padding=True,
            add_generation_prompt=True,
            use_audio_in_video=use_audio_in_video,
            load_audio_from_video=load_audio_from_video,
            video_fps=video_fps,
        )

        # Mask the prompt portion with -100 so loss covers only the
        # assistant reply. Assumes right padding, so the prompt occupies
        # the first `prompt_len` positions of the full sequence -- TODO
        # confirm the processor's padding side.
        prompt_lens = prompt_inputs["attention_mask"].sum(dim=1)
        labels = full_inputs["input_ids"].clone()
        for i, prompt_len in enumerate(prompt_lens):
            labels[i, : int(prompt_len)] = -100
        full_inputs["labels"] = labels

        return full_inputs

    # NOTE(review): every key in cfg["training"] (including "seed" and
    # "output_dir") must be a valid TrainingArguments field; an unknown key
    # raises TypeError here.
    training_args = TrainingArguments(**train_cfg)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=collate_fn,
        # NOTE(review): `tokenizer=` is deprecated in recent transformers in
        # favor of `processing_class=` -- confirm against the pinned version.
        tokenizer=processor.tokenizer,
    )

    trainer.train()
    trainer.save_model(train_cfg["output_dir"])

    # Persist the full config and dataset size next to the adapter for
    # reproducibility; cfg must be JSON-serializable.
    meta_path = Path(train_cfg["output_dir"]) / "training_meta.json"
    meta = {"config": cfg, "num_rows": len(dataset)}
    meta_path.parent.mkdir(parents=True, exist_ok=True)
    with meta_path.open("w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)


if __name__ == "__main__":
    main()
diff --git a/src/xturing/model_apis/__init__.py b/src/xturing/model_apis/__init__.py
index 4fced5e..1ed3a1c 100644
--- a/src/xturing/model_apis/__init__.py
+++ b/src/xturing/model_apis/__init__.py
@@ -5,6 +5,7 @@
from xturing.model_apis.openai import ChatGPT as OpenAIChatGPT
from xturing.model_apis.openai import Davinci as OpenAIDavinci
from xturing.model_apis.openai import OpenAITextGenerationAPI
+from xturing.model_apis.qwen import Qwen3OmniTextGenerationAPI
BaseApi.add_to_registry(OpenAITextGenerationAPI.config_name, OpenAITextGenerationAPI)
BaseApi.add_to_registry(CohereTextGenerationAPI.config_name, CohereTextGenerationAPI)
@@ -13,3 +14,6 @@
BaseApi.add_to_registry(OpenAIChatGPT.config_name, OpenAIChatGPT)
BaseApi.add_to_registry(CohereMedium.config_name, CohereMedium)
BaseApi.add_to_registry(ClaudeSonnet.config_name, ClaudeSonnet)
+BaseApi.add_to_registry(
+ Qwen3OmniTextGenerationAPI.config_name, Qwen3OmniTextGenerationAPI
+)
diff --git a/src/xturing/model_apis/qwen.py b/src/xturing/model_apis/qwen.py
new file mode 100644
index 0000000..63900c2
--- /dev/null
+++ b/src/xturing/model_apis/qwen.py
@@ -0,0 +1,146 @@
+from datetime import datetime
+from typing import Dict, List, Optional, Sequence
+
+import torch
+from transformers import AutoModelForMultimodalLM, AutoTokenizer
+
+from xturing.model_apis.base import TextGenerationAPI
+
+
class Qwen3OmniTextGenerationAPI(TextGenerationAPI):
    """Text generation API wrapper for running Qwen3-Omni locally via Hugging Face."""

    config_name = "qwen3_omni"

    def __init__(
        self,
        model_name_or_path: str = "Qwen/Qwen3-Omni-30B-A3B-Instruct",
        device: Optional[str] = None,
        tokenizer_kwargs: Optional[Dict] = None,
        model_kwargs: Optional[Dict] = None,
        default_generate_kwargs: Optional[Dict] = None,
    ):
        """Load tokenizer and model locally and move the model to *device*.

        Args:
            model_name_or_path: HF hub id or local checkpoint path.
            device: torch device string; auto-selects CUDA when available
                if omitted.
            tokenizer_kwargs / model_kwargs: extra kwargs forwarded to the
                respective ``from_pretrained`` calls.
            default_generate_kwargs: generation kwargs applied to every
                call (per-call overrides win).
        """
        super().__init__(
            engine=model_name_or_path,
            api_key=None,
            request_batch_size=1,
        )
        tokenizer_kwargs = tokenizer_kwargs or {}
        model_kwargs = model_kwargs or {}
        self.default_generate_kwargs = default_generate_kwargs or {}

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path, trust_remote_code=True, **tokenizer_kwargs
        )
        # NOTE(review): `AutoModelForMultimodalLM` is not a documented
        # transformers auto class -- confirm it exists in the pinned
        # version (the Qwen3-Omni model card loads
        # `Qwen3OmniMoeForConditionalGeneration`).
        self.model = AutoModelForMultimodalLM.from_pretrained(
            model_name_or_path, trust_remote_code=True, **model_kwargs
        )

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        # NOTE(review): if callers pass `device_map` in model_kwargs this
        # explicit .to() may conflict -- confirm intended usage.
        self.model.to(self.device)
        # Ensure padding is defined so generate() never fails on a missing
        # pad token.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

    def _trim_stop_sequences(
        self, text: str, stop_sequences: Optional[Sequence[str]]
    ) -> str:
        """Cut *text* at the earliest occurrence of any stop sequence.

        Empty stop strings are ignored; trailing whitespace before the cut
        point is stripped.
        """
        if not stop_sequences:
            return text
        cut_index = len(text)
        for stop in stop_sequences:
            if not stop:
                continue
            idx = text.find(stop)
            if idx != -1 and idx < cut_index:
                cut_index = idx
        return text[:cut_index].rstrip()

    def _generate_single(
        self,
        prompt: str,
        max_tokens: int,
        temperature: float,
        top_p: Optional[float],
        stop_sequences: Optional[Sequence[str]],
        n: int,
        generation_overrides: Dict,
    ) -> List[Dict[str, str]]:
        """Generate *n* completions for one prompt.

        Returns a list of OpenAI-style choice dicts with ``text`` and
        ``finish_reason`` keys.
        """
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        # Greedy decoding when temperature is unset or zero.
        do_sample = temperature is not None and temperature > 0
        generate_kwargs = {
            "max_new_tokens": max_tokens,
            "do_sample": do_sample,
            "num_return_sequences": n,
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        # Fix: forward sampling knobs only when sampling is enabled.
        # Passing temperature/top_p alongside do_sample=False makes
        # transformers emit "flag is only used in sample-based generation
        # modes" warnings, and the values are ignored anyway.
        if do_sample:
            generate_kwargs["temperature"] = temperature
            if top_p is not None:
                generate_kwargs["top_p"] = top_p
        # Instance-wide defaults, then per-call overrides, take precedence
        # over the computed kwargs.
        generate_kwargs.update(self.default_generate_kwargs)
        generate_kwargs.update(generation_overrides)
        outputs = self.model.generate(**inputs, **generate_kwargs)
        # generate() normally returns a 2-D (num_sequences, seq_len)
        # tensor; normalize a 1-D result defensively (the original only
        # did this when n == 1) so iteration below is uniform.
        if outputs.dim() == 1:
            outputs = outputs.unsqueeze(0)
        generated_sequences: List[Dict[str, str]] = []
        prompt_length = inputs["input_ids"].shape[-1]
        for sequence in outputs:
            # Drop the echoed prompt tokens and decode only the completion.
            completion_tokens = sequence[prompt_length:]
            text = self.tokenizer.decode(
                completion_tokens,
                skip_special_tokens=True,
            ).strip()
            text = self._trim_stop_sequences(text, stop_sequences)
            generated_sequences.append(
                {
                    "text": text,
                    # NOTE(review): reported unconditionally; completions
                    # truncated by max_new_tokens are also labeled "stop".
                    "finish_reason": "stop",
                }
            )
        return generated_sequences

    def generate_text(
        self,
        prompts,
        max_tokens,
        temperature,
        top_p=None,
        frequency_penalty=None,
        presence_penalty=None,
        stop_sequences=None,
        logprobs=None,
        n=1,
        best_of=1,
        retries=0,
        **generation_overrides,
    ):
        """Generate completions for one prompt or a list of prompts.

        Mirrors the OpenAI-style API surface; ``frequency_penalty``,
        ``presence_penalty``, ``logprobs``, ``best_of`` and ``retries``
        are accepted for signature compatibility but unused locally.
        """
        if not isinstance(prompts, list):
            prompts = [prompts]

        results = []
        for prompt in prompts:
            choices = self._generate_single(
                prompt=prompt,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                stop_sequences=stop_sequences,
                n=n,
                generation_overrides=generation_overrides,
            )
            data = {
                "prompt": prompt,
                "response": {"choices": choices},
                "created_at": str(datetime.now()),
            }
            results.append(data)

        return results
diff --git a/tests/xturing/model_apis/test_qwen_api.py b/tests/xturing/model_apis/test_qwen_api.py
new file mode 100644
index 0000000..1084a9e
--- /dev/null
+++ b/tests/xturing/model_apis/test_qwen_api.py
@@ -0,0 +1,119 @@
+from types import SimpleNamespace
+
+import torch
+
+
class DummyTokenizer:
    """Stand-in tokenizer: fixed ids on encode, canned text on decode."""

    def __init__(self, decoded_text="Generated response."):
        self.eos_token_id = 0
        self.pad_token_id = 0
        self.pad_token = ""
        self.decoded_text = decoded_text

    def __call__(self, text, return_tensors=None):
        # Every prompt encodes to the same two-token tensor.
        ids = torch.tensor([[11, 12]])
        return {"input_ids": ids, "attention_mask": torch.ones_like(ids)}

    def decode(self, tokens, skip_special_tokens=True):
        # Decoding ignores the tokens and replays the canned response.
        return self.decoded_text
+
+
class DummyModel:
    """Stand-in model: deterministic token grid, records generate kwargs."""

    def __init__(self):
        self.device = torch.device("cpu")
        self.last_kwargs = None

    def to(self, device):
        self.device = device
        return self

    def generate(self, input_ids=None, attention_mask=None, **kwargs):
        # Remember the kwargs so tests can assert on them afterwards.
        self.last_kwargs = kwargs
        prompt_len = 0 if input_ids is None else input_ids.shape[-1]
        # Emit prompt_len + 2 consecutive ids per requested sequence.
        row = torch.arange(prompt_len + 2).unsqueeze(0).long()
        return row.repeat(kwargs.get("num_return_sequences", 1), 1)
+
+
def _install_mocks(monkeypatch, tokenizer):
    """Patch the qwen module's HF factories to return local dummies."""
    dummy_model = DummyModel()
    patches = {
        "xturing.model_apis.qwen.AutoTokenizer": tokenizer,
        "xturing.model_apis.qwen.AutoModelForMultimodalLM": dummy_model,
    }
    for target, instance in patches.items():
        # Bind `instance` as a default to dodge late-binding closures.
        monkeypatch.setattr(
            target,
            SimpleNamespace(
                from_pretrained=lambda *_, _inst=instance, **__: _inst
            ),
            raising=False,
        )
    return tokenizer, dummy_model
+
+
def test_qwen3_omni_initialization(monkeypatch):
    """API construction wires in the mocked tokenizer/model and device."""
    from xturing.model_apis.qwen import Qwen3OmniTextGenerationAPI

    dummy_tok, dummy_model = _install_mocks(monkeypatch, DummyTokenizer())
    api = Qwen3OmniTextGenerationAPI(model_name_or_path="local-qwen", device="cpu")

    assert (api.engine, str(api.device)) == ("local-qwen", "cpu")
    assert api.tokenizer is dummy_tok
    assert api.model is dummy_model
+
+
def test_qwen3_omni_generate_text(monkeypatch):
    """One prompt with n=2 yields a single result carrying two choices."""
    from xturing.model_apis.qwen import Qwen3OmniTextGenerationAPI

    _, dummy_model = _install_mocks(monkeypatch, DummyTokenizer("Hello world."))
    api = Qwen3OmniTextGenerationAPI(model_name_or_path="local-qwen", device="cpu")

    results = api.generate_text(
        prompts="Hi",
        max_tokens=16,
        temperature=0.7,
        top_p=0.9,
        n=2,
    )

    assert len(results) == 1
    choices = results[0]["response"]["choices"]
    assert [choice["text"] for choice in choices] == ["Hello world."] * 2
    assert {choice["finish_reason"] for choice in choices} == {"stop"}

    recorded = dummy_model.last_kwargs
    assert recorded["max_new_tokens"] == 16
    assert recorded["temperature"] == 0.7
    assert recorded["top_p"] == 0.9
    assert recorded["num_return_sequences"] == 2
+
+
def test_qwen3_omni_stop_sequences(monkeypatch):
    """Stop sequences truncate output; empty stop strings are ignored.

    The original test passed only ``stop_sequences=[""]``, which
    ``_trim_stop_sequences`` skips, so the trimming path was never
    exercised.
    """
    from xturing.model_apis.qwen import Qwen3OmniTextGenerationAPI

    tokenizer, _ = _install_mocks(monkeypatch, DummyTokenizer("Answer: hello"))

    api = Qwen3OmniTextGenerationAPI(model_name_or_path="local-qwen", device="cpu")

    # A real stop sequence cuts the completion at the match and strips the
    # trailing whitespace left before it.
    results = api.generate_text(
        prompts="Question?",
        max_tokens=8,
        temperature=0.0,
        stop_sequences=["hello"],
        n=1,
    )
    assert results[0]["response"]["choices"][0]["text"] == "Answer:"

    # Empty stop strings are skipped, so the text comes through untouched.
    results = api.generate_text(
        prompts="Question?",
        max_tokens=8,
        temperature=0.0,
        stop_sequences=[""],
        n=1,
    )
    assert results[0]["response"]["choices"][0]["text"] == "Answer: hello"
+
+
def test_qwen3_omni_registered():
    """The config name resolves to the Qwen API class in the registry."""
    from xturing.model_apis import BaseApi
    from xturing.model_apis.qwen import Qwen3OmniTextGenerationAPI

    registered_cls = BaseApi.registry[Qwen3OmniTextGenerationAPI.config_name]
    assert registered_cls is Qwen3OmniTextGenerationAPI
diff --git a/tests/xturing/models/test_gpt2_model.py b/tests/xturing/models/test_gpt2_model.py
index 4477a40..110572f 100644
--- a/tests/xturing/models/test_gpt2_model.py
+++ b/tests/xturing/models/test_gpt2_model.py
@@ -1,6 +1,8 @@
import tempfile
from pathlib import Path
+import torch
+
from xturing.datasets import TextDataset
from xturing.models import BaseModel
@@ -69,6 +71,22 @@ def test_train_gpt2():
assert len(result) == 2
def test_train_gpt2_updates_weights():
    """One epoch of finetuning must actually move the probed MLP weights."""
    dataset = TextDataset(DATASET_OTHER_EXAMPLE_DICT)
    model = BaseModel.create("distilgpt2")
    config = model.finetuning_config()
    config.num_train_epochs = 1
    config.batch_size = 1

    # Probe a single layer rather than the whole network; any training
    # step should perturb it.
    probe = model.engine.model.transformer.h[0].mlp.c_fc
    weights_before = probe.weight.detach().cpu().clone()

    model.finetune(dataset=dataset)

    weights_after = probe.weight.detach().cpu()
    total_shift = (weights_before - weights_after).abs().sum().item()
    assert total_shift > 0.0
+
+
def test_train_gpt2_lora():
dataset = TextDataset(DATASET_OTHER_EXAMPLE_DICT)
model = BaseModel.create("distilgpt2_lora")