Skip to content

Commit 5cfe57f

Browse files
vchuravyclaude
andcommitted
add extension for AMDGPU
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent c093788 commit 5cfe57f

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Trixi"
22
uuid = "a7f1ee26-1774-49b1-8366-f1abc58fbfcb"
3-
version = "0.15.7-DEV"
43
authors = ["Michael Schlottke-Lakemper <michael.schlottke-lakemper@uni-a.de>", "Gregor Gassner <ggassner@uni-koeln.de>", "Hendrik Ranocha <mail@ranocha.de>", "Andrew R. Winters <andrew.ross.winters@liu.se>", "Jesse Chan <jesse.chan@rice.edu>", "Andrés Rueda-Ramírez <am.rueda@upm.es>"]
4+
version = "0.15.7-DEV"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
@@ -54,6 +54,7 @@ TrixiBase = "9a0f1c46-06d5-4909-a5a3-ce25d3fa3284"
5454
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
5555

5656
[weakdeps]
57+
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
5758
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
5859
Convex = "f65535da-76fb-5f13-bab9-19810c17039a"
5960
ECOS = "e2685f51-7e38-5353-a97d-a921fd2c8199"
@@ -62,6 +63,7 @@ NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
6263
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
6364

6465
[extensions]
66+
TrixiAMDGPUExt = "AMDGPU"
6567
TrixiCUDAExt = "CUDA"
6668
TrixiConvexECOSExt = ["Convex", "ECOS"]
6769
TrixiMakieExt = "Makie"
@@ -70,6 +72,7 @@ TrixiSparseConnectivityTracerExt = "SparseConnectivityTracer"
7072

7173
[compat]
7274
Accessors = "0.1.36"
75+
AMDGPU = "2.2.1"
7376
Adapt = "4.1"
7477
CUDA = "5.8.2"
7578
CodeTracking = "1.0.5, 2, 3"

ext/TrixiAMDGPUExt.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Package extension for adding AMDGPU-based features to Trixi.jl
2+
module TrixiAMDGPUExt
3+
4+
import AMDGPU: ROCArray, ROCDeviceArray
5+
import AMDGPU.Runtime: Adaptor
6+
import Trixi
7+
8+
function Trixi.storage_type(::Type{<:ROCArray})
9+
return ROCArray
10+
end
11+
12+
function Trixi.unsafe_wrap_or_alloc(::Adaptor, vec, size)
13+
return Trixi.unsafe_wrap_or_alloc(ROCDeviceArray, vec, size)
14+
end
15+
16+
function Trixi.unsafe_wrap_or_alloc(::Type{<:ROCDeviceArray}, vec::ROCDeviceArray, size)
17+
return reshape(vec, size)
18+
end
19+
20+
end

0 commit comments

Comments
 (0)