1010#include " vulkan_resources.h"
1111
1212using namespace Halide ::Runtime::Internal::Vulkan;
13+ using Halide::Runtime::Internal::MemoryRegionKind;
14+
15+ // --------------------------------------------------------------------------
16+
17+ namespace Halide {
18+ namespace Runtime {
19+ namespace Internal {
20+ namespace Vulkan {
21+
22+ ALWAYS_INLINE const MemoryRegion *vk_region_root (const MemoryRegion *region) {
23+ const MemoryRegion *current = region;
24+ while (current != nullptr && current->kind == MemoryRegionKind::CropAlias) {
25+ current = current->owner != nullptr ? current->owner : reinterpret_cast <MemoryRegion *>(current->handle );
26+ }
27+ return current;
28+ }
29+
30+ ALWAYS_INLINE MemoryRegion *vk_region_root (MemoryRegion *region) {
31+ return const_cast <MemoryRegion *>(vk_region_root (static_cast <const MemoryRegion *>(region)));
32+ }
33+
34+ ALWAYS_INLINE uint64_t vk_external_buffer_offset_bytes (void *user_context, const MemoryRegion *region) {
35+ halide_debug_assert (user_context, region != nullptr );
36+ const MemoryRegion *root = vk_region_root (region);
37+ return (root != nullptr && root->kind == MemoryRegionKind::ExternalWrapped) ? root->allocation .offset : 0 ;
38+ }
39+
40+ ALWAYS_INLINE uint64_t vk_total_buffer_offset_bytes (void *user_context, const MemoryRegion *region, halide_type_t type) {
41+ return vk_external_buffer_offset_bytes (user_context, region) + (region->indexing .offset * type.bytes ());
42+ }
43+
44+ ALWAYS_INLINE void vk_destroy_wrapped_region (MemoryRegion *region) {
45+ if (region == nullptr ) {
46+ return ;
47+ }
48+
49+ MemoryRegion *root = vk_region_root (region);
50+ if (root == nullptr ) {
51+ return ;
52+ }
53+
54+ if (region != root && region->kind == MemoryRegionKind::CropAlias) {
55+ free (region);
56+ return ;
57+ }
58+
59+ if (root->kind != MemoryRegionKind::ExternalWrapped) {
60+ return ;
61+ }
62+
63+ free (root->handle );
64+ free (root);
65+ }
66+
67+ ALWAYS_INLINE MemoryRegion *vk_create_wrapped_buffer_region (void *user_context,
68+ halide_buffer_t *buf,
69+ VkBuffer vk_buffer,
70+ uint64_t offset) {
71+ VkBuffer *native_handle = reinterpret_cast <VkBuffer *>(malloc (sizeof (VkBuffer)));
72+ if (native_handle == nullptr ) {
73+ error (user_context) << " Vulkan: Failed to allocate wrapped buffer handle metadata.\n " ;
74+ return nullptr ;
75+ }
76+
77+ MemoryRegion *region = reinterpret_cast <MemoryRegion *>(malloc (sizeof (MemoryRegion)));
78+ if (region == nullptr ) {
79+ free (native_handle);
80+ error (user_context) << " Vulkan: Failed to allocate wrapped buffer region metadata.\n " ;
81+ return nullptr ;
82+ }
83+
84+ *native_handle = vk_buffer;
85+ memset (region, 0 , sizeof (MemoryRegion));
86+ region->handle = native_handle;
87+ region->allocation .offset = offset;
88+ region->allocation .size = buf->size_in_bytes ();
89+ region->is_owner = true ;
90+ region->kind = MemoryRegionKind::ExternalWrapped;
91+ region->owner = nullptr ;
92+ return region;
93+ }
94+
95+ } // namespace Vulkan
96+ } // namespace Internal
97+ } // namespace Runtime
98+ } // namespace Halide
1399
14100// --------------------------------------------------------------------------
15101
@@ -110,12 +196,20 @@ WEAK int halide_vulkan_device_free(void *user_context, halide_buffer_t *halide_b
110196
111197 // get the allocated region for the device
112198 MemoryRegion *device_region = reinterpret_cast <MemoryRegion *>(halide_buffer->device );
199+ #ifdef DEBUG_RUNTIME
200+ const uint64_t device_region_size = device_region->allocation .size ;
201+ #endif
113202 MemoryRegion *memory_region = ctx.allocator ->owner_of (user_context, device_region);
114203 if (ctx.allocator && memory_region && memory_region->handle ) {
115- if (halide_can_reuse_device_allocations (user_context)) {
116- ctx.allocator ->release (user_context, memory_region);
204+ if (memory_region->kind == MemoryRegionKind::ExternalWrapped) {
205+ debug (user_context) << " Vulkan: Releasing wrapped external buffer metadata only.\n " ;
206+ vk_destroy_wrapped_region (device_region);
117207 } else {
118- ctx.allocator ->reclaim (user_context, memory_region);
208+ if (halide_can_reuse_device_allocations (user_context)) {
209+ ctx.allocator ->release (user_context, memory_region);
210+ } else {
211+ ctx.allocator ->reclaim (user_context, memory_region);
212+ }
119213 }
120214 }
121215 halide_buffer->device = 0 ;
@@ -126,7 +220,7 @@ WEAK int halide_vulkan_device_free(void *user_context, halide_buffer_t *halide_b
126220 debug (user_context) << " Vulkan: Released memory for device region ("
127221 << " user_context: " << user_context << " , "
128222 << " buffer: " << halide_buffer << " , "
129- << " size_in_bytes: " << ( uint64_t )device_region-> allocation . size << " )\n " ;
223+ << " size_in_bytes: " << device_region_size << " )\n " ;
130224
131225 uint64_t t_after = halide_current_time_ns (user_context);
132226 debug (user_context) << " Time: " << (t_after - t_before) / 1.0e6 << " ms\n " ;
@@ -272,15 +366,24 @@ WEAK int halide_vulkan_device_malloc(void *user_context, halide_buffer_t *buf) {
272366 size_t size = buf->size_in_bytes ();
273367 if (buf->device ) {
274368 MemoryRegion *device_region = (MemoryRegion *)(buf->device );
275- if (device_region->allocation .size >= size) {
369+ MemoryRegion *memory_region = ctx.allocator ->owner_of (user_context, device_region);
370+ if (memory_region != nullptr && memory_region->allocation .size >= size) {
276371 debug (user_context) << " Vulkan: Requested allocation for existing device memory ... using existing buffer!\n " ;
277372 return halide_error_code_success;
278373 } else {
374+ if (memory_region == nullptr ) {
375+ error (user_context) << " Vulkan: Failed to retrieve memory region for existing device buffer!\n " ;
376+ return halide_error_code_internal_error;
377+ }
378+ if (memory_region->kind == MemoryRegionKind::ExternalWrapped) {
379+ error (user_context) << " Vulkan: Wrapped external buffer is too small for requested allocation!\n " ;
380+ return halide_error_code_device_malloc_failed;
381+ }
279382 debug (user_context) << " Vulkan: Requested allocation of different size ... reallocating buffer!\n " ;
280383 if (halide_can_reuse_device_allocations (user_context)) {
281- ctx.allocator ->release (user_context, device_region );
384+ ctx.allocator ->release (user_context, memory_region );
282385 } else {
283- ctx.allocator ->reclaim (user_context, device_region );
386+ ctx.allocator ->reclaim (user_context, memory_region );
284387 }
285388 buf->device = 0 ;
286389 }
@@ -487,7 +590,7 @@ WEAK int halide_vulkan_copy_to_device(void *user_context, halide_buffer_t *halid
487590 bool to_host = false ;
488591
489592 uint64_t src_offset = copy_helper.src_begin ;
490- uint64_t dst_offset = copy_helper.dst_begin + ( device_region-> indexing . offset * halide_buffer->type . bytes () );
593+ uint64_t dst_offset = copy_helper.dst_begin + vk_total_buffer_offset_bytes (user_context, device_region, halide_buffer->type );
491594
492595 copy_helper.src = (uint64_t )(staging_buffer);
493596 copy_helper.dst = (uint64_t )(device_buffer);
@@ -656,7 +759,7 @@ WEAK int halide_vulkan_copy_to_host(void *user_context, halide_buffer_t *halide_
656759 bool from_host = false ;
657760 bool to_host = true ;
658761 uint64_t copy_dst = copy_helper.dst ;
659- uint64_t src_offset = copy_helper.src_begin + ( device_region-> indexing . offset * halide_buffer->type . bytes () );
762+ uint64_t src_offset = copy_helper.src_begin + vk_total_buffer_offset_bytes (user_context, device_region, halide_buffer->type );
660763 uint64_t dst_offset = copy_helper.dst_begin ;
661764
662765 copy_helper.src = (uint64_t )(device_buffer);
@@ -937,8 +1040,8 @@ WEAK int halide_vulkan_buffer_copy(void *user_context, struct halide_buffer_t *s
9371040
9381041 // define the src and dst config
9391042 uint64_t copy_dst = copy_helper.dst ;
940- uint64_t src_offset = copy_helper.src_begin + ( src_buffer_region-> indexing . offset * src->type . bytes () );
941- uint64_t dst_offset = copy_helper.dst_begin + ( dst_buffer_region-> indexing . offset * dst->type . bytes () );
1043+ uint64_t src_offset = copy_helper.src_begin + vk_total_buffer_offset_bytes (user_context, src_buffer_region, src->type );
1044+ uint64_t dst_offset = copy_helper.dst_begin + vk_total_buffer_offset_bytes (user_context, dst_buffer_region, dst->type );
9421045
9431046 copy_helper.src = (uint64_t )(src_device_buffer);
9441047 copy_helper.dst = (uint64_t )(dst_device_buffer);
@@ -1345,18 +1448,37 @@ WEAK int halide_vulkan_device_and_host_free(void *user_context, struct halide_bu
13451448 return halide_default_device_and_host_free (user_context, buf, &vulkan_device_interface);
13461449}
13471450
1348- WEAK int halide_vulkan_wrap_vk_buffer (void *user_context, struct halide_buffer_t *buf, uint64_t vk_buffer) {
1451+ WEAK int halide_vulkan_wrap_vk_buffer_with_offset (void *user_context,
1452+ struct halide_buffer_t *buf,
1453+ uint64_t vk_buffer,
1454+ uint64_t offset) {
13491455 halide_debug_assert (user_context, buf->device == 0 );
13501456 if (buf->device != 0 ) {
13511457 error (user_context) << " Vulkan: Unable to wrap buffer ... invalid device pointer!\n " ;
13521458 return halide_error_code_device_wrap_native_failed;
13531459 }
1354- buf->device = vk_buffer;
1460+ if (vk_buffer == 0 ) {
1461+ error (user_context) << " Vulkan: Unable to wrap buffer ... invalid VkBuffer handle!\n " ;
1462+ return halide_error_code_device_wrap_native_failed;
1463+ }
1464+
1465+ MemoryRegion *region = vk_create_wrapped_buffer_region (user_context, buf,
1466+ reinterpret_cast <VkBuffer>(vk_buffer),
1467+ offset);
1468+ if (region == nullptr ) {
1469+ return halide_error_code_out_of_memory;
1470+ }
1471+
1472+ buf->device = reinterpret_cast <uint64_t >(region);
13551473 buf->device_interface = &vulkan_device_interface;
13561474 buf->device_interface ->impl ->use_module ();
13571475 return halide_error_code_success;
13581476}
13591477
1478+ WEAK int halide_vulkan_wrap_vk_buffer (void *user_context, struct halide_buffer_t *buf, uint64_t vk_buffer) {
1479+ return halide_vulkan_wrap_vk_buffer_with_offset (user_context, buf, vk_buffer, 0 );
1480+ }
1481+
13601482WEAK int halide_vulkan_detach_vk_buffer (void *user_context, halide_buffer_t *buf) {
13611483 if (buf->device == 0 ) {
13621484 return halide_error_code_success;
@@ -1365,6 +1487,13 @@ WEAK int halide_vulkan_detach_vk_buffer(void *user_context, halide_buffer_t *buf
13651487 error (user_context) << " Vulkan: Unable to detach buffer ... invalid device interface!\n " ;
13661488 return halide_error_code_incompatible_device_interface;
13671489 }
1490+ MemoryRegion *device_region = reinterpret_cast <MemoryRegion *>(buf->device );
1491+ MemoryRegion *root_region = vk_region_root (device_region);
1492+ if (root_region == nullptr || root_region->kind != MemoryRegionKind::ExternalWrapped) {
1493+ error (user_context) << " Vulkan: Unable to detach buffer ... buffer is not externally wrapped!\n " ;
1494+ return halide_error_code_device_detach_native_failed;
1495+ }
1496+ vk_destroy_wrapped_region (device_region);
13681497 buf->device = 0 ;
13691498 buf->device_interface ->impl ->release_module ();
13701499 buf->device_interface = nullptr ;
@@ -1376,7 +1505,21 @@ WEAK uintptr_t halide_vulkan_get_vk_buffer(void *user_context, halide_buffer_t *
13761505 return 0 ;
13771506 }
13781507 halide_debug_assert (user_context, buf->device_interface == &vulkan_device_interface);
1379- return (uintptr_t )buf->device ;
1508+ MemoryRegion *device_region = reinterpret_cast <MemoryRegion *>(buf->device );
1509+ MemoryRegion *root_region = vk_region_root (device_region);
1510+ if (root_region == nullptr || root_region->handle == nullptr ) {
1511+ return 0 ;
1512+ }
1513+ return (uintptr_t )(*reinterpret_cast <VkBuffer *>(root_region->handle ));
1514+ }
1515+
1516+ WEAK uint64_t halide_vulkan_get_vk_crop_offset (void *user_context, halide_buffer_t *buf) {
1517+ if (buf->device == 0 ) {
1518+ return 0 ;
1519+ }
1520+ halide_debug_assert (user_context, buf->device_interface == &vulkan_device_interface);
1521+ MemoryRegion *device_region = reinterpret_cast <MemoryRegion *>(buf->device );
1522+ return vk_total_buffer_offset_bytes (user_context, device_region, buf->type );
13801523}
13811524
13821525WEAK const struct halide_device_interface_t *halide_vulkan_device_interface () {
0 commit comments