-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbasic_run.py
More file actions
40 lines (31 loc) · 1.89 KB
/
basic_run.py
File metadata and controls
40 lines (31 loc) · 1.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
"""
examples/basic_run.py — Minimal Inference Example
===================================================
Instantiates a fresh CSNS model (no checkpoint is loaded) and runs a single forward pass.
Demonstrates how to inspect the latent representation and stability score.
"""
import torch
import sys
import os
# Allow running from examples/ directly
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from csns import CSNS
# ── Configuration ─────────────────────────────────────────────────────────────
# The model width and the input feature size must agree, so both come from a
# single constant instead of repeating the literal in two places.
DIM = 32
BATCH_SIZE = 8

# ── Instantiate model ─────────────────────────────────────────────────────────
model = CSNS(dim=DIM)
# NOTE(review): assumes CSNS is a torch.nn.Module, so eval() switches off any
# train-only behaviour (dropout, batch-norm statistic updates) — confirm.
model.eval()

# ── Single forward pass ───────────────────────────────────────────────────────
# Random input of shape (BATCH_SIZE, DIM). No seed is set, so the printed
# numbers differ between runs.
x = torch.randn(BATCH_SIZE, DIM)
with torch.no_grad():
    # The model returns a pair; this example unpacks it as (output, latent).
    output, latent = model(x)

# ── Diagnostics ───────────────────────────────────────────────────────────────
# NOTE(review): these calls deliberately run outside no_grad(), as they always
# have, in case the evaluator / resistance loss rely on autograd internally.
stability = model.evaluator.score(latent).item()
resistance = model.resistance_loss(latent).item()
# L2 norm over the last dimension, averaged over all remaining entries.
# torch.linalg.vector_norm replaces the deprecated torch.norm (same result).
latent_norm = torch.linalg.vector_norm(latent, dim=-1).mean().item()

print("── CSNS Inference ──────────────────────────────")
print(f" Input shape : {list(x.shape)}")
print(f" Output shape : {list(output.shape)}")
print(f" Latent shape : {list(latent.shape)}")
print(f" Mean latent ‖·‖: {latent_norm:.4f} (ceiling={model.constraint.max_norm})")
print(f" Resistance : {resistance:.4f}")
print(f" Stability score: {stability:.4f} (higher = more stable)")
print("────────────────────────────────────────────────")