diff --git a/test/run_tests.sh b/test/run_tests.sh
index eeb2e8ee34d..c1720b53e99 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -6,6 +6,9 @@ MAX_GRAPH_SIZE=500
 GRAPH_CHECK_FREQUENCY=100
 VERBOSITY=2
 
+# Utils file
+source "${CDIR}/utils/run_tests_utils.sh"
+
 # Note [Keep Going]
 #
 # Set the `CONTINUE_ON_ERROR` flag to `true` to make the CI tests continue on error.
@@ -112,16 +115,6 @@ function run_eager_debug {
   XLA_USE_EAGER_DEBUG_MODE=1 run_test "$@"
 }
 
-function run_save_tensor_ir {
-  echo "Running in save tensor file mode: $@"
-  XLA_SAVE_TENSORS_FILE="/tmp/xla_test_save_ir.txt" XLA_SAVE_TENSORS_FMT="text" run_test "$@"
-}
-
-function run_save_tensor_hlo {
-  echo "Running in save tensor file mode: $@"
-  XLA_SAVE_TENSORS_FILE="/tmp/xla_test_save_ir.txt" XLA_SAVE_TENSORS_FMT="hlo" run_test "$@"
-}
-
 function run_pt_xla_debug {
   echo "Running in save tensor file mode: $@"
   PT_XLA_DEBUG=1 PT_XLA_DEBUG_FILE="/tmp/pt_xla_debug.txt" run_test "$@"
@@ -193,7 +186,7 @@ function run_xla_op_tests1 {
   run_test "$CDIR/dynamo/test_num_output.py"
   run_test "$CDIR/dynamo/test_graph_input_matcher.py"
   run_test "$CDIR/dynamo/test_dynamo_config.py"
-  run_save_tensor_ir "$CDIR/dynamo/test_dynamo_graph_dump.py"
+  run_save_tensor_ir run_test "$CDIR/dynamo/test_dynamo_graph_dump.py"
   run_test "$CDIR/test_data_type.py"
   run_use_bf16 "$CDIR/test_data_type.py"
   run_downcast_bf16 "$CDIR/test_data_type.py"
@@ -201,8 +194,8 @@ function run_xla_op_tests1 {
   run_xla_ir_debug "$CDIR/test_env_var_mapper.py"
   run_xla_hlo_debug "$CDIR/test_env_var_mapper.py"
   run_xla_hlo_debug "$CDIR/stablehlo/test_stablehlo_save_load.py"
-  run_save_tensor_ir "$CDIR/spmd/test_spmd_graph_dump.py"
-  run_save_tensor_hlo "$CDIR/spmd/test_spmd_graph_dump.py"
+  run_save_tensor_ir run_test "$CDIR/spmd/test_spmd_graph_dump.py"
+  run_save_tensor_hlo run_test "$CDIR/spmd/test_spmd_graph_dump.py"
 }
 
 function run_xla_op_tests2 {
@@ -248,7 +241,7 @@ function run_xla_op_tests3 {
   run_test "$CDIR/spmd/test_xla_auto_sharding.py"
   run_test "$CDIR/spmd/test_spmd_parameter_wrapping.py"
   run_test "$CDIR/spmd/test_mp_input_sharding.py"
-  run_test "$CDIR/spmd/test_spmd_lowering_context.py"
+  run_save_tensor_hlo run_test "$CDIR/spmd/test_spmd_lowering_context.py"
   run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY
   run_test "$CDIR/test_input_output_aliases.py"
   run_test "$CDIR/test_torch_distributed_xla_backend.py"
diff --git a/test/spmd/test_spmd_graph_dump.py b/test/spmd/test_spmd_graph_dump.py
index e3cadd7b3ce..9ef67ef0a85 100644
--- a/test/spmd/test_spmd_graph_dump.py
+++ b/test/spmd/test_spmd_graph_dump.py
@@ -23,8 +23,7 @@ def setUpClass(cls):
   def test_dump_with_output_sharding(self):
     save_file = os.getenv('XLA_SAVE_TENSORS_FILE')
     save_format = os.getenv('XLA_SAVE_TENSORS_FMT')
-    if not save_file:
-      assert False, "This test should be run with XLA_SAVE_TENSORS_FILE"
+    assert save_file, "This test should be run with XLA_SAVE_TENSORS_FILE"
     should_dump_output_sharding = (save_format == 'hlo')
     save_file += '.0'
     device = xm.xla_device()
@@ -35,12 +34,10 @@ def test_dump_with_output_sharding(self):
       xla_sharded_x = xs.mark_sharding(
           xla_x, self._get_mesh((1, self.n_devices)), partition_spec)
       xla_res = xla_x + xla_y
+      xm.mark_step()
       with open(save_file, 'rb') as f:
-        current_line = sum(1 for line in f)
-      with open(save_file, 'rb') as f:
-        xm.mark_step()
         lines = f.readlines()
-      self.assertGreater(len(lines), current_line)
+      self.assertGreater(len(lines), 0)
       if should_dump_output_sharding:
         self.assertIn('OUTPUT_SHARDING_END', str(lines[-2]))
       else:
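Both dump tests now inherit the save-tensors environment from the shared harness instead of hard-coding it. For reference, a manual run equivalent to what the wrapper sets up would look roughly like this (the path is the helper's default; the runtime appends a device ordinal such as `.0` to the file name):

  XLA_SAVE_TENSORS_FILE="/tmp/xla_test_save_ir.txt" \
  XLA_SAVE_TENSORS_FMT="hlo" \
  python3 test/spmd/test_spmd_graph_dump.py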
diff --git a/test/spmd/test_spmd_lowering_context.py b/test/spmd/test_spmd_lowering_context.py
index cb5018b1a4f..df872073073 100644
--- a/test/spmd/test_spmd_lowering_context.py
+++ b/test/spmd/test_spmd_lowering_context.py
@@ -1,4 +1,7 @@
+import os
+import re
 import sys
+from pathlib import Path
 
 import unittest
 
@@ -6,10 +9,10 @@
 import torch
 import torch_xla
+import torch_xla.core.xla_builder as xb
 import torch_xla.debug.metrics as met
 import torch_xla.distributed.spmd as xs
 import torch_xla.core.xla_model as xm
-import contextlib
 
 
 class TestSPMDLoweringContext(test_xla_sharding_base.XlaShardingTest):
 
@@ -18,10 +21,91 @@ class TestSPMDLoweringContext(test_xla_sharding_base.XlaShardingTest):
   def setUpClass(cls):
     super().setUpClass()
 
+  def _get_computation_hlo_txt(self, ctx):
+    hlo = ctx.hlo()
+    comp = xb.computation_from_module_proto("my_custom_comp", hlo)
+    return xb.get_computation_hlo(comp)
+
+  def test_basic(self):
+    save_file = os.getenv('XLA_SAVE_TENSORS_FILE')
+    save_format = os.getenv('XLA_SAVE_TENSORS_FMT')
+    assert save_file, "This test should be run with XLA_SAVE_TENSORS_FILE"
+    save_file += '.0'  # Identify a single device
+    assert save_format == 'hlo', "This test should be run with XLA_SAVE_TENSORS_FMT=hlo"
+
+    model_axis = max(1, self.n_devices // 2)
+    data_axis = self.n_devices // model_axis
+    mesh_shape = (data_axis, model_axis)
+    spmd_mesh = self._get_mesh(mesh_shape, axis_names=('x', 'y'))
+
+    device = xm.xla_device()
+    a = torch.zeros(2048, device=device, requires_grad=True)
+    xs.mark_sharding(a, spmd_mesh, ('x',))
+    b = torch.randn([32, 2048], device=device, requires_grad=True)
+    xs.mark_sharding(b, spmd_mesh, (None, 'y'))
+
+    def fn(x, y):
+      x = x + 1
+      return x, y * 2
+
+    result = fn(a, b)
+    ctx = torch_xla._XLAC.lowering.LoweringContext("MyCustomName")
+    ctx.build(list(result))
+    torch_xla.sync()
+
+    # Sanity HLO check.
+    hlo_text = ctx.hlo_text()
+    self.assertIn('MyCustomName', hlo_text)
+    self.assertIn('opcode: "parameter"', hlo_text)
+    self.assertIn('opcode: "add"', hlo_text)
+    self.assertIn('sharding', hlo_text)
+
+    # Ensure that the corresponding input parameters contain the expected sharding.
+    hlo_comp_txt = self._get_computation_hlo_txt(ctx)
+    a_sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(a)
+    self.assertRegex(
+        hlo_comp_txt,
+        rf'%custom-call.*f32[2048]{{0}}.*sharding={re.escape(a_sharding_spec)}'
+    )
+    b_sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(b)
+    self.assertRegex(
+        hlo_comp_txt,
+        rf'%custom-call.*f32[32,2048]{{0}}.*sharding={re.escape(b_sharding_spec)}'
+    )
+
+    # Ensure that the results retain the same sharding specs.
+    result_a, result_b = result
+    self.assertEqual(
+        torch_xla._XLAC._get_xla_sharding_spec(result_a), a_sharding_spec)
+    self.assertEqual(
+        torch_xla._XLAC._get_xla_sharding_spec(result_b), b_sharding_spec)
+
+    hlo_content = Path(save_file).read_text()
+    assert len(re.findall('END_GRAPH',
+                          hlo_content)) == 1, "Expected a single graph"
+
+    # Extract the content between OUTPUT_SHARDING_BEGIN and OUTPUT_SHARDING_END.
+    pattern = r'#OUTPUT_SHARDING_BEGIN\n(.*?)\n#OUTPUT_SHARDING_END'
+    match = re.search(pattern, hlo_content, re.DOTALL)
+    assert match is not None, "#OUTPUT_SHARDING not found in the file"
+    assert len(match.groups()) == 1, f"Expected 1 group, but found {len(match.groups())}"
+    expected_output = match.group(1).strip().split('\n')
+
+    # Assert that the output sharding matches our expectation.
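+    # Each dumped line has the form "<dtype>[<dims>] <sharding>", covering
+    # the two sharded inputs (a, b) and the two outputs, which retain the
+    # same shardings.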
+    assert len(expected_output) == 4, f"Expected 4 lines, but found {len(expected_output)}"
+    assert expected_output[0] == f"f32[2048] {a_sharding_spec}"
+    assert expected_output[1] == f"f32[32,2048] {b_sharding_spec}"
+    assert expected_output[2] == f"f32[2048] {a_sharding_spec}"
+    assert expected_output[3] == f"f32[32,2048] {b_sharding_spec}"
+    self.assertTrue(met.counter_value("ExecuteReplicated") == 1)
+    self.assertTrue(met.counter_value("ExecuteComputation") is None)
+
   def test_device_parameter_id_tensor_mapping(self):
     met.clear_all()
 
-    model_axis = min(8, self.n_devices)
+    model_axis = max(1, self.n_devices // 2)
     data_axis = self.n_devices // model_axis
     mesh_shape = (data_axis, model_axis)
     spmd_mesh = self._get_mesh(mesh_shape, axis_names=('x', 'y'))
diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh
index 0134e6730f8..c6e90e9ae0a 100755
--- a/test/tpu/run_tests.sh
+++ b/test/tpu/run_tests.sh
@@ -1,66 +1,70 @@
 #!/bin/bash
 set -xue
 
+CDIR="$(cd "$(dirname "$0")" ; pwd -P)"
+TEST_CDIR="$(dirname "$CDIR")"
+
+source "${TEST_CDIR}/utils/run_tests_utils.sh"
+
 # TODO: merge with other run_tests
-python3 test/test_operations.py -v
-python3 test/pjrt/test_runtime_tpu.py
-python3 test/pjrt/test_collective_ops_tpu.py
-python3 test/spmd/test_mp_input_sharding.py
-python3 test/spmd/test_spmd_lowering_context.py
-python3 test/spmd/test_xla_sharding.py
-python3 test/spmd/test_xla_virtual_device.py
-python3 test/spmd/test_xla_distributed_checkpoint.py
-python3 test/spmd/test_train_spmd_linear_model.py
-python3 test/spmd/test_xla_spmd_python_api_interaction.py
-python3 test/spmd/test_xla_auto_sharding.py
-python3 test/spmd/test_fsdp_v2.py
-XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 test/ds/test_dynamic_shape_models.py -v
-python3 test/test_autocast.py
-python3 test/test_fp8.py
-python3 test/test_grad_checkpoint.py
-python3 test/dynamo/test_dynamo.py
-python3 test/dynamo/test_dynamo_dynamic_shape.py
-python3 test/spmd/test_spmd_debugging.py
-XLA_PARAMETER_WRAPPING_THREADSHOLD=1 python test/spmd/test_spmd_parameter_wrapping.py
-python3 test/pjrt/test_dtypes.py
-python3 test/pjrt/test_dynamic_plugin_tpu.py
-python3 test/test_while_loop.py
-python3 test/scan/test_scan.py
-python3 test/scan/test_scan_spmd.py
-python3 test/scan/test_scan_layers.py
-python3 test/test_pallas.py -v
-python3 test/test_pallas_spmd.py
-python3 test/test_tpu_paged_attention_kernel.py
-python3 test/test_input_output_aliases.py
-python3 test/test_gmm.py
-python3 test/eager/test_eager_spmd.py
-python3 test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py
-python3 test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py
-python3 test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py
-python3 test/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py
-python3 test/quantized_ops/test_dot_general.py
+python3 "$TEST_CDIR/test_operations.py" -v
+python3 "$TEST_CDIR/pjrt/test_runtime_tpu.py"
+python3 "$TEST_CDIR/pjrt/test_collective_ops_tpu.py"
+python3 "$TEST_CDIR/spmd/test_mp_input_sharding.py"
+run_save_tensor_hlo python3 "$TEST_CDIR/spmd/test_spmd_lowering_context.py"
+python3 "$TEST_CDIR/spmd/test_xla_sharding.py"
+python3 "$TEST_CDIR/spmd/test_xla_virtual_device.py"
+python3 "$TEST_CDIR/spmd/test_xla_distributed_checkpoint.py"
+python3 "$TEST_CDIR/spmd/test_train_spmd_linear_model.py"
+python3 "$TEST_CDIR/spmd/test_xla_spmd_python_api_interaction.py"
+python3 "$TEST_CDIR/spmd/test_xla_auto_sharding.py"
+python3 "$TEST_CDIR/spmd/test_fsdp_v2.py"
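+# XLA_EXPERIMENTAL opts the listed ops (nonzero, masked_select, nms) into
+# experimental dynamic-shape support for this invocation only.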
+XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 "$TEST_CDIR/ds/test_dynamic_shape_models.py" -v
+python3 "$TEST_CDIR/test_autocast.py"
+python3 "$TEST_CDIR/test_fp8.py"
+python3 "$TEST_CDIR/test_grad_checkpoint.py"
+python3 "$TEST_CDIR/dynamo/test_dynamo.py"
+python3 "$TEST_CDIR/dynamo/test_dynamo_dynamic_shape.py"
+python3 "$TEST_CDIR/spmd/test_spmd_debugging.py"
+XLA_PARAMETER_WRAPPING_THREADSHOLD=1 python3 "$TEST_CDIR/spmd/test_spmd_parameter_wrapping.py"
+python3 "$TEST_CDIR/pjrt/test_dtypes.py"
+python3 "$TEST_CDIR/pjrt/test_dynamic_plugin_tpu.py"
+python3 "$TEST_CDIR/test_while_loop.py"
+python3 "$TEST_CDIR/scan/test_scan.py"
+python3 "$TEST_CDIR/scan/test_scan_spmd.py"
+python3 "$TEST_CDIR/scan/test_scan_layers.py"
+python3 "$TEST_CDIR/test_pallas.py" -v
+python3 "$TEST_CDIR/test_pallas_spmd.py"
+python3 "$TEST_CDIR/test_tpu_paged_attention_kernel.py"
+python3 "$TEST_CDIR/test_input_output_aliases.py"
+python3 "$TEST_CDIR/test_gmm.py"
+python3 "$TEST_CDIR/eager/test_eager_spmd.py"
+python3 "$TEST_CDIR/torch_distributed/test_torch_distributed_all_gather_xla_backend.py"
+python3 "$TEST_CDIR/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py"
+python3 "$TEST_CDIR/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py"
+python3 "$TEST_CDIR/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py"
+python3 "$TEST_CDIR/quantized_ops/test_dot_general.py"
 
 # run examples, each test should takes <2 minutes
-python3 examples/data_parallel/train_resnet_spmd_data_parallel.py
-python3 examples/fsdp/train_decoder_only_fsdp_v2.py
-python3 examples/train_resnet_amp.py
+python3 "$TEST_CDIR/../examples/data_parallel/train_resnet_spmd_data_parallel.py"
+python3 "$TEST_CDIR/../examples/fsdp/train_decoder_only_fsdp_v2.py"
+python3 "$TEST_CDIR/../examples/train_resnet_amp.py"
 
 # HACK: don't confuse local `torch_xla` folder with installed package
 # Python 3.11 has the permanent fix: https://stackoverflow.com/a/73636559
 # Egaer tests will take more HBM, only run them on TPU v4 CI
 TPU_VERSION=$(python -c "import sys; sys.path.remove(''); import torch_xla; print(torch_xla._internal.tpu.version())")
 if [[ -n "$TPU_VERSION" && "$TPU_VERSION" == "4" ]]; then
-  python3 test/dynamo/test_traceable_collectives.py
-  python3 examples/data_parallel/train_resnet_xla_ddp.py
-  python3 examples/fsdp/train_resnet_fsdp_auto_wrap.py
-  python3 examples/eager/train_decoder_only_eager.py
-  python3 examples/eager/train_decoder_only_eager_spmd_data_parallel.py
-  python3 examples/eager/train_decoder_only_eager_with_compile.py
-  python3 examples/eager/train_decoder_only_eager_multi_process.py
-  XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 test/ds/test_dynamic_shapes.py -v
+  python3 "$TEST_CDIR/dynamo/test_traceable_collectives.py"
+  python3 "$TEST_CDIR/../examples/data_parallel/train_resnet_xla_ddp.py"
+  python3 "$TEST_CDIR/../examples/fsdp/train_resnet_fsdp_auto_wrap.py"
+  python3 "$TEST_CDIR/../examples/eager/train_decoder_only_eager.py"
+  python3 "$TEST_CDIR/../examples/eager/train_decoder_only_eager_spmd_data_parallel.py"
+  python3 "$TEST_CDIR/../examples/eager/train_decoder_only_eager_with_compile.py"
+  python3 "$TEST_CDIR/../examples/eager/train_decoder_only_eager_multi_process.py"
+  XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 "$TEST_CDIR/ds/test_dynamic_shapes.py" -v
 fi
 
 if [[ -n "$TPU_VERSION" && "$TPU_VERSION" != "6" ]]; then
   # Test `tpu-info` CLI compatibility
-  python3 test/tpu/tpu_info/test_cli.py
+  python3 "$CDIR/tpu_info/test_cli.py"
 fi
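The wrappers take the launcher as their first argument, so the same helper serves both the `run_test` shell function in test/run_tests.sh and a bare `python3` here. Both call forms, as used in the scripts above:

  run_save_tensor_ir run_test "$CDIR/dynamo/test_dynamo_graph_dump.py"
  run_save_tensor_hlo python3 "$TEST_CDIR/spmd/test_spmd_lowering_context.py"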
diff --git a/test/utils/run_tests_utils.sh b/test/utils/run_tests_utils.sh
new file mode 100755
index 00000000000..9f519a2885d
--- /dev/null
+++ b/test/utils/run_tests_utils.sh
@@ -0,0 +1,56 @@
+#!/bin/bash
+set -exo pipefail
+
+# Run a test with tensor saving enabled, using a specified graph format. The
+# graph dump files are cleaned up after the test. In case the test crashes,
+# the file is retained.
+#
+# Usage: run_save_tensor <exec> <format> [test arguments...]
+#
+# Arguments:
+#   exec: The executable or function that runs the test (python3 or any function)
+#   format: The graph format to use with XLA_SAVE_TENSORS_FMT
+#   test arguments: Arguments to pass to the test
+#
+# Environment:
+#   Sets XLA_SAVE_TENSORS_FILE and XLA_SAVE_TENSORS_FMT
+function run_save_tensor {
+  local run_test_func="$1" ; local file_graph_format="$2" ; shift 2
+
+  echo "Running in save tensor file mode: $@"
+  local base_file="/tmp/xla_test_save_ir.txt"
+
+  # Check if the file already exists, for any device ordinal number.
+  if ls "${base_file}"* 1> /dev/null 2>&1; then
+    echo "Error: File ${base_file} or a numbered version already exists. Please remove it before running the test."
+    return 1
+  fi
+
+  XLA_SAVE_TENSORS_FILE="$base_file" XLA_SAVE_TENSORS_FMT="$file_graph_format" $run_test_func "$@"
+  local test_status=$?
+
+  # Clean up the file once the test finalizes.
+  local actual_file
+  actual_file=$(ls "${base_file}"* 2>/dev/null | head -n1)
+  if [ -f "$actual_file" ]; then
+    echo "Cleaning up temporary file: $actual_file"
+    rm "$actual_file"
+  else
+    echo "Warning: Expected output file not found"
+  fi
+  return $test_status
+}
+
+function run_save_tensor_ir {
+  local run_test_func="$1"
+  shift
+  echo "Running in save tensor file mode: $@"
+  run_save_tensor "$run_test_func" "text" "$@"
+}
+
+function run_save_tensor_hlo {
+  local run_test_func="$1"
+  shift
+  echo "Running in save tensor file mode: $@"
+  run_save_tensor "$run_test_func" "hlo" "$@"
+}
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index b42a084e85c..b348c898974 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -1006,6 +1006,9 @@ class PyLoweringContext {
           torch::lazy::Output(ir_value.node.get(), ir_value.index));
       lowering_ctx.AddResult(root);
     }
+
+    ShardingUtil::SetHloSharding(&lowering_ctx);
+
     computation = ConsumeValue(lowering_ctx.BuildXla());
   }
 
@@ -1048,21 +1051,27 @@ class PyLoweringContext {
       }
     }
 
+    ShardingUtil::SetHloSharding(&lowering_ctx);
+
     computation = ConsumeValue(lowering_ctx.BuildXla());
 
     // wrap inputs of cond/body_computation
    if ((GetNameString() == "condctx") || (GetNameString() == "bodyctx")) {
       std::vector<std::pair<int64_t, int64_t>> input_output_alias_pair;
-      std::vector<size_t> buffer_donor_indices;
+      std::vector<xla::OpSharding> param_shardings;
+      // If sharded, then extract all input Op shardings.
+      if (UseVirtualDevice()) {
+        param_shardings = XlaHelpers::ExtractInputShardings(computation);
+      }
       xla::ProgramShape program_shape =
           ConsumeValue(computation.GetProgramShape());
       // TODO(@manfei): please confirm whether we check for more than two or use
       // default value true
       bool should_wrap_parameter = (program_shape.parameters_size() >= 2);
       if (should_wrap_parameter) {
-        // For now we assume that we for i loop input is not sharded.
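+        // Wrap the parameters into a single tuple computation, forwarding
+        // the extracted input shardings; no buffers are donated here.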
         computation = ConsumeValue(XlaHelpers::WrapXlaComputation(
-            computation, program_shape.parameters(), {}, buffer_donor_indices));
+            computation, program_shape.parameters(), param_shardings,
+            /* buffer_donor_indices */ {}));
       }
     }
   }
diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp
index c2db9b36309..6c2906dc724 100644
--- a/torch_xla/csrc/lowering_context.cpp
+++ b/torch_xla/csrc/lowering_context.cpp
@@ -111,23 +111,31 @@ LoweringContext::LoweringContext(
 static constexpr int64_t kUnboundedSize = std::numeric_limits<int64_t>::min();
 
 xla::XlaOp LoweringContext::GetParameter(
-    const std::shared_ptr<torch::lazy::BackendData>& data,
+    const std::shared_ptr<torch::lazy::BackendData>& backend_data,
     const std::unordered_set<uint32_t>& unbounded_dynamic_dims) {
-  torch::lazy::BackendData::Handle handle = data->GetHandle();
+  torch::lazy::BackendData::Handle handle = backend_data->GetHandle();
   auto it = parameters_map_.find(handle);
   if (it == parameters_map_.end()) {
-    xla::Shape shape =
-        std::dynamic_pointer_cast<runtime::ComputationClient::Data>(data)
-            ->shape();
+    auto data = std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
+        backend_data);
+    XLA_CHECK(data != nullptr);
+    xla::Shape shape = data->shape();
     for (const int dim : unbounded_dynamic_dims) {
       shape.set_dynamic_dimension(dim, true);
       shape.set_dimensions(dim, kUnboundedSize);
     }
-    xla::XlaOp param = xla::Parameter(builder(), parameters_.size(), shape,
-                                      absl::StrCat("p", parameters_.size()));
-    it = parameters_map_.emplace(handle, Parameter{param, parameters_.size()})
-             .first;
-    parameters_.push_back(data);
+    size_t param_index = parameters_.size();
+    std::string param_name = absl::StrCat("p", param_index);
+    xla::XlaOp param;
+    if (data->HasSharding()) {
+      xla::OpSharding sharding = data->GetSharding();
+      xla::XlaScopedShardingAssignment scoped_sharding(builder(), sharding);
+      param = xla::Parameter(builder(), param_index, shape, param_name);
+    } else {
+      param = xla::Parameter(builder(), param_index, shape, param_name);
+    }
+    it = parameters_map_.emplace(handle, Parameter{param, param_index}).first;
+    parameters_.push_back(backend_data);
   } else {
     XLA_CHECK(unbounded_dynamic_dims.empty())
        << "The unbounded dynamic dims can only be set when Parameter is "
@@ -138,8 +146,8 @@ xla::XlaOp LoweringContext::GetParameter(
 }
 
 std::optional<size_t> LoweringContext::GetParameterId(
-    const std::shared_ptr<torch::lazy::BackendData>& data) const {
-  torch::lazy::BackendData::Handle handle = data->GetHandle();
+    const std::shared_ptr<torch::lazy::BackendData>& backend_data) const {
+  torch::lazy::BackendData::Handle handle = backend_data->GetHandle();
   auto it = parameters_map_.find(handle);
   if (it == parameters_map_.end()) {
     return std::nullopt;
diff --git a/torch_xla/csrc/lowering_context.h b/torch_xla/csrc/lowering_context.h
index 3a36695e1c0..cb4f0bc2d2f 100644
--- a/torch_xla/csrc/lowering_context.h
+++ b/torch_xla/csrc/lowering_context.h
@@ -50,13 +50,13 @@ class LoweringContext : public torch::lazy::LoweringContext {
   // returned. Otherwise a new one will be created, associated with the tensor
   // held in data.
   xla::XlaOp GetParameter(
-      const std::shared_ptr<torch::lazy::BackendData>& data,
+      const std::shared_ptr<torch::lazy::BackendData>& backend_data,
       const std::unordered_set<uint32_t>& dynamic_dims = {});
 
   // If a parameter associated with data has already been declared, returns its
   // ID. Otherwise, returns `std::nullopt`.
   std::optional<size_t> GetParameterId(
-      const std::shared_ptr<torch::lazy::BackendData>& data) const;
+      const std::shared_ptr<torch::lazy::BackendData>& backend_data) const;
 
   // Retrieves the vector holding all the tensors associated with the parameter
   // instructions which have been created.
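With input shardings now attached in `GetParameter` and propagated through `ShardingUtil::SetHloSharding`, the dump-based assertions above can be exercised locally via the new wrapper. A rough invocation, assuming a source checkout with `test/` as the working directory:

  source utils/run_tests_utils.sh
  # Sets XLA_SAVE_TENSORS_FILE=/tmp/xla_test_save_ir.txt and
  # XLA_SAVE_TENSORS_FMT=hlo, then removes the dump file after the test.
  run_save_tensor_hlo python3 spmd/test_spmd_lowering_context.py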