Skip to content

Commit 5d51d4f

Browse files
committed
added .detach() to connst value, added new unit test
1 parent 7db5338 commit 5d51d4f

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

coremltools/converters/mil/frontend/torch/test/test_torch_conversion_api.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -929,6 +929,32 @@ def forward(
929929
past_kv_len += 1
930930

931931

932+
@staticmethod
933+
def test_immediate_return_getattr_model():
934+
class ImmediateReturnGetAttrModel(torch.nn.Module):
935+
def __init__(self):
936+
super().__init__()
937+
self.register_buffer("my_constant_output", torch.tensor([1.0, 2.0, 3.0, 4.0]))
938+
self.register_buffer("my_constant_output2", torch.tensor([5.0, 6.0, 7.0, 8.0]))
939+
940+
def forward(self, x):
941+
# x is a dummy input, not used
942+
return self.my_constant_output, self.my_constant_output2
943+
944+
model = ImmediateReturnGetAttrModel()
945+
model.eval()
946+
dummy_input = torch.zeros(1) # Dummy input for tracing
947+
traced_model = torch.jit.trace(model, example_inputs=(dummy_input,))
948+
mlmodel = ct.convert(
949+
traced_model,
950+
inputs=[ct.TensorType(shape=(1,))],
951+
convert_to='mlprogram'
952+
)
953+
outputs = mlmodel.predict({"x": np.zeros(1)})
954+
assert "my_constant_output" in outputs
955+
assert "my_constant_output2" in outputs
956+
957+
932958
###############################################################################
933959
# Note: Stress tests for PyTorch input / output types
934960
###############################################################################

coremltools/converters/mil/frontend/torch/torchir_passes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def remove_getattr_nodes(graph: InternalTorchIRGraph) -> None:
251251
outputs=node.outputs,
252252
kind="constant",
253253
name="internal_immediate_output_attr",
254-
attr={"value": node.parent.params[node.name]}
254+
attr={"value": node.parent.params[node.name].detach()}
255255
)
256256
)
257257
else:

0 commit comments

Comments
 (0)