Open
Labels: bug, ver: 2.5.x, working as intended
Description
Bug description
I want dropout to stay active during prediction so that the model generates stochastic outputs. To do this, I set the model to train mode before running the prediction. However, Lightning's Trainer.predict resets the model to evaluation mode, which disables dropout and leads to deterministic outputs. Running the prediction on the raw model works as expected.
Note: Contrary to the version dropdown selection, I am running version 2.6.0.
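A possible workaround, sketched here with plain PyTorch only: since the model is in evaluation mode by the time the prediction runs, the dropout layers can be switched back to train mode right before the forward pass, for example at the top of predict_step. The helper name enable_dropout is hypothetical and not part of Lightning or PyTorch, and it only covers torch.nn.Dropout, not the other dropout variants.

```python
import torch
from torch import nn


def enable_dropout(module: nn.Module) -> None:
    """Hypothetical helper: put only the nn.Dropout layers of `module` into train mode."""
    for submodule in module.modules():
        if isinstance(submodule, nn.Dropout):
            submodule.train()


# Small stand-in for the model in this report (same layers, no Lightning).
net = nn.Sequential(nn.Linear(10, 1), nn.Dropout(0.5))
net.eval()  # mimics the mode the model is in when the prediction runs
x = torch.randn(100, 10)

with torch.no_grad():
    eval_a, eval_b = net(x), net(x)  # dropout disabled: both passes match
    enable_dropout(net)              # re-enable dropout only
    mc_a, mc_b = net(x), net(x)      # dropout active: passes differ

print("eval passes identical:", torch.equal(eval_a, eval_b))
print("dropout passes identical:", torch.equal(mc_a, mc_b))
```

In a LightningModule, the same call would presumably go at the start of predict_step (for example enable_dropout(self)), so that it runs after the Trainer has already switched the model to evaluation mode. All other layers stay in evaluation mode, which matters for layers such as batch normalization.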
What version are you seeing the problem on?
master
Reproduced in studio
No response
How to reproduce the bug
# %%
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset
# %%
class SimpleModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 1)
        self.dropout = torch.nn.Dropout(0.5)

    def forward(self, x):
        return self.dropout(self.layer(x))

    def test_step(self, batch, batch_idx):
        x, y = batch
        out = self(x)
        loss = torch.nn.functional.mse_loss(out, y)
        self.log("test_loss", loss)
        return loss

    def predict_step(self, batch, batch_idx):
        x, y = batch
        return self(x)
# Data
X = torch.randn(100, 10)
y = torch.randn(100, 1)
ds = TensorDataset(X, y)
dl = DataLoader(ds, batch_size=10)
# Model
pl.seed_everything(42)
model = SimpleModel()
trainer = pl.Trainer(accelerator="cpu", devices=1)
# %%
print("--- Evaluate model using Trainer ---")
print("--- Run 1: train(False) ---")
model.train(False)
predictions = trainer.predict(model, dataloaders=dl)
y_pred_deterministic = torch.cat(predictions)
print("--- Run 2: train(True) ---")
model.train(True)
predictions = trainer.predict(model, dataloaders=dl)
y_pred_stochastic = torch.cat(predictions)
are_same = torch.allclose(y_pred_deterministic, y_pred_stochastic)
print(f"Predictions are the same: {are_same}")
print("------------------------------------")
# %%
print("--- Evaluate model using Raw Model Loop ---")
print("--- Run 3: raw train(False) ---")
model.train(False)
with torch.no_grad():
    raw_predictions = [model(x) for x, _ in dl]
y_pred_raw_deterministic = torch.cat(raw_predictions)
print("--- Run 4: raw train(True) ---")
model.train(True)
with torch.no_grad():
    raw_predictions = [model(x) for x, _ in dl]
y_pred_raw_stochastic = torch.cat(raw_predictions)
are_same = torch.allclose(y_pred_raw_deterministic, y_pred_raw_stochastic)
print(f"Raw predictions are the same: {are_same}")
# %%
Error messages and logs
# Error messages and logs here please
Environment
Current environment
- CUDA:
- GPU: None
- available: False
- version: 12.8
- Lightning:
- lightning-utilities: 0.15.2
- pytorch-lightning: 2.6.0
- torch: 2.9.1
- torchinfo: 1.8.0
- torchmetrics: 1.8.2
- torchview: 0.2.7
- Packages:
- absl-py: 2.3.1
- aiohappyeyeballs: 2.6.1
- aiohttp: 3.13.2
- aiosignal: 1.4.0
- alembic: 1.17.2
- alibi-detect: 0.13.0
- altair: 6.0.0
- annotated-doc: 0.0.4
- annotated-types: 0.7.0
- anyio: 4.12.0
- appdirs: 1.4.4
- argon2-cffi: 25.1.0
- argon2-cffi-bindings: 25.1.0
- arrow: 1.4.0
- astor: 0.8.1
- asttokens: 3.0.1
- astunparse: 1.6.3
- async-lru: 2.0.5
- attrs: 25.4.0
- autocommand: 2.2.2
- babel: 2.17.0
- backports-zstd: 1.2.0
- backports.tarfile: 1.2.0
- beautifulsoup4: 4.14.3
- bleach: 6.3.0
- blinker: 1.9.0
- bokeh: 3.8.1
- brotli: 1.2.0
- cachetools: 6.2.2
- cartes: 0.8.5
- cartopy: 0.25.0
- catalogue: 2.0.10
- certifi: 2025.11.12
- cffi: 2.0.0
- charset-normalizer: 3.4.4
- cheroot: 11.1.2
- click: 8.2.1
- cloudpickle: 3.1.2
- cmdstanpy: 1.3.0
- coloredlogs: 15.0.1
- comm: 0.2.3
- contourpy: 1.3.3
- cramjam: 2.11.0
- cryptography: 46.0.3
- cycler: 0.12.1
- databricks-sdk: 0.73.0
- debugpy: 1.8.17
- decorator: 5.2.1
- defusedxml: 0.7.1
- dill: 0.3.9
- dm-tree: 0.1.9
- docker: 7.1.0
- etils: 1.13.0
- executing: 2.2.1
- fastapi: 0.124.0
- fastjsonschema: 2.21.2
- fastparquet: 2024.11.0
- filelock: 3.20.0
- flask: 3.1.2
- flask-cors: 6.0.1
- flatbuffers: 25.9.23
- flexcache: 0.3
- flexparser: 0.4
- fonttools: 4.61.0
- fqdn: 1.5.1
- frozenlist: 1.8.0
- fsspec: 2025.12.0
- gast: 0.7.0
- gcsfs: 2025.12.0
- geopandas: 1.1.1
- gitdb: 4.0.12
- gitpython: 3.1.45
- google-api-core: 2.28.1
- google-auth: 2.43.0
- google-auth-oauthlib: 1.2.2
- google-cloud-core: 2.5.0
- google-cloud-storage: 3.6.0
- google-cloud-storage-control: 1.8.0
- google-crc32c: 1.7.1
- google-pasta: 0.2.0
- google-resumable-media: 2.8.0
- googleapis-common-protos: 1.72.0
- graphene: 3.4.3
- graphql-core: 3.2.7
- graphql-relay: 3.2.0
- graphviz: 0.21
- greenlet: 3.3.0
- grpc-google-iam-v1: 0.14.3
- grpcio: 1.76.0
- grpcio-status: 1.76.0
- gunicorn: 23.0.0
- gviz-api: 1.10.0
- h11: 0.16.0
- h2: 4.3.0
- h5py: 3.15.1
- hf-xet: 1.2.0
- holidays: 0.86
- hpack: 4.1.0
- httpcore: 1.0.9
- httpx: 0.28.1
- huggingface-hub: 0.36.0
- humanfriendly: 10.0
- hyperframe: 6.1.0
- idna: 3.11
- imageio: 2.37.2
- importlib-metadata: 8.7.0
- importlib-resources: 6.5.2
- impunity: 1.0.5
- inflate64: 1.0.4
- inflect: 7.3.1
- iniconfig: 2.3.0
- ipykernel: 7.1.0
- ipython: 9.8.0
- ipython-pygments-lexers: 1.1.1
- ipywidgets: 8.1.8
- isoduration: 20.11.0
- itsdangerous: 2.2.0
- jaraco-functools: 4.3.0
- jaraco.classes: 3.4.0
- jaraco.collections: 5.1.0
- jaraco.context: 6.0.1
- jaraco.functools: 4.0.1
- jaraco.text: 3.12.1
- jedi: 0.19.2
- jeepney: 0.9.0
- jinja2: 3.1.6
- joblib: 1.5.2
- json5: 0.12.1
- jsonpointer: 3.0.0
- jsonschema: 4.25.1
- jsonschema-specifications: 2025.9.1
- jupyter: 1.1.1
- jupyter-client: 8.6.3
- jupyter-console: 6.6.3
- jupyter-core: 5.9.1
- jupyter-events: 0.12.0
- jupyter-lsp: 2.3.0
- jupyter-server: 2.17.0
- jupyter-server-terminals: 0.5.3
- jupyterlab: 4.5.0
- jupyterlab-pygments: 0.3.0
- jupyterlab-server: 2.28.0
- jupyterlab-widgets: 3.0.16
- keopscore: 2.2.3
- keras: 3.12.0
- keras-tuner: 1.4.8
- keyring: 25.7.0
- kiwisolver: 1.4.9
- kt-legacy: 1.0.5
- lark: 1.3.1
- lazy-loader: 0.4
- libclang: 18.1.1
- librt: 0.7.3
- lightning-utilities: 0.15.2
- llvmlite: 0.46.0
- lxml: 6.0.2
- lz4: 4.4.5
- mako: 1.3.10
- markdown: 3.10
- markdown-it-py: 4.0.0
- markupsafe: 3.0.3
- matplotlib: 3.10.7
- matplotlib-inline: 0.2.1
- mdurl: 0.1.2
- metar: 1.11.0
- minio: 7.2.20
- mistune: 3.1.4
- ml-dtypes: 0.5.4
- mlflow: 3.5.1
- mlflow-skinny: 3.5.1
- mlflow-tracing: 3.5.1
- more-itertools: 10.8.0
- mpmath: 1.3.0
- msgpack: 1.1.2
- multidict: 6.7.0
- multivolumefile: 0.2.3
- mypy: 1.19.0
- mypy-extensions: 1.1.0
- namex: 0.1.0
- narwhals: 2.13.0
- nbclient: 0.10.2
- nbconvert: 7.16.6
- nbformat: 5.10.4
- nest-asyncio: 1.6.0
- networkx: 3.6.1
- notebook: 7.5.0
- notebook-shim: 0.2.4
- numba: 0.63.0
- numpy: 2.3.5
- nvidia-cublas-cu12: 12.8.4.1
- nvidia-cuda-cupti-cu12: 12.8.90
- nvidia-cuda-nvcc-cu12: 12.9.86
- nvidia-cuda-nvrtc-cu12: 12.8.93
- nvidia-cuda-runtime-cu12: 12.8.90
- nvidia-cudnn-cu12: 9.10.2.21
- nvidia-cufft-cu12: 11.3.3.83
- nvidia-cufile-cu12: 1.13.1.3
- nvidia-curand-cu12: 10.3.9.90
- nvidia-cusolver-cu12: 11.7.3.90
- nvidia-cusparse-cu12: 12.5.8.93
- nvidia-cusparselt-cu12: 0.7.1
- nvidia-nccl-cu12: 2.27.5
- nvidia-nvjitlink-cu12: 12.8.93
- nvidia-nvshmem-cu12: 3.3.20
- nvidia-nvtx-cu12: 12.8.90
- oauthlib: 3.3.1
- onnxruntime: 1.23.2
- openap: 2.4
- opencv-python: 4.11.0.86
- opentelemetry-api: 1.39.0
- opentelemetry-proto: 1.39.0
- opentelemetry-sdk: 1.39.0
- opentelemetry-semantic-conventions: 0.60b0
- opt-einsum: 3.4.0
- optree: 0.18.0
- orjson: 3.11.5
- overrides: 7.7.0
- packaging: 25.0
- pandas: 2.3.3
- pandocfilters: 1.5.1
- parso: 0.8.5
- pathspec: 0.12.1
- patsy: 1.0.2
- pexpect: 4.9.0
- pillow: 10.4.0
- pint: 0.25.2
- pitot: 0.3.2
- platformdirs: 4.5.1
- plotly: 6.5.0
- pluggy: 1.6.0
- prometheus-client: 0.23.1
- prompt-toolkit: 3.0.52
- propcache: 0.4.1
- properscoring: 0.1
- prophet: 1.2.1
- proto-plus: 1.26.1
- protobuf: 6.33.2
- psutil: 7.1.3
- ptyprocess: 0.7.0
- pure-eval: 0.2.3
- py7zr: 1.0.0
- pyarrow: 21.0.0
- pyasn1: 0.6.1
- pyasn1-modules: 0.4.2
- pybcj: 1.0.7
- pybind11: 3.0.1
- pycparser: 2.23
- pycryptodome: 3.23.0
- pycryptodomex: 3.23.0
- pydantic: 2.12.5
- pydantic-core: 2.41.5
- pygments: 2.19.2
- pyjwt: 2.10.1
- pykeops: 2.2.3
- pynverse: 0.1.4.6
- pyod: 2.0.6
- pyogrio: 0.12.1
- pyopensky: 2.15
- pyparsing: 3.2.5
- pyppmd: 1.2.0
- pyproj: 3.7.2
- pyshp: 3.0.3
- pytest: 9.0.2
- python-dateutil: 2.9.0.post0
- python-dotenv: 1.2.1
- python-json-logger: 4.0.0
- pytorch-lightning: 2.6.0
- pytz: 2025.2
- pyyaml: 6.0.3
- pyzmq: 27.1.0
- pyzstd: 0.19.0
- quantile-forest: 1.4.1
- ray: 2.52.1
- referencing: 0.37.0
- regex: 2025.11.3
- requests: 2.32.5
- requests-oauthlib: 2.0.0
- rfc3339-validator: 0.1.4
- rfc3986-validator: 0.1.1
- rfc3987-syntax: 1.1.0
- rich: 14.2.0
- rpds-py: 0.30.0
- rs1090: 0.4.14
- rsa: 4.9.1
- ruff: 0.14.8
- safetensors: 0.7.0
- scikit-image: 0.25.2
- scikit-learn: 1.7.2
- scipy: 1.16.3
- seaborn: 0.13.2
- secretstorage: 3.5.0
- send2trash: 1.8.3
- setuptools: 80.9.0
- sfoutils: 0.2.0
- shap: 0.50.0
- shapely: 2.1.2
- six: 1.17.0
- sklearn-quantile: 0.1.1
- slicer: 0.0.8
- smmap: 5.0.2
- soupsieve: 2.8
- sqlalchemy: 2.0.44
- sqlparse: 0.5.4
- stack-data: 0.6.3
- stanio: 0.5.1
- starlette: 0.50.0
- statsmodels: 0.14.6
- sympy: 1.14.0
- tensorboard: 2.20.0
- tensorboard-data-server: 0.7.2
- tensorboard-plugin-profile: 2.21.3
- tensorboardx: 2.6.4
- tensorflow: 2.20.0
- tensorflow-docs: 2025.12.2.70325
- tensorflow-probability: 0.25.0
- termcolor: 3.2.0
- terminado: 0.18.1
- texttable: 1.7.0
- tf-keras: 2.20.1
- threadpoolctl: 3.6.0
- tifffile: 2025.10.16
- tinycss2: 1.4.0
- tokenizers: 0.21.4
- toml: 0.10.2
- tomli: 2.0.1
- torch: 2.9.1
- torchinfo: 1.8.0
- torchmetrics: 1.8.2
- torchview: 0.2.7
- tornado: 6.5.2
- tqdm: 4.67.1
- traffic: 2.13.post16.dev0+9e00a83
- traitlets: 5.14.3
- transformers: 4.51.3
- trino: 0.336.0
- triton: 3.5.1
- tudcolors: 0.0.1
- typeguard: 4.3.0
- types-protobuf: 6.32.1.20251105
- types-requests: 2.32.4.20250913
- types-tensorflow: 2.18.0.20251008
- typing-extensions: 4.15.0
- typing-inspection: 0.4.2
- tzdata: 2025.2
- tzlocal: 5.3.1
- uri-template: 1.3.0
- urllib3: 2.6.1
- uvicorn: 0.38.0
- wcwidth: 0.2.14
- webcolors: 25.10.0
- webencodings: 0.5.1
- websocket-client: 1.9.0
- werkzeug: 3.1.4
- wheel: 0.45.1
- widgetsnbextension: 4.0.15
- wrapt: 2.0.1
- xprof: 2.21.3
- xyzservices: 2025.11.0
- yarl: 1.22.0
- zipp: 3.23.0
- zstandard: 0.25.0
- System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor:
- python: 3.13.9
- release: 6.17.9-200.fc42.x86_64
- version: #1 SMP PREEMPT_DYNAMIC Mon Nov 24 22:28:05 UTC 2025
More info
No response