Skip to content

feat: giving the gift of vision to OpenAIChatCompletions #2356

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -136,14 +136,41 @@ class OpenAIChatCompletion(override val uid: String) extends OpenAIServicesBase(

override def responseDataType: DataType = ChatCompletionResponse.schema

private[openai] def getStringEntity(messages: Seq[Row], optionalParams: Map[String, Any]): StringEntity = {
val mappedMessages: Seq[Map[String, String]] = messages.map { m =>
Seq("role", "content", "name").map(n =>
n -> Option(m.getAs[String](n))
).toMap.filter(_._2.isDefined).mapValues(_.get)
private[openai] def getStringEntity(
messages: Seq[Row],
optionalParams: Map[String, Any]
): StringEntity = {
import OpenAIJsonProtocol._

val mappedMessages = messages.map { row =>
val role = row.getAs[String]("role")
val name = row.getAs[String]("name")

val maybeContent = Option(row.getAs[String]("content"))
val maybeItems = Option(row.getAs[Seq[Row]]("contentList")).map { rows =>
rows.map { r =>
val ctype = r.getAs[String]("type")
val text = r.getAs[String]("text")
val imgUrlRow = r.getAs[Row]("image_url")
val maybeImgUrl = Option(imgUrlRow).map { irow =>
ImageUrl(irow.getAs[String]("url"))
}
OpenAIContentItem(ctype, Option(text), maybeImgUrl)
}
}

OpenAIMessage(
role = role,
content = maybeContent,
contentList = maybeItems,
name = Option(name)
)
}
val fullPayload = optionalParams.updated("messages", mappedMessages)
new StringEntity(fullPayload.toJson.compactPrint, ContentType.APPLICATION_JSON)

val messagesJson: JsValue = mappedMessages.toJson
val paramsObj = optionalParams.toJson.asJsObject
val jsonString = JsObject(paramsObj.fields + ("messages" -> messagesJson)).compactPrint
new StringEntity(jsonString, ContentType.APPLICATION_JSON)
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
package com.microsoft.azure.synapse.ml.services.openai

import com.microsoft.azure.synapse.ml.core.schema.SparkBindings
import org.apache.spark.sql.Row
import spray.json.{DefaultJsonProtocol, RootJsonFormat}
import spray.json._

object CompletionResponse extends SparkBindings[CompletionResponse]

Expand Down Expand Up @@ -35,7 +34,43 @@ case class EmbeddingObject(`object`: String,
embedding: Array[Double],
index: Int)

case class OpenAIMessage(role: String, content: String, name: Option[String] = None)
case class OpenAIMessage(
role: String,
content: Option[String] = None,
contentList: Option[Seq[OpenAIContentItem]] = None,
name: Option[String] = None
)
case class ImageUrl(url: String)

case class OpenAIContentItem(`type`: String,
text: Option[String] = None,
image_url: Option[ImageUrl] = None)

object OpenAIMessage {

def apply(role: String, content: String): OpenAIMessage =
new OpenAIMessage(role, content = Some(content), name = None)

def apply(role: String, content: String, name: Option[String]): OpenAIMessage =
new OpenAIMessage(role, content = Some(content), name = name)

def apply(role: String, seq: Seq[OpenAIContentItem]): OpenAIMessage = {
new OpenAIMessage(role, contentList = Some(seq), name = None)
}

def apply(role: String, seq: Seq[OpenAIContentItem], name: Option[String]): OpenAIMessage = {
new OpenAIMessage(role, contentList = Some(seq), name = name)
}

def create(
role: String,
content: Option[String],
contentList: Option[Seq[OpenAIContentItem]],
name: Option[String]
): OpenAIMessage = new OpenAIMessage(role, content, contentList, name)
}



case class OpenAIChatChoice(message: OpenAIMessage,
index: Long,
Expand All @@ -54,5 +89,53 @@ case class ChatCompletionResponse(id: String,
object ChatCompletionResponse extends SparkBindings[ChatCompletionResponse]

object OpenAIJsonProtocol extends DefaultJsonProtocol {
implicit val MessageEnc: RootJsonFormat[OpenAIMessage] = jsonFormat3(OpenAIMessage.apply)
implicit val ImageUrlEnc: RootJsonFormat[ImageUrl] = jsonFormat1(ImageUrl)
implicit val OpenAIContentItemEnc: RootJsonFormat[OpenAIContentItem] = jsonFormat3(OpenAIContentItem)

implicit object MessageEnc extends RootJsonFormat[OpenAIMessage] {
override def write(msg: OpenAIMessage): JsValue = {
val baseFields = Map(
"role" -> JsString(msg.role)
) ++ msg.name.map("name" -> JsString(_))

val contentField: (String, JsValue) = (msg.content, msg.contentList) match {
case (Some(text), None) =>
"content" -> JsString(text)

case (None, Some(items)) =>
"content" -> JsArray(items.map(_.toJson).toVector)

case (None, None) =>
// how can we put these errors in the Error col?
//serializationError("OpenAIMessage CANNOT have both content & contentItems")
"content" -> JsString("")

case (Some(_), Some(_)) =>
"content" -> JsString("")
Comment on lines +111 to +114
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if this is actually an error we should yell

//serializationError("OpenAIMessage cannot have both content & contentItems")
}

JsObject(baseFields + contentField)
}

override def read(json: JsValue): OpenAIMessage = {
val obj = json.asJsObject
val role = obj.fields("role").convertTo[String]
val name = obj.fields.get("name").map(_.convertTo[String])

val contentJs = obj.fields.getOrElse("content", JsString(""))

contentJs match {
case JsString(s) =>
OpenAIMessage(role, content = Some(s), contentList = None, name = name)

case JsArray(elements) =>
val items = elements.map(_.convertTo[OpenAIContentItem])
OpenAIMessage(role, content = None, contentList = Some(items), name = name)

case _ =>
deserializationError("content must be string or array")
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,17 @@ class OpenAIChatCompletionSuite extends TransformerFuzzing[OpenAIChatCompletion]
.setTemperature(0)
.setSubscriptionKey(openAIAPIKey)

lazy val completion4o: OpenAIChatCompletion = new OpenAIChatCompletion()
.setDeploymentName(deploymentNameGpt4o)
.setCustomServiceName(openAIServiceName)
.setApiVersion("2023-05-15")
.setMaxTokens(5000)
.setOutputCol("out")
.setMessagesCol("messages")
.setTemperature(0)
.setSubscriptionKey(openAIAPIKey)
.setMaxTokens(4000)


lazy val goodDf: DataFrame = Seq(
Seq(
Expand All @@ -42,11 +53,26 @@ class OpenAIChatCompletionSuite extends TransformerFuzzing[OpenAIChatCompletion]
)
).toDF("messages")

lazy val imageDf: DataFrame = Seq(
Seq(
OpenAIMessage("system", "You are an AI chatbot that specialises with images"),
OpenAIMessage(
"user",
Seq(
OpenAIContentItem("text", text = Some("What is in this image?")),
OpenAIContentItem("image_url", image_url = Some(ImageUrl(
"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg" +
"/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg")))
)
)
)
).toDF("messages")

lazy val badDf: DataFrame = Seq(
Seq(),
Seq(
OpenAIMessage("system", "You are very excited"),
OpenAIMessage("user", null) //scalastyle:ignore null
OpenAIMessage("user", null: String) //scalastyle:ignore null
),
null //scalastyle:ignore null
).toDF("messages")
Expand Down Expand Up @@ -144,6 +170,10 @@ class OpenAIChatCompletionSuite extends TransformerFuzzing[OpenAIChatCompletion]
testCompletion(completion, slowDf)
}

test("Image usage") {
testCompletion(completion4o, imageDf)
}

test("Robustness to bad inputs") {
val results = completion.transform(badDf).collect()
assert(Option(results.head.getAs[Row](completion.getErrorCol)).isDefined)
Expand All @@ -166,7 +196,7 @@ class OpenAIChatCompletionSuite extends TransformerFuzzing[OpenAIChatCompletion]

val messages: Seq[Row] = Seq(
OpenAIMessage("user", "Whats your favorite color")
).toDF("role", "content", "name").collect()
).toDF("role", "content", "contentList", "name").collect()

val optionalParams: Map[String, Any] = completion.getOptionalParams(messages.head)
assert(!optionalParams.contains("response_format"))
Expand Down Expand Up @@ -286,7 +316,7 @@ class OpenAIChatCompletionSuite extends TransformerFuzzing[OpenAIChatCompletion]
val fromRow = ChatCompletionResponse.makeFromRowConverter
completion.transform(df).collect().foreach(r =>
fromRow(r.getAs[Row]("out")).choices.foreach(c =>
assert(c.message.content.length > requiredLength)))
assert(c.message.content.map(_.length).getOrElse(c.message.contentList.get.length) > requiredLength)))
}

override def assertDFEq(df1: DataFrame, df2: DataFrame)(implicit eq: Equality[DataFrame]): Unit = {
Expand Down
Loading