Commit d388f8f

committed
pass the device along with the value to create_getattr_from_value to avoid an unbounded cache that eats up memory
1 parent a119db0 commit d388f8f
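
The failure mode named in the commit message can be sketched generically: when no device is passed, create_getattr_from_value has to derive one from the model itself, and per the commit message that lookup is cached inside torchao, so a model-keyed, unbounded cache keeps a strong reference to every model it has ever seen. The snippet below illustrates only that general pattern with a plain functools.lru_cache; it is not torchao's actual internals:

import functools

import torch


@functools.lru_cache(maxsize=None)  # unbounded: every distinct key is retained forever
def _cached_model_device(model: torch.nn.Module) -> torch.device:
    # Keyed on the model object itself, so the cache holds a strong reference
    # to each model (and all of its weights) that ever passes through.
    devices = {p.device for p in model.parameters()} | {b.device for b in model.buffers()}
    return next(iter(devices))


for _ in range(3):
    m = torch.nn.Linear(4096, 4096)  # ~67 MB of weights, normally freed after each iteration
    _cached_model_device(m)

print(_cached_model_device.cache_info().currsize)  # -> 3: all three models are still alive

Passing the device explicitly, as this commit does, skips that lookup entirely.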

File tree

src/nncf/experimental/torch/fx/transformations.py
tests/torch/fx/test_model_transformer.py

2 files changed

+24
-2
lines changed

src/nncf/experimental/torch/fx/transformations.py

Lines changed: 20 additions & 2 deletions

@@ -73,6 +73,18 @@
 }
 
 
+# Referenced from: https://github.com/pytorch/pytorch/blob/9105d54c6b37099575c0059ef274c86c4dc80c57/torch/ao/quantization/utils.py#L711
+def _get_model_device(model: torch.fx.GraphModule) -> Any:
+    """
+    Adapted from torchao.quantization.pt2e.utils.
+    Returns the device shared by the model's parameters and buffers.
+    Assumes the model sits on a single device and has at least one parameter or buffer.
+    """
+    devices = {p.device for p in model.parameters()} | {p.device for p in model.buffers()}
+    device = next(iter(devices))
+    return device
+
+
 def _set_new_node_meta(
     new_node: torch.fx.Node,
     prev_nodes: tuple[Argument, ...],
@@ -251,9 +263,12 @@ def constant_update(
     # To ensure the updated node has the right order,
     # we insert the constant node before the consumer placed highest in topological order.
     sorted_consumer_nodes = [node for node in graph.nodes if node in consumer_nodes]
+    model_device = _get_model_device(model)
+    tensor_device = value.device if isinstance(value, torch.Tensor) else model_device
 
     with graph.inserting_before(sorted_consumer_nodes[0]):
-        new_const = create_getattr_from_value(model, graph, node_name, value)
+        # Passing the device explicitly is necessary to keep torchao from caching large models.
+        new_const = create_getattr_from_value(model, graph, node_name, value, device=tensor_device)
 
     old_const.replace_all_uses_with(new_const, propagate_meta=True)
     graph.eliminate_dead_code()
@@ -431,7 +446,10 @@ def insert_one_qdq(model: torch.fx.GraphModule, target_point: PTTargetPoint, qua
             # With extra check of scale and zero_point being scalar, it makes
             # sure that the default overload can be used.
             # TODO(dlyakhov): maybe need more complex attr name here
-            qparam_node = create_getattr_from_value(model, graph, target_node.name + key, value_or_node)
+            tensor_device = value_or_node.device
+            qparam_node = create_getattr_from_value(
+                model, graph, target_node.name + key, value_or_node, device=tensor_device
+            )
             quantize_op_inputs.append(qparam_node)
         else:
             # for qparams that are not scale/zero_point (like axis, dtype) we store
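
A minimal sketch of the device resolution that constant_update now performs, using plain torch only (the Linear module is a toy stand-in for a real GraphModule): a tensor value keeps its own device, and any other value falls back to the model's device.

import torch


def _get_model_device(model: torch.nn.Module) -> torch.device:
    # Mirrors the helper added above; assumes a single-device model.
    devices = {p.device for p in model.parameters()} | {b.device for b in model.buffers()}
    return next(iter(devices))


model = torch.nn.Linear(8, 8)  # toy stand-in for an FX GraphModule
value = torch.zeros(8)         # the new constant to register

model_device = _get_model_device(model)
tensor_device = value.device if isinstance(value, torch.Tensor) else model_device
print(tensor_device)  # -> cpu here; this is the device forwarded to create_getattr_from_value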

tests/torch/fx/test_model_transformer.py

Lines changed: 4 additions & 0 deletions

@@ -33,6 +33,7 @@
 from nncf.experimental.torch.fx.node_utils import get_graph_node_by_name
 from nncf.experimental.torch.fx.node_utils import get_node_args
 from nncf.experimental.torch.fx.node_utils import get_tensor_constant_from_node
+from nncf.experimental.torch.fx.transformations import _get_model_device
 from nncf.experimental.torch.fx.transformations import _set_new_node_meta
 from nncf.experimental.torch.fx.transformations import compress_post_quantize_transformation
 from nncf.experimental.torch.fx.transformations import constant_update_transformation_builder
@@ -471,19 +472,22 @@ def insert_qdq_nodes(
     dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default
 
     conv_node = get_graph_node_by_name(model.graph, node_name)
+    model_device = _get_model_device(model)
     if per_channel:
         with model.graph.inserting_before(conv_node):
             scale_node = create_getattr_from_value(
                 model,
                 model.graph,
                 "scale_node",
                 torch.ones([3]),
+                device=model_device,
             )
             zp_node = create_getattr_from_value(
                 model,
                 model.graph,
                 "weight_node",
                 torch.ones([3]),
+                device=model_device,
             )
             qdq_args = (scale_node, zp_node, 0, -128, 127, torch.int8)
     else:
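
For reference, the (scale_node, zp_node, 0, -128, 127, torch.int8) tuple built above feeds the per-channel overloads as (scales, zero_points, axis, quant_min, quant_max, dtype); the per-tensor variant drops the axis. A standalone sketch of the per-tensor round trip (the private _decomposed import is an assumption about where stock PyTorch registers these ops; it may vary across versions):

import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401  registers the quantized_decomposed ops

x = torch.randn(2, 3)
scale, zero_point = 0.1, 0

q = torch.ops.quantized_decomposed.quantize_per_tensor.default(x, scale, zero_point, -128, 127, torch.int8)
xr = torch.ops.quantized_decomposed.dequantize_per_tensor.default(q, scale, zero_point, -128, 127, torch.int8)
print(q.dtype, torch.allclose(x, xr, atol=scale))  # -> torch.int8 True (error within one quantization step)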
