@@ -204,42 +204,38 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
204204 assert(outputDetails != null )
205205 }
206206
207- lazy val promptGpt4 : OpenAIPrompt = new OpenAIPrompt ()
208- .setSubscriptionKey(openAIAPIKey)
209- .setDeploymentName(deploymentName)
210- .setCustomServiceName(openAIServiceName)
211- .setOutputCol(" outParsed" )
212- .setTemperature(0 )
207+ test(" null input returns null output with returnUsage false" ) {
208+ val dfWithNull = Seq (
209+ (Some (" apple" ), " fruits" ),
210+ (None , " cars" ),
211+ (Some (" cake" ), " dishes" )
212+ ).toDF(" text" , " category" )
213+
214+ val p = usagePrompt(" null_test" )
215+ val results = p.transform(dfWithNull).select(" null_test" ).collect()
216+
217+ assert(results(0 ).getSeq[String ](0 ) != null )
218+ assert(results(1 ).get(0 ) == null )
219+ assert(results(2 ).getSeq[String ](0 ) != null )
220+ }
213221
214- test(" Basic Usage - Gpt 4" ) {
215- val nonNullCount = promptGpt4
216- .setPromptTemplate(" give me a comma separated list of 5 {category}, starting with {text} " )
217- .setPostProcessing(" csv" )
218- .transform(df)
219- .select(" outParsed" )
220- .collect()
221- .count(r => Option (r.getSeq[String ](0 )).isDefined)
222+ test(" null input returns null output with returnUsage true" ) {
223+ val dfWithNull = Seq (
224+ (Some (" apple" ), " fruits" ),
225+ (None , " cars" ),
226+ (Some (" cake" ), " dishes" )
227+ ).toDF(" text" , " category" )
222228
223- assert(nonNullCount == 3 )
224- }
229+ val p = usagePrompt( " null_test_usage " ).setReturnUsage( true )
230+ val results = p.transform(dfWithNull).select( " null_test_usage " ).collect()
225231
226- test(" Basic Usage JSON - Gpt 4" ) {
227- promptGpt4.setPromptTemplate(
228- """ Split a word into prefix and postfix a respond in JSON
229- |Cherry: {{"prefix": "Che", "suffix": "rry"}}
230- |{text}:
231- |""" .stripMargin)
232- .setPostProcessing(" json" )
233- .setPostProcessingOptions(Map (" jsonSchema" -> " prefix STRING, suffix STRING" ))
234- .transform(df)
235- .select(" outParsed" )
236- .where(col(" outParsed" ).isNotNull)
237- .collect()
238- .foreach(r => assert(r.getStruct(0 ).getString(0 ).nonEmpty))
232+ assert(results(0 ).getStruct(0 ) != null )
233+ assert(results(1 ).get(0 ) == null )
234+ assert(results(2 ).getStruct(0 ) != null )
239235 }
240236
241- test(" Basic Usage JSON - Gpt 4 without explicit post-processing" ) {
242- promptGpt4 .setPromptTemplate(
237+ test(" Basic Usage JSON - without explicit post-processing" ) {
238+ prompt .setPromptTemplate(
243239 """ Split a word into prefix and postfix a respond in JSON
244240 |Cherry: {{"prefix": "Che", "suffix": "rry"}}
245241 |{text}:
@@ -252,8 +248,8 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
252248 .foreach(r => assert(r.getStruct(0 ).getString(0 ).nonEmpty))
253249 }
254250
255- test(" Setting and Keeping Messages Col - Gpt 4 " ) {
256- promptGpt4 .setMessagesCol(" messages" )
251+ test(" Setting and Keeping Messages Col" ) {
252+ prompt .setMessagesCol(" messages" )
257253 .setDropPrompt(false )
258254 .setPromptTemplate(
259255 """ Classify each word as to whether they are an F1 team or not
@@ -268,23 +264,23 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
268264 .foreach(r => assert(r.get(0 ) != null ))
269265 }
270266
271- test(" Basic Usage JSON - Gpt 4o with responseFormat " ) {
272- val promptGpt4o : OpenAIPrompt = new OpenAIPrompt ()
267+ test(" json_object Response Format Usage " ) {
268+ val promptJSONObject : OpenAIPrompt = new OpenAIPrompt ()
273269 .setSubscriptionKey(openAIAPIKey)
274270 .setDeploymentName(deploymentName)
275271 .setCustomServiceName(openAIServiceName)
276272 .setOutputCol(" outParsed" )
277273 .setTemperature(0 )
278274 .setPromptTemplate(
279- """ Split a word into prefix and postfix
275+ """ Split a word into prefix and postfix in JSON format
280276 |Cherry: {{"prefix": "Che", "suffix": "rry"}}
281277 |{text}:
282278 |""" .stripMargin)
283279 .setResponseFormat(" json_object" )
284280 .setPostProcessingOptions(Map (" jsonSchema" -> " prefix STRING, suffix STRING" ))
285281
286282
287- promptGpt4o .transform(df)
283+ promptJSONObject .transform(df)
288284 .select(" outParsed" )
289285 .where(col(" outParsed" ).isNotNull)
290286 .collect()
@@ -331,21 +327,21 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
331327 lazy val customRootUrlValue : String = sys.env.getOrElse(" CUSTOM_ROOT_URL" , " " )
332328 lazy val customHeadersValues : Map [String , String ] = Map (" X-ModelType" -> " gpt-4-turbo-chat-completions" )
333329
334- lazy val customPromptGpt4 : OpenAIPrompt = new OpenAIPrompt ()
330+ lazy val customPrompt : OpenAIPrompt = new OpenAIPrompt ()
335331 .setCustomUrlRoot(customRootUrlValue)
336332 .setOutputCol(" outParsed" )
337333 .setTemperature(0 )
338334
339335 if (accessToken.isEmpty) {
340- customPromptGpt4 .setSubscriptionKey(openAIAPIKey)
336+ customPrompt .setSubscriptionKey(openAIAPIKey)
341337 .setDeploymentName(deploymentName)
342338 .setCustomServiceName(openAIServiceName)
343339 } else {
344- customPromptGpt4 .setAADToken(accessToken)
340+ customPrompt .setAADToken(accessToken)
345341 .setCustomHeaders(customHeadersValues)
346342 }
347343
348- customPromptGpt4 .setPromptTemplate(" give me a comma separated list of 5 {category}, starting with {text} " )
344+ customPrompt .setPromptTemplate(" give me a comma separated list of 5 {category}, starting with {text} " )
349345 .setPostProcessing(" csv" )
350346 .transform(df)
351347 .select(" outParsed" )
0 commit comments