Summary
When using nnsight to access intermediate tensors (e.g., layer outputs, logits), the tensor shapes differ between local and remote execution:
- Local: tensors include the batch dimension, `[batch, seq_len, hidden_dim]`
- Remote: tensors have the batch dimension squeezed, `[seq_len, hidden_dim]`
This inconsistency requires users to write defensive code that checks tensor dimensions, which is error-prone and unexpected.
Minimal Reproducible Example
```python
import os

from nnsight import LanguageModel, CONFIG

# Configure NDIF (replace with your key)
CONFIG.set_default_api_key(os.environ.get('NDIF_API'))

# Use a model available on NDIF
model = LanguageModel('meta-llama/Llama-3.1-8B', device_map='auto')
token_ids = model.tokenizer.encode('Hello world')
print(f"Input: {len(token_ids)} tokens")

# --- LOCAL EXECUTION ---
with model.trace(token_ids, remote=False):
    hidden_local = model.model.layers[0].output[0]
    logits_local = model.lm_head(model.model.norm(hidden_local))
    shapes_local = {
        'hidden': hidden_local.shape,
        'logits': logits_local.shape,
    }.save()

print("\nLOCAL execution:")
print(f"  hidden shape: {shapes_local['hidden']}")
print(f"  logits shape: {shapes_local['logits']}")

# --- REMOTE EXECUTION ---
with model.trace(token_ids, remote=True):
    hidden_remote = model.model.layers[0].output[0]
    logits_remote = model.lm_head(model.model.norm(hidden_remote))
    shapes_remote = {
        'hidden': hidden_remote.shape,
        'logits': logits_remote.shape,
    }.save()

print("\nREMOTE execution:")
print(f"  hidden shape: {shapes_remote['hidden']}")
print(f"  logits shape: {shapes_remote['logits']}")
```

Expected Output (consistent behavior)
```
Input: 2 tokens

LOCAL execution:
  hidden shape: torch.Size([1, 2, 4096])
  logits shape: torch.Size([1, 2, 128256])

REMOTE execution:
  hidden shape: torch.Size([1, 2, 4096])
  logits shape: torch.Size([1, 2, 128256])
```
Actual Output (inconsistent behavior)
```
Input: 2 tokens

LOCAL execution:
  hidden shape: torch.Size([1, 2, 4096])
  logits shape: torch.Size([1, 2, 128256])

REMOTE execution:
  hidden shape: torch.Size([2, 4096])
  logits shape: torch.Size([2, 128256])
```
Impact
This inconsistency forces library authors to write defensive code like the following (a more general helper is sketched after the list below):
```python
# Workaround for batch dimension inconsistency
if logits.dim() == 3:
    logits = logits.squeeze(0)  # Remove batch dim if present
```

This is:
- Error-prone: easy to forget, and leads to cryptic IndexErrors when code that works locally fails remotely
- Unexpected: Users expect the same tensor shapes regardless of execution backend
- Undocumented: Not mentioned in nnsight documentation
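
For reference, here is a more general form of that workaround as a minimal sketch; the helper name `ensure_batch_dim` is hypothetical, not part of nnsight:

```python
import torch

def ensure_batch_dim(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: guarantee a leading batch dimension.

    Assumes activations arrive as either [batch, seq_len, hidden_dim]
    (local) or [seq_len, hidden_dim] (remote, batch squeezed).
    """
    return t if t.dim() == 3 else t.unsqueeze(0)
```

Every call site that consumes saved activations then has to route through a helper like this, which is exactly the kind of boilerplate a fix would remove.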
Suggested Fix
One of the following:
- Remote should preserve batch dimension (preferred - matches PyTorch conventions)
- Local should squeeze batch=1 (would be a breaking change)
- Document this explicitly if intentional (but consistency would be better)
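
Whichever option is chosen, a consistency check along these lines could guard against regressions. This is a sketch that reuses the trace calls from the MRE above inside a hypothetical helper; on 0.5.13 the assertion fails because remote squeezes the batch dimension:

```python
def check_shape_consistency(model, token_ids):
    """Hypothetical regression check: local and remote shapes must match."""
    shapes = {}
    for remote in (False, True):
        with model.trace(token_ids, remote=remote):
            hidden = model.model.layers[0].output[0]
            shapes[remote] = hidden.shape.save()
    assert shapes[False] == shapes[True], (
        f"local {shapes[False]} != remote {shapes[True]}"
    )
```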
Environment
- nnsight version: 0.5.13
- Python: 3.12
- Model tested: meta-llama/Llama-3.1-8B on NDIF