
【Hackathon 8th No.3】 clean oldIR for engine #71929

Open · wants to merge 36 commits into develop

Commits (36)
5c267e4
clean the oldIR in api.py
aquagull Feb 24, 2025
e79638e
clean oldIR in engine.py
aquagull Feb 24, 2025
8f32e51
revert and delete deprecated test
aquagull Feb 24, 2025
a8660bd
revert
aquagull Feb 26, 2025
b8e5763
revert reshard to avoid test_semi_auto_parallel_global_input.test_sta…
aquagull Feb 26, 2025
d5b1b89
revert deprecated test
aquagull Feb 26, 2025
48e0657
revert cmake
aquagull Feb 26, 2025
b691f5a
delete testslist.csv
aquagull Feb 27, 2025
10e2e56
update
aquagull Mar 26, 2025
38ca789
Merge branch 'PaddlePaddle:develop' into dis
aquagull Mar 26, 2025
cba0886
update
aquagull Mar 26, 2025
435ba42
update
aquagull Mar 26, 2025
2fef84b
update
aquagull Mar 26, 2025
7b7e70b
fix
aquagull Mar 26, 2025
79212b5
update
aquagull Mar 27, 2025
193ce7f
update
aquagull Mar 27, 2025
12d5180
Merge branch 'develop' into engine
aquagull Mar 27, 2025
8377221
Merge branch 'PaddlePaddle:develop' into engine
aquagull Mar 27, 2025
c8408cd
Merge branch 'PaddlePaddle:develop' into dis
aquagull Mar 27, 2025
9d6aa40
add
aquagull Mar 31, 2025
580fcd2
Merge branch 'PaddlePaddle:develop' into engine
aquagull Apr 2, 2025
a944fa1
Merge branch 'dis' into engine
aquagull Apr 2, 2025
184acac
Merge branch 'PaddlePaddle:develop' into engine
aquagull Apr 2, 2025
ad9456a
update
aquagull Apr 2, 2025
d49bdf3
fix
aquagull Apr 2, 2025
ae8d6ed
fix
aquagull Apr 2, 2025
4af8c30
add FLAGS_enable_pir_api
aquagull Apr 2, 2025
b8bdfbc
fix
aquagull Apr 2, 2025
46cb554
update
aquagull Apr 3, 2025
44b005a
reback
aquagull Apr 10, 2025
dcf2e37
update
aquagull Apr 10, 2025
029c365
fix
aquagull Apr 10, 2025
b6801fd
Update CMakeLists.txt
aquagull Apr 14, 2025
73b23be
empty
aquagull Apr 14, 2025
f79be66
Merge branch 'develop' into engine
aquagull Apr 24, 2025
8dbd53b
update
aquagull Apr 28, 2025
145 changes: 23 additions & 122 deletions python/paddle/distributed/auto_parallel/api.py
@@ -27,31 +27,17 @@
from paddle import _C_ops, nn, pir
from paddle.amp.grad_scaler import OptimizerState
from paddle.autograd import PyLayer
from paddle.base import unique_name
from paddle.base.dygraph.base import switch_to_static_graph
from paddle.base.framework import (
EagerParamBase,
Variable,
default_main_program,
in_dygraph_mode,
in_pir_mode,
use_pir_api,
)
from paddle.distributed import fleet
from paddle.distributed.auto_parallel import Engine, strategy as auto_strategy
from paddle.distributed.auto_parallel.interface import (
shard_tensor as shard_tensor_static,
)
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.static.completion import (
mark_as_sharding_propagation_skip_op,
)
from paddle.distributed.auto_parallel.static.dist_context import (
get_default_distributed_context,
)
from paddle.distributed.auto_parallel.static.dist_op import DistributedOperator
from paddle.distributed.auto_parallel.static.utils import (
convert_to_dims_mapping,
fuse_param_func,
get_dist_attr,
split_mesh,
@@ -279,15 +265,7 @@ def shard_tensor(
if stop_gradient is None:
stop_gradient = getattr(data, "stop_gradient", True)

if paddle.framework.in_pir_mode():
assert isinstance(
data, (type(None), pir.Value)
), "input tensor is not pir value."
assert (
data.is_dense_tensor_type()
), "shard_tensor() input data only supported dense tensor type right."
tensor = data
else:
if paddle.in_dynamic_mode():
if isinstance(data, EagerParamBase) and not data._is_initialized():
assert (
data._init_func is not None
@@ -302,7 +280,6 @@
data, dtype=dtype, place=place, stop_gradient=stop_gradient
)

if paddle.in_dynamic_mode():
# here the dist tensor is deep copy constructed
if isinstance(data, EagerParamBase):

@@ -353,14 +330,21 @@ def _init_func(var, block):
dist_tensor.stop_gradient = tensor.stop_gradient
return dist_tensor
elif paddle.framework.in_pir_mode():
dist_tensor = paddle._C_ops.shard_tensor(tensor, mesh, placements)
dist_tensor.stop_gradient = tensor.stop_gradient
dist_tensor.persistable = tensor.persistable
assert isinstance(
data, (type(None), pir.Value)
), "input tensor is not pir value."
assert (
data.is_dense_tensor_type()
), "shard_tensor() input data only supported dense tensor type right."

dist_tensor = paddle._C_ops.shard_tensor(data, mesh, placements)
dist_tensor.stop_gradient = data.stop_gradient
dist_tensor.persistable = data.persistable
return dist_tensor
else:
# TODO(zhiqiu): we need to refine the static shard_tensor
sharding_specs = get_shard_spec(mesh, placements, tensor.ndim)
return shard_tensor_static(tensor, mesh, sharding_specs)
raise NotImplementedError(
"`shard_tensor()` only supported in dynamic and pir mode."
)


class _moe_global_mesh_tensor(PyLayer):
@@ -861,53 +845,9 @@ def reshard(
elif in_pir_mode():
return paddle._C_ops.reshard(dist_tensor, mesh, placements)
else:
assert isinstance(
dist_tensor, Variable
), f"in dy2static mode, reshard's input should be Variable, but got [{dist_tensor}]"
sharding_specs = get_shard_spec(mesh, placements, dist_tensor.ndim)
main_program = default_main_program()
default_dist_ctx = get_default_distributed_context()

# output variable
out_var = main_program.current_block().create_var(
name=unique_name.generate_with_ignorable_key(
".".join(['reshard_api', 'tmp'])
),
dtype=dist_tensor.dtype,
shape=dist_tensor.shape,
type=dist_tensor.type,
persistable=dist_tensor.persistable,
stop_gradient=dist_tensor.stop_gradient,
)

# transition op
# optimization in future to remove redundant D2D memory copy
target_dims_mapping = convert_to_dims_mapping(sharding_specs, mesh)
trans_op = main_program.current_block().append_op(
type='assign',
inputs={'X': [dist_tensor]},
outputs={'Out': [out_var]},
)
dist_op = DistributedOperator(trans_op)
dist_op.dist_attr.process_mesh = mesh
dist_op.dist_attr.mark_annotated("process_mesh")
dist_op.dist_attr.chunk_id = 0

input_dist_attr = dist_op.dist_attr.get_input_dist_attr(
dist_tensor.name
raise NotImplementedError(
"`reshard()` only supported in dynamic and pir mode."
)
input_dist_attr.dims_mapping = target_dims_mapping
input_dist_attr.mark_annotated("dims_mapping")
output_dist_attr = dist_op.dist_attr.get_output_dist_attr(out_var.name)
output_dist_attr.dims_mapping = target_dims_mapping
output_dist_attr.mark_annotated("dims_mapping")

default_dist_ctx.add_dist_op_for_program(dist_op)
mark_as_sharding_propagation_skip_op(trans_op)
# trans_op = shard_op_static(paddle.assign, mesh, [sharding_specs])
# out_var = trans_op(dist_tensor)

return out_var


def shard_layer(
@@ -2550,12 +2490,7 @@ def __init__(
if (
not self._in_pir_mode
): # TODO (2024-Q2) remove this when pir mode is fully constructed.
if optimizer is not None and loss is not None:
self.train()
elif loss is not None:
self.eval()
else:
self.predict()
raise NotImplementedError("Only supported in dynamic and pir mode.")

def train(self) -> None:
"""
@@ -2840,9 +2775,7 @@ def state_dict(
mode=self._engine._mode
).state_dict(mode, scope)
else:
local_state_dict = self.dist_main_program(
mode=self._engine._mode
).state_dict(mode)
raise NotImplementedError("state_dict not support old IR")

dist_state_dict = self._build_distributed_state_dict(local_state_dict)

@@ -2922,9 +2855,8 @@ def _build_distributed_state_dict(self, local_state_dict):
if use_pir_api():
dist_attrs = get_dist_attr(dist_main_program)
else:
# Dict[var.name, Dict["process_shape": process_mesh.shape, "process_group": process_mesh.process_ids, "dims_mapping": dims_mapping]]
dist_attrs = get_dist_attr(
dist_main_program, self._engine._dist_contexts[self._mode]
raise NotImplementedError(
"_build_distributed_state_dict not support old IR"
)

def build_distributed_tensor(local_tensor, dist_attr):
@@ -3062,8 +2994,8 @@ def set_state_dict(self, state_dict: dict[str, Tensor]) -> None:
local_state_dict, paddle.static.global_scope(), copy_tensor
)
else:
dist_main_program.set_state_dict(
local_state_dict, paddle.static.global_scope()
raise NotImplementedError(
"_build_distributed_state_dict not support old IR"
)

def _get_shard_stage1_optimizer(self):
@@ -3285,38 +3217,7 @@ def to_static(
>>> # python -m paddle.distributed.launch {test_case}.py
"""
if isinstance(optimizer, _ShardOptimizer) and not use_pir_api():
shard_fn = optimizer._shard_fn
sharding_degree = optimizer._sharding_degree
optimizer = optimizer._inner_opt

if shard_fn is not None:
strategy = dist.Strategy() if strategy is None else strategy

# Deduce sharding degree for static
# Note: Because limitation of architecture, we need to ensure that
# all parameters are sharded by the same mesh axis
assert (
sharding_degree is not None
), "Sharding degree can not be None."

if isinstance(shard_fn, ShardingStage1):
strategy.sharding.enable = True
strategy.sharding.stage = 1
strategy.sharding.degree = sharding_degree
elif isinstance(shard_fn, ShardingStage2):
strategy.sharding.enable = True
strategy.sharding.stage = 2
strategy.sharding.degree = sharding_degree
elif isinstance(shard_fn, ShardingStage3):
strategy.sharding.enable = True
strategy.sharding.stage = 3
strategy.sharding.degree = sharding_degree
for param in optimizer._parameter_list:
shard_fn._unshard_parameter(param)
else:
raise NotImplementedError(
"Only sharding stage 1, 2 and 3 can to_static for now. User-defined shard_fn will be supported later."
)
raise NotImplementedError("to_static() only support PIR now.")
if strategy is None or strategy.full_graph:
dist_model = DistModel(
layer, loader, loss, optimizer, strategy, input_spec=input_spec
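Taken together, the diff keeps the dynamic-mode and PIR branches of shard_tensor() / reshard() and replaces the old-IR static-graph branches with NotImplementedError. A minimal usage sketch of the resulting behavior — illustrative only; the launch command, tensor shapes, and the FLAGS_enable_pir_api toggle are assumptions, not code from the PR:

# Run under the distributed launcher, e.g.:
#   python -m paddle.distributed.launch --gpus=0,1 demo.py
import os

# Assumption: forcing the legacy IR so the removed branch is actually reached.
os.environ["FLAGS_enable_pir_api"] = "0"

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

# Dynamic mode: unchanged, still builds distributed tensors.
x = paddle.randn([4, 8])
d_x = dist.shard_tensor(x, mesh, [dist.Shard(0)])
d_x = dist.reshard(d_x, mesh, [dist.Replicate()])

# Old-IR static graph: the hand-built assign-op / dist-attr path is gone,
# so the call now fails fast with NotImplementedError.
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
    y = paddle.static.data("y", [4, 8], "float32")
    try:
        dist.shard_tensor(y, mesh, [dist.Shard(0)])
    except NotImplementedError as err:
        print(err)  # "`shard_tensor()` only supported in dynamic and pir mode."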