【Hackathon 8th No.3】del oldIR in api -part #72548

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 1 commit into develop
145 changes: 23 additions & 122 deletions python/paddle/distributed/auto_parallel/api.py
@@ -27,31 +27,17 @@
from paddle import _C_ops, nn, pir
from paddle.amp.grad_scaler import OptimizerState
from paddle.autograd import PyLayer
from paddle.base import unique_name
from paddle.base.dygraph.base import switch_to_static_graph
from paddle.base.framework import (
EagerParamBase,
Variable,
default_main_program,
in_dygraph_mode,
in_pir_mode,
use_pir_api,
)
from paddle.distributed import fleet
from paddle.distributed.auto_parallel import Engine, strategy as auto_strategy
from paddle.distributed.auto_parallel.interface import (
shard_tensor as shard_tensor_static,
)
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.static.completion import (
mark_as_sharding_propagation_skip_op,
)
from paddle.distributed.auto_parallel.static.dist_context import (
get_default_distributed_context,
)
from paddle.distributed.auto_parallel.static.dist_op import DistributedOperator
from paddle.distributed.auto_parallel.static.utils import (
convert_to_dims_mapping,
fuse_param_func,
get_dist_attr,
split_mesh,
@@ -279,15 +265,7 @@ def shard_tensor(
if stop_gradient is None:
stop_gradient = getattr(data, "stop_gradient", True)

if paddle.framework.in_pir_mode():
assert isinstance(
data, (type(None), pir.Value)
), "input tensor is not pir value."
assert (
data.is_dense_tensor_type()
), "shard_tensor() input data only supported dense tensor type right."
tensor = data
else:
if paddle.in_dynamic_mode():
if isinstance(data, EagerParamBase) and not data._is_initialized():
assert (
data._init_func is not None
@@ -302,7 +280,6 @@ def shard_tensor(
data, dtype=dtype, place=place, stop_gradient=stop_gradient
)

if paddle.in_dynamic_mode():
# here the dist tensor is deep copy constructed
if isinstance(data, EagerParamBase):

@@ -353,14 +330,21 @@ def _init_func(var, block):
dist_tensor.stop_gradient = tensor.stop_gradient
return dist_tensor
elif paddle.framework.in_pir_mode():
dist_tensor = paddle._C_ops.shard_tensor(tensor, mesh, placements)
dist_tensor.stop_gradient = tensor.stop_gradient
dist_tensor.persistable = tensor.persistable
assert isinstance(
data, (type(None), pir.Value)
), "input tensor is not a pir.Value."
assert (
data.is_dense_tensor_type()
), "shard_tensor() input data only supports dense tensor type right now."

dist_tensor = paddle._C_ops.shard_tensor(data, mesh, placements)
dist_tensor.stop_gradient = data.stop_gradient
dist_tensor.persistable = data.persistable
return dist_tensor
else:
# TODO(zhiqiu): we need to refine the static shard_tensor
sharding_specs = get_shard_spec(mesh, placements, tensor.ndim)
return shard_tensor_static(tensor, mesh, sharding_specs)
raise NotImplementedError(
"`shard_tensor()` is only supported in dynamic and PIR mode."
)

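For quick orientation, the reordered branches above keep the dynamic-mode deep-copy path first and make the PIR branch operate on `data` directly. A minimal dynamic-mode usage sketch of `shard_tensor` (the mesh, shapes, and placements below are illustrative, not taken from this PR):

```python
import paddle
import paddle.distributed as dist

# Illustrative 1-D mesh over two ranks; run with `python -m paddle.distributed.launch`.
mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

dense = paddle.rand([4, 8])
# Shard dim 0 across mesh axis "x"; in dynamic mode this goes through the
# deep-copy construction shown above.
d_tensor = dist.shard_tensor(dense, mesh, [dist.Shard(0)])
```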

class _moe_global_mesh_tensor(PyLayer):
@@ -861,53 +845,9 @@ def reshard(
elif in_pir_mode():
return paddle._C_ops.reshard(dist_tensor, mesh, placements)
else:
assert isinstance(
dist_tensor, Variable
), f"in dy2static mode, reshard's input should be Variable, but got [{dist_tensor}]"
sharding_specs = get_shard_spec(mesh, placements, dist_tensor.ndim)
main_program = default_main_program()
default_dist_ctx = get_default_distributed_context()

# output variable
out_var = main_program.current_block().create_var(
name=unique_name.generate_with_ignorable_key(
".".join(['reshard_api', 'tmp'])
),
dtype=dist_tensor.dtype,
shape=dist_tensor.shape,
type=dist_tensor.type,
persistable=dist_tensor.persistable,
stop_gradient=dist_tensor.stop_gradient,
)

# transition op
# optimization in future to remove redundant D2D memory copy
target_dims_mapping = convert_to_dims_mapping(sharding_specs, mesh)
trans_op = main_program.current_block().append_op(
type='assign',
inputs={'X': [dist_tensor]},
outputs={'Out': [out_var]},
)
dist_op = DistributedOperator(trans_op)
dist_op.dist_attr.process_mesh = mesh
dist_op.dist_attr.mark_annotated("process_mesh")
dist_op.dist_attr.chunk_id = 0

input_dist_attr = dist_op.dist_attr.get_input_dist_attr(
dist_tensor.name
raise NotImplementedError(
"`reshard()` is only supported in dynamic and PIR mode."
)
input_dist_attr.dims_mapping = target_dims_mapping
input_dist_attr.mark_annotated("dims_mapping")
output_dist_attr = dist_op.dist_attr.get_output_dist_attr(out_var.name)
output_dist_attr.dims_mapping = target_dims_mapping
output_dist_attr.mark_annotated("dims_mapping")

default_dist_ctx.add_dist_op_for_program(dist_op)
mark_as_sharding_propagation_skip_op(trans_op)
# trans_op = shard_op_static(paddle.assign, mesh, [sharding_specs])
# out_var = trans_op(dist_tensor)

return out_var

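With the old-IR branch replaced by an explicit error, `reshard` now only accepts inputs in dynamic or PIR mode. A minimal dynamic-mode sketch (mesh and placements are illustrative):

```python
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
x = dist.shard_tensor(paddle.rand([4, 8]), mesh, [dist.Shard(0)])
# Redistribute from sharded on dim 0 to fully replicated on the same mesh.
y = dist.reshard(x, mesh, [dist.Replicate()])
```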

def shard_layer(
@@ -2550,12 +2490,7 @@ def __init__(
if (
not self._in_pir_mode
): # TODO (2024-Q2) remove this when pir mode is fully constructed.
if optimizer is not None and loss is not None:
self.train()
elif loss is not None:
self.eval()
else:
self.predict()
raise NotImplementedError("Only supported in dynamic and pir mode.")

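For reference, the deleted old-IR branch chose the run mode from which arguments were supplied; on the PIR path the caller drives the same choice through the mode methods defined below. A hypothetical helper reproducing that selection (the helper name and arguments are placeholders):

```python
def select_mode(dist_model, loss=None, optimizer=None):
    # Mirrors the removed logic: loss + optimizer -> train, loss only -> eval, neither -> predict.
    if optimizer is not None and loss is not None:
        dist_model.train()
    elif loss is not None:
        dist_model.eval()
    else:
        dist_model.predict()
```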
def train(self) -> None:
"""
@@ -2840,9 +2775,7 @@ def state_dict(
mode=self._engine._mode
).state_dict(mode, scope)
else:
local_state_dict = self.dist_main_program(
mode=self._engine._mode
).state_dict(mode)
raise NotImplementedError("state_dict does not support old IR")

dist_state_dict = self._build_distributed_state_dict(local_state_dict)

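Since `state_dict` now only assembles the distributed state dict from the PIR program, the usual save flow looks roughly like the sketch below (assuming the `dist.save_state_dict` checkpoint helper; `dist_model` and the path are placeholders):

```python
import paddle.distributed as dist

# `dist_model` is assumed to be a DistModel produced by dist.to_static(...).
state = dist_model.state_dict()              # distributed state dict of dist tensors
dist.save_state_dict(state, "./ckpt_dir")    # each rank persists its local shards plus metadata
```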
@@ -2922,9 +2855,8 @@ def _build_distributed_state_dict(self, local_state_dict):
if use_pir_api():
dist_attrs = get_dist_attr(dist_main_program)
else:
# Dict[var.name, Dict["process_shape": process_mesh.shape, "process_group": process_mesh.process_ids, "dims_mapping": dims_mapping]]
dist_attrs = get_dist_attr(
dist_main_program, self._engine._dist_contexts[self._mode]
raise NotImplementedError(
"_build_distributed_state_dict does not support old IR"
)

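The deleted comment described the layout of the old-IR `dist_attrs` mapping consumed by `build_distributed_tensor`; for reference, one entry looked roughly like this (names and values are illustrative):

```python
# Old-IR dist_attrs layout, per the deleted comment (illustrative entry):
dist_attrs = {
    "linear_0.w_0": {
        "process_shape": [2],      # process_mesh.shape
        "process_group": [0, 1],   # process_mesh.process_ids
        "dims_mapping": [0, -1],   # mesh axis for each tensor dim, -1 means replicated
    },
}
```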
def build_distributed_tensor(local_tensor, dist_attr):
@@ -3062,8 +2994,8 @@ def set_state_dict(self, state_dict: dict[str, Tensor]) -> None:
local_state_dict, paddle.static.global_scope(), copy_tensor
)
else:
dist_main_program.set_state_dict(
local_state_dict, paddle.static.global_scope()
raise NotImplementedError(
"set_state_dict does not support old IR"
)

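Loading is the mirror image: `set_state_dict` now only targets the PIR program. A hedged sketch of the reload flow (assuming the `dist.load_state_dict` helper; `dist_model` and the path are placeholders):

```python
import paddle.distributed as dist

# `dist_model` is assumed to be a DistModel produced by dist.to_static(...).
state = dist_model.state_dict()              # template with the expected keys and structure
dist.load_state_dict(state, "./ckpt_dir")    # assumed helper: fills `state` in place from disk
dist_model.set_state_dict(state)             # push the loaded values into the PIR program
```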
def _get_shard_stage1_optimizer(self):
@@ -3285,38 +3217,7 @@ def to_static(
>>> # python -m paddle.distributed.launch {test_case}.py
"""
if isinstance(optimizer, _ShardOptimizer) and not use_pir_api():
shard_fn = optimizer._shard_fn
sharding_degree = optimizer._sharding_degree
optimizer = optimizer._inner_opt

if shard_fn is not None:
strategy = dist.Strategy() if strategy is None else strategy

# Deduce sharding degree for static
# Note: Because limitation of architecture, we need to ensure that
# all parameters are sharded by the same mesh axis
assert (
sharding_degree is not None
), "Sharding degree can not be None."

if isinstance(shard_fn, ShardingStage1):
strategy.sharding.enable = True
strategy.sharding.stage = 1
strategy.sharding.degree = sharding_degree
elif isinstance(shard_fn, ShardingStage2):
strategy.sharding.enable = True
strategy.sharding.stage = 2
strategy.sharding.degree = sharding_degree
elif isinstance(shard_fn, ShardingStage3):
strategy.sharding.enable = True
strategy.sharding.stage = 3
strategy.sharding.degree = sharding_degree
for param in optimizer._parameter_list:
shard_fn._unshard_parameter(param)
else:
raise NotImplementedError(
"Only sharding stage 1, 2 and 3 can to_static for now. User-defined shard_fn will be supported later."
)
raise NotImplementedError("to_static() only supports PIR now.")
if strategy is None or strategy.full_graph:
dist_model = DistModel(
layer, loader, loss, optimizer, strategy, input_spec=input_spec
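End to end, `to_static` must now run under PIR; sharded optimizers are no longer rewritten into an old-IR sharding strategy. A minimal sketch of the PIR path (the dataset, network, mesh, and placement are placeholders, and it is meant to be launched with `python -m paddle.distributed.launch` as the docstring notes):

```python
import numpy as np
import paddle
import paddle.distributed as dist
from paddle.io import DataLoader, Dataset

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

class RandomDataset(Dataset):  # placeholder dataset
    def __len__(self):
        return 16

    def __getitem__(self, idx):
        return np.random.rand(8).astype("float32"), np.array(idx % 2, dtype="int64")

class Net(paddle.nn.Layer):  # placeholder network
    def __init__(self):
        super().__init__()
        self.fc = paddle.nn.Linear(8, 2)

    def forward(self, x):
        return self.fc(x)

layer = Net()
# Illustrative column-wise sharding of the weight across mesh axis "x".
layer.fc.weight = dist.shard_tensor(layer.fc.weight, mesh, [dist.Shard(1)])

opt = paddle.optimizer.SGD(learning_rate=0.1, parameters=layer.parameters())
loss_fn = paddle.nn.CrossEntropyLoss()
loader = DataLoader(RandomDataset(), batch_size=4)

# Under PIR this builds and returns a DistModel; on the old IR a sharded
# optimizer now raises NotImplementedError instead of being rewritten.
dist_model = dist.to_static(layer, loader, loss_fn, opt)
dist_model.train()
```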
Additional changes in a second file:
@@ -74,7 +74,9 @@ def _apply_impl(self, main_programs, startup_programs, context):
main_program, startup_program, context
)
else:
self._apply_single_impl(main_program, startup_program, context)
raise NotImplementedError(
"Not supported for old IR, please use PIR mode."
)

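The same replacement pattern runs through the whole PR: each old-IR fallback becomes an explicit failure so that only the PIR path remains. A compact sketch of that guard, using `use_pir_api()` as imported from `paddle.base.framework` elsewhere in this diff (the wrapper function itself is hypothetical):

```python
from paddle.base.framework import use_pir_api

def apply_pir_only(apply_pir_fn, *args, **kwargs):
    # Run the PIR implementation, or fail loudly instead of silently using the old IR.
    if use_pir_api():
        return apply_pir_fn(*args, **kwargs)
    raise NotImplementedError("Not supported for old IR, please use PIR mode.")
```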
def _partial_pir_programs(self, program):
"""