Skip to content

Commit 82f34d0

Browse files
committed
c10::optional -> std::optional
1 parent 7cabb53 commit 82f34d0

16 files changed

+132
-132
lines changed

csrc/cpu/scatter_cpu.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
#include "reducer.h"
55
#include "utils.h"
66

7-
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
7+
std::tuple<torch::Tensor, std::optional<torch::Tensor>>
88
scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
9-
torch::optional<torch::Tensor> optional_out,
10-
torch::optional<int64_t> dim_size, std::string reduce) {
9+
std::optional<torch::Tensor> optional_out,
10+
std::optional<int64_t> dim_size, std::string reduce) {
1111
CHECK_CPU(src);
1212
CHECK_CPU(index);
1313
if (optional_out.has_value())
@@ -36,7 +36,7 @@ scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
3636
out = torch::empty(sizes, src.options());
3737
}
3838

39-
torch::optional<torch::Tensor> arg_out = torch::nullopt;
39+
std::optional<torch::Tensor> arg_out = std::nullopt;
4040
int64_t *arg_out_data = nullptr;
4141
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
4242
arg_out = torch::full_like(out, src.size(dim), index.options());

csrc/cpu/scatter_cpu.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
#include "../extensions.h"
44

5-
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
5+
std::tuple<torch::Tensor, std::optional<torch::Tensor>>
66
scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
7-
torch::optional<torch::Tensor> optional_out,
8-
torch::optional<int64_t> dim_size, std::string reduce);
7+
std::optional<torch::Tensor> optional_out,
8+
std::optional<int64_t> dim_size, std::string reduce);

csrc/cpu/segment_coo_cpu.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
#include "utils.h"
66
#include <ATen/OpMathType.h>
77

8-
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
8+
std::tuple<torch::Tensor, std::optional<torch::Tensor>>
99
segment_coo_cpu(torch::Tensor src, torch::Tensor index,
10-
torch::optional<torch::Tensor> optional_out,
11-
torch::optional<int64_t> dim_size, std::string reduce) {
10+
std::optional<torch::Tensor> optional_out,
11+
std::optional<int64_t> dim_size, std::string reduce) {
1212
CHECK_CPU(src);
1313
CHECK_CPU(index);
1414
if (optional_out.has_value())
@@ -45,7 +45,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
4545
out = torch::empty(sizes, src.options());
4646
}
4747

48-
torch::optional<torch::Tensor> arg_out = torch::nullopt;
48+
std::optional<torch::Tensor> arg_out = std::nullopt;
4949
int64_t *arg_out_data = nullptr;
5050
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
5151
arg_out = torch::full_like(out, src.size(dim), index.options());
@@ -141,7 +141,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
141141
}
142142

143143
torch::Tensor gather_coo_cpu(torch::Tensor src, torch::Tensor index,
144-
torch::optional<torch::Tensor> optional_out) {
144+
std::optional<torch::Tensor> optional_out) {
145145
CHECK_CPU(src);
146146
CHECK_CPU(index);
147147
if (optional_out.has_value())

csrc/cpu/segment_coo_cpu.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22

33
#include "../extensions.h"
44

5-
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
5+
std::tuple<torch::Tensor, std::optional<torch::Tensor>>
66
segment_coo_cpu(torch::Tensor src, torch::Tensor index,
7-
torch::optional<torch::Tensor> optional_out,
8-
torch::optional<int64_t> dim_size, std::string reduce);
7+
std::optional<torch::Tensor> optional_out,
8+
std::optional<int64_t> dim_size, std::string reduce);
99

1010
torch::Tensor gather_coo_cpu(torch::Tensor src, torch::Tensor index,
11-
torch::optional<torch::Tensor> optional_out);
11+
std::optional<torch::Tensor> optional_out);

csrc/cpu/segment_csr_cpu.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
#include "utils.h"
66
#include <ATen/OpMathType.h>
77

8-
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
8+
std::tuple<torch::Tensor, std::optional<torch::Tensor>>
99
segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
10-
torch::optional<torch::Tensor> optional_out,
10+
std::optional<torch::Tensor> optional_out,
1111
std::string reduce) {
1212
CHECK_CPU(src);
1313
CHECK_CPU(indptr);
@@ -38,7 +38,7 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
3838
out = torch::empty(sizes, src.options());
3939
}
4040

41-
torch::optional<torch::Tensor> arg_out = torch::nullopt;
41+
std::optional<torch::Tensor> arg_out = std::nullopt;
4242
int64_t *arg_out_data = nullptr;
4343
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
4444
arg_out = torch::full(out.sizes(), src.size(dim), indptr.options());
@@ -92,7 +92,7 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
9292
}
9393

9494
torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr,
95-
torch::optional<torch::Tensor> optional_out) {
95+
std::optional<torch::Tensor> optional_out) {
9696
CHECK_CPU(src);
9797
CHECK_CPU(indptr);
9898
if (optional_out.has_value())

csrc/cpu/segment_csr_cpu.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22

33
#include "../extensions.h"
44

5-
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
5+
std::tuple<torch::Tensor, std::optional<torch::Tensor>>
66
segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
7-
torch::optional<torch::Tensor> optional_out,
7+
std::optional<torch::Tensor> optional_out,
88
std::string reduce);
99

1010
torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr,
11-
torch::optional<torch::Tensor> optional_out);
11+
std::optional<torch::Tensor> optional_out);

csrc/cuda/scatter_cuda.cu

+4-4
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,10 @@ scatter_arg_kernel(const scalar_t *src_data,
5555
}
5656
}
5757

58-
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
58+
std::tuple<torch::Tensor, std::optional<torch::Tensor>>
5959
scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
60-
torch::optional<torch::Tensor> optional_out,
61-
torch::optional<int64_t> dim_size, std::string reduce) {
60+
std::optional<torch::Tensor> optional_out,
61+
std::optional<int64_t> dim_size, std::string reduce) {
6262
CHECK_CUDA(src);
6363
CHECK_CUDA(index);
6464
if (optional_out.has_value())
@@ -89,7 +89,7 @@ scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
8989
out = torch::empty(sizes, src.options());
9090
}
9191

92-
torch::optional<torch::Tensor> arg_out = torch::nullopt;
92+
std::optional<torch::Tensor> arg_out = std::nullopt;
9393
int64_t *arg_out_data = nullptr;
9494
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
9595
arg_out = torch::full_like(out, src.size(dim), index.options());

csrc/cuda/scatter_cuda.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
#include "../extensions.h"
44

5-
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
5+
std::tuple<torch::Tensor, std::optional<torch::Tensor>>
66
scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
7-
torch::optional<torch::Tensor> optional_out,
8-
torch::optional<int64_t> dim_size, std::string reduce);
7+
std::optional<torch::Tensor> optional_out,
8+
std::optional<int64_t> dim_size, std::string reduce);

csrc/cuda/segment_coo_cuda.cu

+5-5
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,10 @@ __global__ void segment_coo_arg_broadcast_kernel(
149149
}
150150
}
151151

152-
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
152+
std::tuple<torch::Tensor, std::optional<torch::Tensor>>
153153
segment_coo_cuda(torch::Tensor src, torch::Tensor index,
154-
torch::optional<torch::Tensor> optional_out,
155-
torch::optional<int64_t> dim_size, std::string reduce) {
154+
std::optional<torch::Tensor> optional_out,
155+
std::optional<int64_t> dim_size, std::string reduce) {
156156
CHECK_CUDA(src);
157157
CHECK_CUDA(index);
158158
if (optional_out.has_value())
@@ -191,7 +191,7 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
191191
out = torch::zeros(sizes, src.options());
192192
}
193193

194-
torch::optional<torch::Tensor> arg_out = torch::nullopt;
194+
std::optional<torch::Tensor> arg_out = std::nullopt;
195195
int64_t *arg_out_data = nullptr;
196196
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
197197
arg_out = torch::full_like(out, src.size(dim), index.options());
@@ -325,7 +325,7 @@ __global__ void gather_coo_broadcast_kernel(
325325
}
326326
327327
torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
328-
torch::optional<torch::Tensor> optional_out) {
328+
std::optional<torch::Tensor> optional_out) {
329329
CHECK_CUDA(src);
330330
CHECK_CUDA(index);
331331
if (optional_out.has_value())

csrc/cuda/segment_coo_cuda.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22

33
#include "../extensions.h"
44

5-
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
5+
std::tuple<torch::Tensor, std::optional<torch::Tensor>>
66
segment_coo_cuda(torch::Tensor src, torch::Tensor index,
7-
torch::optional<torch::Tensor> optional_out,
8-
torch::optional<int64_t> dim_size, std::string reduce);
7+
std::optional<torch::Tensor> optional_out,
8+
std::optional<int64_t> dim_size, std::string reduce);
99

1010
torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
11-
torch::optional<torch::Tensor> optional_out);
11+
std::optional<torch::Tensor> optional_out);

csrc/cuda/segment_csr_cuda.cu

+4-4
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,9 @@ __global__ void segment_csr_broadcast_kernel(
9494
}
9595
}
9696

97-
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
97+
std::tuple<torch::Tensor, std::optional<torch::Tensor>>
9898
segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
99-
torch::optional<torch::Tensor> optional_out,
99+
std::optional<torch::Tensor> optional_out,
100100
std::string reduce) {
101101
CHECK_CUDA(src);
102102
CHECK_CUDA(indptr);
@@ -128,7 +128,7 @@ segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
128128
out = torch::empty(sizes, src.options());
129129
}
130130

131-
torch::optional<torch::Tensor> arg_out = torch::nullopt;
131+
std::optional<torch::Tensor> arg_out = std::nullopt;
132132
int64_t *arg_out_data = nullptr;
133133
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
134134
arg_out = torch::full(out.sizes(), src.size(dim), indptr.options());
@@ -217,7 +217,7 @@ __global__ void gather_csr_broadcast_kernel(
217217
}
218218

219219
torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
220-
torch::optional<torch::Tensor> optional_out) {
220+
std::optional<torch::Tensor> optional_out) {
221221
CHECK_CUDA(src);
222222
CHECK_CUDA(indptr);
223223
if (optional_out.has_value())

csrc/cuda/segment_csr_cuda.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22

33
#include "../extensions.h"
44

5-
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
5+
std::tuple<torch::Tensor, std::optional<torch::Tensor>>
66
segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
7-
torch::optional<torch::Tensor> optional_out,
7+
std::optional<torch::Tensor> optional_out,
88
std::string reduce);
99

1010
torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
11-
torch::optional<torch::Tensor> optional_out);
11+
std::optional<torch::Tensor> optional_out);

csrc/scatter.cpp

+24-24
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@ torch::Tensor broadcast(torch::Tensor src, torch::Tensor other, int64_t dim) {
3232
return src;
3333
}
3434

35-
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
35+
std::tuple<torch::Tensor, std::optional<torch::Tensor>>
3636
scatter_fw(torch::Tensor src, torch::Tensor index, int64_t dim,
37-
torch::optional<torch::Tensor> optional_out,
38-
torch::optional<int64_t> dim_size, std::string reduce) {
37+
std::optional<torch::Tensor> optional_out,
38+
std::optional<int64_t> dim_size, std::string reduce) {
3939
if (src.device().is_cuda()) {
4040
#ifdef WITH_CUDA
4141
return scatter_cuda(src, index, dim, optional_out, dim_size, reduce);
@@ -55,8 +55,8 @@ class ScatterSum : public torch::autograd::Function<ScatterSum> {
5555
public:
5656
static variable_list forward(AutogradContext *ctx, Variable src,
5757
Variable index, int64_t dim,
58-
torch::optional<Variable> optional_out,
59-
torch::optional<int64_t> dim_size) {
58+
std::optional<Variable> optional_out,
59+
std::optional<int64_t> dim_size) {
6060
dim = dim < 0 ? src.dim() + dim : dim;
6161
ctx->saved_data["dim"] = dim;
6262
ctx->saved_data["src_shape"] = src.sizes();
@@ -84,8 +84,8 @@ class ScatterMul : public torch::autograd::Function<ScatterMul> {
8484
public:
8585
static variable_list forward(AutogradContext *ctx, Variable src,
8686
Variable index, int64_t dim,
87-
torch::optional<Variable> optional_out,
88-
torch::optional<int64_t> dim_size) {
87+
std::optional<Variable> optional_out,
88+
std::optional<int64_t> dim_size) {
8989
dim = dim < 0 ? src.dim() + dim : dim;
9090
ctx->saved_data["dim"] = dim;
9191
ctx->saved_data["src_shape"] = src.sizes();
@@ -116,8 +116,8 @@ class ScatterMean : public torch::autograd::Function<ScatterMean> {
116116
public:
117117
static variable_list forward(AutogradContext *ctx, Variable src,
118118
Variable index, int64_t dim,
119-
torch::optional<Variable> optional_out,
120-
torch::optional<int64_t> dim_size) {
119+
std::optional<Variable> optional_out,
120+
std::optional<int64_t> dim_size) {
121121
dim = dim < 0 ? src.dim() + dim : dim;
122122
ctx->saved_data["dim"] = dim;
123123
ctx->saved_data["src_shape"] = src.sizes();
@@ -131,7 +131,7 @@ class ScatterMean : public torch::autograd::Function<ScatterMean> {
131131
auto ones = torch::ones(old_index.sizes(), src.options());
132132
result = scatter_fw(ones, old_index,
133133
old_index.dim() <= dim ? old_index.dim() - 1 : dim,
134-
torch::nullopt, out.size(dim), "sum");
134+
std::nullopt, out.size(dim), "sum");
135135
auto count = std::get<0>(result);
136136
count.masked_fill_(count < 1, 1);
137137
count = broadcast(count, out, dim);
@@ -164,8 +164,8 @@ class ScatterMin : public torch::autograd::Function<ScatterMin> {
164164
public:
165165
static variable_list forward(AutogradContext *ctx, Variable src,
166166
Variable index, int64_t dim,
167-
torch::optional<Variable> optional_out,
168-
torch::optional<int64_t> dim_size) {
167+
std::optional<Variable> optional_out,
168+
std::optional<int64_t> dim_size) {
169169
dim = dim < 0 ? src.dim() + dim : dim;
170170
ctx->saved_data["dim"] = dim;
171171
ctx->saved_data["src_shape"] = src.sizes();
@@ -200,8 +200,8 @@ class ScatterMax : public torch::autograd::Function<ScatterMax> {
200200
public:
201201
static variable_list forward(AutogradContext *ctx, Variable src,
202202
Variable index, int64_t dim,
203-
torch::optional<Variable> optional_out,
204-
torch::optional<int64_t> dim_size) {
203+
std::optional<Variable> optional_out,
204+
std::optional<int64_t> dim_size) {
205205
dim = dim < 0 ? src.dim() + dim : dim;
206206
ctx->saved_data["dim"] = dim;
207207
ctx->saved_data["src_shape"] = src.sizes();
@@ -234,37 +234,37 @@ class ScatterMax : public torch::autograd::Function<ScatterMax> {
234234

235235
SCATTER_API torch::Tensor
236236
scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim,
237-
torch::optional<torch::Tensor> optional_out,
238-
torch::optional<int64_t> dim_size) {
237+
std::optional<torch::Tensor> optional_out,
238+
std::optional<int64_t> dim_size) {
239239
return ScatterSum::apply(src, index, dim, optional_out, dim_size)[0];
240240
}
241241

242242
SCATTER_API torch::Tensor
243243
scatter_mul(torch::Tensor src, torch::Tensor index, int64_t dim,
244-
torch::optional<torch::Tensor> optional_out,
245-
torch::optional<int64_t> dim_size) {
244+
std::optional<torch::Tensor> optional_out,
245+
std::optional<int64_t> dim_size) {
246246
return ScatterMul::apply(src, index, dim, optional_out, dim_size)[0];
247247
}
248248

249249
SCATTER_API torch::Tensor
250250
scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim,
251-
torch::optional<torch::Tensor> optional_out,
252-
torch::optional<int64_t> dim_size) {
251+
std::optional<torch::Tensor> optional_out,
252+
std::optional<int64_t> dim_size) {
253253
return ScatterMean::apply(src, index, dim, optional_out, dim_size)[0];
254254
}
255255

256256
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
257257
scatter_min(torch::Tensor src, torch::Tensor index, int64_t dim,
258-
torch::optional<torch::Tensor> optional_out,
259-
torch::optional<int64_t> dim_size) {
258+
std::optional<torch::Tensor> optional_out,
259+
std::optional<int64_t> dim_size) {
260260
auto result = ScatterMin::apply(src, index, dim, optional_out, dim_size);
261261
return std::make_tuple(result[0], result[1]);
262262
}
263263

264264
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
265265
scatter_max(torch::Tensor src, torch::Tensor index, int64_t dim,
266-
torch::optional<torch::Tensor> optional_out,
267-
torch::optional<int64_t> dim_size) {
266+
std::optional<torch::Tensor> optional_out,
267+
std::optional<int64_t> dim_size) {
268268
auto result = ScatterMax::apply(src, index, dim, optional_out, dim_size);
269269
return std::make_tuple(result[0], result[1]);
270270
}

0 commit comments

Comments
 (0)