Skip to content

Commit be56cd8

Browse files
committed
style: apply project code formatter
1 parent 78c1b09 commit be56cd8

6 files changed

Lines changed: 35 additions & 12 deletions

File tree

examples/generative/data/ml-1m/preprocess_ml_hstu.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -280,9 +280,24 @@ def generate_leave_last_out_samples(self, user_sequences, user_timestamps, vocab
280280
item_to_idx = vocab['item_to_idx'] if 'item_to_idx' in vocab else vocab
281281

282282
splits = {
283-
'train': {'seq_tokens': [], 'seq_positions': [], 'seq_time_diffs': [], 'targets': []},
284-
'val': {'seq_tokens': [], 'seq_positions': [], 'seq_time_diffs': [], 'targets': []},
285-
'test': {'seq_tokens': [], 'seq_positions': [], 'seq_time_diffs': [], 'targets': []},
283+
'train': {
284+
'seq_tokens': [],
285+
'seq_positions': [],
286+
'seq_time_diffs': [],
287+
'targets': []
288+
},
289+
'val': {
290+
'seq_tokens': [],
291+
'seq_positions': [],
292+
'seq_time_diffs': [],
293+
'targets': []
294+
},
295+
'test': {
296+
'seq_tokens': [],
297+
'seq_positions': [],
298+
'seq_time_diffs': [],
299+
'targets': []
300+
},
286301
}
287302

288303
positions_template = list(range(self.max_seq_len))

examples/generative/run_hstu_movielens.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,10 +235,12 @@ def main(
235235
optimizer_params={
236236
"lr": learning_rate,
237237
"weight_decay": weight_decay,
238-
"betas": (0.9, 0.98),
238+
"betas": (0.9,
239+
0.98),
239240
},
240241
n_epoch=epoch,
241-
earlystop_patience=max(epoch + 1, 10),
242+
earlystop_patience=max(epoch + 1,
243+
10),
242244
device=device,
243245
model_path=save_dir,
244246
)

tests/test_hstu_model.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,9 +174,13 @@ def test_vocab_mask_drops_per_user_history():
174174

175175
def test_seqtrainer_next_token_loss_masks_left_pad_positions():
176176
trainer = SeqTrainer(
177-
nn.Linear(1, 1),
177+
nn.Linear(1,
178+
1),
178179
device='cpu',
179-
loss_params={"ignore_index": 0, "reduction": "sum"},
180+
loss_params={
181+
"ignore_index": 0,
182+
"reduction": "sum"
183+
},
180184
)
181185
logits = torch.zeros(1, 4, 10)
182186
seq_tokens = torch.tensor([[0, 0, 5, 6]])

torch_rechub/basic/layers.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -948,8 +948,7 @@ def __init__(self, d_model=512, n_heads=8, n_layers=4, dqk=64, dv=64, dropout=0.
948948
time_bucket_fn=time_bucket_fn,
949949
time_bucket_divisor=time_bucket_divisor,
950950
time_bucket_unit=time_bucket_unit,
951-
)
952-
for _ in range(n_layers)
951+
) for _ in range(n_layers)
953952
]
954953
)
955954

torch_rechub/models/generative/hstu.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,8 @@ def forward(self, x, time_diffs=None):
227227
"""
228228
batch_size, seq_len = x.shape
229229
if seq_len > self.max_seq_len:
230-
raise ValueError(f"Input seq_len ({seq_len}) exceeds max_seq_len ({self.max_seq_len}). " f"Either truncate the input or rebuild the model with a larger max_seq_len.")
230+
raise ValueError(f"Input seq_len ({seq_len}) exceeds max_seq_len ({self.max_seq_len}). "
231+
f"Either truncate the input or rebuild the model with a larger max_seq_len.")
231232

232233
padding_mask = x.ne(0) # (B, L) — True for valid tokens
233234

torch_rechub/trainers/seq_trainer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,8 @@ def export_onnx(self, output_path, batch_size=2, seq_length=50, vocab_size=None,
348348
overlap = set(export_kwargs.keys()) & set(onnx_export_kwargs.keys())
349349
overlap.discard("dynamo")
350350
if overlap:
351-
raise ValueError("onnx_export_kwargs contains keys that overlap with explicit args: " f"{sorted(overlap)}. Please set them via export_onnx() parameters instead.")
351+
raise ValueError("onnx_export_kwargs contains keys that overlap with explicit args: "
352+
f"{sorted(overlap)}. Please set them via export_onnx() parameters instead.")
352353
export_kwargs.update(onnx_export_kwargs)
353354

354355
# Auto-pick exporter:
@@ -473,7 +474,8 @@ def visualization(self, seq_length=50, vocab_size=None, batch_size=2, depth=3, s
473474
elif hasattr(model, 'item_num'):
474475
vocab_size = model.item_num
475476
else:
476-
raise ValueError("vocab_size must be provided or model must have " "'vocab_size' or 'item_num' attribute")
477+
raise ValueError("vocab_size must be provided or model must have "
478+
"'vocab_size' or 'item_num' attribute")
477479

478480
# Generate dummy inputs for sequence model
479481
dummy_seq_tokens = torch.randint(0, vocab_size, (batch_size, seq_length), device=viz_device)

0 commit comments

Comments
 (0)