Some metrics cannot handle GPU tensors #115

@jannik-el

Metrics Requiring CPU Tensors:

  1. torchsurv.metrics.ConcordanceIndex()
     • Cannot handle GPU tensors
     • Requires a manual .cpu() transfer before computation
  2. torchsurv.metrics.Auc()
     • Cannot handle GPU tensors
     • Internal tensor operations create CPU tensors, causing a device mismatch
     • Requires a manual .cpu() transfer before computation
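Until the metrics are device-agnostic, the manual `.cpu()` transfer can be wrapped once instead of being repeated at every call site. This is a minimal sketch, assuming the metric is any callable that takes tensors positionally or by keyword; `call_on_cpu` is a hypothetical helper, not a torchsurv API. It moves a tensor result back to the device of the first tensor argument.

```python
import torch
from typing import Any, Callable


def call_on_cpu(metric: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Call `metric` on CPU copies of all tensor arguments (hypothetical helper)."""
    # Remember where the caller's data lives so the result can be moved back
    tensors = [a for a in (*args, *kwargs.values()) if isinstance(a, torch.Tensor)]
    target = tensors[0].device if tensors else torch.device("cpu")

    def to_cpu(x: Any) -> Any:
        # Leave non-tensor arguments (ints, strings, None) untouched
        return x.cpu() if isinstance(x, torch.Tensor) else x

    result = metric(*(to_cpu(a) for a in args), **{k: to_cpu(v) for k, v in kwargs.items()})
    # Only tensor results are moved; Python scalars are returned as-is
    return result.to(target) if isinstance(result, torch.Tensor) else result
```

For example, `call_on_cpu(ConcordanceIndex(), log_hazards_gpu, events_gpu, times_gpu)` with the GPU tensors from the minimal example. The cost is one device-to-host copy per call, the same trade-off as the manual workaround, and any state a metric object caches internally will hold the CPU copies.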

Specific Error Patterns:

  • BatchNorm Error: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
  • Tensor Cat Error: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0!
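The "Tensor Cat Error" is the typical symptom of code that creates a fresh tensor without a `device=` argument and then concatenates it with its input. This hypothetical sketch illustrates that pattern and the usual fix of allocating on the input's device; the function names are made up and this is not torchsurv's actual code.

```python
import torch


def prepend_zero_buggy(x: torch.Tensor) -> torch.Tensor:
    # torch.tensor(...) without a device argument allocates on the CPU, so
    # torch.cat raises a device-mismatch RuntimeError when x lives on CUDA
    return torch.cat([torch.tensor([0.0]), x])


def prepend_zero_fixed(x: torch.Tensor) -> torch.Tensor:
    # Allocating the helper tensor with x's device and dtype keeps torch.cat
    # working regardless of where x lives
    return torch.cat([torch.zeros(1, device=x.device, dtype=x.dtype), x])


if torch.cuda.is_available():
    x_gpu = torch.arange(3.0, device="cuda")
    try:
        prepend_zero_buggy(x_gpu)
    except RuntimeError as e:
        print(f"buggy: {e}")
    print(f"fixed: {prepend_zero_fixed(x_gpu)}")
```

On CPU-only machines both versions behave identically, which is why this kind of bug tends to surface only once inputs are moved to the GPU.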

Minimal Example for Testing Whether the Metrics Work with GPU Tensors

import torch
from torchsurv.metrics.cindex import ConcordanceIndex
from torchsurv.metrics.auc import Auc

print("🔧 Checking device availability...")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Create sample survival data
n_samples = 1000

print("\n📊 Creating sample survival data...")
# Risk scores (log hazards), e.g. the output of a Cox model
log_hazards = torch.randn(n_samples, 1)
# Events (1 = event occurred, 0 = censored)
events = torch.bernoulli(torch.full((n_samples,), 0.3))  # ~30% expected event rate
# Survival times drawn from an Exponential(1) distribution
times = torch.empty(n_samples).exponential_(1.0)

print(f"Sample data created: {n_samples} observations, {events.sum().item():.0f} events")

# Move tensors to GPU if available
print(f"\n🚀 Moving tensors to {device}...")
log_hazards_gpu = log_hazards.to(device)
events_gpu = events.bool().to(device)
times_gpu = times.to(device)

print(f"✅ Tensors on device: {log_hazards_gpu.device}")

# Attempt 1: Try TorchSurv metrics with GPU tensors (THIS WILL FAIL)
print("\n❌ ATTEMPT 1: Using TorchSurv metrics directly with GPU tensors")
print("=" * 60)

try:
    print("   🧪 Testing ConcordanceIndex with GPU tensors...")
    cindex_metric = ConcordanceIndex()
    cindex = cindex_metric(log_hazards_gpu, events_gpu, times_gpu)
    print(f"   ✅ C-index: {cindex:.4f}")
except Exception as e:
    print(f"   💥 ConcordanceIndex FAILED: {type(e).__name__}: {e}")

try:
    print("   🧪 Testing AUC with GPU tensors...")
    auc_metric = Auc()
    auc = auc_metric(log_hazards_gpu, events_gpu, times_gpu, new_time=torch.tensor(1.0).to(device))
    print(f"   ✅ AUC: {auc:.4f}")
except Exception as e:
    print(f"   💥 AUC FAILED: {type(e).__name__}: {e}")

# Attempt 2: Workaround - move tensors to CPU first (THIS WORKS)
print("\n✅ ATTEMPT 2: Workaround - moving tensors to CPU first")
print("=" * 60)

print("   ⚠️  Moving tensors from GPU to CPU for TorchSurv compatibility...")
log_hazards_cpu = log_hazards_gpu.cpu()
events_cpu = events_gpu.cpu()
times_cpu = times_gpu.cpu()

try:
    print("   🧪 Testing ConcordanceIndex with CPU tensors...")
    cindex_metric = ConcordanceIndex()
    cindex = cindex_metric(log_hazards_cpu, events_cpu, times_cpu)
    print(f"   ✅ C-index: {cindex:.4f}")
except Exception as e:
    print(f"   💥 ConcordanceIndex FAILED: {type(e).__name__}: {e}")

try:
    print("   🧪 Testing AUC with CPU tensors...")
    auc_metric = Auc()
    auc = auc_metric(log_hazards_cpu, events_cpu, times_cpu, new_time=torch.tensor(1.0))
    print(f"   ✅ AUC: {auc:.4f}")
except Exception as e:
    print(f"   💥 AUC FAILED: {type(e).__name__}: {e}")
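As a device-agnostic cross-check while this issue is open, Harrell's concordance index can be computed in plain PyTorch with broadcasting, so it runs wherever its inputs live. This is a minimal sketch, not torchsurv's implementation: `harrell_cindex` is a hypothetical helper that ignores weights, skips pairs with tied times, and builds an n × n pairwise matrix, so memory grows quadratically. Higher scores are assumed to mean higher risk, matching the log-hazard inputs in the script.

```python
import torch


def harrell_cindex(risk: torch.Tensor, event: torch.Tensor, time: torch.Tensor) -> torch.Tensor:
    """Harrell's C-index computed on whatever device the inputs are on.

    Hypothetical helper, not part of torchsurv: unweighted, pairs with tied
    times are skipped, and memory is O(n^2).
    """
    # Flatten (n, 1) risk scores and make sure events are boolean
    risk = risk.reshape(-1)
    event = event.reshape(-1).bool()
    time = time.reshape(-1)

    # comparable[i, j]: subject i had an event strictly before subject j's time
    comparable = event.unsqueeze(1) & (time.unsqueeze(1) < time.unsqueeze(0))

    # Concordant when the earlier-failing subject has the higher risk score;
    # tied risk scores count as half-concordant
    higher = (risk.unsqueeze(1) > risk.unsqueeze(0)).float()
    tied = (risk.unsqueeze(1) == risk.unsqueeze(0)).float()
    score = (higher + 0.5 * tied) * comparable.float()

    # Returns NaN if there are no comparable pairs
    return score.sum() / comparable.sum().float()
```

For example, `harrell_cindex(log_hazards_gpu, events_gpu, times_gpu)` should run directly on the GPU tensors from the script. Its value may differ slightly from `ConcordanceIndex` if the two treat ties differently.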
