Skip to content

Commit 01819a0

Browse files
committed
Return grid_thw output
1 parent 9424eab commit 01819a0

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

shared/api/image_transforms_qwen2_5.hpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ class PatchImage {
1414
public:
1515
PatchImage() = default;
1616

17-
OrtxStatus Compute(const ortc::Tensor<float>& input, ortc::Tensor<float>& output) {
17+
OrtxStatus Compute(const ortc::Tensor<float>& input,
18+
ortc::Tensor<float>& output,
19+
ortc::Tensor<int64_t>& grid_thw_output) {
1820
// Validate and read HWC
1921
const auto& dims = input.Shape();
2022
if (dims.size() != 3ULL) {
@@ -59,6 +61,13 @@ class PatchImage {
5961
int64_t grid_h = H / patch_size_;
6062
int64_t grid_w = W / patch_size_;
6163

64+
// Populate grid_thw output tensor
65+
grid_thw_output.Allocate({1, 3});
66+
int64_t* grid_data = const_cast<int64_t*>(grid_thw_output.Data());
67+
grid_data[0] = grid_t;
68+
grid_data[1] = grid_h;
69+
grid_data[2] = grid_w;
70+
6271
// Reshape dimensions (Python 9D)
6372
if (merge_size_ <= 0) {
6473
return {kOrtxErrorInvalidArgument, "[PatchImage]: merge_size must be > 0"};

0 commit comments

Comments
 (0)