Skip to content

Commit 5713130

Browse files
committed
signal handler for ctrl-c
1 parent cd57319 commit 5713130

9 files changed

Lines changed: 43 additions & 10 deletions

File tree

python/jittor/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# This file is subject to the terms and conditions defined in
88
# file 'LICENSE.txt', which is part of this source code package.
99
# ***************************************************************
10-
__version__ = '1.2.0.9'
10+
__version__ = '1.2.1.0'
1111
from . import lock
1212
with lock.lock_scope():
1313
ori_int = int

python/jittor/dataset/dataset.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -149,11 +149,11 @@ def terminate(self):
149149
'''
150150
if hasattr(self, "workers"):
151151
for w in self.workers:
152-
w.buffer.stop()
153-
w.p.join()
154-
w.p.close()
152+
w.p.terminate()
155153

156154
def _worker_main(self, worker_id, buffer, status):
155+
import jittor_utils
156+
jittor_utils.cc.init_subprocess()
157157
import time
158158
try:
159159
gid_obj = self.gid.get_obj()
@@ -162,7 +162,7 @@ def _worker_main(self, worker_id, buffer, status):
162162
while True:
163163
# get id
164164
with gid_lock:
165-
while gid_obj.value >= self.batch_len:
165+
while gid_obj.value >= self.batch_len or buffer.is_stop():
166166
self.num_idle.value += 1
167167
self.num_idle_c.notify()
168168
self.gidc.wait()
@@ -189,7 +189,12 @@ def _worker_main(self, worker_id, buffer, status):
189189
# send data to main process
190190
if mp_log_v:
191191
print(f"#{worker_id} {os.getpid()} send", type(batch).__name__, [ type(b).__name__ for b in batch ], buffer)
192-
buffer.send(batch)
192+
try:
193+
buffer.send(batch)
194+
except:
195+
if buffer.is_stop():
196+
continue
197+
raise
193198
now = time.time()
194199
send_time = now - start
195200
start = now

python/jittor/test/test_ring_buffer2.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,14 @@ def test_dataset(self):
8484
if batch_idx > 30:
8585
break
8686
pass
87+
for batch_idx, (data, target) in tqdm(enumerate(self.train_loader)):
88+
# time.sleep(5)
89+
# print("break")
90+
# break
91+
# self.train_loader.display_worker_status()
92+
if batch_idx > 300:
93+
break
94+
pass
8795

8896

8997
if __name__ == "__main__":

python/jittor/test/test_where_op.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def test_doc(self):
5757
assert "Where Operator" in jt.where.__doc__
5858

5959

60+
@unittest.skipIf(not jt.has_cuda, "No Torch found")
6061
class TestWhereOpCuda(TestWhereOp):
6162
def setUp(self):
6263
self.where = jt.where

python/jittor_utils/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,9 @@ def pool_cleanup():
154154
p.__exit__(None, None, None)
155155
del p
156156

157+
def pool_initializer():
158+
cc.init_subprocess()
159+
157160
def run_cmds(cmds, cache_path, jittor_path, msg="run_cmds"):
158161
global pool_size, p
159162
bk = mp.current_process()._config.get('daemon')
@@ -163,7 +166,7 @@ def run_cmds(cmds, cache_path, jittor_path, msg="run_cmds"):
163166
mem_gib = mem_bytes/(1024.**3)
164167
pool_size = min(16,max(int(mem_gib // 3), 1))
165168
LOG.i(f"Total mem: {mem_gib:.2f}GB, using {pool_size} procs for compiling.")
166-
p = Pool(pool_size)
169+
p = Pool(pool_size, initializer=pool_initializer)
167170
p.__enter__()
168171
import atexit
169172
atexit.register(pool_cleanup)

src/misc/ring_buffer.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ struct RingBuffer {
4848
}
4949

5050
inline ~Cond() {
51+
// a dirty hack
52+
// ref: https://stackoverflow.com/questions/20439404/pthread-conditions-and-process-termination
53+
cv.__data.__wrefs = 0;
5154
pthread_cond_destroy(&cv);
5255
}
5356

@@ -86,7 +89,7 @@ struct RingBuffer {
8689

8790
inline void wait() {
8891
if (is_stop) {
89-
abort();
92+
throw std::runtime_error("stop");
9093
}
9194
{
9295
MutexScope _(m);

src/pyjt/py_ring_buffer.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,11 @@ struct PyMultiprocessRingBuffer {
2222
// @pyjt(pop,recv)
2323
PyObject* pop();
2424
// @pyjt(clear)
25-
inline void clear() { rb->l = rb->r = 0; }
25+
inline void clear() { rb->l = rb->r = rb->is_stop = 0; }
2626
// @pyjt(stop)
2727
inline void stop() { rb->stop(); }
28+
// @pyjt(is_stop)
29+
inline bool is_stop() { return rb->is_stop; }
2830

2931
// @pyjt(total_pop)
3032
inline uint64 total_pop() { return rb->l; }

src/utils/jit_utils.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,16 @@
1313
#ifdef __GNUC__
1414
#endif
1515
#include <pybind11/iostream.h>
16+
#include <sys/prctl.h>
17+
#include <signal.h>
18+
19+
namespace jittor {
20+
21+
void init_subprocess() {
22+
prctl(PR_SET_PDEATHSIG, SIGKILL);
23+
}
24+
25+
}
1626

1727
PYBIND11_MODULE(jit_utils_core, m) {
1828
pybind11::add_ostream_redirect(m, "ostream_redirect");
@@ -39,4 +49,5 @@ PYBIND11_MODULE(jit_utils_core, m) {
3949
m.def("log_capture_start", &jittor::log_capture_start);
4050
m.def("log_capture_stop", &jittor::log_capture_stop);
4151
m.def("log_capture_read", &jittor::log_capture_read);
52+
m.def("init_subprocess", &jittor::init_subprocess);
4253
}

src/utils/log.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ int register_sigaction() {
219219
sigaction(SIGKILL, &sa, NULL);
220220
sigaction(SIGSTOP, &sa, NULL);
221221
sigaction(SIGFPE, &sa, NULL);
222-
// sigaction(SIGINT, &sa, NULL);
222+
sigaction(SIGINT, &sa, NULL);
223223
sigaction(SIGILL, &sa, NULL);
224224
sigaction(SIGBUS, &sa, NULL);
225225
sigaction(SIGQUIT, &sa, NULL);

0 commit comments

Comments
 (0)