@@ -199,13 +199,8 @@ class OpenAIPrompt(override val uid: String) extends Transformer
199199 private val imageExtensions = Set (" jpg" , " jpeg" , " png" , " gif" , " webp" )
200200 private val audioExtensions = Set (" mp3" , " wav" )
201201
202- private def attachmentPlaceholder (columnName : String ): String =
203- s " [Content for column ' $columnName' will be provided later as an attachment.] "
204-
205- private [openai] def applyPathPlaceholders (template : String , pathColumns : Seq [String ]): String = {
206- pathColumns.foldLeft(template) { (current, columnName) =>
207- current.replace(s " { $columnName} " , attachmentPlaceholder(columnName))
208- }
202+ private def extractFilename = udf { (path : String ) =>
203+ Option (path).map(_.trim).filter(_.nonEmpty).map(p => new HPath (p).getName).orNull
209204 }
210205
211206 private def addRAIErrors [T <: OpenAIServicesBase with HasRAIContentFilter ](
@@ -317,32 +312,51 @@ class OpenAIPrompt(override val uid: String) extends Transformer
317312 results.select(results.columns.filter(_ != getErrorCol).map(col) :+ col(getErrorCol): _* )
318313 }
319314
315+ private def processPathColumns (df : DataFrame ): (DataFrame , Seq [String ], Map [String , String ], String ) = {
316+ val columnTypeMap = if (isSet(columnTypes)) getColumnTypes else Map .empty[String , String ]
317+
318+ columnTypeMap.foreach { case (colName, colType) =>
319+ require(colType == " text" || colType == " path" ,
320+ s " Unsupported column type ' $colType' for column ' $colName'. Supported types are 'text' and 'path'. " )
321+ }
322+
323+ val pathColumnNames = columnTypeMap.collect {
324+ case (colName, colType) if colType == " path" => colName
325+ }.toSeq
326+
327+ pathColumnNames.foreach { colName =>
328+ require(
329+ df.columns.contains(colName),
330+ s " Column ' $colName' specified in columnTypes was not found in the DataFrame. " +
331+ s " Available columns: ${df.columns.mkString(" , " )}"
332+ )
333+ }
334+
335+ import com .microsoft .azure .synapse .ml .core .schema .DatasetExtensions ._
336+ val (dfWithFilenames, filenameColMapping) = pathColumnNames.foldLeft((df, Map .empty[String , String ])) {
337+ case ((currentDf, mapping), colName) =>
338+ val filenameCol = currentDf.withDerivativeCol(s " ${colName}_filename " )
339+ (currentDf.withColumn(filenameCol, extractFilename(F .col(colName))), mapping + (colName -> filenameCol))
340+ }
341+
342+ val templateWithFilenameRefs = filenameColMapping.foldLeft(getPromptTemplate) {
343+ case (template, (colName, filenameCol)) =>
344+ template.replace(s " { $colName} " , s " { $filenameCol} " )
345+ }
346+
347+ (dfWithFilenames, pathColumnNames, filenameColMapping, templateWithFilenameRefs)
348+ }
349+
320350 override def transform (dataset : Dataset [_]): DataFrame = {
321351 transferGlobalParamsToParamMap()
322352 logTransform[DataFrame ]({
323353 val df = dataset.toDF
324354 val service = getOpenAIChatService
325- val columnTypeMap = if (isSet(columnTypes)) getColumnTypes else Map .empty[String , String ]
326355
327- columnTypeMap.foreach { case (colName, colType) =>
328- val normalized = colType.toLowerCase
329- require(normalized == " text" || normalized == " path" ,
330- s " Unsupported column type ' $colType' for column ' $colName'. Supported types are 'text' and 'path'. " )
331- }
332-
333- val pathColumnNames = columnTypeMap.collect {
334- case (colName, colType) if colType.equalsIgnoreCase(" path" ) => colName
335- }.toSeq
356+ val (dfWithFilenames, pathColumnNames, filenameColMapping, templateWithFilenameRefs) =
357+ processPathColumns(df)
336358
337- val promptTemplateWithPlaceholders = applyPathPlaceholders(getPromptTemplate, pathColumnNames)
338- val promptCol = Functions .template(promptTemplateWithPlaceholders)
339-
340- pathColumnNames.foreach { colName =>
341- require(
342- df.columns.contains(colName),
343- s " Column ' $colName' specified in columnTypes was not found in the DataFrame. "
344- )
345- }
359+ val promptCol = Functions .template(templateWithFilenameRefs)
346360
347361 val attachmentsColumn : Column =
348362 if (pathColumnNames.nonEmpty) {
@@ -368,9 +382,14 @@ class OpenAIPrompt(override val uid: String) extends Transformer
368382 }
369383
370384 val (dfTemplated, inputColName, serviceConfigured) =
371- configureService(service, df , promptCol, createMessagesUDF, attachmentsColumn)
385+ configureService(service, dfWithFilenames , promptCol, createMessagesUDF, attachmentsColumn)
372386 val result = generateText(serviceConfigured, dfTemplated)
373- if (getDropPrompt) result.drop(inputColName) else result
387+
388+ val resultCleaned = filenameColMapping.values.foldLeft(result) { (df, colName) =>
389+ if (df.columns.contains(colName)) df.drop(colName) else df
390+ }
391+
392+ if (getDropPrompt) resultCleaned.drop(inputColName) else resultCleaned
374393 }, dataset.columns.length)
375394 }
376395
0 commit comments