
Commit 40012de

hekaisheng authored and Xuye (Chris) Qin committed
Implement MarsDataset to integrate with PyTorch (#937)
1 parent 4a602b5 commit 40012de

19 files changed, +527 -64 lines changed

mars/api.py  +8 -9

@@ -95,11 +95,11 @@ def has_session(self, session_id):
         return self.session_manager.has_session(session_id)
 
     def submit_graph(self, session_id, serialized_graph, graph_key, target,
-                     compose=True, wait=True):
+                     names=None, compose=True, wait=True):
         session_uid = SessionActor.gen_uid(session_id)
         session_ref = self.get_actor_ref(session_uid)
         session_ref.submit_tileable_graph(
-            serialized_graph, graph_key, target, compose=compose, _tell=not wait)
+            serialized_graph, graph_key, target, names=names, compose=compose, _tell=not wait)
 
     def create_mutable_tensor(self, session_id, name, shape, dtype, *args, **kwargs):
         session_uid = SessionActor.gen_uid(session_id)
@@ -176,11 +176,11 @@ def wait_graph_finish(self, session_id, graph_key, timeout=None):
     def fetch_data(self, session_id, graph_key, tileable_key, index_obj=None, compressions=None):
         graph_uid = GraphActor.gen_uid(session_id, graph_key)
         graph_ref = self.get_actor_ref(graph_uid)
-        nsplits, chunk_indexes = graph_ref.get_tileable_meta(tileable_key)
-
+        nsplits, chunk_keys, chunk_indexes = graph_ref.get_tileable_metas([tileable_key])[0]
+        chunk_index_to_key = dict((index, key) for index, key in zip(chunk_indexes, chunk_keys))
         if not index_obj:
             chunk_results = dict((idx, self.fetch_chunk_data(session_id, k)) for
-                                 idx, k in chunk_indexes.items())
+                                 idx, k in zip(chunk_indexes, chunk_keys))
         else:
             chunk_results = dict()
             indexes = dict()
@@ -194,7 +194,7 @@ def fetch_data(self, session_id, graph_key, tileable_key, index_obj=None, compre
             # `arr[tuple(seq)]` instead of `arr[seq]`. In the future this will be interpreted as an array
             # index, `arr[np.array(seq)]`, which will result either in an error or a different result.
             slice_obj = tuple(indexes[axis][chunk_idx] for axis, chunk_idx in enumerate(chunk_index))
-            chunk_key = chunk_indexes[chunk_index]
+            chunk_key = chunk_index_to_key[chunk_index]
             chunk_results[chunk_index] = self.fetch_chunk_data(session_id, chunk_key, slice_obj)
 
         chunk_results = [(idx, dataserializer.loads(f.result())) for
@@ -210,8 +210,7 @@ def fetch_chunk_data(self, session_id, chunk_key, index_obj=None):
         endpoints = self.chunk_meta_client.get_workers(session_id, chunk_key)
         sender_ref = self.actor_client.actor_ref(ResultSenderActor.default_uid(),
                                                  address=random.choice(endpoints))
-        future = sender_ref.fetch_data(session_id, chunk_key, index_obj, _wait=False)
-        return future
+        return sender_ref.fetch_data(session_id, chunk_key, index_obj, _wait=False)
 
     def delete_data(self, session_id, graph_key, tileable_key, wait=False):
         graph_uid = GraphActor.gen_uid(session_id, graph_key)
@@ -223,4 +222,4 @@ def get_tileable_nsplits(self, session_id, graph_key, tileable_key):
         graph_uid = GraphActor.gen_uid(session_id, graph_key)
         graph_ref = self.get_actor_ref(graph_uid)
 
-        return graph_ref.get_tileable_meta(tileable_key)[0]
+        return graph_ref.get_tileable_metas([tileable_key], filter_fields=['nsplits'])[0][0]
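The fetch path above now goes through `get_tileable_metas`, which returns one `(nsplits, chunk_keys, chunk_indexes)` tuple per requested tileable, whereas the old `get_tileable_meta` returned a mapping from chunk index to chunk key for a single tileable. A minimal sketch of consuming the new shape, assuming `graph_ref` and `tileable_key` as in `fetch_data` above:

```python
# Sketch only: `graph_ref` and `tileable_key` are assumed to exist, as in fetch_data() above.
nsplits, chunk_keys, chunk_indexes = graph_ref.get_tileable_metas([tileable_key])[0]

# nsplits describes the chunk layout per axis, e.g. ((5, 5), (10,)) for a 10x10
# tensor split into two chunks along axis 0; summing each axis gives the shape.
shape = tuple(sum(splits) for splits in nsplits)

# chunk_keys and chunk_indexes are parallel sequences; zipping them rebuilds the
# index -> key mapping that the old API returned directly.
chunk_index_to_key = dict(zip(chunk_indexes, chunk_keys))
```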

mars/context.py  +115 -16

@@ -14,7 +14,8 @@
 
 import sys
 import threading
-from collections import namedtuple
+import random
+from collections import namedtuple, defaultdict
 from enum import Enum
 from typing import List
 
@@ -70,11 +71,12 @@ def h(*args, **kwargs):
     # Meta relative
     # ---------------
 
-    def get_chunk_metas(self, chunk_keys):
+    def get_chunk_metas(self, chunk_keys, filter_fields=None):
         """
         Get chunk metas according to the given chunk keys.
 
         :param chunk_keys: chunk keys
+        :param filter_fields: filter the fields in meta
         :return: List of chunk metas
         """
         raise NotImplementedError
@@ -188,7 +190,9 @@ def get_local_address(self):
     def get_ncores(self):
         return self._ncores
 
-    def get_chunk_metas(self, chunk_keys):
+    def get_chunk_metas(self, chunk_keys, filter_fields=None):
+        if filter_fields is not None:  # pragma: no cover
+            raise NotImplementedError("Local context doesn't support filter fields now")
         metas = []
         for chunk_key in chunk_keys:
             chunk_data = self.get(chunk_key)
@@ -219,17 +223,29 @@ def get_chunk_results(self, chunk_keys: List[str]) -> List:
 
 
 class DistributedContext(ContextBase):
-    def __init__(self, cluster_info, session_id, addr, chunk_meta_client,
-                 resource_actor_ref, actor_ctx, **kw):
-        self._cluster_info = cluster_info
-        is_distributed = cluster_info.is_distributed()
+    def __init__(self, scheduler_address, session_id, actor_ctx=None, **kw):
+        from .worker.api import WorkerAPI
+        from .scheduler.api import MetaAPI
+        from .scheduler.resource import ResourceActor
+        from .scheduler.utils import SchedulerClusterInfoActor
+        from .actors import new_client
+
+        self._session_id = session_id
+        self._scheduler_address = scheduler_address
+        self._worker_api = WorkerAPI()
+        self._meta_api = MetaAPI(actor_ctx=actor_ctx, scheduler_endpoint=scheduler_address)
+
+        self._running_mode = None
+        self._actor_ctx = actor_ctx or new_client()
+        self._cluster_info = self._actor_ctx.actor_ref(
+            SchedulerClusterInfoActor.default_uid(), address=scheduler_address)
+        is_distributed = self._cluster_info.is_distributed()
         self._running_mode = RunningMode.local_cluster \
             if not is_distributed else RunningMode.distributed
-        self._session_id = session_id
-        self._address = addr
-        self._chunk_meta_client = chunk_meta_client
-        self._resource_actor_ref = resource_actor_ref
-        self._actor_ctx = actor_ctx
+        self._address = kw.pop('address', None)
+        self._resource_actor_ref = self._actor_ctx.actor_ref(
+            ResourceActor.default_uid(), address=scheduler_address)
+
         self._extra_info = kw
 
     @property
@@ -252,10 +268,6 @@ def get_local_address(self):
     def get_ncores(self):
         return self._extra_info.get('n_cpu')
 
-    def get_chunk_metas(self, chunk_keys):
-        return self._chunk_meta_client.batch_get_chunk_meta(
-            self._session_id, chunk_keys)
-
     def get_chunk_results(self, chunk_keys: List[str]) -> List:
         from .serialize import dataserializer
         from .worker.transfer import ResultSenderActor
@@ -269,6 +281,93 @@ def get_chunk_results(self, chunk_keys: List[str]) -> List:
             dataserializer.loads(sender_ref.fetch_data(self._session_id, chunk_key)))
         return results
 
+    # Meta API
+    def get_tileable_metas(self, tileable_keys, filter_fields: List[str]=None) -> List:
+        return self._meta_api.get_tileable_metas(self._session_id, tileable_keys, filter_fields)
+
+    def get_chunk_metas(self, chunk_keys, filter_fields: List[str] = None) -> List:
+        return self._meta_api.get_chunk_metas(self._session_id, chunk_keys, filter_fields)
+
+    def get_tileable_key_by_name(self, name: str):
+        return self._meta_api.get_tileable_key_by_name(self._session_id, name)
+
+    # Worker API
+    def get_chunks_data(self, worker: str, chunk_keys: List[str], indexes: List=None,
+                        compression_types: List[str]=None):
+        return self._worker_api.get_chunks_data(self._session_id, worker, chunk_keys, indexes=indexes,
+                                                compression_types=compression_types)
+
+    # Fetch tileable data by tileable keys and indexes.
+    def get_tileable_data(self, tileable_key: str, indexes: List=None,
+                          compression_types: List[str]=None):
+        from .serialize import dataserializer
+        from .utils import merge_chunks
+        from .tensor.core import TENSOR_TYPE
+        from .tensor.datasource import empty
+        from .tensor.indexing.getitem import TensorIndexTilesHandler
+
+        nsplits, chunk_keys, chunk_indexes = self.get_tileable_metas([tileable_key])[0]
+        chunk_idx_to_keys = dict(zip(chunk_indexes, chunk_keys))
+        chunk_keys_to_idx = dict(zip(chunk_keys, chunk_indexes))
+        endpoints = self.get_chunk_metas(chunk_keys, filter_fields=['workers'])
+        chunk_keys_to_worker = dict((chunk_key, random.choice(es[0])) for es, chunk_key in zip(endpoints, chunk_keys))
+
+        chunk_workers = defaultdict(list)
+        [chunk_workers[e].append(chunk_key) for chunk_key, e in chunk_keys_to_worker.items()]
+
+        chunk_results = dict()
+        if not indexes:
+            datas = []
+            for endpoint, chunks in chunk_workers.items():
+                datas.append(self.get_chunks_data(endpoint, chunks, compression_types=compression_types))
+            datas = [d.result() for d in datas]
+            for (endpoint, chunks), d in zip(chunk_workers.items(), datas):
+                d = [dataserializer.loads(db) for db in d]
+                chunk_results.update(dict(zip([chunk_keys_to_idx[k] for k in chunks], d)))
+        else:
+            # TODO: make a common util to handle indexes
+            if any(isinstance(ind, TENSOR_TYPE) for ind in indexes):
+                raise TypeError("Doesn't support indexing by tensors")
+            # Reuse the getitem logic to get each chunk's indexes
+            tileable_shape = tuple(sum(s) for s in nsplits)
+            empty_tileable = empty(tileable_shape, chunk_size=nsplits)._inplace_tile()
+            indexed = empty_tileable[tuple(indexes)]
+            index_handler = TensorIndexTilesHandler(indexed.op)
+            index_handler._extract_indexes_info()
+            index_handler._preprocess_fancy_indexes()
+            index_handler._process_fancy_indexes()
+            index_handler._process_in_tensor()
+
+            result_chunks = dict()
+            for c in index_handler._out_chunks:
+                result_chunks[chunk_idx_to_keys[c.inputs[0].index]] = [c.index, c.op.indexes]
+
+            chunk_datas = dict()
+            for endpoint, chunks in chunk_workers.items():
+                to_fetch_keys = []
+                to_fetch_indexes = []
+                to_fetch_idx = []
+                for r_chunk, (chunk_index, index_obj) in result_chunks.items():
+                    if r_chunk in chunks:
+                        to_fetch_keys.append(r_chunk)
+                        to_fetch_indexes.append(index_obj)
+                        to_fetch_idx.append(chunk_index)
+                if to_fetch_keys:
+                    datas = self.get_chunks_data(endpoint, to_fetch_keys, indexes=to_fetch_indexes,
+                                                 compression_types=compression_types)
+                    chunk_datas[tuple(to_fetch_idx)] = datas
+            chunk_datas = dict((k, v.result()) for k, v in chunk_datas.items())
+            for idx, d in chunk_datas.items():
+                d = [dataserializer.loads(db) for db in d]
+                chunk_results.update(dict(zip(idx, d)))
+
+        chunk_results = [(k, v) for k, v in chunk_results.items()]
+        if len(chunk_results) == 1:
+            ret = chunk_results[0][1]
+        else:
+            ret = merge_chunks(chunk_results)
+        return ret
+
 
 class DistributedDictContext(DistributedContext, dict):
     def __init__(self, *args, **kwargs):
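With this change, a `DistributedContext` can be constructed from just a scheduler endpoint and a session id, and it exposes name-based lookup plus whole-tileable fetching. A minimal sketch of using it outside PyTorch, assuming a running Mars cluster; the endpoint, session id, and the name 'data' are placeholders:

```python
from mars.context import DistributedContext

# Placeholders: point these at a real scheduler and an existing session in
# which a tileable was stored under the name 'data'.
ctx = DistributedContext(scheduler_address='localhost:7103',
                         session_id='<session-id>')

# Resolve the user-visible name to a tileable key, then pull the first ten
# rows back as a single numpy array (chunks are fetched from their workers
# and merged by get_tileable_data).
key = ctx.get_tileable_key_by_name('data')
head = ctx.get_tileable_data(key, indexes=[slice(0, 10)])
```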

mars/learn/contrib/pytorch/__init__.py  +1

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from .run_script import run_pytorch_script
+from .dataset import MarsDataset
 
 
 def register_op():
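This re-export means user code can import the dataset from the contrib package directly rather than from the submodule. A one-line sketch:

```python
# The package-level import is the new re-export added above.
from mars.learn.contrib.pytorch import MarsDataset
```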

mars/learn/contrib/pytorch/dataset.py  +54 (new file)

@@ -0,0 +1,54 @@
+# Copyright 1999-2020 Alibaba Group Holding Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+try:
+    from torch.utils.data import Dataset
+except ImportError:  # pragma: no cover
+    Dataset = object
+
+from ....context import get_context, DistributedContext
+from ....tensor.indexing.core import process_index
+from ....tensor.fetch import TensorFetch
+from ....utils import require_not_none
+
+
+@require_not_none(Dataset)
+class MarsDataset(Dataset):
+    def __init__(self, *names):
+        self._context = get_context()
+
+        tensors = []
+        for name in names:
+            tileable_key = self._context.get_tileable_key_by_name(name)
+            nsplits = self._context.get_tileable_metas([tileable_key], filter_fields=['nsplits'])[0][0]
+            shape = tuple(sum(s) for s in nsplits)
+            tensors.append(TensorFetch().new_tensor([], shape=shape, _key=tileable_key))
+        self.tensors = tensors
+
+    def __len__(self):
+        return self.tensors[0].shape[0]
+
+    def __getitem__(self, item):
+        indexes = process_index(self.tensors[0].ndim, item)
+        return tuple(self._context.get_tileable_data(t.key, indexes) for t in self.tensors)
+
+
+def enter_mars_context():
+    scheduler = os.environ['MARS_SCHEDULER']
+    session_id = os.environ['MARS_SESSION']
+    return DistributedContext(scheduler_address=scheduler, session_id=session_id)
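`MarsDataset` resolves each name to a tileable key at construction time and fetches rows on demand in `__getitem__`, so it can be dropped into a standard `DataLoader`. A minimal sketch, assuming a process where `MARS_SCHEDULER` and `MARS_SESSION` are set (as `run_script.py` below arranges) and tensors stored under the placeholder names 'data' and 'labels':

```python
from torch.utils.data import DataLoader
from mars.learn.contrib.pytorch import MarsDataset
from mars.learn.contrib.pytorch.dataset import enter_mars_context

# enter_mars_context() rebuilds a DistributedContext from the MARS_SCHEDULER /
# MARS_SESSION environment variables; 'data' and 'labels' are placeholder
# names for tileables already stored in that session.
with enter_mars_context():
    dataset = MarsDataset('data', 'labels')

    # Each item is a tuple (data_row, labels_row) fetched via
    # get_tileable_data; the default collate_fn stacks them into tensors.
    loader = DataLoader(dataset, batch_size=32, shuffle=False)
    for batch_data, batch_labels in loader:
        pass  # feed the batch to a model here
```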

mars/learn/contrib/pytorch/run_script.py  +6

@@ -125,6 +125,12 @@ def execute(cls, ctx, op):
         env['MASTER_ADDR'] = str(op.master_addr)
         env['RANK'] = str(op.rank)
         env['WORLD_SIZE'] = str(op.world_size)
+
+        # set mars envs
+        if ctx.running_mode != RunningMode.local:
+            env['MARS_SCHEDULER'] = str(ctx._scheduler_address)
+            env['MARS_SESSION'] = str(ctx._session_id)
+
         # exec pytorch code in a new process
         process = subprocess.Popen(
             [sys.executable, filename] + op.command_args, env=env)
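The two environment variables are only injected when the operand is not running in purely local mode, and they are exactly what `enter_mars_context()` in `dataset.py` reads back inside the spawned training script. A hedged driver-side sketch; the `n_workers` and `command_argv` arguments follow how Mars' PyTorch tests invoke `run_pytorch_script`, the script path is a placeholder, and the named tensors consumed by `MarsDataset('data', 'labels')` are assumed to already exist in the current session (the client-side naming mechanism is not part of this excerpt):

```python
from mars.learn.contrib.pytorch import run_pytorch_script

# 'torch_script.py' is a placeholder for a script like the sample below.
# One training process is spawned per worker; on a cluster, the execute()
# hook above injects MARS_SCHEDULER and MARS_SESSION into each subprocess
# so the script's enter_mars_context() can reconnect to this session.
run_pytorch_script('torch_script.py', n_workers=2, command_argv=['multiple'])
```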
New file (filename not shown in this view)  +68

@@ -0,0 +1,68 @@
+# Copyright 1999-2020 Alibaba Group Holding Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import sys
+
+
+def get_model():
+    import torch.nn as nn
+    return nn.Sequential(
+        nn.Linear(32, 64),
+        nn.ReLU(),
+        nn.Linear(64, 64),
+        nn.ReLU(),
+        nn.Linear(64, 10),
+        nn.Softmax(),
+    )
+
+
+def main():
+    import torch.nn as nn
+    import torch.distributed as dist
+    import torch.optim as optim
+    import torch.utils.data
+    from mars.learn.contrib.pytorch.dataset import MarsDataset, enter_mars_context
+
+    dist.init_process_group(backend='gloo')
+    torch.manual_seed(42)
+
+    with enter_mars_context():
+        train_dataset = MarsDataset('data', 'labels')
+
+        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
+        train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
+                                                   batch_size=32,
+                                                   shuffle=False,
+                                                   sampler=train_sampler)
+
+        model = nn.parallel.DistributedDataParallel(get_model())
+        optimizer = optim.SGD(model.parameters(),
+                              lr=0.01, momentum=0.5)
+        criterion = nn.BCELoss()
+
+        for _ in range(2):
+            # 2 epochs
+            for _, (batch_data, batch_labels) in enumerate(train_loader):
+                outputs = model(batch_data)
+                loss = criterion(outputs.squeeze(), batch_labels)
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+
+
+if __name__ == "__main__":
+    assert len(sys.argv) == 2
+    assert sys.argv[1] == 'multiple'
+    main()
