Skip to content

Commit 534a393

Browse files
authored
ImageGen models support for export_models.py (#3301)
CVS-166467
1 parent 21a1bf7 commit 534a393

File tree

2 files changed

+90
-1
lines changed

2 files changed

+90
-1
lines changed

demos/common/export_models/export_model.py

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,15 @@ def add_common_arguments(parser):
6464
parser_rerank.add_argument('--num_streams', default="1", help='The number of parallel execution streams to use for the model. Use at least 2 on 2 socket CPU systems.', dest='num_streams')
6565
parser_rerank.add_argument('--max_doc_length', default=16000, type=int, help='Maximum length of input documents in tokens', dest='max_doc_length')
6666
parser_rerank.add_argument('--version', default="1", help='version of the model', dest='version')
67+
68+
parser_image_generation = subparsers.add_parser('image_generation', help='export model for image generation endpoint')
69+
add_common_arguments(parser_image_generation)
70+
parser_image_generation.add_argument('--num_streams', default=0, type=int, help='The number of parallel execution streams to use for the models in the pipeline.', dest='num_streams')
71+
parser_image_generation.add_argument('--max_resolution', default="", help='Max allowed resolution in a format of WxH; W=width H=height', dest='max_resolution')
72+
parser_image_generation.add_argument('--default_resolution', default="", help='Default resolution when not specified by client', dest='default_resolution')
73+
parser_image_generation.add_argument('--max_number_images_per_prompt', type=int, default=0, help='Max allowed number of images client is allowed to request for a given prompt', dest='max_number_images_per_prompt')
74+
parser_image_generation.add_argument('--default_num_inference_steps', type=int, default=0, help='Default number of inference steps when not specified by client', dest='default_num_inference_steps')
75+
parser_image_generation.add_argument('--max_num_inference_steps', type=int, default=0, help='Max allowed number of inference steps client is allowed to request for a given prompt', dest='max_num_inference_steps')
6776
args = vars(parser.parse_args())
6877

6978
embedding_graph_template = """input_stream: "REQUEST_PAYLOAD:input"
@@ -213,6 +222,35 @@ def add_common_arguments(parser):
213222
]
214223
}"""
215224

225+
image_generation_graph_template = """input_stream: "HTTP_REQUEST_PAYLOAD:input"
226+
output_stream: "HTTP_RESPONSE_PAYLOAD:output"
227+
228+
node: {
229+
name: "ImageGenExecutor"
230+
calculator: "ImageGenCalculator"
231+
input_stream: "HTTP_REQUEST_PAYLOAD:input"
232+
input_side_packet: "IMAGE_GEN_NODE_RESOURCES:pipes"
233+
output_stream: "HTTP_RESPONSE_PAYLOAD:output"
234+
node_options: {
235+
[type.googleapis.com / mediapipe.ImageGenCalculatorOptions]: {
236+
models_path: "{{model_path}}",
237+
{%- if plugin_config_str %}
238+
plugin_config: '{{plugin_config_str}}',{% endif %}
239+
device: "{{target_device|default("CPU", true)}}",
240+
{%- if max_resolution %}
241+
max_resolution: '{{max_resolution}}',{% endif %}
242+
{%- if default_resolution %}
243+
default_resolution: '{{default_resolution}}',{% endif %}
244+
{%- if max_number_images_per_prompt > 0 %}
245+
max_number_images_per_prompt: {{max_number_images_per_prompt}},{% endif %}
246+
{%- if default_num_inference_steps > 0 %}
247+
default_num_inference_steps: {{default_num_inference_steps}},{% endif %}
248+
{%- if max_num_inference_steps > 0 %}
249+
max_num_inference_steps: {{max_num_inference_steps}},{% endif %}
250+
}
251+
}
252+
}"""
253+
216254
def export_rerank_tokenizer(source_model, destination_path, max_length):
217255
import openvino as ov
218256
from openvino_tokenizers import convert_tokenizer
@@ -448,6 +486,46 @@ def export_rerank_model(model_repository_path, source_model, model_name, precisi
448486
add_servable_to_config(config_file_path, model_name, os.path.relpath( os.path.join(model_repository_path, model_name), os.path.dirname(config_file_path)))
449487

450488

489+
def export_image_generation_model(model_repository_path, source_model, model_name, precision, task_parameters, config_file_path, num_streams):
    """Export an image generation model and generate its serving graph.

    Converts `source_model` to OpenVINO format via `optimum-cli` (skipped when a
    `model_index.json` already exists in the target directory), renders the
    MediaPipe graph.pbtxt from `image_generation_graph_template`, and registers
    the servable in the server config file.

    Args:
        model_repository_path: Root directory of the model repository.
        source_model: Model id/path passed to `optimum-cli export openvino`.
        model_name: Name of the servable; also the target subdirectory.
        precision: Weight format passed to `--weight-format` (e.g. int8, fp16).
        task_parameters: Dict of template parameters (target_device,
            max_resolution, default_resolution, step/image limits, ov_cache_dir).
            Mutated in place: resolutions are normalized, plugin_config_str added.
        config_file_path: Path to the server config file to update.
        num_streams: NUM_STREAMS plugin value; 0 means "leave unset".

    Raises:
        ValueError: On export failure, negative num_streams, or malformed
            resolution parameters.
    """
    model_path = "./"
    target_path = os.path.join(model_repository_path, model_name)
    model_index_path = os.path.join(target_path, 'model_index.json')

    if os.path.isfile(model_index_path):
        print("Model index file already exists. Skipping conversion, re-generating graph only.")
    else:
        optimum_command = "optimum-cli export openvino --model {} --weight-format {} {}".format(source_model, precision, target_path)
        if os.system(optimum_command):
            # Fixed duplicated word in the original message ("model model").
            raise ValueError("Failed to export image generation model", source_model)

    plugin_config = {}
    # Raise instead of assert: asserts are stripped under `python -O`, which
    # would silently let negative values through.
    if num_streams < 0:
        raise ValueError("num_streams should be a non-negative integer")
    if num_streams > 0:
        plugin_config['NUM_STREAMS'] = num_streams
    if 'ov_cache_dir' in task_parameters and task_parameters['ov_cache_dir'] is not None:
        plugin_config['CACHE_DIR'] = task_parameters['ov_cache_dir']

    if len(plugin_config) > 0:
        task_parameters['plugin_config_str'] = json.dumps(plugin_config)

    # Validate and normalize resolutions: must be exactly two positive integers
    # separated by 'x'. The original unpacking raised a cryptic "too many
    # values to unpack" for inputs like "1024x768x2"; length is checked first.
    for param in ['max_resolution', 'default_resolution']:
        if task_parameters[param]:
            if 'x' not in task_parameters[param]:
                raise ValueError(param + " should be in WxH format, e.g. 1024x768")
            dimensions = task_parameters[param].split('x')
            if len(dimensions) != 2 or not all(d.isdigit() for d in dimensions):
                raise ValueError(param + " should be in WxH format with positive integers, e.g. 1024x768")
            width, height = dimensions
            # Normalize, e.g. strip leading zeros: "0640x480" -> "640x480".
            task_parameters[param] = '{}x{}'.format(int(width), int(height))

    gtemplate = jinja2.Environment(loader=jinja2.BaseLoader).from_string(image_generation_graph_template)
    graph_content = gtemplate.render(model_path=model_path, **task_parameters)
    graph_path = os.path.join(model_repository_path, model_name, 'graph.pbtxt')
    with open(graph_path, 'w') as f:
        f.write(graph_content)
    print("Created graph {}".format(graph_path))
    add_servable_to_config(config_file_path, model_name, os.path.relpath(os.path.join(model_repository_path, model_name), os.path.dirname(config_file_path)))
527+
528+
451529
if not os.path.isdir(args['model_repository_path']):
452530
raise ValueError(f"The model repository path '{args['model_repository_path']}' is not a valid directory.")
453531
if args['source_model'] is None:
@@ -477,4 +555,14 @@ def export_rerank_model(model_repository_path, source_model, model_name, precisi
477555
elif args['task'] == 'rerank':
478556
export_rerank_model(args['model_repository_path'], args['source_model'], args['model_name'] ,args['precision'], template_parameters, str(args['version']), args['config_file_path'], args['max_doc_length'])
479557

480-
558+
elif args['task'] == 'image_generation':
    # Pick only the CLI arguments consumed by the image generation graph
    # template / exporter; everything else in `args` is irrelevant here.
    template_parameters = {k: v for k, v in args.items() if k in [
        'ov_cache_dir',
        'target_device',
        'max_resolution',
        'default_resolution',
        'max_number_images_per_prompt',
        'default_num_inference_steps',
        'max_num_inference_steps',
    ]}
    export_image_generation_model(args['model_repository_path'], args['source_model'], args['model_name'], args['precision'], template_parameters, args['config_file_path'], args['num_streams'])

demos/common/export_models/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@ einops
1313
torchvision==0.21.0
1414
timm==1.0.15
1515
auto-gptq==0.7.1
16+
diffusers==0.33.1 # for image generation

0 commit comments

Comments
 (0)