Commit ced1e36
WIP: Enhance PyTorch frontend with inlined extension that captures any Python function as a custom PyTorch and OpenVINO operation.
1 parent 89b48ab commit ced1e36
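
As a quick orientation before the diff, a minimal usage sketch of the new API is given below. The function and model names are hypothetical and the flow is inferred from this WIP change: decorating a Python function with `inlined_extension` makes each call to it appear as a single custom operation both in the traced TorchScript graph and in the converted OpenVINO model, and the generated op's `evaluate` calls back into the original Python function at inference time.

import torch
import openvino as ov
from openvino import inlined_extension  # exported from openvino/__init__.py in this commit

# Hypothetical function that should be captured as one custom PyTorch/OpenVINO operation
# instead of being converted op-by-op.
@inlined_extension
def fancy_activation(x):
    return torch.sin(x) * x

class Model(torch.nn.Module):  # hypothetical model
    def forward(self, x):
        return fancy_activation(x) + 1

ov_model = ov.convert_model(Model(), example_input=torch.randn(2, 3))
compiled = ov.compile_model(ov_model, "CPU")
# At inference time the synthesized InlinedCustomOp's evaluate() calls fancy_activation() again.
result = compiled(torch.randn(2, 3).numpy())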

File tree

10 files changed (+200, -2 lines)

src/bindings/python/src/openvino/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -72,6 +72,7 @@
 # Helper functions for openvino module
 from openvino.utils.data_helpers import tensor_from_file
 from openvino._ov_api import compile_model
+from openvino.frontend.pytorch.inlined_extension import inlined_extension


 # Import opsets
src/bindings/python/src/openvino/frontend/pytorch/inlined_extension.py

Lines changed: 140 additions & 0 deletions
@@ -0,0 +1,140 @@
+# Copyright (C) 2018-2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+# flake8: noqa
+# mypy: ignore-errors
+
+import inspect
+from openvino.frontend.pytorch.ts_decoder import InlineConversionExtension
+from openvino.frontend.pytorch.utils import pt_to_ov_type_map
+import openvino as ov
+
+global_counter_id = 0
+
+# makes a custom op class from a func and input/output signatures
+def make_custom_op_class(func, input_signature, output_signature):
+    import torch, numpy
+    global global_counter_id
+    # print('make_custom_op_class, id =', global_counter_id)
+    class InlinedCustomOp(ov.Op):
+        class_type_info = ov.runtime.DiscreteTypeInfo("InlinedCustomOp", "extension")
+
+        def __init__(self, *args):
+            # TODO: What about attributes?
+            super().__init__(self, args)
+            self.attrs = {"id": global_counter_id}  # `id` attribute distinguishes different instances of the same class; we need it because different instances may have different behaviour
+            # print(f'Made custom op class with id = {self.attrs["id"]}')
+            # print(f"Input signature: {input_signature}")
+            # print(f"Output signature: {output_signature}")
+            self.constructor_validate_and_infer_types()
+
+        def evaluate(self, outputs, inputs):
+            # print("called evaluate")
+            inputs_torch = (torch.from_numpy(input.data) for input in inputs)  # TODO: Check memory sharing
+            result = func(*inputs_torch)
+            if not isinstance(result, tuple):
+                result = (result,)
+            for i, tensor in enumerate(result):
+                ov.Tensor(numpy.array(tensor), shared_memory=True).copy_to(outputs[i])  # TODO: set the output tensor directly without copying
+            return True
+
+        def has_evaluate(self, *args):
+            return True
+
+        def visit_attributes(self, visitor):
+            visitor.on_attributes(self.attrs)
+            return True
+
+        def validate_and_infer_types(self):
+            # TODO: Validate input signature
+            for i, output in enumerate(output_signature):
+                self.set_output_type(i, output[0], output[1])
+    global_counter_id += 1
+    return InlinedCustomOp
+
+
+def make_signature(args):
+    # TODO: Extend beyond just tensors
+    # convert each torch.Tensor object in args to a tuple (element_type, partial_shape) in OpenVINO terms
+    return tuple((pt_to_ov_type_map[str(arg.dtype)], ov.PartialShape.dynamic(len(arg.shape))) for arg in args)
+
+
+# Returns a tuple of tuples (element_type, partial_shape) for each argument, flattening nested structures if needed, setting all dimensions dynamic while preserving rank
+# Currently assumes that all input arguments are torch.Tensor objects
+def make_input_signature(args, kwargs):
+    # TODO: Avoid the current limitation: kwargs parameters should be passed in the same order as the function signature without gaps
+    # flatten kwargs relying on the order of the keys
+    assert not kwargs, "Keyword arguments are not supported yet"
+    return make_signature(args + tuple(kwargs.values()))
+
+
+def make_output_signature(args):
+    if args is None:
+        # TODO: This case is not really supported by PT FE -- because we don't support ops that do not have outputs, they will be lost
+        args = ()
+    if not isinstance(args, tuple):
+        args = (args,)
+    return make_signature(args)
+
+
+def is_class_method(obj):
+    if not inspect.isfunction(obj) and not inspect.ismethod(obj):
+        return False
+    argspec = inspect.getfullargspec(obj)
+    if argspec.args and argspec.args[0] == 'self':
+        return True
+    else:
+        return False
+
+
+def make_trampoline_class(func, op, op_attrs):
+    import torch
+    class Trampoline(torch.autograd.Function):
+        target_extension = InlineConversionExtension()  # this is a marker for this type of extension
+
+        # This function defines how the operation behaves when called as a part of PyTorch model code in eager execution or under jit.trace
+        @staticmethod
+        def forward(ctx, *call_args, **call_kwargs):  # TODO: what is `ctx`?
+            # print('Called through the trampoline')
+            func_target = func
+            if not op:
+                if is_class_method(func):
+                    self_obj = call_args[0]
+                    call_args = call_args[1:]
+                    wrapped = lambda *distil_args, **distil_kwargs: func(self_obj, *distil_args, **distil_kwargs)
+                    func_target = wrapped
+                input_signature = make_input_signature(call_args, call_kwargs)
+            # TODO: Try to trace `func` with the hope to obtain traceable shapes to build a more precise `validate_and_infer_types` automatically (unlikely possible)
+            result = func_target(*call_args, **call_kwargs)
+            if not op:
+                output_signature = make_output_signature(result)
+                __class__.op = make_custom_op_class(func_target, input_signature, output_signature)
+            else:
+                __class__.op = op
+            return result
+
+        # This function defines how the operation is represented in the OpenVINO model graph
+        @staticmethod
+        def convert(node_context):
+            inputs = [node_context.get_input(i) for i in range(node_context.get_input_size())]
+            return __class__.op(*inputs, **op_attrs).outputs()
+
+    return Trampoline
+
+
+def inlined_extension(*args, **op_attrs):
+    def make_trampoline(func, op=None):
+        def trampoline(*args, **kwargs):
+            # Keep trampoline class creation at the point when the function is called, so a new trampoline is made each time.
+            # It is required because `func` is fused inside the Trampoline class and can have different behaviour from call to call in the PyTorch world,
+            # even if the same op is specified to wrap multiple different functions.
+            trampoline = make_trampoline_class(func, op, op_attrs)
+            return trampoline.apply(*args, **kwargs)
+        return trampoline

+    if len(args) == 1 and callable(args[0]) and not (isinstance(args[0], type) and issubclass(args[0], ov.Op)):
+        func = args[0]
+        return make_trampoline(func)
+    else:
+        op = args[0]
+        return lambda func: make_trampoline(func, op)
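
Besides the plain decorator form, `inlined_extension(*args, **op_attrs)` also accepts a pre-defined `ov.Op` subclass as its first argument; in that case `Trampoline.convert` instantiates that class with the node inputs and forwards `op_attrs` to its constructor instead of synthesizing an `InlinedCustomOp` from the call signature. A minimal sketch of that form follows; the op class and the `scale` attribute are hypothetical and mirror the `InlinedCustomOp` pattern above.

import openvino as ov
from openvino import inlined_extension

class ScaledIdentity(ov.Op):
    # Hypothetical custom op following the InlinedCustomOp pattern from this file.
    class_type_info = ov.runtime.DiscreteTypeInfo("ScaledIdentity", "extension")

    def __init__(self, *inputs, scale=1):
        super().__init__(self, inputs)
        self.attrs = {"scale": scale}
        self.constructor_validate_and_infer_types()

    def validate_and_infer_types(self):
        # Single output with the same element type and shape as the single input.
        self.set_output_type(0, self.get_input_element_type(0), self.get_input_partial_shape(0))

    def visit_attributes(self, visitor):
        visitor.on_attributes(self.attrs)
        return True

# The ov.Op subclass is passed as the first argument; extra keyword arguments become op attributes
# and are forwarded to ScaledIdentity's constructor when the node is converted.
@inlined_extension(ScaledIdentity, scale=2)
def scaled(x):
    return x * 2  # eager/traced behaviour; the OpenVINO graph gets a ScaledIdentity node instead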

src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py

Lines changed: 22 additions & 2 deletions
@@ -27,6 +27,9 @@

 log = logging.getLogger(__name__)

+# A marker for a special type of conversion extension that is inlined in the Trampoline class
+class InlineConversionExtension:
+    pass

 class TorchScriptPythonDecoder(Decoder):
     def __init__(
@@ -326,14 +329,17 @@ def get_subgraph_decoder(self, index: int):
         self.m_decoders.append(decoder)
         return decoder

-    def get_op_type(self) -> str:
+    def get_op_extension(self):
         assert isinstance(
             self.graph_element, torch.Node), "Function can be called only when self.graph_element is of type torch.Node"
         if self.graph_element.kind() == "prim::PythonOp" and callable(getattr(self.graph_element, "pyobj", None)):
             pyobj = self.graph_element.pyobj()
             trampoline = getattr(pyobj, "__self__", None)
-            target_extension = getattr(trampoline, "target_extension", None)
+            return trampoline, getattr(trampoline, "target_extension", None)

+    def get_op_type(self) -> str:
+        if op_extension := self.get_op_extension():
+            trampoline, target_extension = op_extension
             if isinstance(target_extension, ModuleExtension):
                 target_op = target_extension.target_op
                 if callable(target_op):
@@ -589,3 +595,17 @@ def _transform_optional_constants(graph: torch.Graph):
            const_input.node().moveBefore(node)
            const_input.node().copyMetadata(node)
            node.output().replaceAllUsesWith(const_input)
+
+    def has_converter(self):
+        if op_extension := self.get_op_extension():
+            trampoline, target_extension = op_extension
+            return isinstance(target_extension, InlineConversionExtension)
+        return False
+
+    def convert(self, node_context):
+        if op_extension := self.get_op_extension():
+            trampoline, target_extension = op_extension
+            assert isinstance(target_extension, InlineConversionExtension)
+            result = trampoline.convert(node_context)
+            return result
+        assert False, "PyTorch FrontEnd Internal Error: `convert` method of TorchScriptPythonDecoder is called for a node that has no custom converter"
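
For reference, a conceptual sketch of what the new `get_op_extension`/`has_converter` path relies on (not part of the diff; the function and module names are made up): `torch.jit.trace` records `Trampoline.apply` as a `prim::PythonOp` node whose `pyobj()` is a method bound to the generated Trampoline class, so its `__self__` carries the `target_extension` marker that the decoder checks.

import torch
from openvino import inlined_extension
from openvino.frontend.pytorch.ts_decoder import InlineConversionExtension

@inlined_extension
def double(x):  # hypothetical function
    return x * 2

class M(torch.nn.Module):  # hypothetical module
    def forward(self, x):
        return double(x)

traced = torch.jit.trace(M(), torch.randn(2))
for node in traced.graph.nodes():
    if node.kind() == "prim::PythonOp":
        # Same lookup that get_op_extension() performs on self.graph_element.
        trampoline = node.pyobj().__self__
        print(isinstance(trampoline.target_extension, InlineConversionExtension))  # True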

src/bindings/python/src/pyopenvino/frontend/pytorch/decoder.hpp

Lines changed: 9 additions & 0 deletions
@@ -7,6 +7,7 @@
 #include <pybind11/pybind11.h>

 #include "openvino/frontend/pytorch/decoder.hpp"
+#include "openvino/frontend/node_context.hpp"

 namespace py = pybind11;

@@ -137,6 +138,14 @@ class PyDecoder : public ov::frontend::pytorch::TorchDecoder {
     ov::frontend::pytorch::DecoderRTInfo get_rt_info() const override {
         PYBIND11_OVERRIDE_PURE(ov::frontend::pytorch::DecoderRTInfo, TorchDecoder, get_rt_info);
     }
+
+    bool has_converter() const override {
+        PYBIND11_OVERRIDE_PURE(bool, TorchDecoder, has_converter);
+    }
+
+    ov::OutputVector convert(const ov::frontend::NodeContext* context) const override {
+        PYBIND11_OVERRIDE_PURE(ov::OutputVector, TorchDecoder, convert, context);
+    }
 };

 void regclass_frontend_pytorch_decoder(py::module m);

src/frontends/pytorch/include/openvino/frontend/pytorch/decoder.hpp

Lines changed: 9 additions & 0 deletions
@@ -5,6 +5,7 @@
 #pragma once

 #include "openvino/core/node.hpp"
+#include "openvino/frontend/node_context.hpp"
 #include "openvino/frontend/decoder.hpp"
 #include "openvino/frontend/pytorch/visibility.hpp"

@@ -134,6 +135,14 @@ class PYTORCH_FRONTEND_API TorchDecoder : public IDecoder {

     /// \brief Returns the rt_info for the element
     virtual DecoderRTInfo get_rt_info() const = 0;
+
+    /// @brief Returns true if the node has a custom converter that should be used instead of any default converter registered in the front-end.
+    /// If this method returns true, the `convert` method should be used as the conversion extension for this node instead of any default converter.
+    /// Such a node may not implement other methods, like `get_op_type`, that are usually implemented for "normal" nodes.
+    virtual bool has_converter() const = 0;
+
+    /// @brief Converts the node if `has_converter` returns true
+    virtual OutputVector convert(const ov::frontend::NodeContext* context) const = 0;
 };

 } // namespace pytorch

src/frontends/pytorch/src/op/pythonop.cpp

Lines changed: 5 additions & 0 deletions
@@ -13,6 +13,11 @@ namespace op {

 OutputVector translate_pythonop(const NodeContext& context) {
     auto decoder = context.get_decoder();
+    if (decoder->has_converter()) {
+        // If the node has a custom converter, use it.
+        // A custom converter is defined for an in-model definition of a custom operation.
+        return decoder->convert(&context);
+    }
     PYTORCH_OP_CONVERSION_CHECK(decoder->get_subgraph_size() == 1,
                                 "PythonOp must have 1 subgraph to be able to translate it to OV.");
     auto body = context.convert_subgraph(0);

src/frontends/pytorch/src/transforms/tuple_unpack_replacer.cpp

Lines changed: 6 additions & 0 deletions
@@ -29,6 +29,12 @@ PrimTupleUnpackReplacer::PrimTupleUnpackReplacer() {
         auto input_node = tuple_unpack->get_input_node_shared_ptr(0);
         auto tuple_construct = cast_fw_node(input_node, "prim::TupleConstruct");
         if (!tuple_construct) {
+            if (!ov::as_type_ptr<ov::op::util::FrameworkNode>(input_node)) {
+                // Remove TupleUnpack by bypassing it with all outputs of any op except FrameworkNode.
+                // The FrameworkNode case is left for further processing.
+                replace_node(tuple_unpack, input_node->outputs());
+                return true;
+            }
             return false;
         }
         for (const auto& input : input_node->inputs()) {

src/frontends/pytorch/src/utils.hpp

Lines changed: 6 additions & 0 deletions
@@ -325,6 +325,12 @@ class DummyDecoder : public TorchDecoder {
     virtual std::unordered_map<std::string, ov::Any> get_rt_info() const override {
         FRONT_END_NOT_IMPLEMENTED(get_rt_info);
     }
+    virtual bool has_converter() const override {
+        FRONT_END_NOT_IMPLEMENTED(has_converter);
+    }
+    virtual OutputVector convert(const ov::frontend::NodeContext* context) const override {
+        FRONT_END_NOT_IMPLEMENTED(convert);
+    }

 private:
     const std::string m_schema = "NONE";

tools/benchmark_tool/openvino/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -72,6 +72,7 @@
 # Helper functions for openvino module
 from openvino.utils.data_helpers import tensor_from_file
 from openvino._ov_api import compile_model
+from openvino.frontend.pytorch.inlined_extension import inlined_extension


 # Import opsets

tools/ovc/openvino/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -72,6 +72,7 @@
 # Helper functions for openvino module
 from openvino.utils.data_helpers import tensor_from_file
 from openvino._ov_api import compile_model
+from openvino.frontend.pytorch.inlined_extension import inlined_extension


 # Import opsets
