- 
                Notifications
    
You must be signed in to change notification settings  - Fork 710
 
Description
When a graph creates (constant) tensors as part of the IR, the memory for these tensors appears to be neither a constant nor a graph input. When and where is this allocation handled?
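Here is an example I'm working on, in which the backward pass of sigmoid(x) requires computing sigmoid(x) * (1 - sigmoid(x)), scaled by the gradient of the `.sum()` loss.
The gradient of that sum is a tensor of ones, so the graph materializes an all-ones tensor on the fly (`full_like` on the loss, then `expand` to the input shape).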
Could this category of tensor allocation be exposed as constants or graph inputs, so that different backends can handle it more elegantly and explicitly?
Sample Example:
Click to Expand
import torch
import torch.nn as nn
from torch.export import export_for_training
from torch.export.experimental import _export_forward_backward
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig, to_edge_transform_and_lower
class Model(nn.Module):
    """Toy model whose output is the sum of a sigmoid applied to the input.

    Used to reproduce the joint forward/backward graph of ``sigmoid(x).sum()``.
    """

    def __init__(self):
        super().__init__()
        # The activation is held as a submodule attribute named ``sigmoid``.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Apply the sigmoid element-wise and reduce the result to a scalar sum."""
        activated = self.sigmoid(x)
        return activated.sum()
if __name__ == "__main__":
    # Reproduce the joint forward/backward graph of Model, print it, then lower
    # it to ExecuTorch with the XNNPACK partitioner and print the result.
    model = Model()
    model.train()
    inputs = torch.randn((1, 512, 8192), requires_grad=True)

    # Export the training graph, then extend it with the backward pass.
    exp = export_for_training(model, (inputs,))
    exp = _export_forward_backward(exp)
    print(f"\n\n Graph: {exp.graph_module.graph}\nInput Spec: {exp.graph_signature.input_specs}\n\nOutput Specs: {exp.graph_signature.output_specs}")

    edge = to_edge_transform_and_lower(
        exp,
        partitioner=[XnnpackPartitioner(force_fp32_dynamic_linear=True)],
        compile_config=EdgeCompileConfig(),
    )
    # Bind the to_executorch() result to a new name instead of rebinding
    # `edge`, so the edge-dialect program stays available for inspection.
    executorch_program = edge.to_executorch(
        config=ExecutorchBackendConfig(
            external_mutable_weights=True,
            emit_mutable_buffer_names=True,
        )
    )
    print(executorch_program.exported_program())
Relevant nodes of the generated graph:
%full_like : [num_users=1] = call_function[target=torch.ops.aten.full_like.default](args = (%sum_1, 1), kwargs = {pin_memory: False, memory_format: torch.preserve_format})
%expand : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%full_like, [1, 512, 8192]), kwargs = {}) 
Click to Expand Full graph
Graph: graph():
    %x : [num_users=1] = placeholder[target=x]
    %sigmoid : [num_users=2] = call_function[target=torch.ops.aten.sigmoid.default](args = (%x,), kwargs = {})
    %alias : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%sigmoid,), kwargs = {})
    %alias_1 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias,), kwargs = {})
    %sum_1 : [num_users=2] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%sigmoid, []), kwargs = {})
    %full_like : [num_users=1] = call_function[target=torch.ops.aten.full_like.default](args = (%sum_1, 1), kwargs = {pin_memory: False, memory_format: torch.preserve_format})
    %expand : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%full_like, [1, 512, 8192]), kwargs = {})
    %alias_2 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias_1,), kwargs = {})
    %alias_3 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%alias_2,), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (1, %alias_3), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%alias_3, %sub), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%expand, %mul), kwargs = {})
    return (sum_1, mul_1)
Input Spec: [InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)]
Output Specs: [OutputSpec(kind=<OutputKind.LOSS_OUTPUT: 2>, arg=TensorArgument(name='sum_1'), target=None), OutputSpec(kind=<OutputKind.GRADIENT_TO_USER_INPUT: 6>, arg=TensorArgument(name='mul_1'), target='arg0_1')]
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, _lifted_tensor_constant0: "i64[]", x: "f32[1, 512, 8192]"):
             # File: <eval_with_key>.15 from $PYTHONPATH/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:1350 in wrapped:14 in forward, code: sub = torch.ops.aten.sub.Tensor(1, alias_3)
            dim_order_ops__to_dim_order_copy_default: "f32[]" = executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default(_lifted_tensor_constant0, dtype = torch.float32, dim_order = []);  _lifted_tensor_constant0 = None
            # No stacktrace found for following nodes
            lowered_module_0 = self.lowered_module_0
            executorch_call_delegate = torch.ops.higher_order.executorch_call_delegate(lowered_module_0, x);  lowered_module_0 = x = None
            getitem_1: "f32[1, 512, 8192]" = executorch_call_delegate[0];  executorch_call_delegate = None
             # File: <eval_with_key>.15 from $PYTHONPATH/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:1350 in wrapped:7 in forward, code: alias = torch.ops.aten.alias.default(sigmoid)
            aten_alias_copy_default: "f32[1, 512, 8192]" = executorch_exir_dialects_edge__ops_aten_alias_copy_default(getitem_1)
             # File: <eval_with_key>.15 from $PYTHONPATH/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:1350 in wrapped:9 in forward, code: sum_1 = torch.ops.aten.sum.dim_IntList(sigmoid, []);  sigmoid = None
            aten_sum_dim_int_list: "f32[]" = executorch_exir_dialects_edge__ops_aten_sum_dim_IntList(getitem_1, []);  getitem_1 = None
             # File: <eval_with_key>.15 from $PYTHONPATH/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:1350 in wrapped:8 in forward, code: alias_1 = torch.ops.aten.alias.default(alias);  alias = None
            aten_alias_copy_default_1: "f32[1, 512, 8192]" = executorch_exir_dialects_edge__ops_aten_alias_copy_default(aten_alias_copy_default);  aten_alias_copy_default = None
             # File: <eval_with_key>.15 from $PYTHONPATH/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:1350 in wrapped:10 in forward, code: full_like = torch.ops.aten.full_like.default(sum_1, 1, pin_memory = False, memory_format = torch.preserve_format)
            aten_full_like_default: "f32[]" = executorch_exir_dialects_edge__ops_aten_full_like_default(aten_sum_dim_int_list, 1, pin_memory = False, memory_format = torch.preserve_format)
             # File: <eval_with_key>.15 from $PYTHONPATH/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:1350 in wrapped:12 in forward, code: alias_2 = torch.ops.aten.alias.default(alias_1);  alias_1 = None
            aten_alias_copy_default_2: "f32[1, 512, 8192]" = executorch_exir_dialects_edge__ops_aten_alias_copy_default(aten_alias_copy_default_1);  aten_alias_copy_default_1 = None
             # File: <eval_with_key>.15 from $PYTHONPATH/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:1350 in wrapped:11 in forward, code: expand = torch.ops.aten.expand.default(full_like, [1, 512, 8192]);  full_like = None
            aten_expand_copy_default: "f32[1, 512, 8192]" = executorch_exir_dialects_edge__ops_aten_expand_copy_default(aten_full_like_default, [1, 512, 8192]);  aten_full_like_default = None
             # File: <eval_with_key>.15 from $PYTHONPATH/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:1350 in wrapped:13 in forward, code: alias_3 = torch.ops.aten.alias.default(alias_2);  alias_2 = None
            aten_alias_copy_default_3: "f32[1, 512, 8192]" = executorch_exir_dialects_edge__ops_aten_alias_copy_default(aten_alias_copy_default_2);  aten_alias_copy_default_2 = None
            # No stacktrace found for following nodes
            lowered_module_1 = self.lowered_module_1
            executorch_call_delegate_1 = torch.ops.higher_order.executorch_call_delegate(lowered_module_1, dim_order_ops__to_dim_order_copy_default, aten_alias_copy_default_3, aten_expand_copy_default);  lowered_module_1 = dim_order_ops__to_dim_order_copy_default = aten_alias_copy_default_3 = aten_expand_copy_default = None
            getitem: "f32[1, 512, 8192]" = executorch_call_delegate_1[0];  executorch_call_delegate_1 = None
            return (aten_sum_dim_int_list, getitem)
Graph signature:
    # inputs
    _lifted_tensor_constant0: BUFFER target='_lifted_tensor_constant0' persistent=True
    x: USER_INPUT
    # outputs
    aten_sum_dim_int_list: LOSS_OUTPUT
    getitem: GRADIENT_TO_USER_INPUT target='arg0_1'
Range constraints: {}
Metadata
Assignees
Labels
Type
Projects
Status