diff --git a/cognitive/src/main/python/synapse/ml/services/openai/DataFrameAIExtensions.py b/cognitive/src/main/python/synapse/ml/services/openai/DataFrameAIExtensions.py
new file mode 100644
index 0000000000..0c4583594f
--- /dev/null
+++ b/cognitive/src/main/python/synapse/ml/services/openai/DataFrameAIExtensions.py
@@ -0,0 +1,52 @@
+# Copyright (C) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See LICENSE in project root for information.
+
+import pyspark
+from pyspark import SparkContext
+from pyspark.sql import DataFrame
+
+from synapse.ml.core.init_spark import *
+
+spark = init_spark()
+
+
+class AIFunctions:
+    """OpenAI helpers exposed on every DataFrame via the ``df.ai`` accessor."""
+
+    def __init__(self, df):
+        self.df = df
+        self.subscriptionKey = None
+        self.deploymentName = None
+        self.customServiceName = None
+
+    def setup(self, subscriptionKey=None, deploymentName=None, customServiceName=None):
+        """Store the OpenAI connection settings used by later ``gen`` calls."""
+        self.subscriptionKey = subscriptionKey
+        self.deploymentName = deploymentName
+        self.customServiceName = customServiceName
+
+    def gen(self, template, outputCol=None, **options):
+        """Run an OpenAIPrompt over the DataFrame using ``template``.
+
+        Columns referenced as ``{col}`` in the template are substituted
+        row-wise; completions are written to ``outputCol``.
+        """
+        jvm = SparkContext.getOrCreate()._jvm
+        prompt = jvm.com.microsoft.azure.synapse.ml.services.openai.OpenAIPrompt()
+        prompt = prompt.setSubscriptionKey(self.subscriptionKey)
+        prompt = prompt.setDeploymentName(self.deploymentName)
+        prompt = prompt.setCustomServiceName(self.customServiceName)
+        prompt = prompt.setOutputCol(outputCol)
+        prompt = prompt.setPromptTemplate(template)
+        # Wrap the returned Java DataFrame directly instead of round-tripping
+        # through a global temp view, which could collide across callers.
+        return DataFrame(prompt.transform(self.df._jdf), spark)
+
+
+def get_AI_functions(df):
+    """Return (and cache) the AIFunctions helper bound to ``df``."""
+    if not hasattr(df, "_ai_instance"):
+        df._ai_instance = AIFunctions(df)
+    return df._ai_instance
+
+setattr(pyspark.sql.DataFrame, "ai", property(get_AI_functions))
diff --git a/cognitive/src/main/python/synapse/ml/services/openai/__init__.py b/cognitive/src/main/python/synapse/ml/services/openai/__init__.py
new file mode 100644
index 0000000000..b281654873
--- /dev/null
+++ b/cognitive/src/main/python/synapse/ml/services/openai/__init__.py
@@ -0,0 +1 @@
+from synapse.ml.services.openai.DataFrameAIExtensions import *
diff --git a/cognitive/src/test/python/synapsemltest/services/openai/test_DataFrameAIExtentions.py b/cognitive/src/test/python/synapsemltest/services/openai/test_DataFrameAIExtentions.py
new file mode 100644
index 0000000000..be65bf2c73
--- /dev/null
+++ b/cognitive/src/test/python/synapsemltest/services/openai/test_DataFrameAIExtentions.py
@@ -0,0 +1,72 @@
+# Copyright (C) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See LICENSE in project root for information.
+
+# Prepare training and test data.
+import json
+import subprocess
+import unittest
+
+from synapse.ml.io.http import *
+from pyspark.sql.types import *
+from pyspark.sql.functions import col
+from synapse.ml.services.openai import *
+
+from synapse.ml.core.init_spark import *
+spark = init_spark()
+
+
+def _get_openai_key():
+    # Fetch the OpenAI API key from the build key vault (shared by both tests).
+    secretJson = subprocess.check_output(
+        "az keyvault secret show --vault-name mmlspark-build-keys --name openai-api-key-2",
+        shell=True,
+    )
+    return json.loads(secretJson)["value"]
+
+
+class DataFrameAIExtentionsTest(unittest.TestCase):
+    def test_gen(self):
+        schema = StructType([
+            StructField("text", StringType(), True),
+            StructField("category", StringType(), True),
+        ])
+        data = [
+            ("apple", "fruits"),
+            ("mercedes", "cars"),
+            ("cake", "dishes"),
+        ]
+        df = spark.createDataFrame(data, schema)
+
+        df.ai.setup(
+            subscriptionKey=_get_openai_key(),
+            deploymentName="gpt-35-turbo",
+            customServiceName="synapseml-openai-2",
+        )
+        results = df.ai.gen("Complete this comma separated list of 5 {category}: {text}, ", outputCol="outParsed")
+        results.select("outParsed").show(truncate=False)
+        # Every input row should yield a non-null completion.
+        nonNullCount = results.filter(col("outParsed").isNotNull()).count()
+        assert nonNullCount == 3
+
+    def test_gen_2(self):
+        schema = StructType([
+            StructField("name", StringType(), True),
+            StructField("address", StringType(), True),
+        ])
+        data = [
+            ("Anne F.", "123 First Street, 98053"),
+            ("George K.", "345 Washington Avenue, London"),
+        ]
+        df = spark.createDataFrame(data, schema)
+
+        df.ai.setup(
+            subscriptionKey=_get_openai_key(),
+            deploymentName="gpt-35-turbo",
+            customServiceName="synapseml-openai-2",
+        )
+        results = df.ai.gen("Generate the likely country of {name}, given that they are from {address}. It is imperitive that your response contains the country only, no elaborations.", outputCol="outParsed")
+        results.select("outParsed").show(truncate=False)
+        nonNullCount = results.filter(col("outParsed").isNotNull()).count()
+        assert nonNullCount == 2
+
+if __name__ == "__main__":
+    unittest.main()