41 changes: 33 additions & 8 deletions dali/operators/generic/constant_value.cc
@@ -1,4 +1,4 @@
// Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -144,21 +144,40 @@ void ConstantValue<CPUBackend>::RunImpl(Workspace &ws) {
"Supported types are : ", ListTypeNames<DALI_CONSTANT_VALUE_TYPES>()));
)); // NOLINT
}

if (has_layout_) {
output.SetLayout(layout_);
} else if (is_shape_like_) {
output.SetLayout(ws.Input<CPUBackend>(shape_like_input_idx_).GetLayout());
}
}

DALI_SCHEMA(Full)
.DocStr(R"code(Returns new data of given shape and type, filled with a fill value.)code")
.DocStr(R"code(Returns new data of given shape and type, filled with a fill value.

If `fill_value` is not a scalar, it must be broadcastable to the output shape (NumPy-style broadcasting).
Dimensions are compared from innermost to outermost, and each pair of dimensions must either match or one of them must be 1.
If the dimensionalities differ, the shape with fewer dimensions is padded with 1s in the outermost positions.
)code")
.NumInput(1)
.InputDox(0, "fill_value", "TensorList", R"code(The fill value.)code")
.NumOutput(1)
.AddOptionalArg<std::vector<int>>("shape", R"code(Shape of the output data.)code", nullptr,
true);
true)
.AddOptionalArg<TensorLayout>("layout", R"code(Output layout.

If set and not empty, the layout must match the dimensionality of the output.)code", nullptr);

DALI_REGISTER_OPERATOR(Full, Full<CPUBackend>, CPU);

DALI_SCHEMA(FullLike)
.DocStr(R"code(Returns new data with the same shape and type as the input data, filled with a `fill_value`.)code")
.DocStr(R"code(Returns new data with the same shape, type and layout as the input data, filled with a `fill_value`.

If `fill_value` is not a scalar, it must be broadcastable to the output shape (NumPy-style broadcasting).
Dimensions are compared from innermost to outermost, and each pair of dimensions must either match or one of them must be 1.
If the dimensionalities differ, the shape with fewer dimensions is padded with 1s in the outermost positions.)code")
.NumInput(2)
.InputDox(0, "data_like", "TensorList", R"code(The input data value to copy the shape and type from.)code")
.InputDox(0, "data_like", "TensorList", R"code(The input data value to copy the shape, type and layout from.)code")
.InputDevice(0, InputDevice::Metadata)
.InputDox(1, "fill_value", "TensorList", R"code(The fill value.)code")
.NumOutput(1);
@@ -170,11 +189,14 @@ DALI_SCHEMA(Zeros)
.NumOutput(1)
.AddOptionalArg<std::vector<int>>("shape", R"code(Shape of the output data.)code", nullptr,
true)
.AddOptionalArg<TensorLayout>("layout", R"code(Output layout.

If set and not empty, the layout must match the dimensionality of the output.)code", nullptr)
.AddOptionalTypeArg("dtype", R"code(Output data type.)code", DALI_INT32);
DALI_REGISTER_OPERATOR(Zeros, Zeros<CPUBackend>, CPU);

DALI_SCHEMA(ZerosLike)
.DocStr(R"code(Returns new data with the same shape and type as the input array, filled with zeros.)code")
.DocStr(R"code(Returns new data with the same shape, type and layout as the input array, filled with zeros.)code")
.NumInput(1)
.InputDox(0, "data_like", "TensorList", R"code(The input data value to copy the shape and type from.)code")
.InputDevice(0, InputDevice::Metadata)
@@ -188,13 +210,16 @@ DALI_SCHEMA(Ones)
.NumOutput(1)
.AddOptionalArg<std::vector<int>>("shape", R"code(Shape of the output data.)code", nullptr,
true)
.AddOptionalArg<TensorLayout>("layout", R"code(Output layout.

If set and not empty, the layout must match the dimensionality of the output.)code", nullptr)
.AddOptionalTypeArg("dtype", R"code(Output data type.)code", DALI_INT32);
DALI_REGISTER_OPERATOR(Ones, Ones<CPUBackend>, CPU);

DALI_SCHEMA(OnesLike)
.DocStr(R"code(Returns new data with the same shape and type as the input array, filled with ones.)code")
.DocStr(R"code(Returns new data with the same shape, type and layout as the input array, filled with ones.)code")
.NumInput(1)
.InputDox(0, "data_like", "TensorList", R"code(The input data value to copy the shape and type from.)code")
.InputDox(0, "data_like", "TensorList", R"code(The input data value to copy the shape, type and layout from.)code")
.InputDevice(0, InputDevice::Metadata)
.NumOutput(1)
.AddOptionalTypeArg("dtype", R"code(Overrides the output data type.)code", DALI_INT32);
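For illustration, a minimal sketch of the new `layout` argument and the broadcasting behavior documented above, assuming a DALI build that includes this change (pipeline parameters mirror the tests below):

```python
import numpy as np
from nvidia.dali import pipeline_def, fn


@pipeline_def(batch_size=1, num_threads=3, device_id=0)
def full_pipe():
    # A fill value of shape (5,) broadcasts against the output shape (3, 5).
    return fn.full(np.array([1, 2, 3, 4, 5], dtype=np.int32), shape=(3, 5), layout="HW")


p = full_pipe()
out = p.run()[0]
print(np.array(out[0]))  # three rows of [1 2 3 4 5]
print(out.layout())      # "HW"
```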
28 changes: 25 additions & 3 deletions dali/operators/generic/constant_value.h
@@ -1,4 +1,4 @@
// Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -17,6 +17,7 @@

#include <vector>
#include "dali/core/static_switch.h"
#include "dali/core/tensor_layout.h"
#include "dali/core/tensor_shape_print.h"
#include "dali/pipeline/operator/checkpointing/stateless_operator.h"
#include "dali/core/float16.h"
@@ -30,14 +31,27 @@ namespace dali {
template <typename Backend>
class ConstantValue : public StatelessOperator<Backend> {
public:
/**
 * @brief Operator that returns new data of requested shape and type, filled with a fill value.
 *
 * @param spec Operator specification.
 * @param has_fill_value If true, accepts an input tensor for the fill value (Full, FullLike).
 *                       If false, uses the constant set via SetConstValue() (Zeros, Ones variants).
 * @param is_shape_like If true, the shape matches the input tensor ("Like" operators).
 *                      If false, the shape is taken from the "shape" argument (Full, Zeros, Ones).
 */
explicit ConstantValue(const OpSpec &spec, bool has_fill_value = false,
bool is_shape_like = false)
: StatelessOperator<Backend>(spec),
has_fill_value_(has_fill_value),
is_shape_like_(is_shape_like),
has_shape_(spec.ArgumentDefined("shape")),
has_dtype_(spec.ArgumentDefined("dtype")) {
has_dtype_(spec.ArgumentDefined("dtype")),
has_layout_(spec.ArgumentDefined("layout")) {
dtype_ = has_dtype_ ? spec.GetArgument<DALIDataType>("dtype") : DALI_INT32;
if (has_layout_) {
layout_ = spec.GetArgument<TensorLayout>("layout");
}
}

int GetBatchSize(const Workspace &ws) const {
@@ -96,6 +110,13 @@ class ConstantValue : public StatelessOperator<Backend> {
dtype = fill_value_dtype;
}
}

if (has_layout_ && !layout_.empty()) {
DALI_ENFORCE(layout_.ndim() == shape.sample_dim(),
make_string("Layout '", layout_, "' has ", layout_.ndim(),
" dimensions but output shape has ", shape.sample_dim(), " dimensions."));
}

return true;
}

@@ -111,8 +132,9 @@
using Operator<Backend>::max_batch_size_;
bool has_fill_value_;
bool is_shape_like_;
bool has_shape_, has_dtype_;
bool has_shape_, has_dtype_, has_layout_;
DALIDataType dtype_;
TensorLayout layout_;

bool has_const_value_ = false;
int const_value_ = 0;
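The `DALI_ENFORCE` added in `SetupImpl` above rejects a non-empty layout whose length does not match the output rank. A minimal sketch of how that check surfaces from Python, assuming a build with this change:

```python
from nvidia.dali import pipeline_def, fn


@pipeline_def(batch_size=1, num_threads=3, device_id=0)
def bad_pipe():
    # 2-character layout vs. 3-dimensional output shape.
    return fn.zeros(shape=(2, 3, 4), layout="HW")


p = bad_pipe()
try:
    p.run()
except RuntimeError as e:
    print(e)  # Layout 'HW' has 2 dimensions but output shape has 3 dimensions.
```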
104 changes: 103 additions & 1 deletion dali/test/python/operator_1/test_constant_value.py
@@ -1,4 +1,4 @@
# Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,7 +13,10 @@
# limitations under the License.

import numpy as np
from nose2.tools import params
from nose_utils import assert_raises
from nvidia.dali import pipeline_def, fn
from nvidia.dali import types


def run(op):
@@ -25,6 +28,16 @@ def pipe0():
return np.array(p.run()[0][0])


def run_with_layout(op):
@pipeline_def(batch_size=1, num_threads=3, device_id=0)
def pipe0():
return op

p = pipe0()
out = p.run()[0]
return np.array(out[0]), out.layout()


def test_zeros():
sh = (2, 3)
np.testing.assert_array_equal(run(fn.zeros(shape=sh)), np.zeros(shape=sh))
@@ -103,3 +116,92 @@ def test_full_like():
np.testing.assert_array_almost_equal(
run(fn.full_like(arr, fill_value_arr)), np.full_like(arr, fill_value_arr)
)


@params(
(fn.zeros, np.zeros),
(fn.ones, np.ones),
)
def test_const_layout(op, np_op):
sh = (2, 3, 4)
layout = "HWC"
arr, out_layout = run_with_layout(op(shape=sh, layout=layout))
np.testing.assert_array_equal(arr, np_op(shape=sh))
assert out_layout == layout


@params(
(fn.zeros_like, np.zeros_like),
(fn.ones_like, np.ones_like),
)
def test_const_like_layout(op, np_op):
sh = (2, 3, 4)
layout = "HWC"
arr = types.Constant(np.ones(sh), layout=layout)
result, out_layout = run_with_layout(op(arr))
np.testing.assert_array_equal(result, np_op(np.ones(sh)))
assert out_layout == layout


@params(
((2, 3, 4), "HWC", 42),
((3, 5), "HW", np.array([1, 2, 3, 4, 5], dtype=np.int32)), # broadcast
((2, 3), "HW", np.array([[1], [2]], dtype=np.int32)), # broadcast
)
def test_full_layout(sh, layout, fill_value):
arr, out_layout = run_with_layout(fn.full(fill_value, shape=sh, layout=layout))
np.testing.assert_array_equal(arr, np.full(sh, fill_value))
assert out_layout == layout


@params(
((2, 3, 4), "HWC", 42),
((3, 5), "HW", np.array([1, 2, 3, 4, 5], dtype=np.int32)), # broadcast
((2, 3), "HW", np.array([[1], [2]], dtype=np.int32)), # broadcast
)
def test_full_like_layout(sh, layout, fill_value):
arr = types.Constant(np.ones(sh), layout=layout)
result, out_layout = run_with_layout(fn.full_like(arr, fill_value))
np.testing.assert_array_equal(result, np.full(sh, fill_value))
assert out_layout == layout


@params(
(fn.zeros, (2, 3, 4)),
(fn.ones, (2, 3)),
(fn.full, (5, 4, 3)),
)
def test_const_empty_layout(op, sh):
op_to_run = op(42, shape=sh, layout="") if op == fn.full else op(shape=sh, layout="")
arr, out_layout = run_with_layout(op_to_run)
assert out_layout == ""


@params(
(fn.zeros_like, (2, 3, 4)),
(fn.ones_like, (2, 3)),
(fn.full_like, (5, 4, 3)),
)
def test_const_like_empty_layout(op, sh):
arr = types.Constant(np.ones(sh), layout="")
result, out_layout = run_with_layout(op(arr, 42) if op == fn.full_like else op(arr))
assert out_layout == ""


@params(
(fn.zeros, (2, 3, 4), "HW", 2, 3),
(fn.ones, (2, 3), "FHWC", 4, 2),
(fn.full, (2, 3), "HWC", 3, 2),
)
def test_const_layout_mismatch(op, sh, layout, layout_ndim, shape_ndim):
@pipeline_def(batch_size=1, num_threads=3, device_id=0)
def pipe():
return op(42, shape=sh, layout=layout) if op == fn.full else op(shape=sh, layout=layout)

p = pipe()
p.build()
assert_raises(
RuntimeError,
p.run,
glob=f"Layout '{layout}' has {layout_ndim} dimensions but output shape has {shape_ndim}*",
)
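
As a companion to `test_const_like_layout` above: the `*_like` operators copy the layout from their metadata input, so no `layout` argument is needed. A minimal sketch, assuming a build with this change:

```python
import numpy as np
from nvidia.dali import pipeline_def, fn, types


@pipeline_def(batch_size=1, num_threads=3, device_id=0)
def like_pipe():
    # The constant carries an "HWC" layout; zeros_like propagates it.
    src = types.Constant(np.ones((2, 3, 4), dtype=np.float32), layout="HWC")
    return fn.zeros_like(src)


p = like_pipe()
out = p.run()[0]
print(out.layout())  # "HWC"
```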