forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbuiltin_function.h
116 lines (91 loc) · 3.03 KB
/
builtin_function.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#pragma once
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/function.h>
namespace torch {
namespace jit {
struct BuiltinOpFunction : public Function {
BuiltinOpFunction(
c10::QualifiedName qualname,
c10::FunctionSchema schema,
std::function<void(Stack&)> callable)
: name_(std::move(qualname)),
callable_(std::move(callable)),
schema_(std::move(schema)) {
TORCH_INTERNAL_ASSERT(schema_.returns().size() == 1);
}
bool isGraphFunction() const override {
return false;
}
void run(Stack& stack) override {
callable_(stack);
}
void run(Stack&& stack) override {
callable_(stack);
}
c10::intrusive_ptr<c10::ivalue::Future> runAsync(Stack& stack) override {
run(stack);
auto res = c10::make_intrusive<c10::ivalue::Future>(stack.front().type());
res->markCompleted(std::move(stack.front()));
return res;
}
at::IValue operator()(std::vector<at::IValue> stack, const Kwargs& kwargs)
override {
getSchema().checkAndNormalizeInputs(stack, kwargs);
callable_(stack);
return stack.front();
}
const c10::QualifiedName& qualname() const override {
return name_;
}
const std::string& name() const override {
return name_.name();
}
// if this isn't yet defined, run its method_creator function
void ensure_defined() override {
// nop
}
std::shared_ptr<Graph> graph() const override {
TORCH_INTERNAL_ASSERT(false , "BuiltinFunction had a graph requested "
"from it. This probably indicates that the JIT calling context needs a "
"special case on Function::isGraphFunction()");
}
std::shared_ptr<Graph> optimized_graph() const override {
TORCH_INTERNAL_ASSERT(false , "BuiltinFunction had a graph requested "
"from it. This probably indicates that the JIT calling context needs a "
"special case on Function::isGraphFunction()");
}
void clear_execution_info() override {
TORCH_INTERNAL_ASSERT(false , "BuiltinFunction had a graph requested "
"from it. This probably indicates that the JIT calling context needs a "
"special case on Function::isGraphFunction()");
}
GraphExecutor& get_executor() override {
TORCH_INTERNAL_ASSERT(false , "BuiltinFunction had a GraphExecutor requested "
"from it. This probably indicates that the JIT calling context needs a "
"special case on Function::isGraphFunction()");
}
const c10::FunctionSchema& getSchema() const override {
return schema_;
}
size_t num_inputs() const override {
return schema_.arguments().size();
}
void check_single_output() override {
TORCH_CHECK(schema_.returns().size() == 1);
}
std::string pretty_print_schema() const override {
TORCH_INTERNAL_ASSERT(false);
return "";
}
Function& setSchema(c10::FunctionSchema schema) override {
schema_ = std::move(schema);
return *this;
}
~BuiltinOpFunction() {}
private:
c10::QualifiedName name_;
std::function<void(Stack&)> callable_;
c10::FunctionSchema schema_;
};
} // namespace jit
} // namespace torch