Skip to content

Commit d6ce90f

Browse files
Merge pull request #55 from TCLResearchEurope/add-rsqrt
add torch.rsqrt operation support
2 parents 40faae5 + a53969b commit d6ce90f

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

torch_dag/core/unstructured_to_structured.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,9 @@ def convert_node(self, node: torch.fx.node.Node, modules_dict, state_dict):
383383
assert len(node.args) == 1
384384
return structured_modules.MeanModule(**node.kwargs)
385385

386+
elif node.target == torch.rsqrt:
387+
return structured_modules.RsqrtModule()
388+
386389
elif node.target == torch.sum:
387390
raise NotImplementedError
388391

torch_dag/structured_modules.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1588,6 +1588,14 @@ def __init__(self, pow: Union[float, int]):
15881588
def forward(self, inputs: torch.Tensor):
15891589
return torch.pow(inputs, self.pow)
15901590

1591+
@register_notrace_module
1592+
class RsqrtModule(torch.nn.Module):
1593+
# wraps torch.rsqrt
1594+
def __init__(self):
1595+
super().__init__()
1596+
1597+
def forward(self, x) -> torch.Tensor:
1598+
return torch.rsqrt(x)
15911599

15921600
@register_notrace_module
15931601
class UnsqueezeModule(torch.nn.Module):

0 commit comments

Comments
 (0)