File tree 1 file changed +4
-2
lines changed
1 file changed +4
-2
lines changed Original file line number Diff line number Diff line change @@ -171,11 +171,13 @@ class GPUMatrixDynamic : public GPUMatrixBase {
171
171
void resize (uint32_t rows, uint32_t cols) {
172
172
if (m_arena_allocation) {
173
173
cudaStream_t stream = m_arena_allocation->stream ();
174
- m_arena_allocation.reset ();
174
+ m_arena_allocation.reset (); // reset is called explicitly to ensure memory is freed before being allocated
175
175
m_arena_allocation = std::make_shared<GPUMemoryArena::Allocation>(allocate_workspace (stream, rows * cols * sizeof (T)));
176
+ m_data = (T*)m_arena_allocation->data ();
176
177
} else if (m_malloc_allocation) {
177
- m_malloc_allocation.reset ();
178
+ m_malloc_allocation.reset (); // reset is called explicitly to ensure memory is freed before being allocated
178
179
m_malloc_allocation = std::make_shared<GPUMemory<uint8_t >>(rows * cols * sizeof (T));
180
+ m_data = (T*)m_malloc_allocation->data ();
179
181
} else {
180
182
throw std::runtime_error{" GPUMatrix::resize is not permitted when the underlying memory is not owned. Use GPUMatrix::set instead." };
181
183
}
You can’t perform that action at this time.
0 commit comments