Skip to content

Commit 258c534

Browse files
Merge pull request #310 from jonbinney/cuda
Empty CUDA's cache
2 parents 8de7737 + 12ed5da commit 258c534

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

deep_quoridor/src/agents/alphazero/self_play_manager.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Optional
88

99
import quoridor_env
10+
import torch
1011
import wandb
1112
from plugins.wandb_train import WandbParams
1213
from utils import set_deterministic
@@ -258,3 +259,7 @@ def run_self_play_games(
258259
list(alphazero_agent.replay_buffer),
259260
)
260261
)
262+
263+
del alphazero_agent
264+
if torch.cuda.is_available():
265+
torch.cuda.empty_cache()

deep_quoridor/src/metrics.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import torch
12
from agents import Agent
23
from agents.alphazero.alphazero import AlphaZeroAgent
34
from agents.core.agent import AgentRegistry
@@ -144,11 +145,17 @@ def compute(
144145
absolute_elo = elo_table[agent.name()]
145146

146147
dumb_score = self.dumb_score(agent)
148+
del agent
147149

148150
if isinstance(agent, AlphaZeroAgent):
149151
raw_play_encoded_name = override_subargs(play_encoded_name, {"mcts_n": 0})
150152
agent_raw = AgentRegistry.create_from_encoded_name(raw_play_encoded_name, arena.game)
151153
dumb_score_raw = self.dumb_score(agent_raw, verbose=True)
154+
del agent_raw
155+
156+
157+
if torch.cuda.is_available():
158+
torch.cuda.empty_cache()
152159

153160
return (
154161
VERSION,

0 commit comments

Comments
 (0)