7373}
7474
7575
76+ # Referenced from: https://github.com/pytorch/pytorch/blob/9105d54c6b37099575c0059ef274c86c4dc80c57/torch/ao/quantization/utils.py#L711
77+ def _get_model_device (model : torch .fx .GraphModule ) -> Any :
78+ """
79+ Copied from torchao.quantization.pt2e.utils
80+ Returns the unique device for a module, or None if no device is found.
81+ Throws an error if multiple devices are detected.
82+ """
83+ devices = {p .device for p in model .parameters ()} | {p .device for p in model .buffers ()}
84+ device = next (iter (devices ))
85+ return device
86+
87+
7688def _set_new_node_meta (
7789 new_node : torch .fx .Node ,
7890 prev_nodes : tuple [Argument , ...],
@@ -239,7 +251,7 @@ def constant_update(
239251 """
240252 graph = model .graph
241253 old_const = get_node_args (node )[input_port_id ]
242-
254+ breakpoint ()
243255 if old_const .op != "get_attr" :
244256 msg = f"Constant on input port { input_port_id } for { node } is expected, but node { old_const } is present."
245257 raise nncf .InternalError (msg )
@@ -251,9 +263,12 @@ def constant_update(
251263 # To ensure the updated node has the right order,
252264 # we insert constant node before the node placed at the highest order in topological order.
253265 sorted_consumer_nodes = [node for node in graph .nodes if node in consumer_nodes ]
266+ model_device = _get_model_device (model )
267+ tensor_device = value .device if isinstance (value , torch .Tensor ) else model_device
254268
255269 with graph .inserting_before (sorted_consumer_nodes [0 ]):
256- new_const = create_getattr_from_value (model , graph , node_name , value )
270+        # Passing the device is necessary to avoid large models being cached by torchao.
271+ new_const = create_getattr_from_value (model , graph , node_name , value , device = tensor_device )
257272
258273 old_const .replace_all_uses_with (new_const , propagate_meta = True )
259274 graph .eliminate_dead_code ()
@@ -431,7 +446,10 @@ def insert_one_qdq(model: torch.fx.GraphModule, target_point: PTTargetPoint, qua
431446 # With extra check of scale and zero_point being scalar, it makes
432447 # sure that the default overload can be used.
433448 # TODO(dlyakhov): maybe need more complex attr name here
434- qparam_node = create_getattr_from_value (model , graph , target_node .name + key , value_or_node )
449+ tensor_device = value_or_node .device
450+ qparam_node = create_getattr_from_value (
451+ model , graph , target_node .name + key , value_or_node , device = tensor_device
452+ )
435453 quantize_op_inputs .append (qparam_node )
436454 else :
437455 # for qparams that are not scale/zero_point (like axis, dtype) we store
0 commit comments