Skip to content

Commit 4e62d77

Browse files
committed
Update docs v0.5.0
1 parent 0a58c8a commit 4e62d77

File tree

408 files changed

+107273
-1
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

408 files changed

+107273
-1
lines changed

html/docs/stable

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
v0.3.0
1+
v0.5.0

html/docs/v0.5.0/.buildinfo

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Sphinx build info version 1
2+
# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
3+
config: 8fc7f53b9f0393fb28b7e285b356fa7d
4+
tags: 645f666f9bcd5a90fca523b33c5a78b7
Binary file not shown.
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
"""
2+
.. _add-operator-resolve-rule:
3+
4+
Add Operator Resolve Rule
5+
=========================
6+
7+
This is a tutorial introduces the `operator resolving mechanism` and how to add a new operator resolve rule. An operator
8+
resolve rule is used to resolve an operator to other operators. Usually, we would resolve a more generic operator to
9+
more specific and efficient operators. The operator resolving rules allow us to reuse existing highly-optimized
10+
operators to implement a new operator, while organizing the operators in a more modular way.
11+
12+
13+
Operator Resolving
14+
------------------
15+
16+
The core idea of the **operator resolving** is to resolve a generic operator to more specific and high-optimized
17+
operators. When we define a new operator, we can also attach an operator resolve rule to it. The rule defines how to
18+
resolve the operator to other operators with the same semantics. After the operator is resolved, the original operator
19+
will be replaced by the resolved operators. This process is transparent to the user and is done automatically by a pass
20+
when we optimize a flow graph.
21+
22+
There are typical two scenarios that we need to resolve an operator to other operators:
23+
24+
- **Resolve a generic operator to specialized variants**: We can provide a generic operator and lots of its specialized
25+
variants. When optimizing the model, we can resolve the generic operator to the most suitable specialized operator.
26+
For example, in Hidet, we provided a generic :py:func:`~hidet.ops.matmul` operator with the same semantics as
27+
the numpy equivalent :py:data:`numpy.matmul`. This operator is a generic operator and is scheduled automatically by
28+
our auto-scheduler, thus it is not very efficient. But we also provided a lot of specialized variants of the operators
29+
such as highly-optimized :py:func:`~hidet.ops.batch_matmul` that only accepts :math:`A=[B, M, K]` and
30+
:math:`B=[B, K, N]`. During the operator resolving, we first reshape and broadcast the input tensors to align the
31+
input shapes with the specialized operator, then use the specialized operator to compute the result, and finally
32+
reshape the output tensor to get the correct output shape.
33+
34+
.. tip::
35+
:class: margin
36+
37+
During the operator resolving, we might introduce some extra operators to adjust the input tensors (e.g.,
38+
:func:`~hidet.ops.reshape`, :func:`~hidet.ops.broadcast`, :func:`~hidet.ops.transpose`, etc.).
39+
These extra operators are usually fused into the resolved operators or surrounding operators of the original operator
40+
in the later optimization pass. Thus, the extra overhead is usually negligible.
41+
42+
.. figure:: /_static/img/resolve-example-matmul.svg
43+
:align: center
44+
:scale: 70%
45+
46+
The resolve rule for `Matmul` operator.
47+
48+
- **Reuse a new operator to existing operators**: When we add a new operator and the new operator can be implemented by
49+
existing operators, we can use a resolve rule to resolve the new operator to the existing highly-optimized operators
50+
to reduce the development effort.
51+
52+
.. figure:: /_static/img/resolve-example-conv2d.svg
53+
:align: center
54+
:scale: 70%
55+
56+
This rule resolves the generic :func:`~hidet.ops.conv2d` operator to matrix multiplication using the img2col
57+
algorithm.
58+
59+
The operator resolving pass would repeat the resolving process until no more operators can be resolved. Thus, in the
60+
conv2d example, we will first resolve :func:`~hidet.ops.conv2d` to :func:`~hidet.ops.matmul`, and then
61+
to :func:`~hidet.ops.batch_matmul`.
62+
63+
Add Operator Resolve Rule
64+
-------------------------
65+
66+
To add a resolve rule to an operator, we need to
67+
68+
#. define a subclass of :class:`~hidet.graph.transforms.resolve_variant.ResolveRule` and then
69+
#. register the rule by decorating it with :func:`~hidet.graph.transforms.resolve_variant.register_resolve_rule`.
70+
71+
In the following example, we resolve the :func:`~hidet.ops.pow` operator to normal multiplications if the exponent
72+
is a constant integer 3.
73+
74+
Before we start, let's have a look at the original behavior when there is no such resolve rule.
75+
"""
76+
import hidet
77+
78+
a = hidet.symbol(shape=[2, 3], device='cuda')
79+
b = hidet.ops.pow(a, hidet.asarray(3, device='cuda'))
80+
graph = hidet.trace_from(b, inputs=[a])
81+
print('Original graph:')
82+
print(graph)
83+
84+
print('Optimized graph without resolving Pow:')
85+
graph_opt = hidet.graph.optimize(graph)
86+
print(graph_opt)
87+
88+
# %%
89+
# The original graph contains a :func:`~hidet.ops.pow` operator, and the optimized graph is the same as the
90+
# original graph. Now let's add the resolve rule and see what happens.
91+
92+
from typing import Optional, List
93+
from hidet import Tensor
94+
from hidet.graph.ops.arithmetic import PowOp
95+
from hidet.graph.transforms import register_resolve_rule, ResolveRule
96+
97+
98+
@register_resolve_rule(PowOp)
99+
class PowResolveRule(ResolveRule):
100+
def resolve(self, op: PowOp) -> Optional[List[Tensor]]:
101+
a: Tensor = op.inputs[0] # get the base tensor
102+
b: Tensor = op.inputs[1] # get the exponent tensor
103+
if not b.is_symbolic() and len(b.shape) == 0 and int(b) == 3:
104+
# if the exponent is a constant integer 3, resolve the operator to a * a * a
105+
return [a * a * a]
106+
# otherwise, return None to indicate that the operator cannot be resolved
107+
# and the original operator will be kept
108+
return None
109+
110+
111+
# optimize the original graph again
112+
# the Pow operator will be resolved to a * a * a
113+
# after that, the two multiplications will be fused into one operator
114+
graph_opt_new = hidet.graph.optimize(graph)
115+
print('Optimized graph after resolving Pow:')
116+
print(graph_opt_new)
117+
118+
119+
# %%
120+
# .. seealso::
121+
#
122+
# :func:`~hidet.graph.transforms.resolve_variant.register_resolve_rule`,
123+
# :class:`~hidet.graph.transforms.resolve_variant.ResolveRule` for the details of the resolve rule.
124+
#
125+
# Summary
126+
# -------
127+
# In this tutorial, we learned how to resolve an operator to other operators. We also learned how to add a resolve
128+
# rule to an operator. The resolve rule is a powerful tool to reuse existing operators to implement new operators.
129+
# We can also use it to resolve a generic operator to more specialized variants.
130+
#
Binary file not shown.
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"\n# Hello World!\n\nIn this example, we will show you how to use hidet to write a simple \"Hello World\" program.\n"
8+
]
9+
},
10+
{
11+
"cell_type": "markdown",
12+
"metadata": {},
13+
"source": [
14+
"Hidet is a deep learning compiler implemented in python. Let's import it first.\n\n"
15+
]
16+
},
17+
{
18+
"cell_type": "code",
19+
"execution_count": null,
20+
"metadata": {
21+
"collapsed": false
22+
},
23+
"outputs": [],
24+
"source": [
25+
"import hidet"
26+
]
27+
},
28+
{
29+
"cell_type": "markdown",
30+
"metadata": {},
31+
"source": [
32+
"Hidet caches all its generated source code and binary in its cache directory. We can set the cache directory\nto a local directory ``./outs/cache`` so that you can check the generated code and binary.\n\n"
33+
]
34+
},
35+
{
36+
"cell_type": "code",
37+
"execution_count": null,
38+
"metadata": {
39+
"collapsed": false
40+
},
41+
"outputs": [],
42+
"source": [
43+
"hidet.option.cache_dir('./outs/cache')"
44+
]
45+
},
46+
{
47+
"cell_type": "markdown",
48+
"metadata": {},
49+
"source": [
50+
"The ``hidet.lang`` submodule implements the Hidet Script domain specific language.\nIn this example, we will use ``attrs`` variable and ``printf`` function from ``hidet.lang``.\n\n"
51+
]
52+
},
53+
{
54+
"cell_type": "code",
55+
"execution_count": null,
56+
"metadata": {
57+
"collapsed": false
58+
},
59+
"outputs": [],
60+
"source": [
61+
"from hidet.lang import attrs, printf"
62+
]
63+
},
64+
{
65+
"cell_type": "markdown",
66+
"metadata": {},
67+
"source": [
68+
"A **script module** is a compilation unit that contains a list of functions defined in it. Inside a script module,\nwe can use ``hidet.script`` to define a hidet script function. The following example defines a function named\n``launch`` that prints a message to the standard output.\n\n"
69+
]
70+
},
71+
{
72+
"cell_type": "code",
73+
"execution_count": null,
74+
"metadata": {
75+
"collapsed": false
76+
},
77+
"outputs": [],
78+
"source": [
79+
"with hidet.script_module() as script_module:\n\n # we use `hidet.script` to decorate a python function to define a hidet script function.\n @hidet.script\n def launch():\n # we use `hidet.lang.attrs` to set the attributes of the function.\n # the following line specify this hidet script function is a public function.\n attrs.func_kind = 'public'\n\n # print a message to the standard output.\n printf(\"Hello World!\\n\")"
80+
]
81+
},
82+
{
83+
"cell_type": "markdown",
84+
"metadata": {},
85+
"source": [
86+
"With the script module defined, we can build the script module with ``build()`` method. The returned ``module`` is\nan instance of ``hidet.runtime.CompiledModule``, which contains the compiled binary.\n\n"
87+
]
88+
},
89+
{
90+
"cell_type": "code",
91+
"execution_count": null,
92+
"metadata": {
93+
"collapsed": false
94+
},
95+
"outputs": [],
96+
"source": [
97+
"module = script_module.build()"
98+
]
99+
},
100+
{
101+
"cell_type": "markdown",
102+
"metadata": {},
103+
"source": [
104+
"We can directly call the compiled module, in this case the 'launch' function would be invoked.\n\n<div class=\"alert alert-info\"><h4>Note</h4><p>:class: margin\n\n The printed message has not been captured by our documentation generation tool (i.e., sphinx).\n If you run the script by yourself, you will see the message printed out in your console.</p></div>\n\n"
105+
]
106+
},
107+
{
108+
"cell_type": "code",
109+
"execution_count": null,
110+
"metadata": {
111+
"collapsed": false
112+
},
113+
"outputs": [],
114+
"source": [
115+
"module()"
116+
]
117+
},
118+
{
119+
"cell_type": "markdown",
120+
"metadata": {},
121+
"source": [
122+
"We can also explicitly specify the function to be invoked using ``module['func_name'](args)``.\n\n"
123+
]
124+
},
125+
{
126+
"cell_type": "code",
127+
"execution_count": null,
128+
"metadata": {
129+
"collapsed": false
130+
},
131+
"outputs": [],
132+
"source": [
133+
"module['launch']()"
134+
]
135+
},
136+
{
137+
"cell_type": "markdown",
138+
"metadata": {},
139+
"source": [
140+
"you can access the source code of the compiled module using ``module.source()``.\n\n<div class=\"alert alert-info\"><h4>Note</h4><p>:class: margin\n\n The function in the source code has a prefix ``hidet_``, which is used to avoid name conflict with standard\n library functions.</p></div>\n\n"
141+
]
142+
},
143+
{
144+
"cell_type": "code",
145+
"execution_count": null,
146+
"metadata": {
147+
"collapsed": false
148+
},
149+
"outputs": [],
150+
"source": [
151+
"print(module.source(color=True))"
152+
]
153+
}
154+
],
155+
"metadata": {
156+
"kernelspec": {
157+
"display_name": "Python 3",
158+
"language": "python",
159+
"name": "python3"
160+
},
161+
"language_info": {
162+
"codemirror_mode": {
163+
"name": "ipython",
164+
"version": 3
165+
},
166+
"file_extension": ".py",
167+
"mimetype": "text/x-python",
168+
"name": "python",
169+
"nbconvert_exporter": "python",
170+
"pygments_lexer": "ipython3",
171+
"version": "3.10.16"
172+
}
173+
},
174+
"nbformat": 4,
175+
"nbformat_minor": 0
176+
}
Binary file not shown.
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""
2+
Hello World!
3+
============
4+
5+
In this example, we will show you how to use hidet to write a simple "Hello World" program.
6+
7+
"""
8+
# %%
9+
# Hidet is a deep learning compiler implemented in python. Let's import it first.
10+
import hidet
11+
12+
# %%
13+
# Hidet caches all its generated source code and binary in its cache directory. We can set the cache directory
14+
# to a local directory ``./outs/cache`` so that you can check the generated code and binary.
15+
hidet.option.cache_dir('./outs/cache')
16+
17+
# %%
18+
# The ``hidet.lang`` submodule implements the Hidet Script domain specific language.
19+
# In this example, we will use ``attrs`` variable and ``printf`` function from ``hidet.lang``.
20+
from hidet.lang import attrs, printf
21+
22+
# %%
23+
# A **script module** is a compilation unit that contains a list of functions defined in it. Inside a script module,
24+
# we can use ``hidet.script`` to define a hidet script function. The following example defines a function named
25+
# ``launch`` that prints a message to the standard output.
26+
27+
with hidet.script_module() as script_module:
28+
29+
# we use `hidet.script` to decorate a python function to define a hidet script function.
30+
@hidet.script
31+
def launch():
32+
# we use `hidet.lang.attrs` to set the attributes of the function.
33+
# the following line specify this hidet script function is a public function.
34+
attrs.func_kind = 'public'
35+
36+
# print a message to the standard output.
37+
printf("Hello World!\n")
38+
39+
40+
# %%
41+
# With the script module defined, we can build the script module with ``build()`` method. The returned ``module`` is
42+
# an instance of ``hidet.runtime.CompiledModule``, which contains the compiled binary.
43+
module = script_module.build()
44+
45+
# %%
46+
# We can directly call the compiled module, in this case the 'launch' function would be invoked.
47+
#
48+
# .. note::
49+
# :class: margin
50+
#
51+
# The printed message has not been captured by our documentation generation tool (i.e., sphinx).
52+
# If you run the script by yourself, you will see the message printed out in your console.
53+
module()
54+
55+
# %%
56+
# We can also explicitly specify the function to be invoked using ``module['func_name'](args)``.
57+
module['launch']()
58+
59+
# %%
60+
# you can access the source code of the compiled module using ``module.source()``.
61+
#
62+
# .. note::
63+
# :class: margin
64+
#
65+
# The function in the source code has a prefix ``hidet_``, which is used to avoid name conflict with standard
66+
# library functions.
67+
print(module.source(color=True))

0 commit comments

Comments
 (0)