Skip to content

Commit f3d9e81

Browse files
committed
update: check_bool_param
Signed-off-by: xipingya <xiping.yan@intel.com>
1 parent 9f47613 commit f3d9e81

File tree

4 files changed

+8
-21
lines changed

4 files changed

+8
-21
lines changed

src/cpp/src/module_genai/modules/md_text_to_speech/md_text_to_speech.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ void TextToSpeechModule::print_static_config() {
5959
codec_embedding_model_path: "codec_embedding_model.xml" # codec embedding model IR xml path
6060
code_predictor_ar_model_path: "code_predictor_ar_model" # code predictor autoregressive model directory path
6161
sample_codec_token_greedy_search: false # Eanble greedy decoding in sample_codec_token, which is used for fast debugging and also for GPU inference since random sampling is not easy to implement on GPU.
62+
merge_ar_and_sce_ov_models: false # Merge AR and CSE models into one OV model for better performance. Requires "sample_codec_token_greedy_search=true.
6263
code_predictor_single_codec_embed_model_path: "code_predictor_single_codec_embed_model" # code predictor single codec embedding model directory path
6364
code_predictor_single_codec_embedding_model_path: "code_predictor_single_codec_embedding_model.xml" # code predictor single codec embedding model IR xml path
6465
speech_decoder_model_path: "speech_decoder_model.xml" # speech decoder model IR xml path
@@ -71,10 +72,8 @@ TextToSpeechModule::TextToSpeechModule(const IBaseModuleDesc::PTR& desc,
7172
: IBaseModule(desc, pipeline_desc),
7273
m_model_type(model_type),
7374
m_device(desc->device.empty() ? "CPU" : desc->device) {
74-
auto sample_codec_token_greedy_search_str = get_optional_param("sample_codec_token_greedy_search");
75-
if (!sample_codec_token_greedy_search_str.empty()) {
76-
m_sample_codec_token_greedy_search = str_to_bool(sample_codec_token_greedy_search_str);
77-
}
75+
m_sample_codec_token_greedy_search = check_bool_param("sample_codec_token_greedy_search", false);
76+
m_merge_ar_and_sce_ov_models = check_bool_param("merge_ar_and_sce_ov_models", false);
7877
}
7978

8079
TextToSpeechModule::~TextToSpeechModule() = default;

src/cpp/src/module_genai/modules/md_text_to_speech/md_text_to_speech.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ class TextToSpeechModule : public IBaseModule {
7070
bool m_sample_codec_token_greedy_search =
7171
false; // Eanble greedy decoding in sample_codec_token, which is used for fast debugging and also for GPU
7272
// inference since random sampling is not easy to implement on GPU.
73+
bool m_merge_ar_and_sce_ov_models = false; // Merge AR and CSE models into one OV model for better performance.
74+
// Requires "sample_codec_token_greedy_search=true.
7375
};
7476

7577
}

src/cpp/src/module_genai/pipeline/module_base.cpp

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -77,17 +77,6 @@ size_t IBaseModule::str_to_size_t(const std::string& str) {
7777
}
7878
}
7979

80-
// Convert string to bool. String "true", "True", "TRUE", "1" will be converted to true, and string "false",
81-
// "False", "FALSE", "0" will be converted to false. Other strings will throw an exception.
82-
bool IBaseModule::str_to_bool(const std::string& param_item) {
83-
if (param_item == "true" || param_item == "True" || param_item == "TRUE" || param_item == "1") {
84-
return true;
85-
} else if (param_item == "false" || param_item == "False" || param_item == "FALSE" || param_item == "0") {
86-
return false;
87-
}
88-
OPENVINO_THROW("Failed to parse bool from string: " + param_item);
89-
}
90-
9180
void IBaseModule::init_ov_model() {
9281
if (m_ov_model == nullptr) {
9382
m_ov_model = get_ov_model_from_cfg_models_map("ov_model", false);
@@ -167,8 +156,8 @@ void IBaseModule::check_splitted_model() {
167156
}
168157
}
169158

170-
bool IBaseModule::check_bool_param(const std::string& param_name, const bool& default_value) {
171-
auto p = get_optional_param(param_name);
159+
bool IBaseModule::check_bool_param(const std::string& param_name, const bool& default_value, const bool& requires) {
160+
auto p = requires ? get_param(param_name) : get_optional_param(param_name);
172161
if (p.empty()) {
173162
return default_value;
174163
}

src/cpp/src/module_genai/pipeline/module_base.hpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,6 @@ class IBaseModule {
6565
std::string get_param(const std::string& param_item);
6666
std::string get_optional_param(const std::string& param_item);
6767
size_t str_to_size_t(const std::string& param_item);
68-
// Convert string to bool. String "true", "True", "TRUE", "1" will be converted to true, and string "false",
69-
// "False", "FALSE", "0" will be converted to false. Other strings will throw an exception.
70-
bool str_to_bool(const std::string& param_item);
7168
static void start_generate() {
7269
m_generate_start_time = std::chrono::steady_clock::now();
7370
}
@@ -84,7 +81,7 @@ class IBaseModule {
8481
bool m_splitted_model = false;
8582
void check_splitted_model();
8683

87-
bool check_bool_param(const std::string& param_name, const bool& default_value);
84+
bool check_bool_param(const std::string& param_name, const bool& default_value, const bool& requires=false);
8885

8986
// Initialize ov::Model from config models_map with param_name: "ov_model"
9087
void init_ov_model();

0 commit comments

Comments
 (0)