Skip to content

Commit 257757c

Browse files
authored
Add demo prediction script for 1D CNN classifier
This script demonstrates the prediction process for a 1D CNN classifier using dummy tokenized sequences. It includes model initialization, forward pass, and output of probabilities.
1 parent 9d0ae90 commit 257757c

1 file changed

Lines changed: 41 additions & 0 deletions

File tree

demo_predict.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
"""
2+
Demo prediction script for the 1D CNN complaint classifier.
3+
4+
This uses a dummy tokenized sequence to show the forward pass.
5+
Replace the dummy_ids with real token IDs from your tokenizer if desired.
6+
"""
7+
8+
import torch
9+
from train_1d_cnn import ComplaintCNN # ensure this matches your model class name
10+
11+
12+
def main():
13+
# These hyperparameters must match how the model is defined in train_1d_cnn.py
14+
vocab_size = 5000
15+
embed_dim = 64
16+
num_classes = 3
17+
max_len = 50
18+
19+
model = ComplaintCNN(
20+
vocab_size=vocab_size,
21+
embed_dim=embed_dim,
22+
num_classes=num_classes,
23+
)
24+
25+
# Dummy batch of token ids: [batch_size, seq_len]
26+
batch_size = 2
27+
dummy_ids = torch.randint(0, vocab_size, (batch_size, max_len))
28+
29+
# Forward pass
30+
logits = model(dummy_ids)
31+
probs = torch.softmax(logits, dim=-1)
32+
33+
print(f"Input shape: {dummy_ids.shape}")
34+
print(f"Logits shape: {logits.shape}")
35+
print("Probabilities:")
36+
print(probs)
37+
38+
39+
if __name__ == "__main__":
40+
main()
41+

0 commit comments

Comments
 (0)