
Commit 607c505

Merge pull request #114 from edgeintelligence/brushup
Brushup
2 parents: 3dfa2b4 + 0396da7


20 files changed: +337 -52 lines


distributed/README.ja.md

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+# kakiage distributed machine learning server
+
+The server side is implemented as a Python library.
+
+# Setup
+
+Python 3.8+
+
+```
+pip install -r requirements.txt
+python setup.py develop
+```
+
+How to run the samples: see `samples/*/README.md`
+
+# Build for distribution
+
+```
+python setup.py bdist_wheel
+```
+
+`dist/kakiage-<version>-py3-none-any.whl` will be generated. Users can install kakiage together with its required dependencies (numpy, etc.) by running `pip install /path/to/kakiage-<version>-py3-none-any.whl`.

distributed/README.md

Lines changed: 6 additions & 6 deletions
@@ -1,8 +1,8 @@
-# kakiage 分散機械学習サーバ
+# kakiage distributed training server
 
-サーバ側は Python ライブラリとして実装されている。
+Server-side code is implemented as a Python library.
 
-# セットアップ
+# Setup
 
 Python 3.8+
 
@@ -11,12 +11,12 @@ pip install -r requirements.txt
 python setup.py develop
 ```
 
-サンプル動作方法: `samples/*/README.md`参照
+How to run the samples: see `samples/*/README.md`
 
-# 配布用ビルド
+# Build for distribution
 
 ```
 python setup.py bdist_wheel
 ```
 
-`dist/kakiage-<version>-py3-none-any.whl` が生成される。利用者は、`pip install /path/to/kakiage-<version>-py3-none-any.whl`を実行することで必須依存パッケージ(numpy 等)とともに kakiage をインストールすることが可能。
+`dist/kakiage-<version>-py3-none-any.whl` will be generated. Users can run `pip install /path/to/kakiage-<version>-py3-none-any.whl` to install kakiage along with its required dependencies (numpy, etc.).
Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+# Sample: MNIST data parallel training
+
+Distributed training of an MLP that classifies the MNIST image dataset, using data parallelism.
+
+The CIFAR10 and CIFAR100 datasets can also be used.
+
+# Build
+
+```
+npm install
+npm run build
+```
+
+# Run training
+
+Configuration is done via environment variables.
+
+- MODEL: one of mlp, conv, resnet18. Specifies the model type.
+- N_CLIENTS: the number of clients participating in the distributed computation; an integer of 1 or more. Defaults to 1 if not specified.
+- EPOCH: number of training epochs. Default is 2.
+- BATCH_SIZE: batch size, summed over all clients. Default is 32.
+
+Execution is via uvicorn. Example command (Mac/Linux):
+
+```
+MODEL=conv N_CLIENTS=2 npm run train
+```
+
+On Windows, use the set command:
+
+```
+set MODEL=conv
+set N_CLIENTS=2
+npm run train
+```
+
+Open [http://localhost:8081/](http://localhost:8081/) in a web browser. If `N_CLIENTS` is set, the page must be opened in `N_CLIENTS` browser windows so that `N_CLIENTS` clients compute in parallel. Note: if multiple tabs are opened in one window, computation in the tabs that are not visible slows down.
+
+The trained model is output in ONNX format and can be used for inference with WebDNN, ONNX Runtime Web, etc.
Lines changed: 26 additions & 5 deletions
@@ -1,18 +1,39 @@
-# サンプル: MNIST data parallel training
+# Sample: MNIST data parallel training
 
-MNIST画像データセットを分類するMLPを、データ並列で学習する
+Trains an MLP that classifies the MNIST image dataset using data-parallel distributed training.
 
-# ビルド
+The CIFAR10 and CIFAR100 datasets can also be used.
+
+# Build
 
 ```
 npm install
 npm run build
 ```
 
-# 学習実行
+# Run training
+
+Configuration is done via environment variables.
+
+- MODEL: one of mlp, conv, resnet18. Specifies the model type.
+- N_CLIENTS: the number of clients participating in the distributed computation; an integer of 1 or more. Defaults to 1 if not specified.
+- EPOCH: number of training epochs. Default is 2.
+- BATCH_SIZE: batch size, summed over all clients. Default is 32.
 
+Execution is via uvicorn. Example command (Mac/Linux):
+
+```
+MODEL=conv N_CLIENTS=2 npm run train
 ```
+
+On Windows, use the set command:
+
+```
+set MODEL=conv
+set N_CLIENTS=2
 npm run train
 ```
 
-ブラウザで[http://localhost:8081/](http://localhost:8081/)を開く。3並列で計算するため、3つのタブで開く必要がある。
+Open [http://localhost:8081/](http://localhost:8081/) in a web browser. If `N_CLIENTS` is set, the page must be opened in `N_CLIENTS` browser windows so that `N_CLIENTS` clients compute in parallel. Note: if multiple tabs are opened in one window, computation in the tabs that are not visible slows down.
+
+The trained model is output in ONNX format and can be used for inference with WebDNN, ONNX Runtime Web, etc.
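
A note on the configuration above: the four variables are read by the server-side training script with the defaults listed. Below is a minimal sketch of how such a configuration could be parsed; the helper name and the fallback value for MODEL are illustrative and not taken from main.py.

```python
import os

def read_train_config() -> dict:
    """Illustrative only: reads the documented variables with the documented defaults."""
    return {
        "model": os.environ.get("MODEL", "mlp"),  # mlp, conv, or resnet18; no default documented, "mlp" assumed here
        "n_clients": int(os.environ.get("N_CLIENTS", "1")),
        "epoch": int(os.environ.get("EPOCH", "2")),
        "batch_size": int(os.environ.get("BATCH_SIZE", "32")),  # total over all clients
    }

if __name__ == "__main__":
    print(read_train_config())
```
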

distributed/sample/mnist_data_parallel/main.py

Lines changed: 10 additions & 4 deletions
@@ -29,13 +29,10 @@
 from kakiage.tensor_serializer import serialize_tensors_to_bytes, deserialize_tensor_from_bytes
 from sample_net import make_net, get_io_shape, get_dataset_loader
 
-# distribute the script
+# set up the server to distribute the JavaScript client and communicate with clients
 kakiage_server = setup_server()
 app = kakiage_server.app
 
-# create the initial model with PyTorch; evaluate the trained model on the server side
-
-
 def test(model, loader):
     model.eval()
     loss_sum = 0.0
@@ -56,6 +53,7 @@ def test(model, loader):
 def snake2camel(name):
     """
     running_mean -> runningMean
+    PyTorch uses snake_case, kakiage uses camelCase
     """
     upper = False
     cs = []
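
The docstring line added above records why this helper exists: PyTorch `state_dict` keys are snake_case while kakiage tensor names are camelCase. A minimal sketch of such a conversion follows; the actual implementation continues with the `upper` flag and `cs` list shown above, so this simplified version only illustrates the mapping.

```python
def snake2camel_sketch(name: str) -> str:
    # "running_mean" -> "runningMean"
    head, *rest = name.split("_")
    return head + "".join(part.capitalize() for part in rest)


assert snake2camel_sketch("running_mean") == "runningMean"
assert snake2camel_sketch("weight") == "weight"
```
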
@@ -99,6 +97,7 @@ async def main():
     client_ids = []
     print(f"Waiting {n_client_wait} clients to connect")
 
+    # get the next server event from the event queue
    async def get_event():
        while True:
            event = await kakiage_server.event_queue.get()
@@ -137,11 +136,13 @@ async def get_event():
        chunk_size = math.ceil(batch_size / n_clients)
        chunk_sizes = []
        grad_item_ids = []
+        # split the batch into len(client_ids) chunks
        for c, client_id in enumerate(client_ids):
            image_chunk = image[c*chunk_size:(c+1)*chunk_size]
            label_chunk = label[c*chunk_size:(c+1)*chunk_size]
            chunk_sizes.append(len(image_chunk))
            dataset_item_id = uuid4().hex
+            # register the blob (binary data) on the server so that the client can download it by specifying this id
            kakiage_server.blobs[dataset_item_id] = serialize_tensors_to_bytes(
                {
                    "image": image_chunk.detach().numpy().astype(np.float32),
@@ -150,6 +151,7 @@ async def get_event():
            )
            item_ids_to_delete.append(dataset_item_id)
            grad_item_id = uuid4().hex
+            # ask the client to compute gradients for the given weights and batch
            await kakiage_server.send_message(client_id, {
                "model": model_name,
                "inputShape": list(input_shape),
@@ -161,6 +163,9 @@ async def get_event():
            grad_item_ids.append(grad_item_id)
            item_ids_to_delete.append(grad_item_id)
        complete_count = 0
+        # Wait for all clients to complete.
+        # Disconnection and dynamic addition of clients are not supported (this implementation waits for a disconnected client forever).
+        # To support them, handle events such as KakiageServerWSConnectEvent.
        while True:
            event = await get_event()
            if isinstance(event, KakiageServerWSReceiveEvent):
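
The comments above summarize the synchronization model: the server counts one completion message per client and blocks until all have reported, so a client that disconnects mid-round stalls training. A self-contained asyncio sketch of that barrier pattern, in which a plain queue and a string stand in for kakiage's event queue and KakiageServerWSReceiveEvent:

```python
import asyncio

async def wait_for_all_clients(event_queue: asyncio.Queue, n_clients: int) -> None:
    # One "receive" event per finished client; other events are ignored,
    # which is why a disconnected client would stall the round forever.
    complete_count = 0
    while complete_count < n_clients:
        if await event_queue.get() == "receive":
            complete_count += 1

async def demo() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    for _ in range(3):
        queue.put_nowait("receive")
    await wait_for_all_clients(queue, 3)
    print("all 3 clients reported their gradients")

asyncio.run(demo())
```
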
@@ -184,6 +189,7 @@ async def get_event():
        for k, v in weights.items():
            grad = grad_arrays[k]
            if is_trainable_key(k):
+                # update weight using SGD (no momentum)
                v -= lr * grad
            else:
                # not trainable = BN stats = average latest value
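
The final hunk documents the aggregation rule: trainable parameters take a plain SGD step without momentum, while non-trainable entries (batch-normalization running statistics) simply adopt the averaged latest values. A numpy sketch of one such update, assuming the per-client results have already been combined into a single array per key (how that averaging weights the chunk sizes is not shown in this diff):

```python
import numpy as np

def apply_update(weights, grad_arrays, lr, is_trainable_key):
    # weights: server-side parameter arrays, updated in place.
    # grad_arrays: per-key arrays already combined over clients; gradients for
    # trainable keys, latest running statistics for the rest.
    for k, v in weights.items():
        g = grad_arrays[k]
        if is_trainable_key(k):
            v -= lr * g      # SGD without momentum
        else:
            v[...] = g       # adopt the averaged BN running statistics

# toy usage with made-up keys
w = {"fc.weight": np.ones(3), "bn.running_mean": np.zeros(3)}
g = {"fc.weight": np.full(3, 0.5), "bn.running_mean": np.full(3, 0.1)}
apply_update(w, g, lr=0.1, is_trainable_key=lambda k: "running" not in k)
print(w["fc.weight"], w["bn.running_mean"])
```
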
Lines changed: 29 additions & 10 deletions
@@ -1,14 +1,33 @@
 <!DOCTYPE html>
 <html lang="en">
-<head>
-<meta charset="UTF-8">
-<meta http-equiv="X-UA-Compatible" content="IE=edge">
-<meta name="viewport" content="width=device-width, initial-scale=1.0">
+  <head>
+    <meta charset="UTF-8" />
+    <meta http-equiv="X-UA-Compatible" content="IE=edge" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
     <title>Kakiage Distributed MNIST Training Sample</title>
     <script src="static/index.js"></script>
-</head>
-<body>
-<p id="state"></p>
-<p id="messages"></p>
-</body>
-</html>
+    <link href="static/index.css" rel="stylesheet" />
+  </head>
+  <body>
+    <h1>Kakiage Distributed Training</h1>
+    <main>
+      <p id="state"></p>
+      <table>
+        <tbody>
+          <tr>
+            <td>Processed batches</td>
+            <td id="table-batches"></td>
+          </tr>
+          <tr>
+            <td>Last loss</td>
+            <td id="table-loss"></td>
+          </tr>
+          <tr>
+            <td>Batch size</td>
+            <td id="table-batchsize"></td>
+          </tr>
+        </tbody>
+      </table>
+    </main>
+  </body>
+</html>
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+body {
+  margin: 0;
+}
+
+h1 {
+  background: linear-gradient(180deg, orange, white);
+  margin: 0;
+  padding: 1em;
+}
+
+main {
+  padding: 1em;
+}

distributed/sample/mnist_data_parallel/src/index.ts

Lines changed: 10 additions & 11 deletions
@@ -222,14 +222,16 @@ function makeModel(
   }
 }
 
-function writeLog(message: string) {
-  document.getElementById('messages')!.innerText += message + '\n';
-}
-
 const writeState = throttle((message: string) => {
   document.getElementById('state')!.innerText = message;
 }, 1000);
 
+const writeBatchInfo = throttle((processedBatches: number, lastLoss: number, batchSize: number) => {
+  document.getElementById('table-batches')!.innerText = processedBatches.toString();
+  document.getElementById('table-loss')!.innerText = lastLoss.toString();
+  document.getElementById('table-batchsize')!.innerText = batchSize.toString();
+}, 1000);
+
 async function sendBlob(itemId: string, data: Uint8Array): Promise<void> {
   const blob = new Blob([data]);
   const f = await fetch(`/kakiage/blob/${itemId}`, {
@@ -285,23 +287,21 @@ async function compute(msg: { weight: string; dataset: string; grad: string }) {
   }
   await sendBlob(msg.grad, new TensorSerializer().serialize(grads));
   totalBatches += 1;
-  writeState(
-    `total batch: ${totalBatches}, last loss: ${lossValue}, last batch size: ${y.data.shape[0]}`
-  );
+  writeBatchInfo(totalBatches, lossValue, y.data.shape[0]);
 }
 
 async function run() {
-  writeState('Connecting');
+  writeState('Connecting to distributed training server...');
   ws = new WebSocket(
     (window.location.protocol === 'https:' ? 'wss://' : 'ws://') +
       window.location.host +
      '/kakiage/ws'
   );
   ws.onopen = () => {
-    writeState('Connected to WS server');
+    writeState('Connected to server');
   };
   ws.onclose = () => {
-    writeState('Disconnected from WS server');
+    writeState('Disconnected from server');
   };
   ws.onmessage = async (ev) => {
     const msg = JSON.parse(ev.data);
@@ -320,7 +320,6 @@ async function run() {
 window.addEventListener('load', async () => {
   backend = (new URLSearchParams(window.location.search).get('backend') ||
     'webgl') as K.Backend;
-  writeLog(`backend: ${backend}`);
   if (backend === 'webgl') {
     await K.tensor.initializeNNWebGLContext();
   }
Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+# Sample: Multiplying a tensor by a constant
+
+A simple sample that returns a tensor multiplied by a constant
+
+# Build
+
+```
+npm install
+npm run build
+```
+
+# Run training
+
+```
+npm run train
+```
+
+Open [http://localhost:8081/](http://localhost:8081/) in a web browser.
Lines changed: 5 additions & 5 deletions
@@ -1,18 +1,18 @@
-# サンプル: テンソルの定数倍
+# Sample: Constant times a tensor
 
-テンソルを定数倍にして返す、シンプルなサンプル
+A simple sample that returns a tensor multiplied by a constant
 
-# ビルド
+# Build
 
 ```
 npm install
 npm run build
 ```
 
-# 学習実行
+# Run
 
 ```
 npm run train
 ```
 
-ブラウザで[http://localhost:8081/](http://localhost:8081/)を開く。
+Open [http://localhost:8081/](http://localhost:8081/) in a web browser.
