Skip to content

【Hackathon 8th No.3】del oldIR in engine -part #72542

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 3 additions & 32 deletions python/paddle/distributed/auto_parallel/static/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from paddle.static.amp.fp16_utils import _convert_float_to_bfloat16

from ...utils.log_utils import get_logger
from ..interface import CollectionNames, fetch, get_collection
from ..interface import CollectionNames, get_collection
from ..static.dist_tensor import DistributedTensor
from ..strategy import Strategy
from .callbacks import config_callbacks
Expand Down Expand Up @@ -580,37 +580,8 @@ def _prepare_fetch(self, user_fetches, mode):
# TODO(2024-Q2)
if self._in_pir_mode:
return fetch_names, fetch_indices

def _process_fetch_group(group_name, var_list):
group_indices = []
for var in var_list:
# Remove duplicate var_names
if self._is_local_var(var):
var_name = _to_name_str(var)
if var_name not in fetch_names:
fetch_names.append(var_name)
group_indices.append(fetch_names.index(var_name))
fetch_indices.append(group_indices)

dist_context = self._dist_contexts[mode]
fetch_vars = dist_context.serial_fetch_vars
if mode != "predict":
_process_fetch_group("loss", fetch_vars["loss"])
if mode != "predict":
metrics = fetch_vars["metrics"]
for i, var_list in enumerate(metrics):
_process_fetch_group("metrics_" + str(i), var_list)
if mode == "predict":
_process_fetch_group("outputs", fetch_vars["outputs"])
for usr_fetch in user_fetches or []:
var_name = _to_name_str(usr_fetch)
fetch(var_name)
user_fetches_collection = [
item[1] for item in get_collection(CollectionNames.FETCHES)
]
var_list = user_fetches_collection or []
_process_fetch_group("fetches", var_list)
return fetch_names, fetch_indices
else:
raise NotImplementedError("_prepare_fetch() only support PIR now.")

def _prepare_logger(
self,
Expand Down
Loading