Skip to content

Commit b979064

Browse files
Fix null cases for openai prompt and embeddings (#2457)
1 parent bfa8c46 commit b979064

File tree

4 files changed: +84 additions, -48 deletions

cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbedding.scala

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@ import org.apache.spark.ml.functions.array_to_vector
1616
import org.apache.spark.sql.{DataFrame, Dataset, Row}
1717
import org.apache.spark.sql.types._
1818
import scala.language.existentials
19-
import org.apache.spark.sql.functions.{col, element_at, struct}
19+
import org.apache.spark.sql.functions.{col, element_at, struct, when}
2020
import spray.json.DefaultJsonProtocol._
2121
import spray.json._
2222
import HasReturnUsage.UsageMappings
@@ -101,13 +101,18 @@ class OpenAIEmbedding (override val uid: String) extends OpenAIServicesBase(uid)
101101
)
102102
parsed.withColumn(
103103
getOutputCol,
104-
struct(
105-
vectorCol.alias("response"),
106-
usageCol.alias("usage")
104+
when(responseCol.isNotNull,
105+
struct(
106+
vectorCol.alias("response"),
107+
usageCol.alias("usage")
108+
)
107109
)
108110
)
109111
} else {
110-
parsed.withColumn(getOutputCol, vectorCol)
112+
parsed.withColumn(
113+
getOutputCol,
114+
when(responseCol.isNotNull, vectorCol)
115+
)
111116
}
112117
}
113118

cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -277,10 +277,15 @@ class OpenAIPrompt(override val uid: String) extends Transformer
277277
val usageCol = normalizeUsageColumn(responseCol.getField("usage"), mapping)
278278
df.withColumn(
279279
getOutputCol,
280-
F.struct(parsedCol.alias("response"), usageCol.alias("usage"))
280+
F.when(parsedCol.isNotNull,
281+
F.struct(parsedCol.alias("response"), usageCol.alias("usage"))
282+
)
281283
)
282284
case None =>
283-
df.withColumn(getOutputCol, parsedCol)
285+
df.withColumn(
286+
getOutputCol,
287+
F.when(parsedCol.isNotNull, parsedCol)
288+
)
284289
}
285290
}
286291

cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbeddingsSuite.scala

Lines changed: 30 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -173,6 +173,36 @@ class OpenAIEmbeddingsSuite extends TransformerFuzzing[OpenAIEmbedding] with Ope
173173
assert(outputDetails != null)
174174
}
175175

176+
test("null input returns null output with returnUsage false") {
177+
val dfWithNull = Seq(
178+
Some("Once upon a time"),
179+
None,
180+
Some("SynapseML is ")
181+
).toDF("text")
182+
183+
val e = usageEmbedding("null_test_vec")
184+
val results = e.transform(dfWithNull).collect()
185+
186+
assert(results(0).getAs[Vector]("null_test_vec") != null)
187+
assert(results(1).getAs[Vector]("null_test_vec") == null)
188+
assert(results(2).getAs[Vector]("null_test_vec") != null)
189+
}
190+
191+
test("null input returns null output with returnUsage true") {
192+
val dfWithNull = Seq(
193+
Some("Once upon a time"),
194+
None,
195+
Some("SynapseML is ")
196+
).toDF("text")
197+
198+
val e = usageEmbedding("null_test_usage").setReturnUsage(true)
199+
val results = e.transform(dfWithNull).collect()
200+
201+
assert(results(0).getAs[Row]("null_test_usage") != null)
202+
assert(results(1).getAs[Row]("null_test_usage") == null)
203+
assert(results(2).getAs[Row]("null_test_usage") != null)
204+
}
205+
176206

177207
override def testObjects(): Seq[TestObject[OpenAIEmbedding]] =
178208
Seq(new TestObject(embedding, df))

cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala

Lines changed: 37 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -204,42 +204,38 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
204204
assert(outputDetails != null)
205205
}
206206

207-
lazy val promptGpt4: OpenAIPrompt = new OpenAIPrompt()
208-
.setSubscriptionKey(openAIAPIKey)
209-
.setDeploymentName(deploymentName)
210-
.setCustomServiceName(openAIServiceName)
211-
.setOutputCol("outParsed")
212-
.setTemperature(0)
207+
test("null input returns null output with returnUsage false") {
208+
val dfWithNull = Seq(
209+
(Some("apple"), "fruits"),
210+
(None, "cars"),
211+
(Some("cake"), "dishes")
212+
).toDF("text", "category")
213+
214+
val p = usagePrompt("null_test")
215+
val results = p.transform(dfWithNull).select("null_test").collect()
216+
217+
assert(results(0).getSeq[String](0) != null)
218+
assert(results(1).get(0) == null)
219+
assert(results(2).getSeq[String](0) != null)
220+
}
213221

214-
test("Basic Usage - Gpt 4") {
215-
val nonNullCount = promptGpt4
216-
.setPromptTemplate("give me a comma separated list of 5 {category}, starting with {text} ")
217-
.setPostProcessing("csv")
218-
.transform(df)
219-
.select("outParsed")
220-
.collect()
221-
.count(r => Option(r.getSeq[String](0)).isDefined)
222+
test("null input returns null output with returnUsage true") {
223+
val dfWithNull = Seq(
224+
(Some("apple"), "fruits"),
225+
(None, "cars"),
226+
(Some("cake"), "dishes")
227+
).toDF("text", "category")
222228

223-
assert(nonNullCount == 3)
224-
}
229+
val p = usagePrompt("null_test_usage").setReturnUsage(true)
230+
val results = p.transform(dfWithNull).select("null_test_usage").collect()
225231

226-
test("Basic Usage JSON - Gpt 4") {
227-
promptGpt4.setPromptTemplate(
228-
"""Split a word into prefix and postfix a respond in JSON
229-
|Cherry: {{"prefix": "Che", "suffix": "rry"}}
230-
|{text}:
231-
|""".stripMargin)
232-
.setPostProcessing("json")
233-
.setPostProcessingOptions(Map("jsonSchema" -> "prefix STRING, suffix STRING"))
234-
.transform(df)
235-
.select("outParsed")
236-
.where(col("outParsed").isNotNull)
237-
.collect()
238-
.foreach(r => assert(r.getStruct(0).getString(0).nonEmpty))
232+
assert(results(0).getStruct(0) != null)
233+
assert(results(1).get(0) == null)
234+
assert(results(2).getStruct(0) != null)
239235
}
240236

241-
test("Basic Usage JSON - Gpt 4 without explicit post-processing") {
242-
promptGpt4.setPromptTemplate(
237+
test("Basic Usage JSON - without explicit post-processing") {
238+
prompt.setPromptTemplate(
243239
"""Split a word into prefix and postfix a respond in JSON
244240
|Cherry: {{"prefix": "Che", "suffix": "rry"}}
245241
|{text}:
@@ -252,8 +248,8 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
252248
.foreach(r => assert(r.getStruct(0).getString(0).nonEmpty))
253249
}
254250

255-
test("Setting and Keeping Messages Col - Gpt 4") {
256-
promptGpt4.setMessagesCol("messages")
251+
test("Setting and Keeping Messages Col") {
252+
prompt.setMessagesCol("messages")
257253
.setDropPrompt(false)
258254
.setPromptTemplate(
259255
"""Classify each word as to whether they are an F1 team or not
@@ -268,23 +264,23 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
268264
.foreach(r => assert(r.get(0) != null))
269265
}
270266

271-
test("Basic Usage JSON - Gpt 4o with responseFormat") {
272-
val promptGpt4o: OpenAIPrompt = new OpenAIPrompt()
267+
test("json_object Response Format Usage") {
268+
val promptJSONObject: OpenAIPrompt = new OpenAIPrompt()
273269
.setSubscriptionKey(openAIAPIKey)
274270
.setDeploymentName(deploymentName)
275271
.setCustomServiceName(openAIServiceName)
276272
.setOutputCol("outParsed")
277273
.setTemperature(0)
278274
.setPromptTemplate(
279-
"""Split a word into prefix and postfix
275+
"""Split a word into prefix and postfix in JSON format
280276
|Cherry: {{"prefix": "Che", "suffix": "rry"}}
281277
|{text}:
282278
|""".stripMargin)
283279
.setResponseFormat("json_object")
284280
.setPostProcessingOptions(Map("jsonSchema" -> "prefix STRING, suffix STRING"))
285281

286282

287-
promptGpt4o.transform(df)
283+
promptJSONObject.transform(df)
288284
.select("outParsed")
289285
.where(col("outParsed").isNotNull)
290286
.collect()
@@ -331,21 +327,21 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
331327
lazy val customRootUrlValue: String = sys.env.getOrElse("CUSTOM_ROOT_URL", "")
332328
lazy val customHeadersValues: Map[String, String] = Map("X-ModelType" -> "gpt-4-turbo-chat-completions")
333329

334-
lazy val customPromptGpt4: OpenAIPrompt = new OpenAIPrompt()
330+
lazy val customPrompt: OpenAIPrompt = new OpenAIPrompt()
335331
.setCustomUrlRoot(customRootUrlValue)
336332
.setOutputCol("outParsed")
337333
.setTemperature(0)
338334

339335
if (accessToken.isEmpty) {
340-
customPromptGpt4.setSubscriptionKey(openAIAPIKey)
336+
customPrompt.setSubscriptionKey(openAIAPIKey)
341337
.setDeploymentName(deploymentName)
342338
.setCustomServiceName(openAIServiceName)
343339
} else {
344-
customPromptGpt4.setAADToken(accessToken)
340+
customPrompt.setAADToken(accessToken)
345341
.setCustomHeaders(customHeadersValues)
346342
}
347343

348-
customPromptGpt4.setPromptTemplate("give me a comma separated list of 5 {category}, starting with {text} ")
344+
customPrompt.setPromptTemplate("give me a comma separated list of 5 {category}, starting with {text} ")
349345
.setPostProcessing("csv")
350346
.transform(df)
351347
.select("outParsed")

0 commit comments

Comments (0)