Skip to content

Commit 326b496

Browse files
Fix zero-rank tensor handling in shape functions (#731)
- Fix getHostConstantTensor to handle zero-rank tensors correctly
- Create a rank-1 tensor holding a zero when values.size() is 0
- Add test for zero-rank tensor handling in shape function creation

GitOrigin-RevId: dc1ed6d94ce37d6d47661030273b042fbee85acc
1 parent e5651d6 commit 326b496

2 files changed

Lines changed: 57 additions & 16 deletions

File tree

mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/CreateShapeFuncs.cpp

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -398,13 +398,17 @@ static MemorySpaceAttr getHostSpace(RewriterBase &rewriter) {
398398

399399
/// Materializes `values` as a rank-1 `index` tensor constant whose encoding is
/// the attribute returned by `getHostSpace`.
///
/// When `values` is empty, a single-element tensor holding zero is produced
/// instead of a zero-extent tensor, so that callers always receive a non-empty
/// constant.
///
/// \param rewriter Builder used to create the type, the attribute, and the op.
/// \param loc      Location attached to the created `arith.constant`.
/// \param values   Elements of the constant; may be empty.
/// \returns The result of the newly created `arith::ConstantOp`.
static Value getHostConstantTensor(RewriterBase &rewriter, Location loc,
                                   ArrayRef<int64_t> values) {
  const int64_t numValues = static_cast<int64_t>(values.size());

  // Clamp the extent to at least one element so that an empty `values` still
  // yields a non-empty tensor type.
  auto rtt = RankedTensorType::get({std::max<int64_t>(1, numValues)},
                                   rewriter.getIndexType(),
                                   getHostSpace(rewriter));

  // Create the constant attribute with host space. The provided values are
  // used whenever they fill the tensor exactly. The comparison must be against
  // the element count, not the rank: `rtt` is always rank-1, so comparing with
  // `getRank()` would replace every constant with two or more elements by
  // zeros.
  auto attrData = [&]() -> TypedAttr {
    if (numValues == rtt.getNumElements())
      return DenseIntElementsAttr::get(rtt, values);
    // Only reachable for empty `values`: fall back to a single zero.
    return rewriter.getZeroAttr(rtt);
  }();
  return rewriter.create<arith::ConstantOp>(loc, rtt, attrData);
}
409413

410414
static SmallVector<Value> createConstantIndices(RewriterBase &rewriter,
@@ -523,20 +527,29 @@ static FailureOr<func::FuncOp> createAggregateShapeFunc(
523527

524528
SmallVector<Value> argValues = getArguments(rewriter, func);
525529
SmallVector<Type> argTypes;
530+
526531
for (auto [idx, argValue] : llvm::enumerate(argValues)) {
527532
Type t = argValue.getType();
528533
auto rtt = dyn_cast<RankedTensorType>(t);
529-
if (!rtt)
530-
argTypes.push_back(t);
531-
const TensorKindLattice *lattice =
532-
solver.lookupState<TensorKindLattice>(argValue);
533-
if (!lattice || lattice->getValue().isUninitialized())
534-
return failure();
535-
if (lattice->getValue().isHostVisible()) {
536-
argTypes.push_back(t);
534+
535+
bool isHostTensorArg = [&, idx = idx, argValue = argValue]() {
536+
if (!rtt)
537+
return false;
538+
if (func.getArgAttr(idx, plan::PlanDialect::kValueBoundsAttrName))
539+
return true;
540+
if (auto memSpace = func.getArgAttrOfType<plan::MemorySpaceAttr>(
541+
idx, plan::PlanDialect::kMemorySpaceConstraintAttrName))
542+
return memSpace.isHostVisible();
543+
const TensorKindLattice *lattice =
544+
solver.lookupState<TensorKindLattice>(argValue);
545+
return lattice && !lattice->getValue().isUninitialized() &&
546+
lattice->getValue().isHostVisible();
547+
}();
548+
if (!isHostTensorArg) {
549+
argTypes.push_back(!rtt ? t : getShapeTensorType(rtt));
537550
continue;
538551
}
539-
argTypes.push_back(getShapeTensorType(rtt));
552+
argTypes.push_back(rtt);
540553
}
541554

542555
FailureOr<SmallVector<Value>> returnedValues = getReturnedValues(func);
@@ -712,8 +725,6 @@ class CreateShapeFuncsPass
712725
(*aggShapeFunc)
713726
->setAttr(PlanDialect::kShapeFuncMarkerAttrName,
714727
UnitAttr::get(rewriter.getContext()));
715-
if (failed(aggShapeFunc))
716-
continue;
717728

718729
// Add the symbol to the original func.
719730
symbolTable.insert(*aggShapeFunc);
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: mlir-tensorrt-opt %s -split-input-file -plan-create-shape-funcs="abi-version=1" -inline | FileCheck %s
2+
3+
// Test zero-rank tensor handling in shape function creation
4+
// The fix ensures that when creating a constant tensor from an empty list of
5+
// values, we create a one-element rank-1 tensor holding zero instead of a zero-extent tensor.
6+
7+
// Test input: receives a rank-0 tensor<f32> from %arg0 and takes an i32 in
// %arg1. It sends two results: a tensor<?xf32> (broadcast of the rank-0 value,
// with its size taken from %arg1) to %arg2, and the rank-0 tensor<f32> to
// %arg3. The rank-0 value is annotated by `plan.with_shape` with no shape
// operands.
func.func public @test_zero_rank_shape(%arg0: !executor.ptr<host> {executor.abi = #executor.arg<byval, tensor<f32>>},
8+
%arg1: i32,
9+
%arg2: !executor.ptr<host> {executor.abi = #executor.arg<byref, tensor<?xf32>>},
10+
%arg3: !executor.ptr<host> {executor.abi = #executor.arg<byref, tensor<f32>>})
11+
attributes {executor.func_abi = (tensor<f32>, i32) -> (tensor<?xf32>, tensor<f32>)} {
12+
%0 = executor.abi.recv %arg0 : tensor<f32>
13+
%1 = stablehlo.exponential %0 : tensor<f32>
14+
%2 = plan.with_shape %1() : (tensor<f32>) -> tensor<f32>
15+
%size = tensor.from_elements %arg1 : tensor<1xi32>
16+
%3 = stablehlo.dynamic_broadcast_in_dim %2, %size, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
17+
%4 = plan.with_shape %3(%arg1) : (tensor<?xf32>, i32) -> tensor<?xf32>
18+
executor.abi.send %4 to %arg2 : tensor<?xf32>
19+
executor.abi.send %2 to %arg3 : tensor<f32>
20+
return
21+
}
22+
23+
// CHECK-LABEL: func.func public @test_zero_rank_shape(
24+
// CHECK-LABEL: func.func public @test_zero_rank_shape_get_shapes(
25+
// CHECK: %[[cst:.+]] = arith.constant dense<0> : tensor<1xindex, #plan.memory_space<host>>
26+
// CHECK-DAG: %[[v0:.+]] = arith.index_cast %{{.+}} : i32 to index
27+
// CHECK-DAG: %[[v1:.+]] = tensor.from_elements %[[v0]] : tensor<1xindex, #plan.memory_space<host>>
28+
// CHECK: executor.abi.send %[[v1]] to %{{.+}}
29+
// CHECK: executor.abi.send %[[cst]] to %{{.+}}
30+
// CHECK: return

0 commit comments

Comments
 (0)