@@ -195,13 +195,8 @@ class OpenAIPrompt(override val uid: String) extends Transformer
// Extension sets for image and audio files; presumably used to classify attachment paths — usage is outside this view, confirm.
195195 private val imageExtensions = Set (" jpg" , " jpeg" , " png" , " gif" , " webp" )
196196 private val audioExtensions = Set (" mp3" , " wav" )
197197
198- private def attachmentPlaceholder (columnName : String ): String =
199- s " [Content for column ' $columnName' will be provided later as an attachment.] "
200-
201- private [openai] def applyPathPlaceholders (template : String , pathColumns : Seq [String ]): String = {
202- pathColumns.foldLeft(template) { (current, columnName) =>
203- current.replace(s " { $columnName} " , attachmentPlaceholder(columnName))
204- }
/** Spark UDF that maps a path string to its final path component.
  *
  * Yields null when the input is null or blank after trimming, so rows without
  * a usable path produce a null filename instead of failing.
  */
private def extractFilename = udf { (path: String) =>
  val fileName = for {
    raw <- Option(path)
    cleaned = raw.trim
    if cleaned.nonEmpty
  } yield new HPath(cleaned).getName
  // Spark represents a missing value as null, so unwrap the Option at the boundary.
  fileName.orNull
}
206201
207202 private def addRAIErrors [T <: OpenAIServicesBase with HasRAIContentFilter ](
@@ -313,32 +308,51 @@ class OpenAIPrompt(override val uid: String) extends Transformer
313308 results.select(results.columns.filter(_ != getErrorCol).map(col) :+ col(getErrorCol): _* )
314309 }
315310
/** Validates `columnTypes` and prepares the DataFrame and prompt template for `path` columns.
  *
  * For every column declared as `path`, a filename column is added (its name comes from
  * `withDerivativeCol`, its value from `extractFilename`), and each `{column}` placeholder
  * in the prompt template is rewritten to reference that filename column instead.
  *
  * Column types are matched case-insensitively ("Path", "TEXT", ...).
  *
  * @param df input DataFrame
  * @return a tuple of: the DataFrame with the added filename columns, the names of the
  *         `path` columns, a map from each `path` column to its filename column, and the
  *         prompt template with rewritten placeholders
  * @throws IllegalArgumentException if a declared type is neither `text` nor `path`, or if
  *                                  a `path` column is missing from `df`
  */
private def processPathColumns(df: DataFrame): (DataFrame, Seq[String], Map[String, String], String) = {
  import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions._

  val columnTypeMap = if (isSet(columnTypes)) getColumnTypes else Map.empty[String, String]

  // Literal-first equalsIgnoreCase is null-safe: a null type fails `require` with the
  // message below rather than throwing a NullPointerException.
  columnTypeMap.foreach { case (colName, colType) =>
    require("text".equalsIgnoreCase(colType) || "path".equalsIgnoreCase(colType),
      s"Unsupported column type '$colType' for column '$colName'. Supported types are 'text' and 'path'.")
  }

  val pathColumnNames = columnTypeMap.collect {
    case (colName, colType) if "path".equalsIgnoreCase(colType) => colName
  }.toSeq

  pathColumnNames.foreach { colName =>
    require(
      df.columns.contains(colName),
      s"Column '$colName' specified in columnTypes was not found in the DataFrame. " +
        s"Available columns: ${df.columns.mkString(", ")}"
    )
  }

  // Thread the DataFrame through the fold so each derived name is chosen against the
  // columns added so far.
  val (dfWithFilenames, filenameColMapping) =
    pathColumnNames.foldLeft((df, Map.empty[String, String])) {
      case ((currentDf, mapping), colName) =>
        val filenameCol = currentDf.withDerivativeCol(s"${colName}_filename")
        val withFilename = currentDf.withColumn(filenameCol, extractFilename(F.col(colName)))
        (withFilename, mapping + (colName -> filenameCol))
    }

  // Point each `{pathColumn}` placeholder at the filename column.
  val templateWithFilenameRefs = filenameColMapping.foldLeft(getPromptTemplate) {
    case (template, (colName, filenameCol)) =>
      template.replace(s"{$colName}", s"{$filenameCol}")
  }

  (dfWithFilenames, pathColumnNames, filenameColMapping, templateWithFilenameRefs)
}
345+
316346 override def transform (dataset : Dataset [_]): DataFrame = {
317347 transferGlobalParamsToParamMap()
318348 logTransform[DataFrame ]({
319349 val df = dataset.toDF
320350 val service = getOpenAIChatService
321- val columnTypeMap = if (isSet(columnTypes)) getColumnTypes else Map .empty[String , String ]
322351
323- columnTypeMap.foreach { case (colName, colType) =>
324- val normalized = colType.toLowerCase
325- require(normalized == " text" || normalized == " path" ,
326- s " Unsupported column type ' $colType' for column ' $colName'. Supported types are 'text' and 'path'. " )
327- }
328-
329- val pathColumnNames = columnTypeMap.collect {
330- case (colName, colType) if colType.equalsIgnoreCase(" path" ) => colName
331- }.toSeq
352+ val (dfWithFilenames, pathColumnNames, filenameColMapping, templateWithFilenameRefs) =
353+ processPathColumns(df)
332354
333- val promptTemplateWithPlaceholders = applyPathPlaceholders(getPromptTemplate, pathColumnNames)
334- val promptCol = Functions .template(promptTemplateWithPlaceholders)
335-
336- pathColumnNames.foreach { colName =>
337- require(
338- df.columns.contains(colName),
339- s " Column ' $colName' specified in columnTypes was not found in the DataFrame. "
340- )
341- }
355+ val promptCol = Functions .template(templateWithFilenameRefs)
342356
343357 val attachmentsColumn : Column =
344358 if (pathColumnNames.nonEmpty) {
@@ -364,9 +378,14 @@ class OpenAIPrompt(override val uid: String) extends Transformer
364378 }
365379
366380 val (dfTemplated, inputColName, serviceConfigured) =
367- configureService(service, df , promptCol, createMessagesUDF, attachmentsColumn)
381+ configureService(service, dfWithFilenames , promptCol, createMessagesUDF, attachmentsColumn)
368382 val result = generateText(serviceConfigured, dfTemplated)
369- if (getDropPrompt) result.drop(inputColName) else result
383+
384+ val resultCleaned = filenameColMapping.values.foldLeft(result) { (df, colName) =>
385+ if (df.columns.contains(colName)) df.drop(colName) else df
386+ }
387+
388+ if (getDropPrompt) resultCleaned.drop(inputColName) else resultCleaned
370389 }, dataset.columns.length)
371390 }
372391
0 commit comments