
predict() is very slow latency wise. #1384

@LukePogaPersonal

Description

What happened + What you expected to happen

When I run inference, i.e. predict(), it takes at minimum 100 ms on GPU or CPU, on every server or workstation I have tried.

When I run the same model, e.g. NHITS, as an ONNX equivalent, it runs in about 1 ms on CPU.

I cannot find any way to get the latency of predict() below ~90 ms, which is staggeringly slow and unusable in the real world. It would be great to have a way to get fast inference without exporting the model to ONNX first. Any thoughts?
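
For comparison, below is a minimal sketch of the kind of ONNX-equivalent timing mentioned above. It assumes the network is available as a plain torch.nn.Module that accepts a (1, input_size) tensor; the helper name time_onnx_equivalent, the tensor names, and the export path are illustrative only (not neuralforecast's API), and the actual forward signature of a fitted NHITS module may differ between versions.

import time
import numpy as np
import torch
import onnxruntime as ort

def time_onnx_equivalent(torch_module: torch.nn.Module, input_size: int,
                         onnx_path: str = 'nhits_equiv.onnx') -> float:
    """Export a plain torch module to ONNX and time one CPU inference, in ms."""
    torch_module.eval()
    dummy = torch.randn(1, input_size)
    # Export only the network itself; the DataFrame/Trainer machinery that
    # predict() runs through is deliberately excluded from this measurement.
    torch.onnx.export(torch_module, dummy, onnx_path,
                      input_names=['y_lookback'], output_names=['y_hat'])
    sess = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
    feed = {'y_lookback': dummy.numpy().astype(np.float32)}
    sess.run(None, feed)  # warm-up call
    start = time.perf_counter()
    sess.run(None, feed)
    return (time.perf_counter() - start) * 1000.0

Timed this way, the ONNX session measures only the network's forward pass, while predict() also pays per-call overhead for building the inference dataset and running a PyTorch Lightning predict loop, which is presumably where most of the fixed ~100 ms goes.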

Versions / Dependencies

Latest releases; RTX 5090, etc.

Reproduction script

import time
import os
import numpy as np
import pandas as pd
import torch

def generate_synthetic_df(total_points: int, freq: str = 'h') -> pd.DataFrame:
    np.random.seed(42)
    torch.manual_seed(42)

    dates = pd.date_range(start='2025-01-01', periods=total_points, freq=freq)
    # simple trend + seasonality + noise
    t = np.arange(total_points)
    seasonal = 2.0 * np.sin(2 * np.pi * t / 24) + 0.5 * np.sin(2 * np.pi * t / (24 * 7))
    trend = 0.001 * t
    noise = np.random.normal(scale=0.5, size=total_points)
    y = trend + seasonal + noise

    df = pd.DataFrame({
        'unique_id': 'series_1',
        'ds': dates,
        'y': y.astype(float),
    })
    return df

def gpu_is_usable() -> bool:
    if not torch.cuda.is_available():
        return False
    try:
        # Check if the installed wheel includes our device's SM in the arch list
        get_arch_list = getattr(torch.cuda, 'get_arch_list', None)
        if callable(get_arch_list):
            archs = set(get_arch_list())
            major, minor = torch.cuda.get_device_capability(0)
            sm = f"sm_{major}{minor}"
            if sm not in archs:
                return False
        # Try a tiny CUDA op to catch runtime failures
        x = torch.zeros(1, device='cuda')
        torch.cuda.synchronize()
        del x
        return True
    except Exception:
        return False

def main():
    # Defer heavy imports to runtime to keep import-time light for tests
    from neuralforecast import NeuralForecast
    from neuralforecast.models import NHITS

    # Speed knobs for CI/tests
    try:
        torch.set_num_threads(int(os.getenv('TORCH_NUM_THREADS', '1')))
    except Exception:
        pass

    lookback = 720
    horizon = 1200
    quantiles = [0.01, 0.1, 0.25, 0.5, 0.75, 0.9, 0.99]
    max_steps = int(os.getenv('LATENCY_MAX_STEPS', '10'))

    total_points = lookback + horizon
    df = generate_synthetic_df(total_points=total_points, freq='h')

    # Train on lookback + horizon to ensure at least one training window exists
    train_df = df.iloc[: lookback + horizon].copy()

    # Configure NHITS with robust quantile loss handling across versions
    loss = 'MAE'
    quantiles_arg = None
    try:
        from neuralforecast.losses.pytorch import MQLoss  # type: ignore
        loss = MQLoss(quantiles=quantiles)
    except Exception:
        # Fallback for versions expecting a quantiles arg on the model
        quantiles_arg = quantiles

    # Decide accelerator and precision
    want_gpu = os.getenv('USE_GPU', '1').lower() not in {'0', 'false', 'no'}
    can_use_gpu = gpu_is_usable() if want_gpu else False
    if can_use_gpu:
        try:
            torch.set_float32_matmul_precision(os.getenv('TORCH_F32_MATMUL', 'high'))
        except Exception:
            pass

    trainer_kwargs = {
        'accelerator': 'gpu' if can_use_gpu else 'cpu',
        'devices': 1,
        'enable_progress_bar': False,
        'logger': False,
    }
    if can_use_gpu:
        # Mixed precision can speed up inference on modern GPUs
        precision = os.getenv('LATENCY_PRECISION', '16-mixed')
        trainer_kwargs['precision'] = precision

    model_kwargs = dict(
        h=horizon,
        input_size=lookback,
        loss=loss,
        max_steps=max_steps,
        scaler_type='identity',
    )
    if quantiles_arg is not None:
        model_kwargs['quantiles'] = quantiles_arg

    # Unpack trainer kwargs at top level so the PL Trainer receives them
    model = NHITS(**model_kwargs, **trainer_kwargs)

    # Optional torch.compile on the model to reduce Python overhead / JIT optimize
    if os.getenv('USE_TORCH_COMPILE', '0').lower() in {'1', 'true', 'yes'}:
        try:
            model = torch.compile(model, mode=os.getenv('TORCH_COMPILE_MODE', 'max-autotune'))
        except Exception as e:
            print(f"torch.compile skipped: {e}")

    nf = NeuralForecast(models=[model], freq='h')

    # Fit quickly
    nf.fit(df=train_df)

    # Warm-up call to exclude any first-call overheads (JIT, caching, etc.)
    try:
        with torch.inference_mode():
            nf.predict()
    except Exception:
        # Some versions require passing the full df; pass the last input window
        with torch.inference_mode():
            nf.predict(df=train_df)

    # Time a single predict() call, with CUDA synchronization if on GPU
    if can_use_gpu:
        torch.cuda.synchronize()
    start = time.perf_counter()
    with torch.inference_mode():
        preds = nf.predict()
    if can_use_gpu:
        torch.cuda.synchronize()
    elapsed_s = time.perf_counter() - start

    # Output simple stats and latency
    num_rows = len(preds)
    num_cols = len(preds.columns)
    device_str = 'cuda' if can_use_gpu else 'cpu'
    print(f"device: {device_str}")
    print(f"predict() latency: {elapsed_s * 1000:.2f} ms")
    print(f"predictions shape: rows={num_rows}, cols={num_cols}")
    print(preds.head().to_string(index=False))


if __name__ == '__main__':
    main()

Issue Severity

None
