Commit 5c816ba

feat: add local Qwen3-Omni API and fine-tune regression

* swap Qwen3-Omni dataset generation to load Hugging Face checkpoints locally and wire up tests/docs
* ensure AGENTS.md/CLAUDE.md stay untracked
* add a lightweight GPT-2 fine-tuning test that asserts weights actually change

1 parent 71ef9e5 · commit 5c816ba
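The commit message mentions a GPT-2 fine-tuning regression test that "asserts weights actually change", but that test file is not among the diffs shown below. The pattern it describes can be sketched with a toy model instead of GPT-2 (everything here — `TinyModel`, `weights_changed` — is a hypothetical stand-in, not the committed test): snapshot the parameters, run one training step, and assert at least one parameter moved.

```python
# Illustration of the "weights actually change" regression pattern: a real
# test would snapshot model.state_dict(), call finetune(), and compare.
# Here a hand-rolled one-parameter linear model stands in for GPT-2.
import copy


class TinyModel:
    """Stand-in for a Hugging Face model: one weight, one bias."""

    def __init__(self):
        self.params = {"w": 0.5, "b": 0.0}

    def train_step(self, x, y, lr=0.1):
        # One gradient-descent step on squared error (w*x + b - y)^2.
        pred = self.params["w"] * x + self.params["b"]
        grad = 2 * (pred - y)
        self.params["w"] -= lr * grad * x
        self.params["b"] -= lr * grad


def weights_changed(before, after, tol=1e-12):
    """Return True if any parameter moved by more than `tol`."""
    return any(abs(before[k] - after[k]) > tol for k in before)


model = TinyModel()
before = copy.deepcopy(model.params)
model.train_step(x=1.0, y=2.0)  # fine-tuning stand-in
assert weights_changed(before, model.params)
```

The key design point is comparing against a deep copy of the parameters; comparing against a live reference to the same dict (or the same tensors) would trivially pass.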

File tree

7 files changed: +436 −3 lines changed

.gitignore

Lines changed: 3 additions & 0 deletions

```diff
@@ -148,3 +148,6 @@ dmypy.json
 .cache/
 
 .DS_Store
+
+AGENTS.md
+CLAUDE.md
```
README.md

Lines changed: 138 additions & 3 deletions

````diff
@@ -36,6 +36,24 @@ Why xTuring:
 pip install xturing
 ```
 
+### Development Installation
+
+If you want to contribute to xTuring or run from source:
+
+```bash
+# Clone the repository
+git clone https://github.com/stochasticai/xturing.git
+cd xturing
+
+# Install in editable mode with development dependencies
+pip install -e .
+pip install -r requirements-dev.txt
+
+# Set up pre-commit hooks (required before contributing)
+pre-commit install
+pre-commit install --hook-type commit-msg
+```
+
 <br>
 
 ## 🚀 Quickstart
@@ -158,7 +176,7 @@ dataset = InstructionDataset('../llama/alpaca_data')
 model = GenericLoraKbitModel('tiiuae/falcon-7b')
 
 # Generate outputs on desired prompts
-outputs = model.generate(dataset = dataset, batch_size=10)
+outputs = model.generate(dataset = dataset, batch_size=10)
 
 ```
@@ -173,6 +191,16 @@ model.finetune(dataset=dataset)
 ```
 > See `examples/models/qwen3/qwen3_lora_finetune.py` for a runnable script.
 
+8. __Qwen3-Omni dataset generation__ – Run the multimodal checkpoint locally (download from Hugging Face) to bootstrap instruction corpora without leaving your machine.
+```python
+from xturing.datasets import InstructionDataset
+from xturing.model_apis.qwen import Qwen3OmniTextGenerationAPI
+
+# Download `Qwen/Qwen2.5-Omni` (or another HF variant) ahead of time
+engine = Qwen3OmniTextGenerationAPI(model_name_or_path="Qwen/Qwen2.5-Omni")
+dataset = InstructionDataset.generate_dataset("./tasks.jsonl", engine=engine)
+```
+
 An exploration of the [Llama LoRA INT4 working example](examples/features/int4_finetuning/LLaMA_lora_int4.ipynb) is recommended for an understanding of its application.
 
 For an extended insight, consider examining the [GenericModel working example](examples/features/generic/generic_model.py) available in the repository.
@@ -182,9 +210,17 @@ For an extended insight, consider examining the [GenericModel working example](e
 ## CLI playground
 <img src=".github/cli-playground.gif" width="80%" style="margin: 0 1%;"/>
 
+The `xturing` CLI provides interactive tools for working with fine-tuned models:
+
 ```bash
-$ xturing chat -m "<path-to-model-folder>"
+# Chat with a fine-tuned model
+xturing chat -m "<path-to-model-folder>"
+
+# Launch the UI playground (alternative to programmatic Playground)
+xturing ui
 
+# Get help and see all available commands
+xturing --help
 ```
 
 ## UI playground
@@ -250,13 +286,27 @@ Contribute to this by submitting your performance results on other GPUs by creat
 
 ## 📎 Fine‑tuned model checkpoints
 We have already fine-tuned some models that you can use as your base or start playing with.
-Here is how you would load them:
 
+### Loading Models
+
+**Load from xTuring hub:**
 ```python
 from xturing.models import BaseModel
 model = BaseModel.load("x/distilgpt2_lora_finetuned_alpaca")
 ```
 
+**Load from local directory:**
+```python
+model = BaseModel.load("/path/to/saved/model")
+```
+
+**Create a new model for fine-tuning:**
+```python
+model = BaseModel.create("llama_lora")
+```
+
+### Available Pre-trained Models
+
 | model | dataset | Path |
 |---------------------|--------|---------------|
 | DistilGPT-2 LoRA | alpaca | `x/distilgpt2_lora_finetuned_alpaca` |
@@ -281,6 +331,7 @@ Below is a list of all the supported models via `BaseModel` class of `xTuring` a
 |LLaMA2 | llama2|
 |MiniMaxM2 | minimax_m2|
 |OPT-1.3B | opt|
+|Qwen3-0.6B | qwen3_0_6b|
 
 The above are the base variants. Use these templates for `LoRA`, `INT8`, and `INT8 + LoRA` versions:
 
@@ -314,17 +365,101 @@ Replace `<model_path>` with a local directory or a Hugging Face model like `face
 
 <br>
 
+## 🧪 Running Tests
+
+The project uses pytest for testing. Test files are located in the `tests/` directory.
+
+Run all tests:
+```bash
+pytest
+```
+
+Run a specific test file:
+```bash
+pytest tests/xturing/models/test_qwen_model.py
+```
+
+Skip slow tests:
+```bash
+pytest -m "not slow"
+```
+
+Skip GPU tests (for CPU-only environments):
+```bash
+pytest -m "not gpu"
+```
+
+Test markers used in this project:
+- `@pytest.mark.slow` - Tests that take significant time to run
+- `@pytest.mark.gpu` - Tests requiring GPU hardware
+
+<br>
+
 ## 🤝 Help and Support
 If you have any questions, you can create an issue on this repository.
 
 You can also join our [Discord server](https://discord.gg/TgHXuSJEk6) and start a discussion in the `#xturing` channel.
 
 <br>
 
+## 🏗️ Project Structure
+
+Understanding the codebase organization:
+
+```
+src/xturing/
+├── models/          # Model classes and registry (BaseModel, LLaMA, GPT-2, etc.)
+├── engines/         # Low-level model loading, tokenization, and operations
+├── datasets/        # Dataset loaders (InstructionDataset, TextDataset)
+├── trainers/        # Training loops (LightningTrainer with DeepSpeed support)
+├── preprocessors/   # Data preprocessing and tokenization
+├── config/          # YAML configurations for finetuning and generation
+├── cli/             # CLI commands (chat, ui, api)
+├── ui/              # Gradio UI playground
+├── self_instruct/   # Dataset generation utilities
+└── utils/           # Shared utilities
+
+tests/xturing/       # Test suite mirroring src structure
+examples/            # Example scripts organized by model and feature
+```
+
+**Key architectural patterns:**
+- **Registry Pattern**: Models and engines use a registry-based factory pattern via `BaseModel.create()` and `BaseEngine.create()`
+- **Model Variants**: Each model family has multiple variants following the naming template `<base>_[lora]_[int8|kbit]`
+  - Example: `llama`, `llama_lora`, `llama_int8`, `llama_lora_int8`
+- **Configuration**: Training and generation parameters are defined in YAML files per model in `src/xturing/config/`
+- **Engines**: Handle the low-level operations (loading weights, tokenization, DeepSpeed integration)
+- **Models**: Provide high-level API (`finetune()`, `generate()`, `evaluate()`, `save()`, `load()`)
+
+<br>
+
 ## 📝 License
 This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details.
 
 <br>
 
 ## 🌎 Contributing
 As an open source project in a rapidly evolving field, we welcome contributions of all kinds, including new features and better documentation. Please read our [contributing guide](CONTRIBUTING.md) to learn how you can get involved.
+
+### Quick Contribution Guidelines
+
+**Important:** All pull requests should target the `dev` branch, not `main`.
+
+The project uses pre-commit hooks to enforce code quality:
+- **black** - Code formatting
+- **isort** - Import sorting (black profile)
+- **autoflake** - Remove unused imports
+- **absolufy-imports** - Convert relative to absolute imports
+- **gitlint** - Commit message linting
+
+You can manually format code:
+```bash
+black src/ tests/
+isort src/ tests/
+```
+
+Pre-commit hooks will automatically run these checks when you commit. Make sure to install them:
+```bash
+pre-commit install
+pre-commit install --hook-type commit-msg
+```
````
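The README hunks above describe the variant naming template `<base>_[lora]_[int8|kbit]`, where each base model key expands into its LoRA/quantized combinations. A tiny sketch of that expansion (the helper below is purely illustrative, not part of xTuring):

```python
# Expand a base model key into the four variant names the README's
# template <base>_[lora]_[int8|kbit] gives as its example set.
def variant_names(base: str) -> list:
    suffixes = ["", "_lora", "_int8", "_lora_int8"]
    return [base + suffix for suffix in suffixes]


print(variant_names("llama"))
# → ['llama', 'llama_lora', 'llama_int8', 'llama_lora_int8']
```

Any of these strings can then be passed to `BaseModel.create()`, which is what makes the flat naming scheme convenient: the variant is selected by key rather than by importing a different class.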

docs/docs/advanced/generate.md

Lines changed: 10 additions & 0 deletions

````diff
@@ -41,6 +41,16 @@ engine = Davinci("your-api-key")
 engine = ClaudeSonnet("your-api-key")
 ```
 
+</TabItem>
+<TabItem value="qwen" label="Qwen3-Omni (local)">
+
+Download the desired checkpoint from [Hugging Face](https://huggingface.co/Qwen/Qwen2.5-Omni) (or point to a local directory) and load it directly.
+
+```python
+from xturing.model_apis.qwen import Qwen3OmniTextGenerationAPI
+engine = Qwen3OmniTextGenerationAPI(model_name_or_path="Qwen/Qwen2.5-Omni")
+```
+
 </TabItem>
 </Tabs>
````

src/xturing/model_apis/__init__.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -5,6 +5,7 @@
 from xturing.model_apis.openai import ChatGPT as OpenAIChatGPT
 from xturing.model_apis.openai import Davinci as OpenAIDavinci
 from xturing.model_apis.openai import OpenAITextGenerationAPI
+from xturing.model_apis.qwen import Qwen3OmniTextGenerationAPI
 
 BaseApi.add_to_registry(OpenAITextGenerationAPI.config_name, OpenAITextGenerationAPI)
 BaseApi.add_to_registry(CohereTextGenerationAPI.config_name, CohereTextGenerationAPI)
@@ -13,3 +14,6 @@
 BaseApi.add_to_registry(OpenAIChatGPT.config_name, OpenAIChatGPT)
 BaseApi.add_to_registry(CohereMedium.config_name, CohereMedium)
 BaseApi.add_to_registry(ClaudeSonnet.config_name, ClaudeSonnet)
+BaseApi.add_to_registry(
+    Qwen3OmniTextGenerationAPI.config_name, Qwen3OmniTextGenerationAPI
+)
```
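The `add_to_registry` calls above follow the registry-based factory pattern used throughout xTuring: each API class registers under its `config_name` and is later instantiated by key. A minimal stand-in for that mechanism (simplified sketch — the real `BaseApi` lives in `xturing.model_apis.base` and does more):

```python
# Simplified registry/factory sketch: classes register under a string key,
# and a classmethod factory instantiates them by that key.
class BaseApi:
    registry = {}

    @classmethod
    def add_to_registry(cls, name, api_cls):
        # Map a config name to the class that implements it.
        cls.registry[name] = api_cls

    @classmethod
    def create(cls, name, *args, **kwargs):
        # Look up the registered class and instantiate it.
        return cls.registry[name](*args, **kwargs)


class DummyApi:
    config_name = "dummy"


BaseApi.add_to_registry(DummyApi.config_name, DummyApi)
api = BaseApi.create("dummy")
assert isinstance(api, DummyApi)
```

The payoff is that adding a new engine, as this commit does, only requires one import and one `add_to_registry` call; no call sites need to change.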

src/xturing/model_apis/qwen.py

Lines changed: 144 additions & 0 deletions

```python
from datetime import datetime
from typing import Dict, List, Optional, Sequence

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from xturing.model_apis.base import TextGenerationAPI


class Qwen3OmniTextGenerationAPI(TextGenerationAPI):
    """Text generation API wrapper for running Qwen3-Omni locally via Hugging Face."""

    config_name = "qwen3_omni"

    def __init__(
        self,
        model_name_or_path: str = "Qwen/Qwen2.5-Omni",
        device: Optional[str] = None,
        tokenizer_kwargs: Optional[Dict] = None,
        model_kwargs: Optional[Dict] = None,
        default_generate_kwargs: Optional[Dict] = None,
    ):
        super().__init__(
            engine=model_name_or_path,
            api_key=None,
            request_batch_size=1,
        )
        tokenizer_kwargs = tokenizer_kwargs or {}
        model_kwargs = model_kwargs or {}
        self.default_generate_kwargs = default_generate_kwargs or {}

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path, trust_remote_code=True, **tokenizer_kwargs
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path, trust_remote_code=True, **model_kwargs
        )

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.model.to(self.device)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

    def _trim_stop_sequences(
        self, text: str, stop_sequences: Optional[Sequence[str]]
    ) -> str:
        if not stop_sequences:
            return text
        cut_index = len(text)
        for stop in stop_sequences:
            if not stop:
                continue
            idx = text.find(stop)
            if idx != -1 and idx < cut_index:
                cut_index = idx
        return text[:cut_index].rstrip()

    def _generate_single(
        self,
        prompt: str,
        max_tokens: int,
        temperature: float,
        top_p: Optional[float],
        stop_sequences: Optional[Sequence[str]],
        n: int,
        generation_overrides: Dict,
    ) -> List[Dict[str, str]]:
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        do_sample = temperature is not None and temperature > 0
        generate_kwargs = {
            "max_new_tokens": max_tokens,
            "do_sample": do_sample,
            "num_return_sequences": n,
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        if temperature is not None:
            generate_kwargs["temperature"] = temperature
        if top_p is not None:
            generate_kwargs["top_p"] = top_p
        generate_kwargs.update(self.default_generate_kwargs)
        generate_kwargs.update(generation_overrides)
        outputs = self.model.generate(**inputs, **generate_kwargs)
        if n == 1:
            outputs = outputs.unsqueeze(0) if outputs.dim() == 1 else outputs
        generated_sequences: List[Dict[str, str]] = []
        prompt_length = inputs["input_ids"].shape[-1]
        for sequence in outputs:
            completion_tokens = sequence[prompt_length:]
            text = self.tokenizer.decode(
                completion_tokens,
                skip_special_tokens=True,
            ).strip()
            text = self._trim_stop_sequences(text, stop_sequences)
            generated_sequences.append(
                {
                    "text": text,
                    "finish_reason": "stop",
                }
            )
        return generated_sequences

    def generate_text(
        self,
        prompts,
        max_tokens,
        temperature,
        top_p=None,
        frequency_penalty=None,
        presence_penalty=None,
        stop_sequences=None,
        logprobs=None,
        n=1,
        best_of=1,
        retries=0,
        **generation_overrides,
    ):
        if not isinstance(prompts, list):
            prompts = [prompts]

        results = []
        for prompt in prompts:
            choices = self._generate_single(
                prompt=prompt,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                stop_sequences=stop_sequences,
                n=n,
                generation_overrides=generation_overrides,
            )
            data = {
                "prompt": prompt,
                "response": {"choices": choices},
                "created_at": str(datetime.now()),
            }
            results.append(data)

        return results
```
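The `_trim_stop_sequences` helper in the new file is pure string logic and can be exercised on its own. Here it is reproduced as a standalone function (same algorithm, no class) with a worked example:

```python
def trim_stop_sequences(text, stop_sequences):
    """Cut `text` at the earliest occurrence of any stop sequence, then
    strip trailing whitespace. Mirrors _trim_stop_sequences above."""
    if not stop_sequences:
        return text
    cut_index = len(text)
    for stop in stop_sequences:
        if not stop:  # ignore empty stop strings
            continue
        idx = text.find(stop)
        if idx != -1 and idx < cut_index:
            cut_index = idx
    return text[:cut_index].rstrip()


print(trim_stop_sequences("Answer: 42\n###\nextra text", ["###"]))
# → Answer: 42
```

Note the earliest-match semantics: with several stop sequences, the one that appears first in the text wins, regardless of its order in the list.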
