@@ -398,13 +398,17 @@ static MemorySpaceAttr getHostSpace(RewriterBase &rewriter) {
398398
399399static Value getHostConstantTensor (RewriterBase &rewriter, Location loc,
400400 ArrayRef<int64_t > values) {
401- auto rtt =
402- RankedTensorType::get ({ static_cast <int64_t >(values.size ())},
403- rewriter.getIndexType (), getHostSpace (rewriter));
401+ auto rtt = RankedTensorType::get (
402+ {std::max< int64_t >( 1 , static_cast <int64_t >(values.size () ))},
403+ rewriter.getIndexType (), getHostSpace (rewriter));
404404
405405 // Create the DenseIntElementsAttr with host space
406- auto attr = DenseIntElementsAttr::get (rtt, values);
407- return rewriter.create <arith::ConstantOp>(loc, rtt, attr);
406+ auto attrData = [&]() -> TypedAttr {
407+ if (static_cast <int64_t >(values.size ()) == rtt.getRank ())
408+ return DenseIntElementsAttr::get (rtt, values);
409+ return rewriter.getZeroAttr (rtt);
410+ }();
411+ return rewriter.create <arith::ConstantOp>(loc, rtt, attrData);
408412}
409413
410414static SmallVector<Value> createConstantIndices (RewriterBase &rewriter,
@@ -523,20 +527,29 @@ static FailureOr<func::FuncOp> createAggregateShapeFunc(
523527
524528 SmallVector<Value> argValues = getArguments (rewriter, func);
525529 SmallVector<Type> argTypes;
530+
526531 for (auto [idx, argValue] : llvm::enumerate (argValues)) {
527532 Type t = argValue.getType ();
528533 auto rtt = dyn_cast<RankedTensorType>(t);
529- if (!rtt)
530- argTypes.push_back (t);
531- const TensorKindLattice *lattice =
532- solver.lookupState <TensorKindLattice>(argValue);
533- if (!lattice || lattice->getValue ().isUninitialized ())
534- return failure ();
535- if (lattice->getValue ().isHostVisible ()) {
536- argTypes.push_back (t);
534+
535+ bool isHostTensorArg = [&, idx = idx, argValue = argValue]() {
536+ if (!rtt)
537+ return false ;
538+ if (func.getArgAttr (idx, plan::PlanDialect::kValueBoundsAttrName ))
539+ return true ;
540+ if (auto memSpace = func.getArgAttrOfType <plan::MemorySpaceAttr>(
541+ idx, plan::PlanDialect::kMemorySpaceConstraintAttrName ))
542+ return memSpace.isHostVisible ();
543+ const TensorKindLattice *lattice =
544+ solver.lookupState <TensorKindLattice>(argValue);
545+ return lattice && !lattice->getValue ().isUninitialized () &&
546+ lattice->getValue ().isHostVisible ();
547+ }();
548+ if (!isHostTensorArg) {
549+ argTypes.push_back (!rtt ? t : getShapeTensorType (rtt));
537550 continue ;
538551 }
539- argTypes.push_back (getShapeTensorType ( rtt) );
552+ argTypes.push_back (rtt);
540553 }
541554
542555 FailureOr<SmallVector<Value>> returnedValues = getReturnedValues (func);
@@ -712,8 +725,6 @@ class CreateShapeFuncsPass
712725 (*aggShapeFunc)
713726 ->setAttr (PlanDialect::kShapeFuncMarkerAttrName ,
714727 UnitAttr::get (rewriter.getContext ()));
715- if (failed (aggShapeFunc))
716- continue ;
717728
718729 // Add the symbol to the original func.
719730 symbolTable.insert (*aggShapeFunc);
0 commit comments