Skip to content

Commit fa9ea8a

Browse files
committed
fix spec checking issue.
Signed-off-by: xiping.yan <xiping.yan@intel.com>
1 parent 81aea1f commit fa9ea8a

File tree

3 files changed

+26
-12
lines changed

3 files changed

+26
-12
lines changed

samples/cpp/module_genai/md_omni.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ int main(int argc, char* argv[]) {
9191
bool perf = std::stoi(utils::get_input_arg(argc, argv, "-perf", std::string("0")));
9292
bool use_tts = std::stoi(utils::get_input_arg(argc, argv, "-tts", std::string("0"))) != 0;
9393

94+
std::cout << "Config YAML: " << std::endl;
95+
std::cout << " - " << config_path << std::endl;
96+
9497
utils::OmniInputParams input_params = utils::parse_omni_input_params(argc, argv);
9598
ov::AnyMap inputs = parse_inputs_for_omni(input_params);
9699

@@ -119,9 +122,7 @@ int main(int argc, char* argv[]) {
119122

120123
std::cout << "[Generation] Running main generation..." << std::endl;
121124
auto t1 = std::chrono::high_resolution_clock::now();
122-
123125
pipe.generate(inputs);
124-
125126
auto t2 = std::chrono::high_resolution_clock::now();
126127
if (perf) {
127128
auto diff = std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1).count();

src/cpp/src/module_genai/modules/md_llm_inference_sdpa/md_llm_inference_sdpa.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ const ModuleSpec& LLMInferenceSDPAModule::get_spec() {
3838
static const ModuleSpec spec = []() {
3939
ModuleSpec s("LLMInferenceSDPAModule", "llm_inference_sdpa");
4040
s.add_input("input_ids", {DataType::OVTensor});
41+
s.add_input("position_ids", {DataType::OVTensor}, true);
42+
s.add_input("rope_delta", {DataType::OVTensor}, true);
4143
s.add_input("visual_embeds", {DataType::OVTensor}, true);
4244
s.add_input("visual_pos_mask", {DataType::OVTensor}, true);
4345
s.add_input("grid_thw", {DataType::OVTensor}, true);

src/cpp/src/module_genai/pipeline/module_base.cpp

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -179,30 +179,41 @@ bool IBaseModule::check_bool_optional_param(const std::string& param_name, const
179179
}
180180

181181
void IBaseModule::check_params_with_spec(const ModuleSpec& spec) {
182+
auto error_tip = [this](const std::string& name) {
183+
return "Module[" + ModuleTypeConverter::toString(module_desc->type) + "(" + module_desc->name + ")" + "]: '" +
184+
name + "' ";
185+
};
186+
182187
// Check Inputs
183188
const auto& inputs = module_desc->inputs;
184189
for (const auto& input : inputs) {
185190
// Find input.name in spec.inputs()
186-
auto it = std::find_if(spec.inputs().begin(), spec.inputs().end(),
187-
[&input](const auto& input_spec) { return input_spec.name == input.name; });
188-
OPENVINO_ASSERT (it != spec.inputs().end(), "Module[" + module_desc->name + "]: input '" + input.name + "' is not defined in ModuleSpec");
191+
auto it = std::find_if(spec.inputs().begin(), spec.inputs().end(), [&input](const auto& input_spec) {
192+
return input_spec.name == input.name;
193+
});
194+
OPENVINO_ASSERT(it != spec.inputs().end(), error_tip(input.name) + "is not defined in ModuleSpec");
195+
189196
// Check input.dt_type in it->supported_types
190197
auto& supported_types = it->supported_types;
191-
OPENVINO_ASSERT(std::find(supported_types.begin(), supported_types.end(), input.dt_type) != supported_types.end(),
192-
"Module[" + module_desc->name + "]: input '" + input.name + "' has unsupported data type");
198+
OPENVINO_ASSERT(
199+
std::find(supported_types.begin(), supported_types.end(), input.dt_type) != supported_types.end(),
200+
error_tip(input.name) + "has unsupported data type");
193201
}
194202

195203
// Check Outputs
196204
const auto& outputs = module_desc->outputs;
197205
for (const auto& output : outputs) {
198206
// Find output.name in spec.outputs()
199-
auto it = std::find_if(spec.outputs().begin(), spec.outputs().end(),
200-
[&output](const auto& output_spec) { return output_spec.name == output.name; });
201-
OPENVINO_ASSERT (it != spec.outputs().end(), "Module[" + module_desc->name + "]: output '" + output.name + "' is not defined in ModuleSpec");
207+
auto it = std::find_if(spec.outputs().begin(), spec.outputs().end(), [&output](const auto& output_spec) {
208+
return output_spec.name == output.name;
209+
});
210+
OPENVINO_ASSERT(it != spec.outputs().end(), error_tip(output.name) + "is not defined in ModuleSpec");
211+
202212
// Check output.dt_type in it->supported_types
203213
auto& supported_types = it->supported_types;
204-
OPENVINO_ASSERT(std::find(supported_types.begin(), supported_types.end(), output.dt_type) != supported_types.end(),
205-
"Module[" + module_desc->name + "]: output '" + output.name + "' has unsupported data type");
214+
OPENVINO_ASSERT(
215+
std::find(supported_types.begin(), supported_types.end(), output.dt_type) != supported_types.end(),
216+
error_tip(output.name) + "has unsupported data type");
206217
}
207218
}
208219

0 commit comments

Comments
 (0)