@@ -400,31 +400,25 @@ bool isOnFullVersion(Operation* op, const bool enableFullVersion) {
400400// - All op results have the same unreduced axes.
401401// - If the op has no results, none of the operands has unreduced axes.
402402// - Operand and result meshes are the same ignoring device id order.
403+ // - There are no overflow axes.
403404//
404405// Returns the union of axes along all the reduction factors which may not be
405406// canonicalized.
406- SmallVector<AxisRefAttr> processOp (Operation* op,
407- ArrayRef<TensorShardingAttr> inShardings,
408- ArrayRef<TensorShardingAttr> outShardings,
409- IRRewriter& rewriter,
410- const SymbolTable& symbolTable,
411- OpShardingRuleAttr shardingRule,
412- const Mesh& mesh, const bool onFullVersion) {
413- ShardingProjection shardingProjection = ShardingProjection::build (
414- inShardings, outShardings, shardingRule, mesh.attr (),
415- /* closedIfMissing=*/ true );
416-
417- // Return without inserting reshards if any factor sharding has overflow
418- // axes. This case is not handled yet.
419- // TODO(enver): Handle the case when factor shardings have overflow axes.
420- if (hasOverflowAxes (shardingProjection)) {
421- return {};
422- }
423-
407+ //
408+ // Guarantees to return non-empty `AxesPerFactor` if `onFullVersion` is true.
409+ AxesPerFactor processOp (Operation* op, ShardingProjection& shardingProjection,
410+ ArrayRef<TensorShardingAttr> inShardings,
411+ ArrayRef<TensorShardingAttr> outShardings,
412+ IRRewriter& rewriter, const SymbolTable& symbolTable,
413+ OpShardingRuleAttr shardingRule, const Mesh& mesh,
414+ const bool onFullVersion) {
415+ // Checks if factors are sharded the same way across operands and results.
416+ AxesPerFactor commonAxesPerFactor =
417+ getCompatibleFactorShardings (shardingProjection, shardingRule);
418+
419+ // TODO(b/446833985): Return common axes factors also when the sharding
420+ // projection have overflow axes.
424421 if (onFullVersion) {
425- // Checks if factors are sharded the same way across operands and results.
426- AxesPerFactor commonAxesPerFactor =
427- getCompatibleFactorShardings (shardingProjection, shardingRule);
428422 // Find compatible shardings if it is not already compatible.
429423 if (commonAxesPerFactor.empty ()) {
430424 commonAxesPerFactor =
@@ -443,49 +437,19 @@ SmallVector<AxisRefAttr> processOp(Operation* op,
443437 insertExplicitReshards (op, inShardings, outShardings, shardingProjection,
444438 updateTensorShardings, rewriter, shardingRule,
445439 symbolTable, mesh);
446-
447- return getReductionAxes (commonAxesPerFactor, shardingRule);
440+ } else {
441+ TypeSwitch<Operation*>(op)
442+ .Case <stablehlo::DotOp>([&](stablehlo::DotOp dotOp) {
443+ processDot (dotOp, inShardings, outShardings, rewriter, symbolTable,
444+ shardingRule, mesh);
445+ })
446+ .Case <stablehlo::DotGeneralOp>(
447+ [&](stablehlo::DotGeneralOp dotGeneralOp) {
448+ processDot (dotGeneralOp, inShardings, outShardings, rewriter,
449+ symbolTable, shardingRule, mesh);
450+ });
448451 }
449-
450- TypeSwitch<Operation*>(op)
451- .Case <stablehlo::DotOp>([&](stablehlo::DotOp dotOp) {
452- processDot (dotOp, inShardings, outShardings, rewriter, symbolTable,
453- shardingRule, mesh);
454- })
455- .Case <stablehlo::DotGeneralOp>([&](stablehlo::DotGeneralOp dotGeneralOp) {
456- processDot (dotGeneralOp, inShardings, outShardings, rewriter,
457- symbolTable, shardingRule, mesh);
458- });
459-
460- if (outShardings.empty () || getUnreducedAxes (outShardings[0 ]).empty ()) {
461- return {};
462- }
463-
464- // TODO(enver): Repurpose getCompatibleFactorShardings to return compatible
465- // factors, and simplify the following logic.
466- SmallVector<AxisRefAttr> axesAlongAllReductionFactors;
467- for (int64_t reductionFactor : shardingRule.getReductionFactors ()) {
468- // We only iterate operands since reduction factors are not in results.
469- bool seen = false ;
470- SmallVector<AxisRefAttr> axesAlongCurrentReductionFactor;
471- for (const TensorFactorShardings& tensorFactorSharding :
472- shardingProjection.getOperands ()) {
473- if (std::optional<ArrayRef<AxisRefAttr>> factorSharding =
474- getFactorSharding (tensorFactorSharding, reductionFactor)) {
475- if (seen) {
476- SDY_CHECK (axesAlongCurrentReductionFactor == *factorSharding)
477- << " For the operation " << op
478- << " , the result has unreduced axes while the operand has "
479- " incompatible sharding along reduction factors." ;
480- } else {
481- axesAlongCurrentReductionFactor = llvm::to_vector (*factorSharding);
482- seen = true ;
483- }
484- }
485- }
486- axesAlongAllReductionFactors.append (axesAlongCurrentReductionFactor);
487- }
488- return axesAlongAllReductionFactors;
452+ return commonAxesPerFactor;
489453}
490454
491455struct InsertExplicitReshardsPass
@@ -544,11 +508,22 @@ struct InsertExplicitReshardsPass
544508 return ;
545509 }
546510
547- SmallVector<AxisRefAttr> reductionAxes =
548- processOp (op, inShardings, outShardings, rewriter, symbolTable,
549- shardingRule, *mesh, onFullVersion);
511+ ShardingProjection shardingProjection = ShardingProjection::build (
512+ inShardings, outShardings, shardingRule, mesh->attr (),
513+ /* closedIfMissing=*/ true );
514+ // Return without inserting reshards if any factor sharding has overflow
515+ // axes. This case is not handled yet.
516+ // TODO(enver): Handle the case when factor shardings have overflow axes.
517+ if (hasOverflowAxes (shardingProjection)) {
518+ return ;
519+ }
520+ AxesPerFactor commonAxesPerFactor =
521+ processOp (op, shardingProjection, inShardings, outShardings, rewriter,
522+ symbolTable, shardingRule, *mesh, onFullVersion);
550523 // TODO(b/440055868): Insert a reshard from unreduced to replicated axes.
551- insertAllReducesForReductionFactors (op, reductionAxes, *mesh, rewriter);
524+ insertAllReducesForReductionFactors (op, shardingProjection,
525+ commonAxesPerFactor, shardingRule,
526+ *mesh, rewriter, onFullVersion);
552527
553528 // TODO(enver): Remove sharding rules from ops.
554529 });
0 commit comments