diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/search/AzureSearch.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/search/AzureSearch.scala index 3b1940b19e..eae4a33db5 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/search/AzureSearch.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/search/AzureSearch.scala @@ -226,23 +226,33 @@ object AzureSearchWriter extends IndexParser with IndexJsonGetter with SLogging private def prepareDF(df: DataFrame, //scalastyle:ignore method.length options: Map[String, String] = Map()): DataFrame = { val applicableOptions = Set( - "subscriptionKey", "actionCol", "serviceName", "indexName", "indexJson", - "apiVersion", "batchSize", "fatalErrors", "filterNulls", "keyCol", "vectorCols" + "subscriptionKey", "AADToken", "actionCol", "serviceName", "indexName", "indexJson", + "apiVersion", "batchSize", "fatalErrors", "filterNulls", "keyCol", "vectorCols", "url" ) options.keys.foreach(k => assert(applicableOptions(k), s"$k not an applicable option ${applicableOptions.toList}")) - val subscriptionKey = options("subscriptionKey") + val subscriptionKey = options.get("subscriptionKey") + val aadToken = options.get("AADToken") + val actionCol = options.getOrElse("actionCol", "@search.action") + val serviceName = options("serviceName") val indexJsonOpt = options.get("indexJson") val apiVersion = options.getOrElse("apiVersion", AzureSearchAPIConstants.DefaultAPIVersion) + val batchSize = options.getOrElse("batchSize", "100").toInt val fatalErrors = options.getOrElse("fatalErrors", "true").toBoolean val filterNulls = options.getOrElse("filterNulls", "false").toBoolean val vectorColsInfo = options.get("vectorCols") + + assert(!(subscriptionKey.isEmpty && aadToken.isEmpty), + "No auth found: Please set either subscriptionKey or AADToken") + assert(!(subscriptionKey.isDefined && aadToken.isDefined), + "Both subscriptionKey and AADToken is 
set. Please set either subscriptionKey or AADToken") + val keyCol = options.get("keyCol") val indexName = options.getOrElse("indexName", parseIndexJson(indexJsonOpt.get).name.get) if (indexJsonOpt.isDefined) { @@ -260,12 +270,15 @@ object AzureSearchWriter extends IndexParser with IndexJsonGetter with SLogging } } - val (indexJson, preppedDF) = if (getExisting(subscriptionKey, serviceName, apiVersion).contains(indexName)) { + val existingIndices = getExisting(subscriptionKey, aadToken, serviceName, apiVersion) + val (indexJson, preppedDF) = if (existingIndices.contains(indexName)) { if (indexJsonOpt.isDefined) { println(f"indexJsonOpt is specified, however an index for $indexName already exists," + f"we will use the index definition obtained from the existing index instead") } - val existingIndexJson = getIndexJsonFromExistingIndex(subscriptionKey, serviceName, indexName) + // Fetch the definition with the caller-selected apiVersion, as getExisting/createIfNoneExists do + val existingIndexJson = getIndexJsonFromExistingIndex( + subscriptionKey, aadToken, serviceName, indexName, apiVersion) val vectorColNameTypeTuple = getVectorColConf(existingIndexJson) (existingIndexJson, makeColsCompatible(vectorColNameTypeTuple, df)) } else if (indexJsonOpt.isDefined) { @@ -283,7 +296,7 @@ object AzureSearchWriter extends IndexParser with IndexJsonGetter with SLogging // Throws an exception if any nested field is a vector in the schema parseIndexJson(indexJson).fields.foreach(_.fields.foreach(assertNoNestedVectors)) - SearchIndex.createIfNoneExists(subscriptionKey, serviceName, indexJson, apiVersion) + SearchIndex.createIfNoneExists(subscriptionKey, aadToken, serviceName, indexJson, apiVersion) logInfo("checking schema parity") checkSchemaParity(preppedDF.schema, indexJson, actionCol) @@ -297,15 +310,17 @@ object AzureSearchWriter extends IndexParser with IndexJsonGetter with SLogging preppedDF } - new AddDocuments() - .setSubscriptionKey(subscriptionKey) + val ad = new AddDocuments() .setServiceName(serviceName) .setIndexName(indexName) .setActionCol(actionCol) .setBatchSize(batchSize) 
.setOutputCol("out") .setErrorCol("error") - .transform(df1) + val ad1 = subscriptionKey.map(key => ad.setSubscriptionKey(key)).getOrElse(ad) + val ad2 = aadToken.map(token => ad1.setAADToken(token)).getOrElse(ad1) + + ad2.transform(df1) .withColumn("error", UDFUtils.oldUdf(checkForErrors(fatalErrors) _, ErrorUtils.ErrorSchema)(col("error"), col("input"))) } diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/search/AzureSearchAPI.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/search/AzureSearchAPI.scala index 5d2fc8eb4a..f16f0b6676 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/search/AzureSearchAPI.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/search/AzureSearchAPI.scala @@ -11,12 +11,14 @@ import org.apache.http.entity.StringEntity import org.apache.log4j.{LogManager, Logger} import spray.json._ +import java.util.UUID import scala.util.{Failure, Success, Try} object AzureSearchAPIConstants { val DefaultAPIVersion = "2023-07-01-Preview" val VectorConfigName = "vectorConfig" val VectorSearchAlgorithm = "hnsw" + val AADHeaderName = "Authorization" } import com.microsoft.azure.synapse.ml.services.search.AzureSearchAPIConstants._ @@ -27,34 +29,44 @@ trait IndexParser { } trait IndexLister { - def getExisting(key: String, + + def getExisting(key: Option[String], + AADToken: Option[String], serviceName: String, apiVersion: String = DefaultAPIVersion): Seq[String] = { - val indexListRequest = new HttpGet( + val req = new HttpGet( s"https://$serviceName.search.windows.net/indexes?api-version=$apiVersion&$$select=name" ) - indexListRequest.setHeader("api-key", key) - val indexListResponse = safeSend(indexListRequest, close = false) - val indexList = IOUtils.toString(indexListResponse.getEntity.getContent, "utf-8").parseJson.convertTo[IndexList] - indexListResponse.close() + key.foreach(k => req.setHeader("api-key", k)) + AADToken.foreach { token => + 
req.setHeader(AADHeaderName, "Bearer " + token) + } + + val response = safeSend(req, close = false) + val indexList = IOUtils.toString(response.getEntity.getContent, "utf-8").parseJson.convertTo[IndexList] + response.close() for (i <- indexList.value.seq) yield i.name } } trait IndexJsonGetter extends IndexLister { - def getIndexJsonFromExistingIndex(key: String, + def getIndexJsonFromExistingIndex(key: Option[String], + AADToken: Option[String], serviceName: String, indexName: String, apiVersion: String = DefaultAPIVersion): String = { - val existingIndexNames = getExisting(key, serviceName, apiVersion) + val existingIndexNames = getExisting(key, AADToken, serviceName, apiVersion) assert(existingIndexNames.contains(indexName), s"Cannot find an existing index name with $indexName") - val indexJsonRequest = new HttpGet( + val req = new HttpGet( s"https://$serviceName.search.windows.net/indexes/$indexName?api-version=$apiVersion" ) - indexJsonRequest.setHeader("api-key", key) - indexJsonRequest.setHeader("Content-Type", "application/json") - val indexJsonResponse = safeSend(indexJsonRequest, close = false) + key.foreach(k => req.setHeader("api-key", k)) + AADToken.foreach { token => + req.setHeader(AADHeaderName, "Bearer " + token) + } + req.setHeader("Content-Type", "application/json") + val indexJsonResponse = safeSend(req, close = false) val indexJson = IOUtils.toString(indexJsonResponse.getEntity.getContent, "utf-8") indexJsonResponse.close() indexJson @@ -67,20 +79,24 @@ object SearchIndex extends IndexParser with IndexLister { val Logger: Logger = LogManager.getRootLogger - def createIfNoneExists(key: String, + def createIfNoneExists(key: Option[String], + AADToken: Option[String], serviceName: String, indexJson: String, apiVersion: String = DefaultAPIVersion): Unit = { val indexName = parseIndexJson(indexJson).name.get - val existingIndexNames = getExisting(key, serviceName, apiVersion) + val existingIndexNames = getExisting(key, AADToken, serviceName, 
apiVersion) if (!existingIndexNames.contains(indexName)) { - val createRequest = new HttpPost(s"https://$serviceName.search.windows.net/indexes?api-version=$apiVersion") - createRequest.setHeader("Content-Type", "application/json") - createRequest.setHeader("api-key", key) - createRequest.setEntity(prepareEntity(indexJson)) - val response = safeSend(createRequest) + val req = new HttpPost(s"https://$serviceName.search.windows.net/indexes?api-version=$apiVersion") + req.setHeader("Content-Type", "application/json") + key.foreach(k => req.setHeader("api-key", k)) + AADToken.foreach { token => + req.setHeader(AADHeaderName, "Bearer " + token) + } + req.setEntity(prepareEntity(indexJson)) + val response = safeSend(req) val status = response.getStatusLine.getStatusCode assert(status == 201) () @@ -133,7 +149,7 @@ object SearchIndex extends IndexParser with IndexLister { } private def validType(t: String, fields: Option[Seq[IndexField]]): Try[String] = { - val tdt = Try(AzureSearchWriter.edmTypeToSparkType(t,fields)) + val tdt = Try(AzureSearchWriter.edmTypeToSparkType(t, fields)) tdt.map(_ => t) } diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/search/SearchWriterSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/search/SearchWriterSuite.scala index 8b2069b664..c201a75106 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/search/SearchWriterSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/search/SearchWriterSuite.scala @@ -4,6 +4,7 @@ package com.microsoft.azure.synapse.ml.services.search import com.microsoft.azure.synapse.ml.Secrets +import com.microsoft.azure.synapse.ml.Secrets.getAccessToken import com.microsoft.azure.synapse.ml.services._ import com.microsoft.azure.synapse.ml.services.openai.{OpenAIAPIKey, OpenAIEmbedding} import com.microsoft.azure.synapse.ml.services.vision.AnalyzeImage @@ -132,9 +133,13 @@ class SearchWriterSuite extends TestBase 
with AzureSearchKey with IndexJsonGette override def beforeAll(): Unit = { println("WARNING CREATING SEARCH ENGINE!") - SearchIndex.createIfNoneExists(azureSearchKey, + SearchIndex.createIfNoneExists( + Some(azureSearchKey), + None, testServiceName, createSimpleIndexJson(indexName)) + val aadToken = getAccessToken("https://search.azure.com") + println(s"Triggering token creation early ${aadToken.length}") } def deleteIndex(indexName: String): Int = { @@ -148,7 +153,7 @@ class SearchWriterSuite extends TestBase with AzureSearchKey with IndexJsonGette override def afterAll(): Unit = { //TODO make this existing search indices when multiple builds are allowed println("Cleaning up services") - val successfulCleanup = getExisting(azureSearchKey, testServiceName) + val successfulCleanup = getExisting(Some(azureSearchKey), None, testServiceName) .intersect(createdIndexes).map { n => deleteIndex(n) }.forall(_ == 204) @@ -163,7 +168,7 @@ class SearchWriterSuite extends TestBase with AzureSearchKey with IndexJsonGette val twoDaysAgo = LocalDateTime.now().minusDays(2) val endingDatePattern: Regex = "^.*-(\\d{17})$".r - val e = getExisting(azureSearchKey, testServiceName) + val e = getExisting(Some(azureSearchKey), None, testServiceName) e.foreach { name => name match { case endingDatePattern(dateString) => @@ -235,7 +240,7 @@ class SearchWriterSuite extends TestBase with AzureSearchKey with IndexJsonGette } ignore("clean up all search indexes") { - getExisting(azureSearchKey, testServiceName) + getExisting(Some(azureSearchKey), None, testServiceName) .foreach { n => val deleteRequest = new HttpDelete( s"https://$testServiceName.search.windows.net/indexes/$n?api-version=2017-11-11") @@ -266,7 +271,7 @@ class SearchWriterSuite extends TestBase with AzureSearchKey with IndexJsonGette dependsOn(2, writeHelper(dfA, in2, isVectorField=false)) dependsOn(2, retryWithBackoff({ - if (getExisting(azureSearchKey, testServiceName).contains(in2)) { + if (getExisting(Some(azureSearchKey), 
None, testServiceName).contains(in2)) { writeHelper(dfB, in2, isVectorField=false) } else { throw new RuntimeException("No existing service found") @@ -315,7 +320,7 @@ class SearchWriterSuite extends TestBase with AzureSearchKey with IndexJsonGette """.stripMargin assertThrows[IllegalArgumentException] { - SearchIndex.createIfNoneExists(azureSearchKey, testServiceName, badJson) + SearchIndex.createIfNoneExists(Some(azureSearchKey), None, testServiceName, badJson) } } @@ -370,7 +375,9 @@ class SearchWriterSuite extends TestBase with AzureSearchKey with IndexJsonGette ("upload", "1", "file1", Array("p4", null, "p6"))) .toDF("searchAction", "id", "fileName", "phrases") - SearchIndex.createIfNoneExists(azureSearchKey, + SearchIndex.createIfNoneExists( + Some(azureSearchKey), + None, testServiceName, phraseIndex) @@ -404,6 +411,27 @@ class SearchWriterSuite extends TestBase with AzureSearchKey with IndexJsonGette retryWithBackoff(assertSize(in, 2)) } + test("Use AAD") { + val in = generateIndexName() + val phraseDF = Seq( + ("upload", "0", "file0", Array("p1", "p2", "p3")), + ("upload", "1", "file1", Array("p4", null, "p6"))) + .toDF("searchAction", "id", "fileName", "phrases") + val aadToken = getAccessToken("https://search.azure.com") + + AzureSearchWriter.write(phraseDF, + Map( + "AADToken" -> aadToken, + "actionCol" -> "searchAction", + "serviceName" -> testServiceName, + "filterNulls" -> "true", + "indexName" -> in, + "keyCol" -> "id" + )) + + retryWithBackoff(assertSize(in, 2)) + } + test("pipeline with analyze image") { val in = generateIndexName() @@ -449,7 +477,7 @@ class SearchWriterSuite extends TestBase with AzureSearchKey with IndexJsonGette writeHelper(dfA, in2, isVectorField=true) retryWithBackoff({ - if (getExisting(azureSearchKey, testServiceName).contains(in2)) { + if (getExisting(Some(azureSearchKey), None, testServiceName).contains(in2)) { writeHelper(dfB, in2, isVectorField=true) } else { throw new RuntimeException("No existing service found") @@ 
-459,7 +487,7 @@ class SearchWriterSuite extends TestBase with AzureSearchKey with IndexJsonGette retryWithBackoff(assertSize(in1, 4)) retryWithBackoff(assertSize(in2, 10)) - val indexJson = retryWithBackoff(getIndexJsonFromExistingIndex(azureSearchKey, testServiceName, in1)) + val indexJson = retryWithBackoff(getIndexJsonFromExistingIndex(Some(azureSearchKey), None, testServiceName, in1)) // assert if vectorCol is a vector field assert(parseIndexJson(indexJson).fields.find(_.name == "vectorCol").get.vectorSearchConfiguration.nonEmpty) } @@ -496,7 +524,7 @@ class SearchWriterSuite extends TestBase with AzureSearchKey with IndexJsonGette retryWithBackoff(assertSize(in, 2)) // assert if vectorCols are a vector field - val indexJson = retryWithBackoff(getIndexJsonFromExistingIndex(azureSearchKey, testServiceName, in)) + val indexJson = retryWithBackoff(getIndexJsonFromExistingIndex(Some(azureSearchKey), None, testServiceName, in)) assert(parseIndexJson(indexJson).fields.find(_.name == "vectorCol1").get.vectorSearchConfiguration.nonEmpty) assert(parseIndexJson(indexJson).fields.find(_.name == "vectorCol2").get.vectorSearchConfiguration.nonEmpty) assert(parseIndexJson(indexJson).fields.find(_.name == "vectorCol3").get.vectorSearchConfiguration.nonEmpty) @@ -578,7 +606,7 @@ class SearchWriterSuite extends TestBase with AzureSearchKey with IndexJsonGette """.stripMargin assertThrows[IllegalArgumentException] { - SearchIndex.createIfNoneExists(azureSearchKey, testServiceName, badJson) + SearchIndex.createIfNoneExists(Some(azureSearchKey), None, testServiceName, badJson) } } @@ -661,7 +689,7 @@ class SearchWriterSuite extends TestBase with AzureSearchKey with IndexJsonGette )) retryWithBackoff(assertSize(in, 2)) - val indexJson = retryWithBackoff(getIndexJsonFromExistingIndex(azureSearchKey, testServiceName, in)) + val indexJson = retryWithBackoff(getIndexJsonFromExistingIndex(Some(azureSearchKey), None, testServiceName, in)) 
assert(parseIndexJson(indexJson).fields.find(_.name == "vectorContent").get.vectorSearchConfiguration.nonEmpty) } }