Skip to content

Commit 1b797d6

Browse files
authored
Add support for an optional parameter in the example repeat int32 model (#396)
1 parent f61d423 commit 1b797d6

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

examples/decoupled/repeat_model.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
#
33
# Redistribution and use in source and binary forms, with or without
44
# modification, are permitted provided that the following conditions
@@ -112,6 +112,14 @@ def initialize(self, args):
112112
self.out_dtype = pb_utils.triton_string_to_numpy(out_config["data_type"])
113113
self.idx_dtype = pb_utils.triton_string_to_numpy(idx_config["data_type"])
114114

115+
# Optional parameter to specify the number of elements in the OUT tensor in each response.
116+
# Defaults to 1 if not provided. Example: If input 'IN' is [4] and 'output_num_elements' is set to 3,
117+
# then 'OUT' will be [4, 4, 4]. If 'output_num_elements' is not specified, 'OUT' will default to [4].
118+
parameters = self.model_config.get("parameters", {})
119+
self.output_num_elements = int(
120+
parameters.get("output_num_elements", {}).get("string_value", 1)
121+
)
122+
115123
# To keep track of response threads so that we can delay
116124
# the finalizing the model until all response threads
117125
# have completed.
@@ -209,7 +217,10 @@ def response_thread(self, response_sender, in_input, delay_input):
209217
time.sleep(delay_value / 1000)
210218

211219
idx_output = pb_utils.Tensor("IDX", numpy.array([idx], idx_dtype))
212-
out_output = pb_utils.Tensor("OUT", numpy.array([in_value], out_dtype))
220+
out_output = pb_utils.Tensor(
221+
"OUT",
222+
numpy.full((self.output_num_elements,), in_value, dtype=out_dtype),
223+
)
213224
response = pb_utils.InferenceResponse(
214225
output_tensors=[idx_output, out_output]
215226
)

0 commit comments

Comments
 (0)