Skip to content

Commit 6e16a15

Browse files
msmiatactjanczak
andauthored
[ITEP-26367] Add PaddleOCR (LPR) converter (open-edge-platform#199)
Co-authored-by: Tomasz Janczak <[email protected]>
1 parent c655f1f commit 6e16a15

File tree

3 files changed

+535
-0
lines changed

3 files changed

+535
-0
lines changed

libraries/dl-streamer/src/monolithic/gst/inference_elements/common/post_processor/blob_to_meta_converter.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "converters/to_tensor/keypoints_hrnet.h"
2020
#include "converters/to_tensor/keypoints_openpose.h"
2121
#include "converters/to_tensor/label.h"
22+
#include "converters/to_tensor/paddle_ocr.h"
2223
#include "converters/to_tensor/raw_data_copy.h"
2324
#include "converters/to_tensor/semantic_mask.h"
2425
#include "converters/to_tensor/text.h"
@@ -182,6 +183,8 @@ BlobToMetaConverter::Ptr BlobToMetaConverter::create(Initializer initializer, Co
182183
return std::make_unique<SemanticMaskConverter>(std::move(initializer));
183184
else if (converter_name == docTROCRConverter::getName())
184185
return std::make_unique<docTROCRConverter>(std::move(initializer));
186+
else if (converter_name == PaddleOCRConverter::getName())
187+
return std::make_unique<PaddleOCRConverter>(std::move(initializer));
185188
else
186189
throw std::runtime_error("Unsupported converter: " + converter_name);
187190
default:
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
/*******************************************************************************
2+
* Copyright (C) 2021-2025 Intel Corporation
3+
*
4+
* SPDX-License-Identifier: MIT
5+
******************************************************************************/
6+
7+
#include "paddle_ocr.h"
8+
#include "copy_blob_to_gststruct.h"
9+
#include "inference_backend/logger.h"
10+
#include "safe_arithmetic.hpp"
11+
#include <algorithm>
12+
#include <cmath>
13+
#include <gst/gst.h>
14+
#include <sstream>
15+
#include <stdexcept>
16+
17+
#include <fstream>
18+
#include <iostream>
19+
20+
using namespace post_processing;
21+
using namespace InferenceBackend;
22+
23+
// Constructor to initialize the OCRConverter with the initializer.
24+
PaddleOCRConverter::PaddleOCRConverter(BlobToMetaConverter::Initializer initializer)
25+
: BlobToTensorConverter(std::move(initializer)) {
26+
}
27+
28+
TensorsTable PaddleOCRConverter::convert(const OutputBlobs &output_blobs) {
29+
ITT_TASK(__FUNCTION__);
30+
TensorsTable tensors_table;
31+
32+
try {
33+
const size_t batch_size = getModelInputImageInfo().batch_size;
34+
tensors_table.resize(batch_size);
35+
36+
for (const auto &blob_iter : output_blobs) {
37+
OutputBlob::Ptr blob = blob_iter.second;
38+
if (!blob) {
39+
throw std::invalid_argument("Output blob is empty");
40+
}
41+
42+
const float *data = reinterpret_cast<const float *>(blob->GetData());
43+
if (!data) {
44+
throw std::invalid_argument("Output blob data is nullptr");
45+
}
46+
47+
const size_t data_size = blob->GetSize();
48+
const std::string layer_name = blob_iter.first;
49+
50+
for (size_t batch_elem_index = 0; batch_elem_index < batch_size; ++batch_elem_index) {
51+
GVA::Tensor classification_result = createTensor();
52+
53+
if (!raw_tensor_copying->enabled(RawTensorCopyingToggle::id))
54+
CopyOutputBlobToGstStructure(blob, classification_result.gst_structure(),
55+
BlobToMetaConverter::getModelName().c_str(), layer_name.c_str(),
56+
batch_size, batch_elem_index);
57+
58+
const auto item = get_data_by_batch_index(data, data_size, batch_size, batch_elem_index);
59+
const float *item_data = item.first;
60+
61+
std::string decoded_text =
62+
decodeOutputTensor(item_data);
63+
64+
if (decoded_text.size() > SEQ_MINLEN)
65+
classification_result.set_string("label", decoded_text);
66+
else
67+
classification_result.set_string("label", "");
68+
69+
// Set metadata for the tensor in the GstStructure
70+
gst_structure_set(classification_result.gst_structure(), "tensor_id", G_TYPE_INT,
71+
safe_convert<int>(batch_elem_index), "type", G_TYPE_STRING, "classification_result",
72+
NULL);
73+
std::vector<GstStructure *> tensors{classification_result.gst_structure()};
74+
tensors_table[batch_elem_index].push_back(tensors);
75+
}
76+
}
77+
} catch (const std::exception &e) {
78+
GVA_ERROR("An error occurred in OCR converter: %s", e.what());
79+
}
80+
81+
return tensors_table;
82+
}
83+
84+
// Function to decode output tensor into text using the charset
85+
std::string PaddleOCRConverter::decodeOutputTensor(const float *item_data) {
86+
87+
std::vector<int> pred_indices(SEQUENCE_LENGTH); // Stores indices of max elements for each sequence
88+
89+
for (size_t i = 0; i < SEQUENCE_LENGTH; ++i) {
90+
const float *row_start = item_data + i * CHARSET_LEN; // Pointer to the start of the current sequence
91+
const float *max_element_ptr = std::max_element(row_start, row_start + CHARSET_LEN); // Find max element
92+
int max_index = std::distance(row_start, max_element_ptr); // Calculate index of max element
93+
pred_indices[i] = max_index; // Store the index
94+
}
95+
96+
// Decode the indices into text using the charset
97+
return decode(pred_indices);
98+
}
99+
100+
// Function to decode text indices into text labels using a charset
101+
std::string PaddleOCRConverter::decode(const std::vector<int> &text_index) {
102+
103+
std::string char_list; // Accumulates characters for the sequence
104+
std::vector<int> ignored_tokens = {0}; // Tokens to ignore during decoding
105+
106+
// Iterate over each index in the sequence
107+
for (size_t idx = 0; idx < text_index.size(); ++idx) {
108+
int current_index = text_index[idx];
109+
110+
// Skip ignored tokens
111+
if (std::find(ignored_tokens.begin(), ignored_tokens.end(), current_index) != ignored_tokens.end()) {
112+
continue;
113+
}
114+
115+
// Remove consecutive duplicate indices (optional)
116+
if (idx > 0 && text_index[idx - 1] == current_index) {
117+
continue;
118+
}
119+
120+
// Append the corresponding character from charset
121+
char_list += CHARACTER_SET[current_index];
122+
}
123+
124+
return char_list; // Return the decoded text
125+
}

0 commit comments

Comments
 (0)