@@ -453,99 +453,128 @@ Status HttpRestApiHandler::dispatchToProcessor(
453453 return StatusCode::UNKNOWN_REQUEST_COMPONENTS_TYPE;
454454}
455455
456- Status HttpRestApiHandler::processV3 (const std::string_view uri, const HttpRequestComponents& request_components, std::string& response, const std::string& request_body, std::shared_ptr<HttpAsyncWriter> serverReaderWriter, std::shared_ptr<MultiPartParser> multiPartParser) {
457- #if (MEDIAPIPE_DISABLE == 0)
458- OVMS_PROFILE_FUNCTION ();
459- HttpPayload request;
460- std::shared_ptr<Document> doc = std::make_shared<Document>();
461- std::shared_ptr<MediapipeGraphExecutor> executor;
462- bool streamFieldVal = false ;
463- {
464- auto it = request_components.headers .find (" content-type" );
465- bool isApplicationJson = it != request_components.headers .end () && it->second .find (" application/json" ) != std::string::npos;
466- bool isMultiPart = it != request_components.headers .end () && it->second .find (" multipart/form-data" ) != std::string::npos;
467- bool isDefault = !isApplicationJson && !isMultiPart;
468-
469- std::string model_name;
470-
471- if (isMultiPart) {
472- OVMS_PROFILE_SCOPE (" multipart parse" );
473- if (!multiPartParser->parse ()) {
474- return StatusCode::REST_INVALID_URL;
475- }
476- model_name = multiPartParser->getFieldByName (" model" );
477- if (model_name.empty ()) {
478- isDefault = true ;
479- } else {
480- SPDLOG_DEBUG (" Model name from deduced from MultiPart field: {}" , model_name);
481- }
482- // Set json parser in invalid state in order to get HasParseError to respond with true
483- doc->Parse (" error" );
484- } else if (isApplicationJson) {
485- {
486- OVMS_PROFILE_SCOPE (" rapidjson parse" );
487- doc->Parse (request_body.c_str ());
488- }
489- OVMS_PROFILE_SCOPE (" rapidjson validate" );
490- if (doc->HasParseError ()) {
491- return Status (StatusCode::JSON_INVALID, " Cannot parse JSON body" );
492- }
456+ static void ensureJsonParserInErrorState (std::shared_ptr<Document>& parsedJson) {
457+ // Hack to set json parser in invalid state in order to get HasParseError to respond with true
458+ parsedJson->Parse (" error" );
459+ }
493460
494- if (!doc->IsObject ()) {
495- return Status (StatusCode::JSON_INVALID, " JSON body must be an object" );
496- }
461+ static Status createV3HttpPayload (
462+ const std::string_view uri,
463+ const HttpRequestComponents& request_components,
464+ std::string& response,
465+ const std::string& request_body,
466+ std::shared_ptr<HttpAsyncWriter> serverReaderWriter,
467+ std::shared_ptr<MultiPartParser> multiPartParser,
468+ HttpPayload& request,
469+ std::string& modelName,
470+ bool & streamFieldVal) {
471+ OVMS_PROFILE_SCOPE (" createV3HttpPayload" );
472+
473+ std::shared_ptr<Document> parsedJson = std::make_shared<Document>();
474+
475+ auto it = request_components.headers .find (" content-type" );
476+ bool isApplicationJson = it != request_components.headers .end () && it->second .find (" application/json" ) != std::string::npos;
477+ bool isMultiPart = it != request_components.headers .end () && it->second .find (" multipart/form-data" ) != std::string::npos;
478+ bool isUriBasedRouting = !isApplicationJson && !isMultiPart; // For content types other than "application/json" and "multipart/form-data", we look for model information in the URI
479+
480+ if (isMultiPart) {
481+ OVMS_PROFILE_SCOPE (" multipart parse" );
482+ if (!multiPartParser->parse ()) {
483+ SPDLOG_DEBUG (" Failed to parse multipart content type request" );
484+ return StatusCode::FAILED_TO_PARSE_MULTIPART_CONTENT_TYPE;
485+ }
486+ modelName = multiPartParser->getFieldByName (" model" );
487+ if (modelName.empty ()) {
488+ isUriBasedRouting = true ;
489+ } else {
490+ SPDLOG_DEBUG (" Model name from deduced from MultiPart field: {}" , modelName);
491+ }
492+ ensureJsonParserInErrorState (parsedJson);
493+ } else if (isApplicationJson) {
494+ {
495+ OVMS_PROFILE_SCOPE (" rapidjson parse" );
496+ parsedJson->Parse (request_body.c_str ());
497+ }
498+ OVMS_PROFILE_SCOPE (" rapidjson validate" );
499+ if (parsedJson->HasParseError ()) {
500+ return Status (StatusCode::JSON_INVALID, " Cannot parse JSON body" );
501+ }
497502
498- auto modelNameIt = doc->FindMember (" model" );
499- if (modelNameIt == doc->MemberEnd ()) {
500- return Status (StatusCode::JSON_INVALID, " model field is missing in JSON body" );
501- }
503+ if (!parsedJson->IsObject ()) {
504+ return Status (StatusCode::JSON_INVALID, " JSON body must be an object" );
505+ }
502506
503- if (!modelNameIt->value .IsString ()) {
504- return Status (StatusCode::JSON_INVALID, " model field is not a string" );
505- }
507+ auto modelNameIt = parsedJson->FindMember (" model" );
508+ if (modelNameIt == parsedJson->MemberEnd ()) {
509+ return Status (StatusCode::JSON_INVALID, " model field is missing in JSON body" );
510+ }
506511
507- bool isTextGenerationEndpoint = uri.find (" completions" ) != std::string_view::npos;
508- if (isTextGenerationEndpoint) {
509- auto streamIt = doc->FindMember (" stream" );
510- if (streamIt != doc->MemberEnd ()) {
511- if (!streamIt->value .IsBool ()) {
512- return Status (StatusCode::JSON_INVALID, " stream field is not a boolean" );
513- }
514- streamFieldVal = streamIt->value .GetBool ();
515- }
516- }
512+ if (!modelNameIt->value .IsString ()) {
513+ return Status (StatusCode::JSON_INVALID, " model field is not a string" );
514+ }
517515
518- model_name = modelNameIt->value .GetString ();
519- if (model_name.empty ()) {
520- isDefault = true ;
521- } else {
522- SPDLOG_DEBUG (" Model name from deduced from JSON: {}" , model_name);
516+ bool isTextGenerationEndpoint = uri.find (" completions" ) != std::string_view::npos;
517+ if (isTextGenerationEndpoint) {
518+ auto streamIt = parsedJson->FindMember (" stream" );
519+ if (streamIt != parsedJson->MemberEnd ()) {
520+ if (!streamIt->value .IsBool ()) {
521+ return Status (StatusCode::JSON_INVALID, " stream field is not a boolean" );
522+ }
523+ streamFieldVal = streamIt->value .GetBool ();
523524 }
524525 }
525526
526- // Deduce Graph Name from URI since there is no info in JSON or MultiPart
527- if (isDefault) {
528- if (uri.size () <= 4 ) { // nothing after "/v3/..."
529- return StatusCode::REST_INVALID_URL;
530- }
531- model_name = std::string (uri.substr (4 ));
532- SPDLOG_DEBUG (" Model name from deduced from URI: {}" , model_name);
533- // Set json parser in invalid state in order to get HasParseError to respond with true
534- doc->Parse (" error" );
527+ modelName = modelNameIt->value .GetString ();
528+ if (modelName.empty ()) {
529+ isUriBasedRouting = true ;
530+ } else {
531+ SPDLOG_DEBUG (" Model name from deduced from JSON: {}" , modelName);
535532 }
533+ }
536534
537- auto status = this ->modelManager .createPipeline (executor, model_name);
538- if (!status.ok ()) {
539- return status;
535+ // Deduce Graph Name from URI since there is no info in JSON or MultiPart
536+ if (isUriBasedRouting) {
537+ if (uri.size () <= 4 ) { // nothing after "/v3/..."
538+ SPDLOG_DEBUG (" Failed to deduce model name from URI" );
539+ return StatusCode::FAILED_TO_DEDUCE_MODEL_NAME_FROM_URI;
540540 }
541+ modelName = std::string (uri.substr (4 ));
542+ SPDLOG_DEBUG (" Model name from deduced from URI: {}" , modelName);
543+ // Set json parser in invalid state in order to get HasParseError to respond with true
544+ ensureJsonParserInErrorState (parsedJson);
545+ }
546+
547+ request.headers = request_components.headers ;
548+ request.body = request_body;
549+ request.parsedJson = std::move (parsedJson);
550+ request.uri = std::string (uri);
551+ request.client = std::make_shared<HttpClientConnection>(serverReaderWriter);
552+ request.multipartParser = std::move (multiPartParser);
553+
554+ return StatusCode::OK;
555+ }
556+
557+ Status HttpRestApiHandler::processV3 (const std::string_view uri, const HttpRequestComponents& request_components, std::string& response, const std::string& request_body, std::shared_ptr<HttpAsyncWriter> serverReaderWriter, std::shared_ptr<MultiPartParser> multiPartParser) {
558+ #if (MEDIAPIPE_DISABLE == 0)
559+ OVMS_PROFILE_FUNCTION ();
560+
561+ HttpPayload request;
562+ std::string modelName;
563+ bool streamFieldVal = false ;
564+
565+ auto status = createV3HttpPayload (uri, request_components, response, request_body, serverReaderWriter, std::move (multiPartParser), request, modelName, streamFieldVal);
566+ if (!status.ok ()) {
567+ // TODO: Add logger
568+ SPDLOG_DEBUG (" Failed to create V3 payload" );
569+ return status;
570+ }
541571
542- request.headers = request_components.headers ;
543- request.body = request_body;
544- request.parsedJson = std::move (doc);
545- request.uri = std::string (uri);
546- request.client = std::make_shared<HttpClientConnection>(serverReaderWriter);
547- request.multipartParser = std::move (multiPartParser);
572+ std::shared_ptr<MediapipeGraphExecutor> executor;
573+ status = this ->modelManager .createPipeline (executor, modelName);
574+ if (!status.ok ()) {
575+ return status;
548576 }
577+
549578 if (streamFieldVal == false ) {
550579 ExecutionContext executionContext{ExecutionContext::Interface::REST, ExecutionContext::Method::V3Unary};
551580 return executor->infer (&request, &response, executionContext);
0 commit comments