Skip to content

Commit e733dec

Browse files
authored
Add HuggingFace ViT classifier config parser (#558)
1 parent b5bd4f4 commit e733dec

File tree

2 files changed

+244
-0
lines changed

2 files changed

+244
-0
lines changed

src/monolithic/inference_backend/image_inference/openvino/model_api_converters.cpp

Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,240 @@ bool convertYoloMeta2ModelApi(const std::string model_file, ov::AnyMap &modelCon
170170
return true;
171171
}
172172

173+
// Helper function to return a config file path in the same directory as model_file
174+
std::string getConfigPath(const std::string &model_file, const std::string &filename) {
175+
std::filesystem::path model_path(model_file);
176+
std::filesystem::path model_dir = model_path.parent_path();
177+
std::filesystem::path config_path = model_dir / filename;
178+
179+
if (std::filesystem::exists(config_path))
180+
return config_path.string();
181+
return {};
182+
}
183+
184+
// Helper function to load JSON from a config file
185+
bool loadJsonFromFile(const std::string &file_path, nlohmann::json &json_out) {
186+
std::ifstream file_stream(file_path);
187+
if (!file_stream.is_open()) {
188+
GST_ERROR("Failed to open config file: %s", file_path.c_str());
189+
return false;
190+
}
191+
192+
try {
193+
file_stream >> json_out;
194+
} catch (const std::exception &e) {
195+
GST_ERROR("Failed to parse config file: %s", e.what());
196+
return false;
197+
}
198+
199+
return true;
200+
}
201+
202+
// Helper function to load JSON from a config file in the same directory as model_file
203+
bool loadJsonFromModelDir(const std::string &model_file, const std::string &filename, nlohmann::json &json_out) {
204+
const std::string file_path = getConfigPath(model_file, filename);
205+
if (file_path.empty())
206+
return false;
207+
return loadJsonFromFile(file_path, json_out);
208+
}
209+
210+
// Return matched HuggingFace architecture name from config.json, empty string otherwise
211+
std::string getHuggingFaceArchitecture(const nlohmann::json &config_json) {
212+
auto is_supported = [](const std::string &arch_name) {
213+
for (const auto &supported : kHfSupportedArchitectures) {
214+
if (arch_name == supported)
215+
return true;
216+
}
217+
return false;
218+
};
219+
220+
if (config_json.contains("architectures") && config_json["architectures"].is_array()) {
221+
for (const auto &arch : config_json["architectures"]) {
222+
if (!arch.is_string())
223+
continue;
224+
const std::string arch_name = arch.get<std::string>();
225+
if (is_supported(arch_name))
226+
return arch_name;
227+
}
228+
}
229+
230+
if (config_json.contains("architecture") && config_json["architecture"].is_string()) {
231+
const std::string arch_name = config_json["architecture"].get<std::string>();
232+
if (is_supported(arch_name))
233+
return arch_name;
234+
}
235+
236+
return {};
237+
}
238+
239+
static bool parseViTForImageClassification(const nlohmann::json &config_json, const nlohmann::json &preproc_json,
240+
ov::AnyMap &modelConfig) {
241+
// Setting up preprocessing parameters
242+
243+
// Set reshape size from preprocessor_config.json size field
244+
if (preproc_json.contains("size") && preproc_json["size"].is_object()) {
245+
const auto &size = preproc_json["size"];
246+
if (size.contains("height") && size.contains("width") && size["height"].is_number_integer() &&
247+
size["width"].is_number_integer()) {
248+
const int height = size["height"].get<int>();
249+
const int width = size["width"].get<int>();
250+
modelConfig["reshape"] = ov::Any(std::vector<int>{height, width});
251+
}
252+
}
253+
254+
// Default resize_type to "standard"
255+
modelConfig["resize_type"] = ov::Any(std::string("standard"));
256+
257+
// Check if "do_center_crop": true, then set resize_type to "crop"
258+
if (preproc_json.contains("do_center_crop") && preproc_json["do_center_crop"].is_boolean() &&
259+
preproc_json["do_center_crop"].get<bool>() == true) {
260+
modelConfig["resize_type"] = ov::Any(std::string("crop"));
261+
262+
if (preproc_json.contains("crop_size") && preproc_json["crop_size"].is_object()) {
263+
const auto &size = preproc_json["crop_size"];
264+
if (size.contains("height") && size.contains("width") && size["height"].is_number_integer() &&
265+
size["width"].is_number_integer()) {
266+
const int height = size["height"].get<int>();
267+
const int width = size["width"].get<int>();
268+
modelConfig["reshape"] = ov::Any(std::vector<int>{height, width});
269+
}
270+
}
271+
}
272+
273+
// Return an error if modelConfig does not have reshape set
274+
if (modelConfig.find("reshape") == modelConfig.end()) {
275+
GST_ERROR("HuggingFace ViTForImageClassification image size is not specified in preprocessor_config.json");
276+
return false;
277+
}
278+
279+
double rescale_factor = 1.0 / 255.0;
280+
if (preproc_json.contains("rescale_factor") && preproc_json["rescale_factor"].is_number()) {
281+
rescale_factor = preproc_json["rescale_factor"].get<double>();
282+
}
283+
284+
if (preproc_json.contains("image_mean") && preproc_json["image_mean"].is_array()) {
285+
std::vector<std::string> mean_values;
286+
for (const auto &val : preproc_json["image_mean"]) {
287+
if (val.is_number()) {
288+
mean_values.push_back(std::to_string(val.get<double>() / rescale_factor));
289+
}
290+
}
291+
if (!mean_values.empty()) {
292+
std::ostringstream mean_values_stream;
293+
for (size_t i = 0; i < mean_values.size(); ++i) {
294+
if (i)
295+
mean_values_stream << ' ';
296+
mean_values_stream << mean_values[i];
297+
}
298+
modelConfig["mean_values"] = ov::Any(mean_values_stream.str());
299+
}
300+
}
301+
302+
if (preproc_json.contains("image_std") && preproc_json["image_std"].is_array()) {
303+
std::vector<std::string> std_values;
304+
for (const auto &val : preproc_json["image_std"]) {
305+
if (val.is_number()) {
306+
std_values.push_back(std::to_string(val.get<double>() / rescale_factor));
307+
}
308+
}
309+
if (!std_values.empty()) {
310+
std::ostringstream std_values_stream;
311+
for (size_t i = 0; i < std_values.size(); ++i) {
312+
if (i)
313+
std_values_stream << ' ';
314+
std_values_stream << std_values[i];
315+
}
316+
modelConfig["scale_values"] = ov::Any(std_values_stream.str());
317+
}
318+
}
319+
320+
// Check if do_convert_rgb is not false, then set model format to RGB
321+
const bool do_convert_rgb =
322+
!(preproc_json.contains("do_convert_rgb") && preproc_json["do_convert_rgb"].is_boolean() &&
323+
preproc_json["do_convert_rgb"].get<bool>() == false);
324+
if (do_convert_rgb) {
325+
modelConfig["reverse_input_channels"] = ov::Any(std::string("true"));
326+
}
327+
328+
// Setting up postprocessing parameters
329+
330+
// Model type is always "label" for ViTForImageClassification
331+
modelConfig["model_type"] = ov::Any(std::string("label"));
332+
modelConfig["output_raw_scores"] = ov::Any(std::string("True"));
333+
334+
// Parse label2id mapping to extract labels ordered by their IDs
335+
if (config_json.contains("label2id") && config_json["label2id"].is_object()) {
336+
std::vector<std::pair<int, std::string>> id_labels;
337+
for (auto it = config_json["label2id"].begin(); it != config_json["label2id"].end(); ++it) {
338+
if (!it.value().is_number_integer())
339+
continue;
340+
std::string label = it.key();
341+
std::replace(label.begin(), label.end(), ' ', '_');
342+
id_labels.emplace_back(it.value().get<int>(), label);
343+
}
344+
std::sort(id_labels.begin(), id_labels.end(), [](const auto &a, const auto &b) { return a.first < b.first; });
345+
346+
if (!id_labels.empty()) {
347+
std::ostringstream labels_stream;
348+
for (size_t i = 0; i < id_labels.size(); ++i) {
349+
if (i)
350+
labels_stream << ' ';
351+
labels_stream << id_labels[i].second;
352+
}
353+
modelConfig["labels"] = ov::Any(labels_stream.str());
354+
}
355+
}
356+
357+
return true;
358+
}
359+
360+
// Convert HuggingFace metadata file into Model API format
361+
bool convertHuggingFaceMeta2ModelApi(const std::string &model_file, ov::AnyMap &modelConfig) {
362+
nlohmann::json config_json;
363+
if (!loadJsonFromModelDir(model_file, "config.json", config_json))
364+
return false;
365+
366+
const std::string architecture = getHuggingFaceArchitecture(config_json);
367+
if (architecture.empty())
368+
return false;
369+
370+
if (architecture == "ViTForImageClassification") {
371+
nlohmann::json preproc_json;
372+
if (!loadJsonFromModelDir(model_file, "preprocessor_config.json", preproc_json))
373+
return false;
374+
375+
return parseViTForImageClassification(config_json, preproc_json, modelConfig);
376+
}
377+
378+
return false;
379+
}
380+
381+
// Helper function to check XML for HuggingFace metadata
382+
bool isHuggingFaceModel(const std::string &model_file) {
383+
std::filesystem::path model_path(model_file);
384+
if (model_path.extension() != ".xml") {
385+
model_path.replace_extension(".xml");
386+
}
387+
388+
if (!std::filesystem::exists(model_path))
389+
return false;
390+
391+
std::ifstream xml_stream(model_path.string());
392+
if (!xml_stream.is_open()) {
393+
GST_ERROR("Failed to open XML model file: %s", model_path.string().c_str());
394+
return false;
395+
}
396+
397+
std::string line;
398+
while (std::getline(xml_stream, line)) {
399+
if (line.find("transformers_version") != std::string::npos) {
400+
return true;
401+
}
402+
}
403+
404+
return false;
405+
}
406+
173407
// Convert third-party input metadata config files into Model API format
174408
bool convertThirdPartyModelConfig(const std::string model_file, ov::AnyMap &modelConfig) {
175409
bool updated = false;
@@ -180,6 +414,9 @@ bool convertThirdPartyModelConfig(const std::string model_file, ov::AnyMap &mode
180414
}
181415
}
182416

417+
else if (isHuggingFaceModel(model_file))
418+
updated = convertHuggingFaceMeta2ModelApi(model_file, modelConfig);
419+
183420
return updated;
184421
}
185422

@@ -241,6 +478,8 @@ std::map<std::string, GstStructure *> get_model_info_preproc(const std::shared_p
241478

242479
// override model config with third-party config files (if found)
243480
convertThirdPartyModelConfig(model_file, modelConfig);
481+
if (!modelConfig.empty() && s == nullptr)
482+
s = gst_structure_new_empty(layer_name.data());
244483

245484
// the parameter parsing loop may use locale-dependent floating point conversion
246485
// save current locale and restore after the loop
@@ -405,6 +644,8 @@ std::map<std::string, GstStructure *> get_model_info_postproc(const std::shared_
405644

406645
// update model config with third-party config files (if found)
407646
convertThirdPartyModelConfig(model_file, modelConfig);
647+
if (!modelConfig.empty() && s == nullptr)
648+
s = gst_structure_new_empty(layer_name.data());
408649

409650
// the parameter parsing loop may use locale-dependent floating point conversion
410651
// save current locale and restore after the loop

src/monolithic/inference_backend/image_inference/openvino/model_api_converters.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313

1414
namespace ModelApiConverters {
1515

16+
// Supported HuggingFace architectures
17+
const std::vector<std::string> kHfSupportedArchitectures = {"ViTForImageClassification"};
18+
1619
// convert varying metadata input formats (Yolo, HuggingFace, Geti, ...) to OV Model API pre-processing metadata
1720
std::map<std::string, GstStructure *> get_model_info_preproc(const std::shared_ptr<ov::Model> model,
1821
const std::string model_file,

0 commit comments

Comments
 (0)