Skip to content

Commit 5b7e2e2

Browse files
committed
fix stream initialization
1 parent 1883268 commit 5b7e2e2

File tree

3 files changed

+35
-4
lines changed

3 files changed

+35
-4
lines changed

KunQuant/Driver.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -211,6 +211,7 @@ def push_source(is_simple=False):
211211
is_single_source = split_source == 0
212212
# the set of names of custom cross sectional functions
213213
generated_cross_sectional_func = set()
214+
stream_state_buffer_init = []
214215
for func in impl:
215216
if split_source > 0 and cur_count > split_source:
216217
push_source()
@@ -232,7 +233,6 @@ def push_source(is_simple=False):
232233
def query_temp_buf_id(tempname: str, window: int) -> int:
233234
input_windows[tempname] = window
234235
return insert_name_str(tempname, "TEMP").idx
235-
stream_state_buffer_init = []
236236
src, decl = codegen_cpp(module_name, func, input_name_to_idx, ins, outs, options, stream_mode, query_temp_buf_id, input_windows, stream_state_buffer_init, generated_cross_sectional_func, dtype, blocking_len, not allow_unaligned, is_single_source)
237237
impl_src.append(src)
238238
decl_src.append(decl)

cpp/Kun/Runtime.cpp

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -327,7 +327,9 @@ template <typename T>
327327
char *StreamBuffer<T>::make(size_t stock_count, size_t window_size,
328328
size_t simd_len) {
329329
auto ret = kunAlignedAlloc(
330-
sizeof(T) * simd_len, StreamBuffer::getBufferSize(stock_count, window_size, simd_len));
330+
KUN_MALLOC_ALIGNMENT,
331+
roundUp(StreamBuffer::getBufferSize(stock_count, window_size, simd_len),
332+
KUN_MALLOC_ALIGNMENT));
331333
auto buf = (StreamBuffer *)ret;
332334
auto data = buf->getBuffer();
333335
auto rounded_stock_count = roundUp(stock_count, simd_len);
@@ -487,8 +489,10 @@ bool StreamContext::serializeStates(OutputStreamBase* stream) {
487489
StateBuffer *StateBuffer::make(size_t num_objs, size_t elem_size,
488490
CtorFn_t ctor_fn, DtorFn_t dtor_fn, SerializeFn_t serialize_fn,
489491
DeserializeFn_t deserialize_fn) {
490-
auto ret = kunAlignedAlloc(KUN_MALLOC_ALIGNMENT,
491-
sizeof(StateBuffer) + num_objs * elem_size);
492+
auto ret =
493+
kunAlignedAlloc(KUN_MALLOC_ALIGNMENT,
494+
roundUp(sizeof(StateBuffer) + num_objs * elem_size,
495+
KUN_MALLOC_ALIGNMENT));
492496
auto buf = (StateBuffer *)ret;
493497
buf->num_objs = num_objs;
494498
buf->elem_size = elem_size;

tests/test_runtime.py

Lines changed: 27 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -596,6 +596,31 @@ def test_stream_lifetime_gh_issue_41():
596596

597597
####################################
598598

599+
def repro_crash_gh_issue_71():
600+
print("Building factors...")
601+
builder = Builder()
602+
with builder:
603+
inp1 = Input("close")
604+
# Generate MANY factors to stress memory/heap
605+
for i in range(20):
606+
Output(WindowedQuantile(inp1, 10, 0.5 + i * 0.01), f"qtl_{i}")
607+
for i in range(20):
608+
Output(WindowedLinearRegressionSlope(inp1, 10 + i), f"beta_{i}")
609+
f = Function(builder.ops)
610+
return "test_repro_crash_gh_issue_71", f, KunCompilerConfig(partition_factor=3, output_layout="STREAM", options={"opt_reduce": False, "fast_log": True})
611+
612+
def test_repro_crash_gh_issue_71(lib):
613+
num_symbols = 24
614+
executor = kr.createSingleThreadExecutor()
615+
modu = lib.getModule("test_repro_crash_gh_issue_71")
616+
stream = kr.StreamContext(executor, modu, num_symbols)
617+
data = np.random.rand(num_symbols).astype("float32")
618+
h_close = stream.queryBufferHandle("close")
619+
for i in range(20):
620+
stream.pushData(h_close, data)
621+
stream.run()
622+
623+
####################################
599624

600625
def create_stream_double():
601626
builder = Builder()
@@ -680,6 +705,7 @@ def rolling_max_dd(x, window_size, min_periods=1):
680705
check_covar(),
681706
check_quantile(),
682707
check_large_rank(),
708+
repro_crash_gh_issue_71(),
683709
]
684710
lib = cfake.compileit(funclist, "test", cfake.CppCompilerConfig(machine=get_compiler_flags()))
685711

@@ -706,4 +732,5 @@ def rolling_max_dd(x, window_size, min_periods=1):
706732
test_covar(lib)
707733
test_quantile(lib)
708734
test_large_rank(lib)
735+
test_repro_crash_gh_issue_71(lib)
709736
print("done")

0 commit comments

Comments
 (0)