Trainer.predict and Trainer.test reset model state to evaluation #21410

@sfo

Bug description

I want to keep dropout active during prediction so that the model produces stochastic outputs. To do this, I set the model to train mode before running the prediction. However, Lightning's Trainer.predict resets the model to evaluation mode, which disables dropout and makes the outputs deterministic. Running the same prediction on the raw model, outside the Trainer, works as expected.

Note: Contrary to the version dropdown selection, I am running version 2.6.0.

What version are you seeing the problem on?

master

Reproduced in studio

No response

How to reproduce the bug

# %%
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset


# %%
class SimpleModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 1)
        self.dropout = torch.nn.Dropout(0.5)

    def forward(self, x):
        return self.dropout(self.layer(x))

    def test_step(self, batch, batch_idx):
        x, y = batch
        out = self(x)
        loss = torch.nn.functional.mse_loss(out, y)
        self.log("test_loss", loss)
        return loss

    def predict_step(self, batch, batch_idx):
        x, y = batch
        return self(x)


# Data
X = torch.randn(100, 10)
y = torch.randn(100, 1)
ds = TensorDataset(X, y)
dl = DataLoader(ds, batch_size=10)

# Model
pl.seed_everything(42)
model = SimpleModel()
trainer = pl.Trainer(accelerator="cpu", devices=1)

# %%
print("--- Evaluate model using Trainer ---")
print("--- Run 1: train(False) ---")
model.train(False)
predictions = trainer.predict(model, dataloaders=dl)
y_pred_deterministic = torch.cat(predictions)

print("--- Run 2: train(True) ---")
model.train(True)
predictions = trainer.predict(model, dataloaders=dl)
y_pred_stochastic = torch.cat(predictions)

are_same = torch.allclose(y_pred_deterministic, y_pred_stochastic)
print(f"Predictions are the same: {are_same}")

print("------------------------------------")
# %%
print("--- Evaluate model using Raw Model Loop ---")
print("--- Run 3: raw train(False) ---")
model.train(False)
with torch.no_grad():
    raw_predictions = [model(x) for x, _ in dl]
y_pred_raw_deterministic = torch.cat(raw_predictions)

print("--- Run 4: raw train(True) ---")
model.train(True)
with torch.no_grad():
    raw_predictions = [model(x) for x, _ in dl]
y_pred_raw_stochastic = torch.cat(raw_predictions)

are_same = torch.allclose(y_pred_raw_deterministic, y_pred_raw_stochastic)
print(f"Raw predictions are the same: {are_same}")

# %%
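
A possible workaround, sketched here under assumptions and not verified against Lightning 2.6.0: since the Trainer appears to put the whole module into evaluation mode before predicting, the dropout layers alone can be switched back to train mode afterwards. The helper name `enable_dropout` is hypothetical and not part of Lightning or PyTorch; it only uses `torch.nn.Module.modules()` and `torch.nn.Module.train()`.

```python
import torch


def enable_dropout(module: torch.nn.Module) -> int:
    """Put only the dropout layers of `module` into train mode.

    Hypothetical helper (not a Lightning or PyTorch API).
    Returns the number of dropout layers that were switched.
    """
    dropout_types = (
        torch.nn.Dropout,
        torch.nn.Dropout2d,
        torch.nn.Dropout3d,
        torch.nn.AlphaDropout,
    )
    switched = 0
    for submodule in module.modules():
        if isinstance(submodule, dropout_types):
            # train() on a leaf dropout layer only flips that layer's flag.
            submodule.train()
            switched += 1
    return switched


# Stand-in for the model from the reproduction script, without Lightning.
model = torch.nn.Sequential(torch.nn.Linear(10, 1), torch.nn.Dropout(0.5))
model.eval()  # mimics what Trainer.predict is reported to do
n_switched = enable_dropout(model)
print(f"dropout layers re-enabled: {n_switched}")
print(f"linear.training={model[0].training}, dropout.training={model[1].training}")
```

In the SimpleModel above, the helper could presumably be called as `enable_dropout(self)` at the start of predict_step, which runs after the Trainer has changed the mode. Whether this is the intended way to keep dropout active under Trainer.predict is an open question for the maintainers.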

Error messages and logs

# Error messages and logs here please

Environment

Current environment
  • CUDA:
    • GPU: None
    • available: False
    • version: 12.8
  • Lightning:
    • lightning-utilities: 0.15.2
    • pytorch-lightning: 2.6.0
    • torch: 2.9.1
    • torchinfo: 1.8.0
    • torchmetrics: 1.8.2
    • torchview: 0.2.7
  • Packages:
    • absl-py: 2.3.1
    • aiohappyeyeballs: 2.6.1
    • aiohttp: 3.13.2
    • aiosignal: 1.4.0
    • alembic: 1.17.2
    • alibi-detect: 0.13.0
    • altair: 6.0.0
    • annotated-doc: 0.0.4
    • annotated-types: 0.7.0
    • anyio: 4.12.0
    • appdirs: 1.4.4
    • argon2-cffi: 25.1.0
    • argon2-cffi-bindings: 25.1.0
    • arrow: 1.4.0
    • astor: 0.8.1
    • asttokens: 3.0.1
    • astunparse: 1.6.3
    • async-lru: 2.0.5
    • attrs: 25.4.0
    • autocommand: 2.2.2
    • babel: 2.17.0
    • backports-zstd: 1.2.0
    • backports.tarfile: 1.2.0
    • beautifulsoup4: 4.14.3
    • bleach: 6.3.0
    • blinker: 1.9.0
    • bokeh: 3.8.1
    • brotli: 1.2.0
    • cachetools: 6.2.2
    • cartes: 0.8.5
    • cartopy: 0.25.0
    • catalogue: 2.0.10
    • certifi: 2025.11.12
    • cffi: 2.0.0
    • charset-normalizer: 3.4.4
    • cheroot: 11.1.2
    • click: 8.2.1
    • cloudpickle: 3.1.2
    • cmdstanpy: 1.3.0
    • coloredlogs: 15.0.1
    • comm: 0.2.3
    • contourpy: 1.3.3
    • cramjam: 2.11.0
    • cryptography: 46.0.3
    • cycler: 0.12.1
    • databricks-sdk: 0.73.0
    • debugpy: 1.8.17
    • decorator: 5.2.1
    • defusedxml: 0.7.1
    • dill: 0.3.9
    • dm-tree: 0.1.9
    • docker: 7.1.0
    • etils: 1.13.0
    • executing: 2.2.1
    • fastapi: 0.124.0
    • fastjsonschema: 2.21.2
    • fastparquet: 2024.11.0
    • filelock: 3.20.0
    • flask: 3.1.2
    • flask-cors: 6.0.1
    • flatbuffers: 25.9.23
    • flexcache: 0.3
    • flexparser: 0.4
    • fonttools: 4.61.0
    • fqdn: 1.5.1
    • frozenlist: 1.8.0
    • fsspec: 2025.12.0
    • gast: 0.7.0
    • gcsfs: 2025.12.0
    • geopandas: 1.1.1
    • gitdb: 4.0.12
    • gitpython: 3.1.45
    • google-api-core: 2.28.1
    • google-auth: 2.43.0
    • google-auth-oauthlib: 1.2.2
    • google-cloud-core: 2.5.0
    • google-cloud-storage: 3.6.0
    • google-cloud-storage-control: 1.8.0
    • google-crc32c: 1.7.1
    • google-pasta: 0.2.0
    • google-resumable-media: 2.8.0
    • googleapis-common-protos: 1.72.0
    • graphene: 3.4.3
    • graphql-core: 3.2.7
    • graphql-relay: 3.2.0
    • graphviz: 0.21
    • greenlet: 3.3.0
    • grpc-google-iam-v1: 0.14.3
    • grpcio: 1.76.0
    • grpcio-status: 1.76.0
    • gunicorn: 23.0.0
    • gviz-api: 1.10.0
    • h11: 0.16.0
    • h2: 4.3.0
    • h5py: 3.15.1
    • hf-xet: 1.2.0
    • holidays: 0.86
    • hpack: 4.1.0
    • httpcore: 1.0.9
    • httpx: 0.28.1
    • huggingface-hub: 0.36.0
    • humanfriendly: 10.0
    • hyperframe: 6.1.0
    • idna: 3.11
    • imageio: 2.37.2
    • importlib-metadata: 8.7.0
    • importlib-resources: 6.5.2
    • impunity: 1.0.5
    • inflate64: 1.0.4
    • inflect: 7.3.1
    • iniconfig: 2.3.0
    • ipykernel: 7.1.0
    • ipython: 9.8.0
    • ipython-pygments-lexers: 1.1.1
    • ipywidgets: 8.1.8
    • isoduration: 20.11.0
    • itsdangerous: 2.2.0
    • jaraco-functools: 4.3.0
    • jaraco.classes: 3.4.0
    • jaraco.collections: 5.1.0
    • jaraco.context: 6.0.1
    • jaraco.functools: 4.0.1
    • jaraco.text: 3.12.1
    • jedi: 0.19.2
    • jeepney: 0.9.0
    • jinja2: 3.1.6
    • joblib: 1.5.2
    • json5: 0.12.1
    • jsonpointer: 3.0.0
    • jsonschema: 4.25.1
    • jsonschema-specifications: 2025.9.1
    • jupyter: 1.1.1
    • jupyter-client: 8.6.3
    • jupyter-console: 6.6.3
    • jupyter-core: 5.9.1
    • jupyter-events: 0.12.0
    • jupyter-lsp: 2.3.0
    • jupyter-server: 2.17.0
    • jupyter-server-terminals: 0.5.3
    • jupyterlab: 4.5.0
    • jupyterlab-pygments: 0.3.0
    • jupyterlab-server: 2.28.0
    • jupyterlab-widgets: 3.0.16
    • keopscore: 2.2.3
    • keras: 3.12.0
    • keras-tuner: 1.4.8
    • keyring: 25.7.0
    • kiwisolver: 1.4.9
    • kt-legacy: 1.0.5
    • lark: 1.3.1
    • lazy-loader: 0.4
    • libclang: 18.1.1
    • librt: 0.7.3
    • lightning-utilities: 0.15.2
    • llvmlite: 0.46.0
    • lxml: 6.0.2
    • lz4: 4.4.5
    • mako: 1.3.10
    • markdown: 3.10
    • markdown-it-py: 4.0.0
    • markupsafe: 3.0.3
    • matplotlib: 3.10.7
    • matplotlib-inline: 0.2.1
    • mdurl: 0.1.2
    • metar: 1.11.0
    • minio: 7.2.20
    • mistune: 3.1.4
    • ml-dtypes: 0.5.4
    • mlflow: 3.5.1
    • mlflow-skinny: 3.5.1
    • mlflow-tracing: 3.5.1
    • more-itertools: 10.8.0
    • mpmath: 1.3.0
    • msgpack: 1.1.2
    • multidict: 6.7.0
    • multivolumefile: 0.2.3
    • mypy: 1.19.0
    • mypy-extensions: 1.1.0
    • namex: 0.1.0
    • narwhals: 2.13.0
    • nbclient: 0.10.2
    • nbconvert: 7.16.6
    • nbformat: 5.10.4
    • nest-asyncio: 1.6.0
    • networkx: 3.6.1
    • notebook: 7.5.0
    • notebook-shim: 0.2.4
    • numba: 0.63.0
    • numpy: 2.3.5
    • nvidia-cublas-cu12: 12.8.4.1
    • nvidia-cuda-cupti-cu12: 12.8.90
    • nvidia-cuda-nvcc-cu12: 12.9.86
    • nvidia-cuda-nvrtc-cu12: 12.8.93
    • nvidia-cuda-runtime-cu12: 12.8.90
    • nvidia-cudnn-cu12: 9.10.2.21
    • nvidia-cufft-cu12: 11.3.3.83
    • nvidia-cufile-cu12: 1.13.1.3
    • nvidia-curand-cu12: 10.3.9.90
    • nvidia-cusolver-cu12: 11.7.3.90
    • nvidia-cusparse-cu12: 12.5.8.93
    • nvidia-cusparselt-cu12: 0.7.1
    • nvidia-nccl-cu12: 2.27.5
    • nvidia-nvjitlink-cu12: 12.8.93
    • nvidia-nvshmem-cu12: 3.3.20
    • nvidia-nvtx-cu12: 12.8.90
    • oauthlib: 3.3.1
    • onnxruntime: 1.23.2
    • openap: 2.4
    • opencv-python: 4.11.0.86
    • opentelemetry-api: 1.39.0
    • opentelemetry-proto: 1.39.0
    • opentelemetry-sdk: 1.39.0
    • opentelemetry-semantic-conventions: 0.60b0
    • opt-einsum: 3.4.0
    • optree: 0.18.0
    • orjson: 3.11.5
    • overrides: 7.7.0
    • packaging: 25.0
    • pandas: 2.3.3
    • pandocfilters: 1.5.1
    • parso: 0.8.5
    • pathspec: 0.12.1
    • patsy: 1.0.2
    • pexpect: 4.9.0
    • pillow: 10.4.0
    • pint: 0.25.2
    • pitot: 0.3.2
    • platformdirs: 4.5.1
    • plotly: 6.5.0
    • pluggy: 1.6.0
    • prometheus-client: 0.23.1
    • prompt-toolkit: 3.0.52
    • propcache: 0.4.1
    • properscoring: 0.1
    • prophet: 1.2.1
    • proto-plus: 1.26.1
    • protobuf: 6.33.2
    • psutil: 7.1.3
    • ptyprocess: 0.7.0
    • pure-eval: 0.2.3
    • py7zr: 1.0.0
    • pyarrow: 21.0.0
    • pyasn1: 0.6.1
    • pyasn1-modules: 0.4.2
    • pybcj: 1.0.7
    • pybind11: 3.0.1
    • pycparser: 2.23
    • pycryptodome: 3.23.0
    • pycryptodomex: 3.23.0
    • pydantic: 2.12.5
    • pydantic-core: 2.41.5
    • pygments: 2.19.2
    • pyjwt: 2.10.1
    • pykeops: 2.2.3
    • pynverse: 0.1.4.6
    • pyod: 2.0.6
    • pyogrio: 0.12.1
    • pyopensky: 2.15
    • pyparsing: 3.2.5
    • pyppmd: 1.2.0
    • pyproj: 3.7.2
    • pyshp: 3.0.3
    • pytest: 9.0.2
    • python-dateutil: 2.9.0.post0
    • python-dotenv: 1.2.1
    • python-json-logger: 4.0.0
    • pytorch-lightning: 2.6.0
    • pytz: 2025.2
    • pyyaml: 6.0.3
    • pyzmq: 27.1.0
    • pyzstd: 0.19.0
    • quantile-forest: 1.4.1
    • ray: 2.52.1
    • referencing: 0.37.0
    • regex: 2025.11.3
    • requests: 2.32.5
    • requests-oauthlib: 2.0.0
    • rfc3339-validator: 0.1.4
    • rfc3986-validator: 0.1.1
    • rfc3987-syntax: 1.1.0
    • rich: 14.2.0
    • rpds-py: 0.30.0
    • rs1090: 0.4.14
    • rsa: 4.9.1
    • ruff: 0.14.8
    • safetensors: 0.7.0
    • scikit-image: 0.25.2
    • scikit-learn: 1.7.2
    • scipy: 1.16.3
    • seaborn: 0.13.2
    • secretstorage: 3.5.0
    • send2trash: 1.8.3
    • setuptools: 80.9.0
    • sfoutils: 0.2.0
    • shap: 0.50.0
    • shapely: 2.1.2
    • six: 1.17.0
    • sklearn-quantile: 0.1.1
    • slicer: 0.0.8
    • smmap: 5.0.2
    • soupsieve: 2.8
    • sqlalchemy: 2.0.44
    • sqlparse: 0.5.4
    • stack-data: 0.6.3
    • stanio: 0.5.1
    • starlette: 0.50.0
    • statsmodels: 0.14.6
    • sympy: 1.14.0
    • tensorboard: 2.20.0
    • tensorboard-data-server: 0.7.2
    • tensorboard-plugin-profile: 2.21.3
    • tensorboardx: 2.6.4
    • tensorflow: 2.20.0
    • tensorflow-docs: 2025.12.2.70325
    • tensorflow-probability: 0.25.0
    • termcolor: 3.2.0
    • terminado: 0.18.1
    • texttable: 1.7.0
    • tf-keras: 2.20.1
    • threadpoolctl: 3.6.0
    • tifffile: 2025.10.16
    • tinycss2: 1.4.0
    • tokenizers: 0.21.4
    • toml: 0.10.2
    • tomli: 2.0.1
    • torch: 2.9.1
    • torchinfo: 1.8.0
    • torchmetrics: 1.8.2
    • torchview: 0.2.7
    • tornado: 6.5.2
    • tqdm: 4.67.1
    • traffic: 2.13.post16.dev0+9e00a83
    • traitlets: 5.14.3
    • transformers: 4.51.3
    • trino: 0.336.0
    • triton: 3.5.1
    • tudcolors: 0.0.1
    • typeguard: 4.3.0
    • types-protobuf: 6.32.1.20251105
    • types-requests: 2.32.4.20250913
    • types-tensorflow: 2.18.0.20251008
    • typing-extensions: 4.15.0
    • typing-inspection: 0.4.2
    • tzdata: 2025.2
    • tzlocal: 5.3.1
    • uri-template: 1.3.0
    • urllib3: 2.6.1
    • uvicorn: 0.38.0
    • wcwidth: 0.2.14
    • webcolors: 25.10.0
    • webencodings: 0.5.1
    • websocket-client: 1.9.0
    • werkzeug: 3.1.4
    • wheel: 0.45.1
    • widgetsnbextension: 4.0.15
    • wrapt: 2.0.1
    • xprof: 2.21.3
    • xyzservices: 2025.11.0
    • yarl: 1.22.0
    • zipp: 3.23.0
    • zstandard: 0.25.0
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor:
    • python: 3.13.9
    • release: 6.17.9-200.fc42.x86_64
    • version: #1 SMP PREEMPT_DYNAMIC Mon Nov 24 22:28:05 UTC 2025

More info

No response

cc @ethanwharris
