Commit 6ff539a

[Reland] Use torch.multiprocessing instead of multiprocessing module (#976)
Reland #970 due to ghstack issues.
1 parent 0898a2f

1 file changed: +2 -3

torchtitan/components/checkpoint.py (+2 -3)
@@ -6,19 +6,18 @@
 
 import enum
 import functools
-import multiprocessing as mp
 import os
 import queue
 import re
 import shutil
 import threading
 import time
-from multiprocessing import get_context
 from typing import Any, Dict, List, Optional, Union
 
 import torch
 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
+import torch.multiprocessing as mp
 import torch.nn as nn
 from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict
 from torch.distributed.checkpoint.state_dict import (
@@ -294,7 +293,7 @@ def load_state_dict(state_dict):
             self.async_future = None
         elif async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
             self.async_mode = AsyncMode.ASYNC_WITH_PINNED_MEM
-            ctx = get_context("spawn")
+            ctx = mp.get_context("spawn")
             self.mp_queue_send = ctx.Queue()
             self.mp_queue_recv = ctx.Queue()
             self.mp = ctx.Process(
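
For context, torch.multiprocessing is a drop-in wrapper around the standard multiprocessing module that registers custom reducers, so tensors sent through its queues are moved to shared memory instead of being copied via pickle. Below is a minimal sketch of the spawn-context pattern used in the diff, under the assumption of a hypothetical worker function checkpoint_worker; it is not torchtitan's actual checkpointer.

```python
# Minimal sketch of the spawn-context queue/process pattern from the diff,
# not torchtitan's implementation. `checkpoint_worker` is a hypothetical stand-in.
import torch
import torch.multiprocessing as mp  # drop-in wrapper around stdlib multiprocessing


def checkpoint_worker(recv, send):
    # Block on the receive queue until the parent sends work or a shutdown sentinel.
    while True:
        item = recv.get()
        if item is None:
            break
        # Pretend to "checkpoint" the payload and report back to the parent.
        send.put(f"received tensor of shape {tuple(item.shape)}")


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    mp_queue_send = ctx.Queue()  # parent -> worker
    mp_queue_recv = ctx.Queue()  # worker -> parent
    proc = ctx.Process(
        target=checkpoint_worker,
        args=(mp_queue_send, mp_queue_recv),
        daemon=True,
    )
    proc.start()

    payload = torch.zeros(4)      # kept alive until the worker has consumed it
    mp_queue_send.put(payload)    # CPU tensor data travels via shared memory
    print(mp_queue_recv.get())    # -> "received tensor of shape (4,)"

    mp_queue_send.put(None)       # shutdown sentinel
    proc.join()
```

Routing both the queues and the process through the torch.multiprocessing context keeps a single import path, which is presumably why the stdlib get_context import could be dropped in this change.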
