Skip to content

Commit 143bb01

Browse files
committed
optimize ring buffer and copy free array op
1 parent 1e21b66 commit 143bb01

22 files changed

Lines changed: 855 additions & 28 deletions

python/jittor/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,12 @@
77
# This file is subject to the terms and conditions defined in
88
# file 'LICENSE.txt', which is part of this source code package.
99
# ***************************************************************
10-
__version__ = '1.2.0.7'
10+
__version__ = '1.2.0.8'
1111
from . import lock
1212
with lock.lock_scope():
13+
ori_int = int
14+
ori_float = float
15+
ori_bool = bool
1316
from . import compiler
1417
from .compiler import LOG, has_cuda
1518
from .compiler import compile_custom_ops, compile_custom_op
@@ -874,16 +877,13 @@ def to_float(v):
874877
def to_bool(v):
875878
dtype = str(v.dtype)
876879
assert dtype.startswith("int") or dtype=="bool"
877-
return bool(v.item())
880+
return ori_bool(v.item())
878881

879882
Var.item = item
880883
Var.__int__ = to_int
881884
Var.__float__ = to_float
882885
Var.__bool__ = to_bool
883886

884-
ori_int = int
885-
ori_float = float
886-
887887
int = int32
888888
Var.int = Var.int32
889889
float = float32

python/jittor/dataset/dataset.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
from collections.abc import Sequence, Mapping
1616
import pathlib
1717
from PIL import Image
18-
from jittor_utils import ring_buffer
19-
from jittor_utils.ring_buffer import RingBuffer
2018
import multiprocessing as mp
2119
import signal
2220
from jittor_utils import LOG
@@ -30,8 +28,8 @@
3028

3129
class Worker:
3230
def __init__(self, target, args, buffer_size):
33-
buffer = mp.Array('c', buffer_size, lock=False)
34-
self.buffer = RingBuffer(buffer)
31+
self.buffer = jt.RingBuffer(buffer_size)
32+
3533
self.status = mp.Array('f', 5, lock=False)
3634
self.p = mp.Process(target=target, args=args+(self.buffer,self.status))
3735
self.p.daemon = True
@@ -253,13 +251,12 @@ class YourDataset(Dataset):
253251
msg.append(f"progress:{self.last_id}/{self.batch_len}")
254252
msg.append(f"batch(s): {self.batch_time:.3f}\twait(s):{self.wait_time:.3f}")
255253
msg.append(f"recv(s): {self.recv_time:.3f}\tto_jittor(s):{self.to_jittor_time:.3f}")
256-
msg.append(f"recv_raw_call: {ring_buffer.recv_raw_call}")
257254
msg.append(f"last 10 workers: {self.idmap[max(0, self.last_id-9):self.last_id+1]}")
258255
msg.append(f"ID\twait(s)\topen(s)\tload(s)\tsend(s)\ttotal(s)")
259256
for i in range(self.num_workers):
260257
w = self.workers[i]
261258
s = w.status
262-
msg.append(f"#{i}\t{s[0]:.3f}\t{s[4]:.3f}\t{s[1]:.3f}\t{s[2]:.3f}\t{s[3]:.3f}\t{w.buffer.allocator}")
259+
msg.append(f"#{i}\t{s[0]:.3f}\t{s[4]:.3f}\t{s[1]:.3f}\t{s[2]:.3f}\t{s[3]:.3f}\t{w.buffer}")
263260
LOG.i('\n'.join(msg))
264261

265262
def _stop_all_workers(self):

python/jittor/pyjt_compiler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@ def parse_attrs(s):
3131
"uint": ["PyLong_AsUnsignedLong", "PyLong_FromUnsignedLong", "PyLong_CheckExact"],
3232
"uint64": ["PyLong_AsUnsignedLongLong", "PyLong_FromUnsignedLongLong", "PyLong_CheckExact"],
3333
"void": ["...", "GET_PY_NONE", "..."],
34+
"PyObject*": ["","",""],
3435
}
3536
def get_pytype_map(T, i):
37+
assert T != ""
3638
if T in pytype_map:
3739
return pytype_map[T][i]
3840
return ["from_py_object", "to_py_object", "is_type"][i]+"<"+T+">"
@@ -204,7 +206,7 @@ def get_def_code(df, scope_name, pyname, self_as_arg0=False):
204206
func_call = f"(GET_RAW_PTR({scope_name},self))->" + func_call
205207
if pyname == "__init__":
206208
# XXX->xxx(...) ---> new XXX xxx(...)
207-
assert "->" in func_call
209+
assert "->" in func_call, func_call
208210
func_call = "new " + func_call.replace("->", " ")
209211
if no_need_convert:
210212
func_quick_check_runable = ""

python/jittor/test/test_reduce_op.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def check(a, op, dims):
3939
idims = [(), (0,), (1,), (2,), (3,), (0, 2), (1,3), (1,2,3), 2, 3]
4040

4141
iop = [ op[7:] for op in dir(jt) if op.startswith("reduce_")]
42-
assert len(iop) >= 10
42+
assert len(iop) >= 10, iop
4343
for a in ia:
4444
check(a, iop[0], idims[0])
4545
for op in iop:

python/jittor/test/test_relu.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,9 @@ def test_relu(self):
6565
# ***************************************************************
6666
# Test GELU Layer
6767
# ***************************************************************
68-
arr = np.random.randn(16,10,224,224)
69-
check_equal(arr, jnn.GELU(), tnn.GELU())
68+
if hasattr(tnn, "GELU"):
69+
arr = np.random.randn(16,10,224,224)
70+
check_equal(arr, jnn.GELU(), tnn.GELU())
7071

7172
# ***************************************************************
7273
# Test Softplus Layer
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# ***************************************************************
2+
# Copyright (c) 2020 Jittor. Authors:
3+
# Dun Liang <randonlang@gmail.com>.
4+
# All Rights Reserved.
5+
# This file is subject to the terms and conditions defined in
6+
# file 'LICENSE.txt', which is part of this source code package.
7+
# ***************************************************************
8+
import jittor as jt
9+
import unittest
10+
import numpy as np
11+
import random
12+
from .test_core import expect_error
13+
from jittor.dataset.mnist import MNIST
14+
import jittor.transform as trans
15+
from tqdm import tqdm
16+
17+
def test_ring_buffer():
18+
buffer = jt.RingBuffer(1000)
19+
def test_send_recv(data):
20+
print("test send recv", type(data))
21+
buffer.push(data)
22+
recv = buffer.pop()
23+
if isinstance(data, np.ndarray):
24+
assert (recv == data).all()
25+
else:
26+
assert data == recv
27+
28+
n_byte = 0
29+
test_send_recv(1)
30+
n_byte += 1 + 8
31+
assert n_byte == buffer.total_pop() and n_byte == buffer.total_push()
32+
test_send_recv(100000000000)
33+
n_byte += 1 + 8
34+
assert n_byte == buffer.total_pop() and n_byte == buffer.total_push()
35+
36+
test_send_recv(1e-5)
37+
n_byte += 1 + 8
38+
assert n_byte == buffer.total_pop() and n_byte == buffer.total_push()
39+
test_send_recv(100000000000.0)
40+
n_byte += 1 + 8
41+
assert n_byte == buffer.total_pop() and n_byte == buffer.total_push()
42+
43+
test_send_recv("float32")
44+
n_byte += 1 + 8 + 7
45+
assert n_byte == buffer.total_pop() and n_byte == buffer.total_push()
46+
test_send_recv("")
47+
n_byte += 1 + 8 + 0
48+
assert n_byte == buffer.total_pop() and n_byte == buffer.total_push()
49+
test_send_recv("xxxxxxxxxx")
50+
n_byte += 1 + 8 + 10
51+
assert n_byte == buffer.total_pop() and n_byte == buffer.total_push()
52+
53+
test_send_recv([1,0.2])
54+
n_byte += 1 + 8 + 1 + 8 + 1 + 8
55+
assert n_byte == buffer.total_pop() and n_byte == buffer.total_push()
56+
test_send_recv({'asd':1})
57+
n_byte += 1 + 8 + 1 + 8 + 3 + 1 + 8
58+
assert n_byte == buffer.total_pop() and n_byte == buffer.total_push()
59+
60+
test_send_recv(np.random.rand(10,10))
61+
n_byte += 1 + 16 + 2 + 10*10*8
62+
assert n_byte == buffer.total_pop() and n_byte == buffer.total_push()
63+
test_send_recv(test_ring_buffer)
64+
65+
expect_error(lambda: test_send_recv(np.random.rand(10,1000)))
66+
67+
68+
class TestRingBuffer(unittest.TestCase):
69+
70+
def test_ring_buffer(self):
71+
test_ring_buffer()
72+
73+
def test_dataset(self):
74+
return
75+
self.train_loader = MNIST(train=True, transform=trans.Resize(224)) \
76+
.set_attrs(batch_size=300, shuffle=True)
77+
self.train_loader.num_workers = 1
78+
for batch_idx, (data, target) in tqdm(enumerate(self.train_loader)):
79+
# self.train_loader.display_worker_status()
80+
if batch_idx > 30:
81+
break
82+
pass
83+
84+
85+
if __name__ == "__main__":
86+
unittest.main()

python/jittor/transform/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def to_tensor(img):
213213
img_ = transform.to_tensor(img)
214214
"""
215215
if isinstance(img, Image.Image):
216-
return np.array(img).transpose((2,0,1)) / np.float32(255)
216+
return np.array(img).transpose((2,0,1)) * np.float32(1.0/255.0)
217217
return img
218218

219219

@@ -323,7 +323,7 @@ def __call__(self, img):
323323
if isinstance(img, Image.Image):
324324
img = (np.array(img).transpose((2,0,1)) \
325325
- self.mean*np.float32(255.)) \
326-
/ (self.std*np.float32(255.))
326+
* (np.float32(1./255.)/self.std)
327327
else:
328328
img = (img - self.mean) / self.std
329329
return img

python/jittor_utils/ring_buffer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def str_to_char_array(s, array_len):
127127
return a
128128

129129
def char_array_to_str(a):
130-
return str(a.tostring(), 'ascii').strip()
130+
return str(a.tobytes(), 'ascii').strip()
131131

132132
class RingBuffer:
133133
def __init__(self, buffer):

src/misc/nano_vector.h

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ struct NanoVector {
6464
// @pyjt(__init__)
6565
inline NanoVector(const NanoVector& nv) : data(nv.data), offset(nv.offset) {}
6666

67-
void clear() { data = offset = 0; }
67+
inline void clear() { data = offset = 0; }
6868

6969
// @pyjt(__len__, __map_len__)
7070
inline int size() const {
@@ -158,10 +158,22 @@ struct NanoVector {
158158
for (auto a : v) push_back_check_overflow(a);
159159
}
160160

161+
inline static NanoVector make(const int64* v, int n) {
162+
NanoVector nv;
163+
for (int i=0; i<n; i++) nv.push_back_check_overflow(v[i]);
164+
return nv;
165+
}
166+
167+
inline static NanoVector make(const int32* v, int n) {
168+
NanoVector nv;
169+
for (int i=0; i<n; i++) nv.push_back_check_overflow(v[i]);
170+
return nv;
171+
}
172+
161173
inline NanoVector(int64 x) { push_back(x); }
162174

163175
// @pyjt(__repr__)
164-
string to_string() {
176+
inline string to_string() {
165177
string s="[";
166178
for (int i=0; i<size(); i++) {
167179
s += S(at(i));

src/misc/ring_buffer.cc

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
// ***************************************************************
2+
// Copyright (c) 2020 Jittor. All Rights Reserved.
3+
// Authors: Dun Liang <randonlang@gmail.com>.
4+
// This file is subject to the terms and conditions defined in
5+
// file 'LICENSE.txt', which is part of this source code package.
6+
// ***************************************************************
7+
#include <chrono>
8+
#include <thread>
9+
#include <sys/mman.h>
10+
#include "common.h"
11+
#include "misc/ring_buffer.h"
12+
13+
namespace jittor {
14+
15+
RingBuffer::RingBuffer(uint64 size, bool multiprocess) : m(multiprocess), cv(multiprocess) {
16+
int i=0;
17+
for (;(1ll<<i)<size;i++);
18+
size_mask = (1ll<<i)-1;
19+
this->size = size_mask+1;
20+
size_bit = i;
21+
l = r = is_wait = 0;
22+
is_multiprocess = multiprocess;
23+
}
24+
25+
RingBuffer::~RingBuffer() {
26+
}
27+
28+
29+
RingBuffer* RingBuffer::make_ring_buffer(uint64 size, bool multiprocess) {
30+
int i=0;
31+
for (;(1ll<<i)<size;i++);
32+
uint64 size_mask = (1ll<<i)-1;
33+
size = size_mask+1;
34+
uint64 total_size = sizeof(RingBuffer) + size;
35+
void* ptr = multiprocess ?
36+
// mmap(NULL, total_size, PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS | MAP_HUGETLB, -1, 0) :
37+
mmap(NULL, total_size, PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0) :
38+
// mmap(NULL, total_size, PROT_READ | PROT_WRITE, MAP_SHARED, -1, 0) :
39+
(void*)malloc(total_size);
40+
std::memset(ptr, 0, total_size);
41+
auto rb = (RingBuffer*)ptr;
42+
new (rb) RingBuffer(size, multiprocess);
43+
return rb;
44+
}
45+
46+
void RingBuffer::free_ring_buffer(RingBuffer* rb) {
47+
uint64 total_size = sizeof(RingBuffer) + rb->size;
48+
auto is_multiprocess = rb->is_multiprocess;
49+
rb->~RingBuffer();
50+
if (is_multiprocess) {
51+
munmap(rb, total_size);
52+
} else {
53+
rb->~RingBuffer();
54+
free((void*)rb);
55+
}
56+
}
57+
58+
// test
59+
60+
JIT_TEST(ring_buffer_benchmark) {
61+
size_t n = 1ll << 20;
62+
size_t size = 1<<15;
63+
// size_t n = 1ll << 30;
64+
// size_t size = 1<<20;
65+
// size_t n = 1ll << 10;
66+
// size_t size = 1<<5;
67+
RingBuffer* rb = RingBuffer::make_ring_buffer(size, 0);
68+
std::thread p([&]() {
69+
for (size_t i=0; i<n; i++) {
70+
rb->push_t<int>(i);
71+
}
72+
});
73+
auto start = std::chrono::high_resolution_clock::now();
74+
size_t s = 0;
75+
for (size_t i=0; i<n; i++) {
76+
auto x = rb->pop_t<int>();
77+
s += x;
78+
}
79+
auto finish = std::chrono::high_resolution_clock::now();
80+
auto tt = std::chrono::duration_cast<std::chrono::nanoseconds>(finish-start).count();
81+
p.join();
82+
expect_error([&]() { rb->push(size+1); });
83+
RingBuffer::free_ring_buffer(rb);
84+
85+
LOGi << tt << tt*1.0/n;
86+
LOGi << s << (n*(n-1)/2);
87+
ASSERTop(s,==,(n*(n-1)/2));
88+
ASSERTop(tt*1.0/n,<=,50);
89+
}
90+
91+
}

0 commit comments

Comments
 (0)