Batch Dimension Inconsistency Between Local and Remote Execution #581

@davidbau

Description

Summary

When using nnsight to access intermediate tensors (e.g., layer outputs, logits), the tensor shapes differ between local and remote execution:

  • Local: Tensors include batch dimension [batch, seq_len, hidden_dim]
  • Remote: Tensors have batch dimension squeezed [seq_len, hidden_dim]

This inconsistency forces users to write defensive code that checks tensor dimensions before indexing; such checks are error-prone, and the behavior itself is unexpected.

Minimal Reproducible Example

import os
from nnsight import LanguageModel, CONFIG

# Configure NDIF (replace with your key)
CONFIG.set_default_api_key(os.environ.get('NDIF_API'))

# Use a model available on NDIF
model = LanguageModel('meta-llama/Llama-3.1-8B', device_map='auto')

token_ids = model.tokenizer.encode('Hello world')
print(f"Input: {len(token_ids)} tokens")

# --- LOCAL EXECUTION ---
with model.trace(token_ids, remote=False):
    hidden_local = model.model.layers[0].output[0]
    logits_local = model.lm_head(model.model.norm(hidden_local))
    # Save each shape proxy individually (a plain dict has no .save())
    shapes_local = {
        'hidden': hidden_local.shape.save(),
        'logits': logits_local.shape.save(),
    }

print(f"\nLOCAL execution:")
print(f"  hidden shape: {shapes_local['hidden']}")
print(f"  logits shape: {shapes_local['logits']}")

# --- REMOTE EXECUTION ---
with model.trace(token_ids, remote=True):
    hidden_remote = model.model.layers[0].output[0]
    logits_remote = model.lm_head(model.model.norm(hidden_remote))
    shapes_remote = {
        'hidden': hidden_remote.shape.save(),
        'logits': logits_remote.shape.save(),
    }

print(f"\nREMOTE execution:")
print(f"  hidden shape: {shapes_remote['hidden']}")
print(f"  logits shape: {shapes_remote['logits']}")

Expected Output (consistent behavior)

Input: 2 tokens

LOCAL execution:
  hidden shape: torch.Size([1, 2, 4096])
  logits shape: torch.Size([1, 2, 128256])

REMOTE execution:
  hidden shape: torch.Size([1, 2, 4096])
  logits shape: torch.Size([1, 2, 128256])

Actual Output (inconsistent behavior)

Input: 2 tokens

LOCAL execution:
  hidden shape: torch.Size([1, 2, 4096])
  logits shape: torch.Size([1, 2, 128256])

REMOTE execution:
  hidden shape: torch.Size([2, 4096])
  logits shape: torch.Size([2, 128256])

Impact

This inconsistency forces library authors to write defensive code like:

# Workaround for batch dimension inconsistency
if logits.dim() == 3:
    logits = logits.squeeze(0)  # Remove batch dim if present

This is:

  1. Error-prone: Easy to forget, and omitting it causes a cryptic IndexError when code that works locally fails remotely
  2. Unexpected: Users expect identical tensor shapes regardless of execution backend
  3. Undocumented: The behavior is not mentioned in the nnsight documentation
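
Until this is fixed, a sturdier interim workaround is to normalize shapes once at a boundary instead of scattering dimension checks. A minimal sketch (the helper name ensure_batch_dim is illustrative, not part of nnsight):

import torch

def ensure_batch_dim(t: torch.Tensor) -> torch.Tensor:
    """Return t with a leading batch dimension, adding one if absent."""
    # Remote traces may yield [seq_len, hidden_dim]; local traces yield
    # [batch, seq_len, hidden_dim]. Normalize everything to the 3-D form.
    return t.unsqueeze(0) if t.dim() == 2 else t

# Usage, regardless of execution backend:
# logits = ensure_batch_dim(logits)  # always [batch, seq_len, vocab]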

Suggested Fix

Either:

  1. Remote execution should preserve the batch dimension (preferred; matches PyTorch conventions)
  2. Local execution should squeeze size-1 batch dimensions (a breaking change)
  3. If the behavior is intentional, document it explicitly (though consistency would be better)
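
Whichever option is chosen, a shape-parity regression test along these lines could guard the contract (a sketch reusing model and token_ids from the example above; assumes saved shape proxies resolve to their values after the trace exits):

def test_shape_parity():
    with model.trace(token_ids, remote=False):
        local_shape = model.model.layers[0].output[0].shape.save()
    with model.trace(token_ids, remote=True):
        remote_shape = model.model.layers[0].output[0].shape.save()
    # Both backends should report the same rank and sizes
    assert tuple(local_shape) == tuple(remote_shape)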

Environment

  • nnsight version: 0.5.13
  • Python: 3.12
  • Model tested: meta-llama/Llama-3.1-8B on NDIF
