Skip to content

Commit 4ee4476

Browse files
committed
Helper: to_torch
Add helper methods to generate PyTorch tensors.
1 parent 56d5c98 commit 4ee4476

File tree

5 files changed

+39
-0
lines changed

5 files changed

+39
-0
lines changed

src/amrex/Array4.py

+3
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ def array4_to_cupy(self, copy=False, order="F"):
8282
raise ValueError("The order argument must be F or C.")
8383

8484

85+
# torch
86+
87+
8588
def register_Array4_extension(amr):
8689
"""Array4 helper methods"""
8790
import inspect

src/amrex/ArrayOfStructs.py

+3
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ def aos_to_cupy(self, copy=False):
7575
return cp.array(self, copy=copy)
7676

7777

78+
# torch
79+
80+
7881
def register_AoS_extension(amr):
7982
"""ArrayOfStructs helper methods"""
8083
import inspect

src/amrex/MultiFab.py

+5
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,9 @@ def mf_to_cupy(self, copy=False, order="F"):
8989
return views
9090

9191

92+
# torch
93+
94+
9295
def register_MultiFab_extension(amr):
9396
"""MultiFab helper methods"""
9497

@@ -99,3 +102,5 @@ def register_MultiFab_extension(amr):
99102
amr.MultiFab.to_numpy.__doc__ = mf_to_numpy.__doc__
100103

101104
amr.MultiFab.to_cupy = mf_to_cupy
105+
106+
# torch

src/amrex/PODVector.py

+14
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,19 @@ def podvector_to_cupy(self, copy=False):
6868
raise ValueError("Vector is empty.")
6969

7070

71+
def podvector_to_torch(self, copy=False):
72+
"""
73+
Provide PyTorch tensor views into a PODVector (e.g., RealVector, IntVector).
74+
75+
...
76+
"""
77+
import torch
78+
79+
# if CUDA else ...
80+
# pick right device (context? device number?)
81+
return torch.as_tensor(self.to_cupy(copy), device="cuda")
82+
83+
7184
def register_PODVector_extension(amr):
7285
"""PODVector helper methods"""
7386
import inspect
@@ -82,3 +95,4 @@ def register_PODVector_extension(amr):
8295
):
8396
POD_type.to_numpy = podvector_to_numpy
8497
POD_type.to_cupy = podvector_to_cupy
98+
POD_type.to_torch = podvector_to_torch

src/amrex/StructOfArrays.py

+14
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,19 @@ def soa_to_cupy(self, copy=False):
8383
return soa_view
8484

8585

86+
def soa_to_torch(self, copy=False):
87+
"""
88+
Provide PyTorch tensor views into a StructOfArrays.
89+
90+
...
91+
"""
92+
import torch
93+
94+
# if CUDA else ...
95+
# pick right device (context? device number?)
96+
return torch.as_tensor(self.to_cupy(copy), device="cuda")
97+
98+
8699
def register_SoA_extension(amr):
87100
"""StructOfArrays helper methods"""
88101
import inspect
@@ -97,3 +110,4 @@ def register_SoA_extension(amr):
97110
):
98111
SoA_type.to_numpy = soa_to_numpy
99112
SoA_type.to_cupy = soa_to_cupy
113+
SoA_type.to_torch = soa_to_torch

0 commit comments

Comments
 (0)