17
17
namespace Halide {
18
18
namespace Internal {
19
19
20
- using std::endl;
21
- using std::map;
22
20
using std::ostream;
23
21
using std::ostringstream;
24
22
using std::string;
@@ -47,7 +45,7 @@ CodeGen_PyTorch::CodeGen_PyTorch(ostream &s, Target t, std::string cpp_header) :
47
45
" UserContext feature to properly manage the GPU memory. "
48
46
" Please add \" -user_context\" to the generator's target options.\n " ;
49
47
}
50
- stream << " #include \" ATen/cuda/CUDAContext.h\"\n " ;
48
+ stream << " #include \" ATen/cuda/CUDAContext.h\"\n " ;
51
49
stream << " #include \" HalidePyTorchCudaHelpers.h\"\n " ;
52
50
}
53
51
@@ -92,7 +90,7 @@ void CodeGen_PyTorch::compile(const LoweredFunc &f, bool is_cuda) {
92
90
continue ;
93
91
} else if (args[i].is_buffer ()) {
94
92
buffer_args.push_back (args[i]);
95
- stream
93
+ stream
96
94
<< type_to_pytorch_tensor (args[i].type , is_cuda)
97
95
<< " &"
98
96
<< c_print_name (args[i].name );
@@ -134,14 +132,14 @@ void CodeGen_PyTorch::compile(const LoweredFunc &f, bool is_cuda) {
134
132
do_indent ();
135
133
stream
136
134
<< " HLPT_CHECK_CONTIGUOUS("
137
- << c_print_name (buffer_args[i].name )
135
+ << c_print_name (buffer_args[i].name )
138
136
<< " );\n " ;
139
137
140
138
if (is_cuda) {
141
139
do_indent ();
142
140
stream
143
141
<< " HLPT_CHECK_DEVICE("
144
- << c_print_name (buffer_args[i].name )
142
+ << c_print_name (buffer_args[i].name )
145
143
<< " , device_id);\n " ;
146
144
}
147
145
}
@@ -157,9 +155,9 @@ void CodeGen_PyTorch::compile(const LoweredFunc &f, bool is_cuda) {
157
155
string tp = type_to_c_type (buffer_args[i].type , false );
158
156
stream
159
157
<< " Buffer<" << tp << " > "
160
- << c_print_name (buffer_args[i].name )
158
+ << c_print_name (buffer_args[i].name )
161
159
<< " _buffer = Halide::PyTorch::wrap<" << tp << " >("
162
- << c_print_name (buffer_args[i].name )
160
+ << c_print_name (buffer_args[i].name )
163
161
<< " );\n "
164
162
;
165
163
}
@@ -172,7 +170,7 @@ void CodeGen_PyTorch::compile(const LoweredFunc &f, bool is_cuda) {
172
170
stream << " int err = " << simple_name << " (" ;
173
171
for (size_t i = 0 ; i < args.size (); i++) {
174
172
if (args[i].is_buffer ()) {
175
- stream
173
+ stream
176
174
<< c_print_name (args[i].name )
177
175
<< " _buffer" ;
178
176
} else {
@@ -194,15 +192,15 @@ void CodeGen_PyTorch::compile(const LoweredFunc &f, bool is_cuda) {
194
192
for (size_t i = 0 ; i < buffer_args.size (); i++) {
195
193
if (buffer_args[i].is_buffer ()) {
196
194
do_indent ();
197
- stream
195
+ stream
198
196
<< " AT_ASSERTM(!"
199
197
<< c_print_name (buffer_args[i].name ) << " _buffer.host_dirty(),"
200
198
<< " \" device not synchronized for buffer "
201
199
<< c_print_name (buffer_args[i].name )
202
200
<< " , make sure all update stages are excplicitly computed on GPU."
203
201
<<" \" );\n " ;
204
202
do_indent ();
205
- stream
203
+ stream
206
204
<< c_print_name (buffer_args[i].name ) << " _buffer"
207
205
<< " .device_detach_native();\n " ;
208
206
}
@@ -260,7 +258,7 @@ void CodeGen_PyTorch::test() {
260
258
{
261
259
// TODO(mgharbi): test that Target("host-cuda") raises an exception since
262
260
// we require the "user_context" feature when using CUDA
263
-
261
+
264
262
CodeGen_PyTorch cg (source, Target (" host" ), " PyTorchTestOp.h" );
265
263
cg.compile (m);
266
264
@@ -270,7 +268,7 @@ void CodeGen_PyTorch::test() {
270
268
string src = source.str () + " \n " + source_cuda.str ();
271
269
272
270
// The correct source concatenates CPU and GPU headers
273
- string correct_src =
271
+ string correct_src =
274
272
R"GOLDEN_CODE( #include "torch/extension.h"
275
273
#include "HalideBuffer.h"
276
274
#include "HalidePyTorchHelpers.h"
0 commit comments