From 7f9f7ace7c0d35d693b56470436f09725fe9d6c4 Mon Sep 17 00:00:00 2001 From: Will Sackfield Date: Thu, 6 Mar 2025 13:43:45 -0500 Subject: [PATCH 1/7] Add CommonChatSchema types MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add typed dictionaries which conform to OpenAI’s protocol specification as first class objects due to the ubiquity of their use. --- python/cog/__init__.py | 2 + python/cog/predictor.py | 2 + python/cog/types.py | 108 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 112 insertions(+) diff --git a/python/cog/__init__.py b/python/cog/__init__.py index 72f1399cd0..b79d806ac9 100644 --- a/python/cog/__init__.py +++ b/python/cog/__init__.py @@ -7,6 +7,7 @@ from .server.scope import current_scope, emit_metric from .types import ( AsyncConcatenateIterator, + CommonChatSchemaChatMessage, ConcatenateIterator, ExperimentalFeatureWarning, File, @@ -36,4 +37,5 @@ "Input", "Path", "Secret", + "CommonChatSchemaChatMessage", ] diff --git a/python/cog/predictor.py b/python/cog/predictor.py index a9b32f7553..890e915ac5 100644 --- a/python/cog/predictor.py +++ b/python/cog/predictor.py @@ -35,6 +35,7 @@ from .code_xforms import load_module_from_string, strip_model_source_code from .types import ( PYDANTIC_V2, + CommonChatSchemaChatMessage, Input, Weights, ) @@ -56,6 +57,7 @@ CogFile, CogPath, CogSecret, + CommonChatSchemaChatMessage, ] diff --git a/python/cog/types.py b/python/cog/types.py index 8b0de604cd..e551bf98a7 100644 --- a/python/cog/types.py +++ b/python/cog/types.py @@ -64,6 +64,114 @@ class CogConcurrencyConfig(TypedDict, total=False): # pylint: disable=too-many- max: NotRequired[int] +class CommonChatSchemaTextContentPart(TypedDict): + type: str + text: str + + +class CommonChatSchemaImageURL(TypedDict): + url: str + detail: Optional[str] + + +class CommonChatSchemaImageContentPart(TypedDict): + type: str + image_url: CommonChatSchemaImageURL + + +class CommonChatSchemaInputAudio(TypedDict): + data: str + format: str + + +class CommonChatSchemaAudioContentPart(TypedDict): + type: str + input_audio: CommonChatSchemaInputAudio + + +class CommonChatSchemaRefuslaContentPart(TypedDict): + type: str + refusal: str + + +class CommonChatSchemaAudio(TypedDict): + id: str + + +class CommonChatSchemaFunction(TypedDict): + name: str + arguments: str + + +class CommonChatSchemaToolCall(TypedDict): + id: str + type: str + function: CommonChatSchemaFunction + + +class CommonChatSchemaDeveloperMessage(TypedDict, total=False): + content: Union[str, List[str]] + role: str + name: Optional[str] + + +class CommonChatSchemaSystemMessage(TypedDict, total=False): + content: Union[str, List[str]] + role: str + name: Optional[str] + + +class CommonChatSchemaUserMessage(TypedDict, total=False): + content: Union[ + str, + List[ + Union[ + CommonChatSchemaTextContentPart, + CommonChatSchemaImageContentPart, + CommonChatSchemaAudioContentPart, + ] + ], + ] + role: str + name: Optional[str] + + +class CommonChatSchemaAssistantMessage(TypedDict, total=False): + content: Union[ + str, + List[ + Union[CommonChatSchemaTextContentPart, CommonChatSchemaRefuslaContentPart] + ], + ] + role: str + name: Optional[str] + audio: Optional[CommonChatSchemaAudio] + tool_calls: Optional[List[CommonChatSchemaToolCall]] + function_call: Optional[CommonChatSchemaFunction] + + +class CommonChatSchemaToolMessage(TypedDict): + role: str + content: Union[str, List[str]] + tool_call_id: str + + +class CommonChatSchemaFunctionMessage(TypedDict, total=False): + role: str + content: Optional[str] + name: str + + +CommonChatSchemaChatMessage = Union[ + CommonChatSchemaDeveloperMessage, + CommonChatSchemaSystemMessage, + CommonChatSchemaUserMessage, + CommonChatSchemaAssistantMessage, + CommonChatSchemaToolMessage, + CommonChatSchemaFunctionMessage, +] + + def Input( # pylint: disable=invalid-name, too-many-arguments default: Any = ..., description: Optional[str] = None, From f3c1593e083f113d799d6df1d3b37e4055c1b26c Mon Sep 17 00:00:00 2001 From: Will Sackfield Date: Thu, 6 Mar 2025 13:43:56 -0500 Subject: [PATCH 2/7] Add integration test for chat messages --- .../fixtures/chat-message-project/cog.yaml | 3 +++ .../fixtures/chat-message-project/predict.py | 8 ++++++++ test-integration/test_integration/test_predict.py | 15 +++++++++++++++ tox.ini | 1 + 4 files changed, 27 insertions(+) create mode 100644 test-integration/test_integration/fixtures/chat-message-project/cog.yaml create mode 100644 test-integration/test_integration/fixtures/chat-message-project/predict.py diff --git a/test-integration/test_integration/fixtures/chat-message-project/cog.yaml b/test-integration/test_integration/fixtures/chat-message-project/cog.yaml new file mode 100644 index 0000000000..e357cab833 --- /dev/null +++ b/test-integration/test_integration/fixtures/chat-message-project/cog.yaml @@ -0,0 +1,3 @@ +build: + python_version: "3.9" +predict: predict.py:Predictor diff --git a/test-integration/test_integration/fixtures/chat-message-project/predict.py b/test-integration/test_integration/fixtures/chat-message-project/predict.py new file mode 100644 index 0000000000..b4833c1b4d --- /dev/null +++ b/test-integration/test_integration/fixtures/chat-message-project/predict.py @@ -0,0 +1,8 @@ +from cog import BasePredictor, CommonChatSchemaChatMessage + + +class Predictor(BasePredictor): + + def predict(self, messages: list[CommonChatSchemaChatMessage]) -> str: + print(messages) + return f"HELLO {messages[0]['role']}" diff --git a/test-integration/test_integration/test_predict.py b/test-integration/test_integration/test_predict.py index 459f09f03e..26e90751dc 100644 --- a/test-integration/test_integration/test_predict.py +++ b/test-integration/test_integration/test_predict.py @@ -7,6 +7,7 @@ import httpx import pytest +import requests from .util import cog_server_http_run @@ -369,3 +370,17 @@ async def make_request(i: int) -> httpx.Response: for i, task in enumerate(tasks): assert task.result().status_code == 200 assert task.result().json()["output"] == f"wake up sleepyhead{i}" + + +def test_predict_chat_message(): + with cog_server_http_run( + Path(__file__).parent / "fixtures" / "chat-message-project" + ) as addr: + response = requests.post( + addr + "/predictions", + json={"input": {"messages": [{"role": "User", "content": "Hello There!"}]}}, + timeout=3.0, + ) + response.raise_for_status() + body = response.json() + assert body["output"] == "HELLO User" diff --git a/tox.ini b/tox.ini index 848b138b83..f32787ed6a 100644 --- a/tox.ini +++ b/tox.ini @@ -71,4 +71,5 @@ deps = pytest-rerunfailures pytest-timeout pytest-xdist + requests commands = pytest {posargs:-n auto -vv --reruns 3} From 8107a5f5d8817ac43a8b55bc968ccc75c3a117f9 Mon Sep 17 00:00:00 2001 From: Will Sackfield Date: Thu, 6 Mar 2025 15:00:20 -0500 Subject: [PATCH 3/7] Fix ALLOWED_INPUT_TYPES type linting --- python/cog/predictor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/cog/predictor.py b/python/cog/predictor.py index 890e915ac5..25acc28cd1 100644 --- a/python/cog/predictor.py +++ b/python/cog/predictor.py @@ -49,7 +49,7 @@ log = structlog.get_logger("cog.server.predictor") -ALLOWED_INPUT_TYPES: List[Type[Any]] = [ +ALLOWED_INPUT_TYPES: List[Union[Type[Any], Type[CommonChatSchemaChatMessage]]] = [ str, int, float, From ec7ac313c0dbe7466c3973a584f8d4c9f067b83f Mon Sep 17 00:00:00 2001 From: Will Sackfield Date: Fri, 7 Mar 2025 09:40:49 -0500 Subject: [PATCH 4/7] Use alphanumeric ordering --- python/cog/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/cog/__init__.py b/python/cog/__init__.py index b79d806ac9..b656d77fe9 100644 --- a/python/cog/__init__.py +++ b/python/cog/__init__.py @@ -31,11 +31,11 @@ "AsyncConcatenateIterator", "BaseModel", "BasePredictor", + "CommonChatSchemaChatMessage", "ConcatenateIterator", "ExperimentalFeatureWarning", "File", "Input", "Path", "Secret", - "CommonChatSchemaChatMessage", ] From 076c47817c5fb4e75b9e508d97e4c2af0c48ef28 Mon Sep 17 00:00:00 2001 From: Will Sackfield Date: Fri, 7 Mar 2025 15:28:23 -0500 Subject: [PATCH 5/7] Change CommonChatSchemaChatMessage to ChatMessage --- python/cog/__init__.py | 4 ++-- python/cog/predictor.py | 6 +++--- python/cog/types.py | 2 +- .../fixtures/chat-message-project/predict.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/python/cog/__init__.py b/python/cog/__init__.py index b656d77fe9..7e8403d522 100644 --- a/python/cog/__init__.py +++ b/python/cog/__init__.py @@ -7,7 +7,7 @@ from .server.scope import current_scope, emit_metric from .types import ( AsyncConcatenateIterator, - CommonChatSchemaChatMessage, + ChatMessage, ConcatenateIterator, ExperimentalFeatureWarning, File, @@ -31,7 +31,7 @@ "AsyncConcatenateIterator", "BaseModel", "BasePredictor", - "CommonChatSchemaChatMessage", + "ChatMessage", "ConcatenateIterator", "ExperimentalFeatureWarning", "File", diff --git a/python/cog/predictor.py b/python/cog/predictor.py index 25acc28cd1..3c60f4604c 100644 --- a/python/cog/predictor.py +++ b/python/cog/predictor.py @@ -35,7 +35,7 @@ from .code_xforms import load_module_from_string, strip_model_source_code from .types import ( PYDANTIC_V2, - CommonChatSchemaChatMessage, + ChatMessage, Input, Weights, ) @@ -49,15 +49,15 @@ log = structlog.get_logger("cog.server.predictor") -ALLOWED_INPUT_TYPES: List[Union[Type[Any], Type[CommonChatSchemaChatMessage]]] = [ +ALLOWED_INPUT_TYPES: List[Union[Type[Any], Type[ChatMessage]]] = [ str, int, float, bool, + ChatMessage, CogFile, CogPath, CogSecret, - CommonChatSchemaChatMessage, ] diff --git a/python/cog/types.py b/python/cog/types.py index e551bf98a7..15bc98c95b 100644 --- a/python/cog/types.py +++ b/python/cog/types.py @@ -162,7 +162,7 @@ class CommonChatSchemaFunctionMessage(TypedDict, total=False): name: str -CommonChatSchemaChatMessage = Union[ +ChatMessage = Union[ CommonChatSchemaDeveloperMessage, CommonChatSchemaSystemMessage, CommonChatSchemaUserMessage, diff --git a/test-integration/test_integration/fixtures/chat-message-project/predict.py b/test-integration/test_integration/fixtures/chat-message-project/predict.py index b4833c1b4d..371edbdda6 100644 --- a/test-integration/test_integration/fixtures/chat-message-project/predict.py +++ b/test-integration/test_integration/fixtures/chat-message-project/predict.py @@ -1,8 +1,8 @@ -from cog import BasePredictor, CommonChatSchemaChatMessage +from cog import BasePredictor, ChatMessage class Predictor(BasePredictor): - def predict(self, messages: list[CommonChatSchemaChatMessage]) -> str: + def predict(self, messages: list[ChatMessage]) -> str: print(messages) return f"HELLO {messages[0]['role']}" From 631ac1f7f7138f6dda4b6f5f2b692502a272f87d Mon Sep 17 00:00:00 2001 From: Will Sackfield Date: Mon, 10 Mar 2025 14:55:44 -0400 Subject: [PATCH 6/7] Support chat message in cog predict * Use the schema to parse the expected format of the key, and if it is in our anticipated formats encode it with JSON and send that via API to the cog service. --- pkg/cli/predict.go | 6 ++-- pkg/predict/input.go | 65 ++++++++++++++++++++++++++++++++++----- pkg/predict/input_test.go | 24 +++++++++++++++ 3 files changed, 85 insertions(+), 10 deletions(-) create mode 100644 pkg/predict/input_test.go diff --git a/pkg/cli/predict.go b/pkg/cli/predict.go index c3c325c54b..99c2544878 100644 --- a/pkg/cli/predict.go +++ b/pkg/cli/predict.go @@ -185,7 +185,7 @@ func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, o return err } - inputs, err := parseInputFlags(inputFlags) + inputs, err := parseInputFlags(inputFlags, schema) if err != nil { return err } @@ -361,7 +361,7 @@ func writeDataURLOutput(outputString string, outputPath string, addExtension boo return nil } -func parseInputFlags(inputs []string) (predict.Inputs, error) { +func parseInputFlags(inputs []string, schema *openapi3.T) (predict.Inputs, error) { keyVals := map[string][]string{} for _, input := range inputs { var name, value string @@ -383,7 +383,7 @@ func parseInputFlags(inputs []string) (predict.Inputs, error) { keyVals[name] = append(keyVals[name], value) } - return predict.NewInputs(keyVals), nil + return predict.NewInputs(keyVals, schema) } func addSetupTimeoutFlag(cmd *cobra.Command) { diff --git a/pkg/predict/input.go b/pkg/predict/input.go index 7965c79431..509306c298 100644 --- a/pkg/predict/input.go +++ b/pkg/predict/input.go @@ -1,11 +1,13 @@ package predict import ( + "encoding/json" "fmt" "os" "path/filepath" "strings" + "github.com/getkin/kin-openapi/openapi3" "github.com/mitchellh/go-homedir" "github.com/vincent-petithory/dataurl" @@ -13,22 +15,69 @@ import ( ) type Input struct { - String *string - File *string - Array *[]any + String *string + File *string + Array *[]any + ChatMessage *json.RawMessage } type Inputs map[string]Input -func NewInputs(keyVals map[string][]string) Inputs { +var jsonSerializableSchemas = map[string]bool{ + "#/components/schemas/CommonChatSchemaDeveloperMessage": true, + "#/components/schemas/CommonChatSchemaSystemMessage": true, + "#/components/schemas/CommonChatSchemaUserMessage": true, + "#/components/schemas/CommonChatSchemaAssistantMessage": true, + "#/components/schemas/CommonChatSchemaToolMessage": true, + "#/components/schemas/CommonChatSchemaFunctionMessage": true, +} + +func NewInputs(keyVals map[string][]string, schema *openapi3.T) (Inputs, error) { + var inputComponent *openapi3.SchemaRef + for name, component := range schema.Components.Schemas { + if name == "Input" { + inputComponent = component + break + } + } + input := Inputs{} for key, vals := range keyVals { if len(vals) == 1 { val := vals[0] - if strings.HasPrefix(val, "@") { + + // Check if we should explicitly parse the JSON based on a known schema + if inputComponent != nil { + properties, err := inputComponent.JSONLookup("properties") + if err != nil { + return input, err + } + propertiesSchemas := properties.(openapi3.Schemas) + messages, err := propertiesSchemas.JSONLookup("messages") + if err != nil { + return input, err + } + messagesSchemas := messages.(*openapi3.Schema) + found := false + for _, schemaRef := range messagesSchemas.Items.Value.AnyOf { + if _, ok := jsonSerializableSchemas[schemaRef.Ref]; ok { + found = true + message := json.RawMessage(val) + input[key] = Input{ChatMessage: &message} + break + } + } + if found { + continue + } + } + + switch { + case strings.HasPrefix(val, "@"): val = val[1:] input[key] = Input{File: &val} - } else { + + default: input[key] = Input{String: &val} } } else if len(vals) > 1 { @@ -39,7 +88,7 @@ func NewInputs(keyVals map[string][]string) Inputs { input[key] = Input{Array: &anyVals} } } - return input + return input, nil } func NewInputsWithBaseDir(keyVals map[string]string, baseDir string) Inputs { @@ -86,6 +135,8 @@ func (inputs *Inputs) toMap() (map[string]any, error) { } } keyVals[key] = dataURLs + case input.ChatMessage != nil: + keyVals[key] = *input.ChatMessage } } return keyVals, nil diff --git a/pkg/predict/input_test.go b/pkg/predict/input_test.go new file mode 100644 index 0000000000..cbd5ff4a74 --- /dev/null +++ b/pkg/predict/input_test.go @@ -0,0 +1,24 @@ +package predict + +import ( + "encoding/json" + "testing" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/stretchr/testify/require" +) + +func TestNewInputsChatMessage(t *testing.T) { + chatMessage := `[{"role": "user", "content": "hello"}]` + key := "Key" + expected := json.RawMessage(chatMessage) + keyVals := map[string][]string{ + key: {chatMessage}, + } + openapiBody := `{"components":{"schemas":{"CommonChatSchemaAssistantMessage":{"properties":{"audio":{"$ref":"#/components/schemas/CommonChatSchemaAudio"},"content":{"anyOf":[{"type":"string"},{"items":{"anyOf":[{"$ref":"#/components/schemas/CommonChatSchemaTextContentPart"},{"$ref":"#/components/schemas/CommonChatSchemaRefuslaContentPart"}]},"type":"array"}],"title":"Content"},"function_call":{"$ref":"#/components/schemas/CommonChatSchemaFunction"},"name":{"title":"Name","type":"string"},"role":{"title":"Role","type":"string"},"tool_calls":{"items":{"$ref":"#/components/schemas/CommonChatSchemaToolCall"},"title":"Tool Calls","type":"array"}},"title":"CommonChatSchemaAssistantMessage","type":"object"},"CommonChatSchemaAudio":{"properties":{"id":{"title":"Id","type":"string"}},"required":["id"],"title":"CommonChatSchemaAudio","type":"object"},"CommonChatSchemaAudioContentPart":{"properties":{"input_audio":{"$ref":"#/components/schemas/CommonChatSchemaInputAudio"},"type":{"title":"Type","type":"string"}},"required":["type","input_audio"],"title":"CommonChatSchemaAudioContentPart","type":"object"},"CommonChatSchemaDeveloperMessage":{"properties":{"content":{"anyOf":[{"type":"string"},{"items":{"type":"string"},"type":"array"}],"title":"Content"},"name":{"title":"Name","type":"string"},"role":{"title":"Role","type":"string"}},"title":"CommonChatSchemaDeveloperMessage","type":"object"},"CommonChatSchemaFunction":{"properties":{"arguments":{"title":"Arguments","type":"string"},"name":{"title":"Name","type":"string"}},"required":["name","arguments"],"title":"CommonChatSchemaFunction","type":"object"},"CommonChatSchemaFunctionMessage":{"properties":{"content":{"title":"Content","type":"string"},"name":{"title":"Name","type":"string"},"role":{"title":"Role","type":"string"}},"title":"CommonChatSchemaFunctionMessage","type":"object"},"CommonChatSchemaImageContentPart":{"properties":{"image_url":{"$ref":"#/components/schemas/CommonChatSchemaImageURL"},"type":{"title":"Type","type":"string"}},"required":["type","image_url"],"title":"CommonChatSchemaImageContentPart","type":"object"},"CommonChatSchemaImageURL":{"properties":{"detail":{"title":"Detail","type":"string"},"url":{"title":"Url","type":"string"}},"required":["url","detail"],"title":"CommonChatSchemaImageURL","type":"object"},"CommonChatSchemaInputAudio":{"properties":{"data":{"title":"Data","type":"string"},"format":{"title":"Format","type":"string"}},"required":["data","format"],"title":"CommonChatSchemaInputAudio","type":"object"},"CommonChatSchemaRefuslaContentPart":{"properties":{"refusal":{"title":"Refusal","type":"string"},"type":{"title":"Type","type":"string"}},"required":["type","refusal"],"title":"CommonChatSchemaRefuslaContentPart","type":"object"},"CommonChatSchemaSystemMessage":{"properties":{"content":{"anyOf":[{"type":"string"},{"items":{"type":"string"},"type":"array"}],"title":"Content"},"name":{"title":"Name","type":"string"},"role":{"title":"Role","type":"string"}},"title":"CommonChatSchemaSystemMessage","type":"object"},"CommonChatSchemaTextContentPart":{"properties":{"text":{"title":"Text","type":"string"},"type":{"title":"Type","type":"string"}},"required":["type","text"],"title":"CommonChatSchemaTextContentPart","type":"object"},"CommonChatSchemaToolCall":{"properties":{"function":{"$ref":"#/components/schemas/CommonChatSchemaFunction"},"id":{"title":"Id","type":"string"},"type":{"title":"Type","type":"string"}},"required":["id","type","function"],"title":"CommonChatSchemaToolCall","type":"object"},"CommonChatSchemaToolMessage":{"properties":{"content":{"anyOf":[{"type":"string"},{"items":{"type":"string"},"type":"array"}],"title":"Content"},"role":{"title":"Role","type":"string"},"tool_call_id":{"title":"Tool Call Id","type":"string"}},"required":["role","content","tool_call_id"],"title":"CommonChatSchemaToolMessage","type":"object"},"CommonChatSchemaUserMessage":{"properties":{"content":{"anyOf":[{"type":"string"},{"items":{"anyOf":[{"$ref":"#/components/schemas/CommonChatSchemaTextContentPart"},{"$ref":"#/components/schemas/CommonChatSchemaImageContentPart"},{"$ref":"#/components/schemas/CommonChatSchemaAudioContentPart"}]},"type":"array"}],"title":"Content"},"name":{"title":"Name","type":"string"},"role":{"title":"Role","type":"string"}},"title":"CommonChatSchemaUserMessage","type":"object"},"HTTPValidationError":{"properties":{"detail":{"items":{"$ref":"#/components/schemas/ValidationError"},"title":"Detail","type":"array"}},"title":"HTTPValidationError","type":"object"},"Input":{"properties":{"messages":{"items":{"anyOf":[{"$ref":"#/components/schemas/CommonChatSchemaDeveloperMessage"},{"$ref":"#/components/schemas/CommonChatSchemaSystemMessage"},{"$ref":"#/components/schemas/CommonChatSchemaUserMessage"},{"$ref":"#/components/schemas/CommonChatSchemaAssistantMessage"},{"$ref":"#/components/schemas/CommonChatSchemaToolMessage"},{"$ref":"#/components/schemas/CommonChatSchemaFunctionMessage"}]},"title":"Messages","type":"array","x-order":0}},"required":["messages"],"title":"Input","type":"object"},"Output":{"title":"Output","type":"string"},"PredictionRequest":{"properties":{"created_at":{"format":"date-time","title":"Created At","type":"string"},"id":{"title":"Id","type":"string"},"input":{"$ref":"#/components/schemas/Input"},"output_file_prefix":{"title":"Output File Prefix","type":"string"},"webhook":{"format":"uri","maxLength":65536,"minLength":1,"title":"Webhook","type":"string"},"webhook_events_filter":{"default":["start","output","logs","completed"],"items":{"$ref":"#/components/schemas/WebhookEvent"},"type":"array"}},"title":"PredictionRequest","type":"object"},"PredictionResponse":{"properties":{"completed_at":{"format":"date-time","title":"Completed At","type":"string"},"created_at":{"format":"date-time","title":"Created At","type":"string"},"error":{"title":"Error","type":"string"},"id":{"title":"Id","type":"string"},"input":{"$ref":"#/components/schemas/Input"},"logs":{"default":"","title":"Logs","type":"string"},"metrics":{"title":"Metrics","type":"object"},"output":{"$ref":"#/components/schemas/Output"},"started_at":{"format":"date-time","title":"Started At","type":"string"},"status":{"$ref":"#/components/schemas/Status"},"version":{"title":"Version","type":"string"}},"title":"PredictionResponse","type":"object"},"Status":{"description":"An enumeration.","enum":["starting","processing","succeeded","canceled","failed"],"title":"Status","type":"string"},"ValidationError":{"properties":{"loc":{"items":{"anyOf":[{"type":"string"},{"type":"integer"}]},"title":"Location","type":"array"},"msg":{"title":"Message","type":"string"},"type":{"title":"Error Type","type":"string"}},"required":["loc","msg","type"],"title":"ValidationError","type":"object"},"WebhookEvent":{"description":"An enumeration.","enum":["start","output","logs","completed"],"title":"WebhookEvent","type":"string"}}},"info":{"title":"Cog","version":"0.1.0"},"openapi":"3.0.2","paths":{"/":{"get":{"operationId":"root__get","responses":{"200":{"content":{"application/json":{"schema":{"title":"Response Root Get"}}},"description":"Successful Response"}},"summary":"Root"}},"/health-check":{"get":{"operationId":"healthcheck_health_check_get","responses":{"200":{"content":{"application/json":{"schema":{"title":"Response Healthcheck Health Check Get"}}},"description":"Successful Response"}},"summary":"Healthcheck"}},"/predictions":{"post":{"description":"Run a single prediction on the model","operationId":"predict_predictions_post","parameters":[{"in":"header","name":"prefer","required":false,"schema":{"title":"Prefer","type":"string"}}],"requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/PredictionRequest"}}}},"responses":{"200":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/PredictionResponse"}}},"description":"Successful Response"},"422":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}},"description":"Validation Error"}},"summary":"Predict"}},"/predictions/{prediction_id}":{"put":{"description":"Run a single prediction on the model (idempotent creation).","operationId":"predict_idempotent_predictions__prediction_id__put","parameters":[{"in":"path","name":"prediction_id","required":true,"schema":{"title":"Prediction ID","type":"string"}},{"in":"header","name":"prefer","required":false,"schema":{"title":"Prefer","type":"string"}}],"requestBody":{"content":{"application/json":{"schema":{"allOf":[{"$ref":"#/components/schemas/PredictionRequest"}],"title":"Prediction Request"}}},"required":true},"responses":{"200":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/PredictionResponse"}}},"description":"Successful Response"},"422":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}},"description":"Validation Error"}},"summary":"Predict Idempotent"}},"/predictions/{prediction_id}/cancel":{"post":{"description":"Cancel a running prediction","operationId":"cancel_predictions__prediction_id__cancel_post","parameters":[{"in":"path","name":"prediction_id","required":true,"schema":{"title":"Prediction ID","type":"string"}}],"responses":{"200":{"content":{"application/json":{"schema":{"title":"Response Cancel Predictions Prediction Id Cancel Post"}}},"description":"Successful Response"},"422":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}},"description":"Validation Error"}},"summary":"Cancel"}},"/shutdown":{"post":{"operationId":"start_shutdown_shutdown_post","responses":{"200":{"content":{"application/json":{"schema":{"title":"Response Start Shutdown Shutdown Post"}}},"description":"Successful Response"}},"summary":"Start Shutdown"}}}}` + schema, err := openapi3.NewLoader().LoadFromData([]byte(openapiBody)) + require.NoError(t, err) + inputs, err := NewInputs(keyVals, schema) + require.NoError(t, err) + require.Equal(t, expected, *inputs[key].ChatMessage) +} From fda86d47269171331818990e7dd18faa618864a4 Mon Sep 17 00:00:00 2001 From: Will Sackfield Date: Mon, 10 Mar 2025 16:21:13 -0400 Subject: [PATCH 7/7] Fix no messages within Input schema * This is valid in OpenAPI --- pkg/predict/input.go | 28 ++++++++++++++-------------- pkg/predict/input_test.go | 14 ++++++++++++++ 2 files changed, 28 insertions(+), 14 deletions(-) diff --git a/pkg/predict/input.go b/pkg/predict/input.go index 509306c298..9e21709f87 100644 --- a/pkg/predict/input.go +++ b/pkg/predict/input.go @@ -54,21 +54,21 @@ func NewInputs(keyVals map[string][]string, schema *openapi3.T) (Inputs, error) } propertiesSchemas := properties.(openapi3.Schemas) messages, err := propertiesSchemas.JSONLookup("messages") - if err != nil { - return input, err - } - messagesSchemas := messages.(*openapi3.Schema) - found := false - for _, schemaRef := range messagesSchemas.Items.Value.AnyOf { - if _, ok := jsonSerializableSchemas[schemaRef.Ref]; ok { - found = true - message := json.RawMessage(val) - input[key] = Input{ChatMessage: &message} - break + // If there is an error it means messages was not found, this is valid for an OpenAPI schema. + if err == nil { + messagesSchemas := messages.(*openapi3.Schema) + found := false + for _, schemaRef := range messagesSchemas.Items.Value.AnyOf { + if _, ok := jsonSerializableSchemas[schemaRef.Ref]; ok { + found = true + message := json.RawMessage(val) + input[key] = Input{ChatMessage: &message} + break + } + } + if found { + continue } - } - if found { - continue } } diff --git a/pkg/predict/input_test.go b/pkg/predict/input_test.go index cbd5ff4a74..0b9c536a26 100644 --- a/pkg/predict/input_test.go +++ b/pkg/predict/input_test.go @@ -22,3 +22,17 @@ func TestNewInputsChatMessage(t *testing.T) { require.NoError(t, err) require.Equal(t, expected, *inputs[key].ChatMessage) } + +func TestNewInputsWithoutMessages(t *testing.T) { + expected := "world" + key := "s" + keyVals := map[string][]string{ + key: {expected}, + } + openapiBody := `{"components":{"schemas":{"CommonChatSchemaAssistantMessage":{"properties":{"audio":{"$ref":"#/components/schemas/CommonChatSchemaAudio"},"content":{"anyOf":[{"type":"string"},{"items":{"anyOf":[{"$ref":"#/components/schemas/CommonChatSchemaTextContentPart"},{"$ref":"#/components/schemas/CommonChatSchemaRefuslaContentPart"}]},"type":"array"}],"title":"Content"},"function_call":{"$ref":"#/components/schemas/CommonChatSchemaFunction"},"name":{"title":"Name","type":"string"},"role":{"title":"Role","type":"string"},"tool_calls":{"items":{"$ref":"#/components/schemas/CommonChatSchemaToolCall"},"title":"Tool Calls","type":"array"}},"title":"CommonChatSchemaAssistantMessage","type":"object"},"CommonChatSchemaAudio":{"properties":{"id":{"title":"Id","type":"string"}},"required":["id"],"title":"CommonChatSchemaAudio","type":"object"},"CommonChatSchemaAudioContentPart":{"properties":{"input_audio":{"$ref":"#/components/schemas/CommonChatSchemaInputAudio"},"type":{"title":"Type","type":"string"}},"required":["type","input_audio"],"title":"CommonChatSchemaAudioContentPart","type":"object"},"CommonChatSchemaDeveloperMessage":{"properties":{"content":{"anyOf":[{"type":"string"},{"items":{"type":"string"},"type":"array"}],"title":"Content"},"name":{"title":"Name","type":"string"},"role":{"title":"Role","type":"string"}},"title":"CommonChatSchemaDeveloperMessage","type":"object"},"CommonChatSchemaFunction":{"properties":{"arguments":{"title":"Arguments","type":"string"},"name":{"title":"Name","type":"string"}},"required":["name","arguments"],"title":"CommonChatSchemaFunction","type":"object"},"CommonChatSchemaFunctionMessage":{"properties":{"content":{"title":"Content","type":"string"},"name":{"title":"Name","type":"string"},"role":{"title":"Role","type":"string"}},"title":"CommonChatSchemaFunctionMessage","type":"object"},"CommonChatSchemaImageContentPart":{"properties":{"image_url":{"$ref":"#/components/schemas/CommonChatSchemaImageURL"},"type":{"title":"Type","type":"string"}},"required":["type","image_url"],"title":"CommonChatSchemaImageContentPart","type":"object"},"CommonChatSchemaImageURL":{"properties":{"detail":{"title":"Detail","type":"string"},"url":{"title":"Url","type":"string"}},"required":["url","detail"],"title":"CommonChatSchemaImageURL","type":"object"},"CommonChatSchemaInputAudio":{"properties":{"data":{"title":"Data","type":"string"},"format":{"title":"Format","type":"string"}},"required":["data","format"],"title":"CommonChatSchemaInputAudio","type":"object"},"CommonChatSchemaRefuslaContentPart":{"properties":{"refusal":{"title":"Refusal","type":"string"},"type":{"title":"Type","type":"string"}},"required":["type","refusal"],"title":"CommonChatSchemaRefuslaContentPart","type":"object"},"CommonChatSchemaSystemMessage":{"properties":{"content":{"anyOf":[{"type":"string"},{"items":{"type":"string"},"type":"array"}],"title":"Content"},"name":{"title":"Name","type":"string"},"role":{"title":"Role","type":"string"}},"title":"CommonChatSchemaSystemMessage","type":"object"},"CommonChatSchemaTextContentPart":{"properties":{"text":{"title":"Text","type":"string"},"type":{"title":"Type","type":"string"}},"required":["type","text"],"title":"CommonChatSchemaTextContentPart","type":"object"},"CommonChatSchemaToolCall":{"properties":{"function":{"$ref":"#/components/schemas/CommonChatSchemaFunction"},"id":{"title":"Id","type":"string"},"type":{"title":"Type","type":"string"}},"required":["id","type","function"],"title":"CommonChatSchemaToolCall","type":"object"},"CommonChatSchemaToolMessage":{"properties":{"content":{"anyOf":[{"type":"string"},{"items":{"type":"string"},"type":"array"}],"title":"Content"},"role":{"title":"Role","type":"string"},"tool_call_id":{"title":"Tool Call Id","type":"string"}},"required":["role","content","tool_call_id"],"title":"CommonChatSchemaToolMessage","type":"object"},"CommonChatSchemaUserMessage":{"properties":{"content":{"anyOf":[{"type":"string"},{"items":{"anyOf":[{"$ref":"#/components/schemas/CommonChatSchemaTextContentPart"},{"$ref":"#/components/schemas/CommonChatSchemaImageContentPart"},{"$ref":"#/components/schemas/CommonChatSchemaAudioContentPart"}]},"type":"array"}],"title":"Content"},"name":{"title":"Name","type":"string"},"role":{"title":"Role","type":"string"}},"title":"CommonChatSchemaUserMessage","type":"object"},"HTTPValidationError":{"properties":{"detail":{"items":{"$ref":"#/components/schemas/ValidationError"},"title":"Detail","type":"array"}},"title":"HTTPValidationError","type":"object"},"Input":{"properties":{},"required":[],"title":"Input","type":"object"},"Output":{"title":"Output","type":"string"},"PredictionRequest":{"properties":{"created_at":{"format":"date-time","title":"Created At","type":"string"},"id":{"title":"Id","type":"string"},"input":{"$ref":"#/components/schemas/Input"},"output_file_prefix":{"title":"Output File Prefix","type":"string"},"webhook":{"format":"uri","maxLength":65536,"minLength":1,"title":"Webhook","type":"string"},"webhook_events_filter":{"default":["start","output","logs","completed"],"items":{"$ref":"#/components/schemas/WebhookEvent"},"type":"array"}},"title":"PredictionRequest","type":"object"},"PredictionResponse":{"properties":{"completed_at":{"format":"date-time","title":"Completed At","type":"string"},"created_at":{"format":"date-time","title":"Created At","type":"string"},"error":{"title":"Error","type":"string"},"id":{"title":"Id","type":"string"},"input":{"$ref":"#/components/schemas/Input"},"logs":{"default":"","title":"Logs","type":"string"},"metrics":{"title":"Metrics","type":"object"},"output":{"$ref":"#/components/schemas/Output"},"started_at":{"format":"date-time","title":"Started At","type":"string"},"status":{"$ref":"#/components/schemas/Status"},"version":{"title":"Version","type":"string"}},"title":"PredictionResponse","type":"object"},"Status":{"description":"An enumeration.","enum":["starting","processing","succeeded","canceled","failed"],"title":"Status","type":"string"},"ValidationError":{"properties":{"loc":{"items":{"anyOf":[{"type":"string"},{"type":"integer"}]},"title":"Location","type":"array"},"msg":{"title":"Message","type":"string"},"type":{"title":"Error Type","type":"string"}},"required":["loc","msg","type"],"title":"ValidationError","type":"object"},"WebhookEvent":{"description":"An enumeration.","enum":["start","output","logs","completed"],"title":"WebhookEvent","type":"string"}}},"info":{"title":"Cog","version":"0.1.0"},"openapi":"3.0.2","paths":{"/":{"get":{"operationId":"root__get","responses":{"200":{"content":{"application/json":{"schema":{"title":"Response Root Get"}}},"description":"Successful Response"}},"summary":"Root"}},"/health-check":{"get":{"operationId":"healthcheck_health_check_get","responses":{"200":{"content":{"application/json":{"schema":{"title":"Response Healthcheck Health Check Get"}}},"description":"Successful Response"}},"summary":"Healthcheck"}},"/predictions":{"post":{"description":"Run a single prediction on the model","operationId":"predict_predictions_post","parameters":[{"in":"header","name":"prefer","required":false,"schema":{"title":"Prefer","type":"string"}}],"requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/PredictionRequest"}}}},"responses":{"200":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/PredictionResponse"}}},"description":"Successful Response"},"422":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}},"description":"Validation Error"}},"summary":"Predict"}},"/predictions/{prediction_id}":{"put":{"description":"Run a single prediction on the model (idempotent creation).","operationId":"predict_idempotent_predictions__prediction_id__put","parameters":[{"in":"path","name":"prediction_id","required":true,"schema":{"title":"Prediction ID","type":"string"}},{"in":"header","name":"prefer","required":false,"schema":{"title":"Prefer","type":"string"}}],"requestBody":{"content":{"application/json":{"schema":{"allOf":[{"$ref":"#/components/schemas/PredictionRequest"}],"title":"Prediction Request"}}},"required":true},"responses":{"200":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/PredictionResponse"}}},"description":"Successful Response"},"422":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}},"description":"Validation Error"}},"summary":"Predict Idempotent"}},"/predictions/{prediction_id}/cancel":{"post":{"description":"Cancel a running prediction","operationId":"cancel_predictions__prediction_id__cancel_post","parameters":[{"in":"path","name":"prediction_id","required":true,"schema":{"title":"Prediction ID","type":"string"}}],"responses":{"200":{"content":{"application/json":{"schema":{"title":"Response Cancel Predictions Prediction Id Cancel Post"}}},"description":"Successful Response"},"422":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}},"description":"Validation Error"}},"summary":"Cancel"}},"/shutdown":{"post":{"operationId":"start_shutdown_shutdown_post","responses":{"200":{"content":{"application/json":{"schema":{"title":"Response Start Shutdown Shutdown Post"}}},"description":"Successful Response"}},"summary":"Start Shutdown"}}}}` + schema, err := openapi3.NewLoader().LoadFromData([]byte(openapiBody)) + require.NoError(t, err) + inputs, err := NewInputs(keyVals, schema) + require.NoError(t, err) + require.Equal(t, expected, *inputs[key].String) +}