Skip to content

Commit 9f8e318

Browse files
committed
adding seeded dropout
1 parent d43b94b commit 9f8e318

File tree

2 files changed

+89
-0
lines changed

2 files changed

+89
-0
lines changed

makani/models/common/context.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 20245 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import torch
17+
from torch import nn
18+
from typing import Optional
19+
from contextlib import contextmanager
20+
21+
@contextmanager
22+
def rng_context(cpu_rng: torch.Generator, device_rng: Optional[torch.Generator] = None):
23+
"""
24+
Context manager for temporarily setting CPU and device RNG states.
25+
26+
This context manager allows you to temporarily set specific RNG states
27+
for reproducibility, then automatically restore the original global states.
28+
29+
Parameters
30+
----------
31+
cpu_rng_state : torch.Tensor
32+
CPU RNG state to set temporarily
33+
device_rng_state : torch.Tensor, optional
34+
Device (CUDA) RNG state to set temporarily. Uses current device.
35+
36+
Examples
37+
--------
38+
>>> # Save current states
39+
>>> cpu_state = torch.get_rng_state()
40+
>>> device_state = torch.cuda.get_rng_state()
41+
>>>
42+
>>> # Later, temporarily use those states
43+
>>> with rng_context(cpu_state, device_state):
44+
>>> # Code here uses the provided RNG states
45+
>>> x = torch.randn(10)
46+
>>> # Original RNG states are restored here
47+
"""
48+
49+
# Backup and set CPU RNG state
50+
cpu_backup = torch.get_rng_state()
51+
torch.set_rng_state(cpu_rng.get_state())
52+
53+
# Backup and set device RNG state if provided
54+
device_backup = None
55+
if device_rng is not None and torch.cuda.is_available():
56+
device_backup = torch.cuda.get_rng_state()
57+
torch.cuda.set_rng_state(device_rng.get_state())
58+
try:
59+
yield
60+
61+
finally:
62+
# Restore states
63+
cpu_rng.set_state(torch.get_rng_state())
64+
torch.set_rng_state(cpu_backup)
65+
if device_backup is not None:
66+
device_rng.set_state(torch.cuda.get_rng_state())
67+
torch.cuda.set_rng_state(device_backup)

makani/models/common/layers.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
2020
import math
2121

22+
from makani.models.common.context import rng_context
23+
2224

2325
@torch.compile
2426
def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
@@ -49,6 +51,26 @@ def forward(self, x):
4951
return drop_path(x, self.drop_prob, self.training)
5052

5153

54+
class SeededDropout(nn.Module):
55+
def __init__(self, drop_prob=0.0, seed=333):
56+
super(SeededDropout, self).__init__()
57+
self.drop_prob = drop_prob
58+
self.seed = seed
59+
self.drop = nn.Dropout(p=self.drop_prob)
60+
61+
# set RNG states
62+
self.rng_cpu = torch.Generator(device=torch.device("cpu"))
63+
self.rng_cpu.manual_seed(seed)
64+
if torch.cuda.is_available():
65+
self.rng_gpu = torch.Generator(device=torch.cuda.current_device())
66+
self.rng_gpu.manual_seed(seed)
67+
68+
def forward(self, x):
69+
with rng_context(self.rng_cpu, self.rng_gpu):
70+
xdrop = self.drop(x)
71+
return xdrop
72+
73+
5274
class LayerScale(nn.Module):
5375
def __init__(self, num_chans=3, init_value=0.1):
5476
super().__init__()

0 commit comments

Comments
 (0)