Commit b3d7b79

Merge pull request #244 from TyndaleLym/fix/hstu-correctness-and-meta-alignment
Align HSTU with Meta reference; fix training/eval correctness
2 parents f9ddc09 + be56cd8 commit b3d7b79

14 files changed

Lines changed: 1791 additions & 1369 deletions

docs/en/api/api.md

Lines changed: 136 additions & 156 deletions
Large diffs are not rendered by default.

docs/en/models/generative.md

Lines changed: 90 additions & 97 deletions
@@ -11,58 +11,57 @@ Generative recommendation models are an emerging approach that leverages generat
 
 ### Description
 
-HSTU (Hierarchical Sequence Transformer Unit) is a hierarchical sequence transformation unit designed for large-scale sequence recommendation, capable of supporting trillion-parameter recommendation systems.
+HSTU (Hierarchical Sequential Transduction Units) is an autoregressive sequence recommender for next-item prediction. In Torch-RecHub, `HSTUModel` consumes padded item-token sequences plus optional per-position time-difference features and returns logits over the item vocabulary at every sequence position.
 
 ### Core Principles
 
-- **Hierarchical Structure**: Uses hierarchical design to decompose long sequences into multiple sub-sequences, improving model parallelism and scalability
-- **Transformer Architecture**: Based on Transformer architecture, capable of capturing long-range dependencies
-- **Large-scale Pretraining**: Supports large-scale pretraining, learning universal representations from massive data
-- **Efficient Inference**: Optimized inference process, supporting real-time recommendations
+- **Eq. 2 UVQK projection**: applies one `SiLU` to the joint `UVQK` projection before splitting, so `U`, `V`, `Q`, and `K` all pass through the same non-linearity.
+- **Eq. 3 attention bias**: adds per-head bucketed relative position/time bias `rab^{p,t}` to attention scores before `silu(scores) / max_seq_len`.
+- **Eq. 4 gated output**: projects `LayerNorm(A V) * U` through one output linear layer, without concat-u/x bypasses or a separate FFN.
+- **External residuals**: each layer is wrapped as `x = x + HSTULayer(x)` in `HSTUBlock`.
+- **Generative training**: predicts the next token in the sequence and masks PAD token `0` in the loss.
 
 ### Usage
 
 ```python
-from torch_rechub.models.generative import HSTUModel
-from torch_rechub.basic.features import SparseFeature, SequenceFeature
-
-# Define features
-user_features = [
-    SparseFeature(name="user_id", vocab_size=10000, embed_dim=32),
-    SequenceFeature(name="user_history", vocab_size=100000, embed_dim=32, pooling="mean")
-]
+import torch
 
-item_features = [
-    SparseFeature(name="item_id", vocab_size=100000, embed_dim=32),
-    SparseFeature(name="category", vocab_size=1000, embed_dim=16)
-]
+from torch_rechub.models.generative import HSTUModel
 
-# Create model
 model = HSTUModel(
-    user_features=user_features,
-    item_features=item_features,
-    transformer_params={
-        "num_layers": 2,
-        "num_heads": 4,
-        "hidden_size": 128,
-        "intermediate_size": 256,
-        "dropout": 0.2
-    },
-    hierarchical_params={
-        "level1_window_size": 10,
-        "level2_window_size": 5
-    }
+    vocab_size=100000,
+    d_model=128,
+    n_heads=4,
+    n_layers=2,
+    dqk=32,
+    dv=32,
+    max_seq_len=200,
+    num_time_buckets=128,
 )
+
+seq_tokens = torch.randint(1, 100000, (32, 200))
+time_diffs = torch.zeros_like(seq_tokens) # seconds from query time
+logits = model(seq_tokens, time_diffs)
+print(logits.shape) # torch.Size([32, 200, 100000])
 ```
 
 ### Parameters
 
 | Parameter | Type | Description | Default |
 | --- | --- | --- | --- |
-| user_features | list | User feature list | None |
-| item_features | list | Item feature list | None |
-| transformer_params | dict | Transformer parameters | None |
-| hierarchical_params | dict | Hierarchical structure parameters | None |
+| vocab_size | int | Item vocabulary size, with PAD reserved as token `0` | required |
+| d_model | int | Hidden dimension | 512 |
+| n_heads | int | Number of attention heads | 8 |
+| n_layers | int | Number of stacked HSTU layers | 4 |
+| dqk | int | Query/key dimension per head | 64 |
+| dv | int | Value/U dimension per head | 64 |
+| max_seq_len | int | Maximum supported sequence length | 256 |
+| dropout | float | Dropout rate | 0.1 |
+| use_time_embedding | bool | Add input-side time-bucket embedding; `time_diffs` is still used by `rab^{p,t}` | True |
+| num_time_buckets | int | Number of time buckets for embeddings and attention bias | 128 |
+| time_bucket_fn | {"sqrt", "log"} | Time bucketization function | "sqrt" |
+| time_bucket_divisor | float | Divisor applied after bucketization | 1.0 |
+| tie_embeddings | bool | Tie output projection to token embedding weights | True |
 
 ### Use Cases
 
@@ -156,76 +155,70 @@ model = HLLMModel(
 ## 5. Complete Training Example
 
 ```python
+import pickle
+import torch
+
 from torch_rechub.models.generative import HSTUModel
-from torch_rechub.trainers import GenRecTrainer
-from torch_rechub.utils.data import DataGenerator
-from torch_rechub.basic.features import SparseFeature, SequenceFeature
-
-# 1. Define features
-user_features = [
-    SparseFeature(name="user_id", vocab_size=10000, embed_dim=32),
-    SequenceFeature(name="user_history", vocab_size=100000, embed_dim=32, pooling="mean")
-]
-
-item_features = [
-    SparseFeature(name="item_id", vocab_size=100000, embed_dim=32),
-    SparseFeature(name="category", vocab_size=1000, embed_dim=16)
-]
-
-# 2. Prepare data
-# Assume x and y are preprocessed feature and label data
-x = {
-    "user_id": user_id_data,
-    "user_history": user_history_data,
-    "item_id": item_id_data,
-    "category": category_data
-}
-y = label_data # click/no-click labels
-
-# 3. Create data generator
-dg = DataGenerator(x, y)
-train_dl, val_dl, test_dl = dg.generate_dataloader(split_ratio=[0.7, 0.1], batch_size=256)
-
-# 4. Create model
+from torch_rechub.trainers import SeqTrainer
+from torch_rechub.utils.data import SequenceDataGenerator
+
+with open("examples/generative/data/ml-1m/processed/train_data.pkl", "rb") as f:
+    train_data = pickle.load(f)
+with open("examples/generative/data/ml-1m/processed/val_data.pkl", "rb") as f:
+    val_data = pickle.load(f)
+with open("examples/generative/data/ml-1m/processed/test_data.pkl", "rb") as f:
+    test_data = pickle.load(f)
+with open("examples/generative/data/ml-1m/processed/vocab.pkl", "rb") as f:
+    vocab = pickle.load(f)
+
+train_gen = SequenceDataGenerator(
+    train_data["seq_tokens"],
+    train_data["seq_positions"],
+    train_data["targets"],
+    train_data["seq_time_diffs"],
+)
+val_gen = SequenceDataGenerator(
+    val_data["seq_tokens"],
+    val_data["seq_positions"],
+    val_data["targets"],
+    val_data["seq_time_diffs"],
+)
+test_gen = SequenceDataGenerator(
+    test_data["seq_tokens"],
+    test_data["seq_positions"],
+    test_data["targets"],
+    test_data["seq_time_diffs"],
+)
+
+train_dl = train_gen.generate_dataloader(batch_size=512, num_workers=0)[0]
+val_dl = val_gen.generate_dataloader(batch_size=512, num_workers=0)[0]
+test_dl = test_gen.generate_dataloader(batch_size=512, num_workers=0)[0]
+
+vocab_size = len(vocab["item_to_idx"]) if "item_to_idx" in vocab else len(vocab)
 model = HSTUModel(
-    user_features=user_features,
-    item_features=item_features,
-    transformer_params={
-        "num_layers": 2,
-        "num_heads": 4,
-        "hidden_size": 128,
-        "intermediate_size": 256,
-        "dropout": 0.2
-    },
-    hierarchical_params={
-        "level1_window_size": 10,
-        "level2_window_size": 5
-    }
+    vocab_size=vocab_size,
+    d_model=128,
+    n_heads=4,
+    n_layers=2,
+    dqk=32,
+    dv=32,
+    max_seq_len=200,
+    dropout=0.1,
 )
 
-# 5. Create trainer
-trainer = GenRecTrainer(
-    model=model,
+trainer = SeqTrainer(
+    model,
+    optimizer_fn=torch.optim.Adam,
     optimizer_params={"lr": 0.001, "weight_decay": 0.0001},
-    n_epoch=50,
+    n_epoch=10,
     earlystop_patience=10,
-    device="cuda:0",
-    model_path="saved/hstu"
+    device="cuda" if torch.cuda.is_available() else "cpu",
+    model_path="saved/hstu",
 )
 
-# 6. Train model
 trainer.fit(train_dl, val_dl)
-
-# 7. Evaluate model
-auc = trainer.evaluate(trainer.model, test_dl)
-print(f"Test AUC: {auc}")
-
-# 8. Export model
-trainer.export_onnx("hstu.onnx")
-
-# 9. Model prediction
-preds = trainer.predict(trainer.model, test_dl)
-print(f"Predictions shape: {preds.shape}")
+test_loss, top1_acc = trainer.evaluate(test_dl)
+print(f"test_loss={test_loss:.4f}, top1_acc={top1_acc:.4f}")
 ```
 
 ## 6. FAQ
@@ -307,4 +300,4 @@ A: Try the following approaches:
 - Develop more efficient model training and inference methods
 - Achieve distributed and scalable generative recommendation systems
 
-Generative recommendation is an important development direction for recommendation systems, capable of providing richer, more natural, and more personalized recommendation experiences. Torch-RecHub provides various advanced generative recommendation models for developers to choose based on business requirements. With the continuous development of large language models and generative AI technologies, generative recommendation will be applied in more scenarios, providing users with better recommendation experiences.
+Generative recommendation is an important development direction for recommendation systems, capable of providing richer, more natural, and more personalized recommendation experiences. Torch-RecHub provides various advanced generative recommendation models for developers to choose based on business requirements. With the continuous development of large language models and generative AI technologies, generative recommendation will be applied in more scenarios, providing users with better recommendation experiences.
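
Reviewer note on the Core Principles bullets in `docs/en/models/generative.md`: the Eq. 2 to Eq. 4 steps and the external residual described there can be illustrated with a small PyTorch sketch. This is not the `HSTULayer` from this commit. The class name `TinyHSTULayer`, its default sizes, and the position-only learned bias table (standing in for the bucketed position/time bias `rab^{p,t}`) are all assumptions made for illustration, and input normalization and dropout are left out.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyHSTULayer(nn.Module):
    """Illustrative single HSTU-style layer (hypothetical, not Torch-RecHub's HSTULayer)."""

    def __init__(self, d_model=16, n_heads=2, dqk=4, dv=4, max_seq_len=8):
        super().__init__()
        self.n_heads, self.dqk, self.dv = n_heads, dqk, dv
        self.max_seq_len = max_seq_len
        # Joint projection producing U, V, Q, K for all heads at once.
        self.uvqk = nn.Linear(d_model, n_heads * (2 * dv + 2 * dqk))
        # Assumption: per-head bias indexed by relative position only (no time buckets).
        self.rel_pos_bias = nn.Parameter(torch.zeros(n_heads, 2 * max_seq_len - 1))
        self.norm = nn.LayerNorm(n_heads * dv)
        self.out = nn.Linear(n_heads * dv, d_model)

    def forward(self, x):
        b, n, _ = x.shape
        h = self.n_heads
        # Eq. 2: one SiLU over the joint projection, then split into U, V, Q, K.
        u, v, q, k = torch.split(
            F.silu(self.uvqk(x)),
            [h * self.dv, h * self.dv, h * self.dqk, h * self.dqk],
            dim=-1,
        )
        q = q.view(b, n, h, self.dqk).transpose(1, 2)  # (b, h, n, dqk)
        k = k.view(b, n, h, self.dqk).transpose(1, 2)  # (b, h, n, dqk)
        v = v.view(b, n, h, self.dv).transpose(1, 2)   # (b, h, n, dv)
        # Eq. 3: add the bias to the scores, then pointwise SiLU scaled by max_seq_len.
        idx = torch.arange(n, device=x.device)
        rel = idx[:, None] - idx[None, :] + self.max_seq_len - 1  # shift to >= 0
        scores = q @ k.transpose(-1, -2) + self.rel_pos_bias[:, rel]
        attn = F.silu(scores) / self.max_seq_len
        # Causal mask: position i may only attend to positions j <= i.
        causal = torch.tril(torch.ones(n, n, dtype=torch.bool, device=x.device))
        attn = attn.masked_fill(~causal, 0.0)
        # Eq. 4: LayerNorm(A V) gated elementwise by U, then one output projection.
        av = (attn @ v).transpose(1, 2).reshape(b, n, h * self.dv)
        return self.out(self.norm(av) * u)


torch.manual_seed(0)
layer = TinyHSTULayer()
x = torch.randn(2, 8, 16)
y = x + layer(x)  # external residual, as in the `x = x + HSTULayer(x)` bullet
print(y.shape)  # torch.Size([2, 8, 16])
```

Because the attention weights come from a pointwise `SiLU` rather than a softmax, masked entries are set to zero instead of negative infinity.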

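
Reviewer note on the "Generative training" bullet and the `(test_loss, top1_acc)` pair returned by `trainer.evaluate`: a PAD-masked next-token loss and top-1 accuracy could look like the sketch below. It assumes `targets` holds one token id per position with PAD encoded as `0`. The helper names are hypothetical, and `SeqTrainer` in this commit may compute these quantities differently.

```python
import torch
import torch.nn.functional as F

PAD = 0  # PAD token id, as stated in the documentation touched by this diff


def masked_next_token_loss(logits, targets):
    """Mean cross-entropy over positions whose target is not PAD.

    logits: (batch, seq_len, vocab_size); targets: (batch, seq_len) token ids.
    """
    return F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        targets.reshape(-1),
        ignore_index=PAD,
    )


def masked_top1_accuracy(logits, targets):
    """Fraction of non-PAD positions where the argmax equals the target."""
    mask = targets != PAD
    correct = (logits.argmax(dim=-1) == targets) & mask
    return correct.sum().item() / max(mask.sum().item(), 1)


# Toy batch: one sequence, three positions, vocabulary of five tokens.
logits = torch.zeros(1, 3, 5)
logits[0, 0, 2] = 5.0                # position 0 predicts token 2 (correct)
logits[0, 1, 1] = 5.0                # position 1 predicts token 1 (target is 3)
targets = torch.tensor([[2, 3, 0]])  # last position is PAD and is ignored

loss = masked_next_token_loss(logits, targets)
acc = masked_top1_accuracy(logits, targets)
print(acc)  # 0.5
```

Using `ignore_index` keeps padded positions out of both the loss value and its gradient, so sequences of different lengths contribute only their real tokens.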