Skip to content

Commit

Permalink
partial vision processor impl.
Browse files Browse the repository at this point in the history
  • Loading branch information
wenbingl committed Jan 18, 2025
1 parent e00b123 commit 1e6ffce
Show file tree
Hide file tree
Showing 6 changed files with 495 additions and 38 deletions.
2 changes: 2 additions & 0 deletions shared/api/image_processor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "image_transforms.hpp"
#include "image_transforms_phi_3.hpp"
#include "image_transforms_mllama.hpp"
#include "vision_processor_phi_4.hpp"

namespace ort_extensions {
std::tuple<std::unique_ptr<ImageRawData[]>, size_t>
Expand All @@ -39,6 +40,7 @@ Operation::KernelRegistry ImageProcessor::kernel_registry_ = {
{"Permute3D", []() { return CreateKernelInstance(&Permute3D::Compute); }},
{"Phi3ImageTransform", []() { return CreateKernelInstance(phi3_hd_transform); }},
{"Llama3ImageTransform", []() { return CreateKernelInstance(&Llama3ImageTransform::Compute); }},
// Phi-4 vision processor kernel, registered under its own operation name.
{"Phi4VisionProcessor", []() { return CreateKernelInstance(&Phi4VisionProcessor::Compute); }},
};

OrtxStatus ImageProcessor::Init(std::string_view processor_def) {
Expand Down
39 changes: 38 additions & 1 deletion shared/api/image_transforms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,44 @@ void DumpTensorToFile(const ortc::Tensor<T>& tensor, const char* name) {
#endif
}

/// Splits a channel-last image into fixed-size tiles stored channel-first.
///
/// The input is read as (image_height, image_width, num_channels). The output is
/// allocated as (num_tiles_height * num_tiles_width, num_channels, tile_height, tile_width),
/// with tiles ordered row-major over the tile grid (tile index = i * num_tiles_width + j).
///
/// @tparam T                 element type shared by the input and output tensors.
/// @param normalized_image   source image, rank 3, channel-last.
/// @param pixel_values       destination tensor; allocated by this function.
/// @param tile_height        tile height in pixels; must be positive and divide image_height.
/// @param tile_width         tile width in pixels; must be positive and divide image_width.
///
/// Preconditions are enforced with assert only, so they are not checked when NDEBUG is defined.
template <typename T>
void SplitIntoTitles(const ortc::Tensor<T>& normalized_image, ortc::Tensor<T>& pixel_values,
                     int64_t tile_height, int64_t tile_width) {
  const auto& shape = normalized_image.Shape();
  assert(shape.size() == 3ULL);
  assert(tile_height > 0 && tile_width > 0);

  const int64_t image_height = shape[0];
  const int64_t image_width = shape[1];
  const int64_t num_channels = shape[2];

  // Number of elements in one channel plane of a single tile.
  const int64_t image_1c_size = tile_height * tile_width;
  assert(image_height % tile_height == 0);
  const int64_t num_tiles_height = image_height / tile_height;
  assert(image_width % tile_width == 0);
  const int64_t num_tiles_width = image_width / tile_width;

  auto p_normalized_image = normalized_image.Data();
  // shape (num_tiles_width * num_tiles_height, num_channels, tile_height, tile_width)
  // The pointer type follows T so the template is usable for element types other than float.
  T* output_pixel =
      pixel_values.Allocate({num_tiles_height * num_tiles_width, num_channels, tile_height, tile_width});

  // From (image_height, image_width, num_channels)
  // Permute to (num_tiles_height, num_tiles_width, num_channels, tile_height, tile_width)
  for (int64_t i = 0; i < num_tiles_height; ++i) {
    for (int64_t j = 0; j < num_tiles_width; ++j) {
      // Offset of the first element of tile (i, j) in the output.
      const int64_t tile_index = image_1c_size * (i * num_tiles_width + j) * num_channels;
      // convert to be channel first
      for (int64_t k = 0; k < num_channels; ++k) {
        const int64_t sub_index = tile_index + image_1c_size * k;
        for (int64_t y = 0; y < tile_height; ++y) {
          // Offset of source row (i * tile_height + y) in the channel-last input.
          const int64_t src_row = (i * tile_height + y) * image_width * num_channels;
          for (int64_t x = 0; x < tile_width; ++x) {
            output_pixel[sub_index + y * tile_width + x] =
                p_normalized_image[src_row + (j * tile_width + x) * num_channels + k];
          }
        }
      }
    }
  }
}

inline OrtxStatus convert_to_rgb(const ortc::Tensor<uint8_t>& input, ortc::Tensor<uint8_t>& output) {
auto& dimensions = input.Shape();
if (dimensions.size() != 3ULL || dimensions[2] != 3) {
Expand Down Expand Up @@ -106,7 +144,6 @@ struct Resize {
std::memcpy(p_output_image + c0_index, output_image->image[i] + j * 4, c);
}
}
// DumpTensor(output);

ImagingDelete(output_image);
return {};
Expand Down
37 changes: 0 additions & 37 deletions shared/api/image_transforms_mllama.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,43 +11,6 @@
#include "image_transforms.hpp"

struct Llama3ImageTransform {
// Cuts a (height, width, channels) float image into tiles of tile_height x tile_width and
// writes them channel-first into pixel_values, which is allocated here with shape
// (tiles_down * tiles_across, channels, tile_height, tile_width).
// Both image dimensions are asserted to be exact multiples of the tile dimensions.
static void SplitIntoTitles(const ortc::Tensor<float>& normalized_image, ortc::Tensor<float>& pixel_values,
                            int64_t tile_height, int64_t tile_width) {
  auto& dims = normalized_image.Shape();
  const int64_t rows = dims[0];
  const int64_t cols = dims[1];
  const int64_t channels = dims[2];

  assert(rows % tile_height == 0);
  const int64_t tiles_down = rows / tile_height;
  assert(cols % tile_width == 0);
  const int64_t tiles_across = cols / tile_width;

  auto src = normalized_image.Data();
  float* dst = pixel_values.Allocate({tiles_down * tiles_across, channels, tile_height, tile_width});

  // Distance, in elements, between two vertically adjacent pixels of the source.
  const int64_t src_row_stride = cols * channels;

  // The loops visit the destination in storage order (tile, channel, row, column),
  // so the output is filled through a running cursor while the source is gathered.
  float* cursor = dst;
  for (int64_t tile_row = 0; tile_row < tiles_down; ++tile_row) {
    const int64_t first_src_row = tile_row * tile_height;
    for (int64_t tile_col = 0; tile_col < tiles_across; ++tile_col) {
      const int64_t first_src_col = tile_col * tile_width;
      for (int64_t ch = 0; ch < channels; ++ch) {
        for (int64_t r = 0; r < tile_height; ++r) {
          const int64_t row_base = (first_src_row + r) * src_row_stride;
          for (int64_t c = 0; c < tile_width; ++c) {
            *cursor++ = src[row_base + (first_src_col + c) * channels + ch];
          }
        }
      }
    }
  }
}

OrtxStatus Compute(const ortc::Tensor<uint8_t>& image, ortc::Tensor<float>& pixel_values,
ortc::Tensor<int64_t>& aspect_ratio_ids, ortc::Tensor<int64_t>& aspect_ratio_mask,
ortc::Tensor<int64_t>& num_tiles) {
Expand Down
Loading

0 comments on commit 1e6ffce

Please sign in to comment.