diff --git a/Experiments/Experiments/DefaultFeatureFlagService.swift b/Experiments/Experiments/DefaultFeatureFlagService.swift index b1d9069dcca..c0f7ee10d11 100644 --- a/Experiments/Experiments/DefaultFeatureFlagService.swift +++ b/Experiments/Experiments/DefaultFeatureFlagService.swift @@ -95,6 +95,8 @@ public struct DefaultFeatureFlagService: FeatureFlagService { return buildConfig == .localDeveloper || buildConfig == .alpha case .backgroundProductImageUpload: return buildConfig == .localDeveloper || buildConfig == .alpha + case .allowMerchantAIAPIKey: + return buildConfig == .localDeveloper || buildConfig == .alpha default: return true } diff --git a/Experiments/Experiments/FeatureFlag.swift b/Experiments/Experiments/FeatureFlag.swift index 35000778f6c..1e5e654999e 100644 --- a/Experiments/Experiments/FeatureFlag.swift +++ b/Experiments/Experiments/FeatureFlag.swift @@ -204,4 +204,8 @@ public enum FeatureFlag: Int { /// Supports uploading product images in background /// case backgroundProductImageUpload + + /// Allows merchants to use their own API keys for AI-powered features + /// + case allowMerchantAIAPIKey } diff --git a/Networking/Networking/Mapper/AIProductMapper.swift b/Networking/Networking/Mapper/AIProductMapper.swift index e58d99f896a..8b107a1bb94 100644 --- a/Networking/Networking/Mapper/AIProductMapper.swift +++ b/Networking/Networking/Mapper/AIProductMapper.swift @@ -14,4 +14,12 @@ struct AIProductMapper: Mapper { .removingSuffix("```") return try decoder.decode(AIProduct.self, from: Data(textCompletion.utf8)) } + + func map(dictionary: [String: Any]) throws -> AIProduct { + // For non-network requests we pass [String: Any] to the map() as data, however decoding + // relies on mapping to a JetpackAIQueryResponse model, which fails in the direct API usage key + let decoder = JSONDecoder() + let data = try JSONSerialization.data(withJSONObject: dictionary) + return try decoder.decode(AIProduct.self, from: data) + } } diff --git a/Networking/Networking/Remote/GenerativeContentRemote.swift b/Networking/Networking/Remote/GenerativeContentRemote.swift index fff77e64224..62112b4b7a6 100644 --- a/Networking/Networking/Remote/GenerativeContentRemote.swift +++ b/Networking/Networking/Remote/GenerativeContentRemote.swift @@ -1,4 +1,5 @@ import Foundation +import KeychainAccess /// Used by backend to track AI-generation usage and measure costs public enum GenerativeContentRemoteFeature: String { @@ -14,6 +15,31 @@ public enum GenerativeContentRemoteResponseFormat: String { case text = "text" } +public struct AIProviderKeyStorage { + private let keychain: Keychain + + public init(keychain: Keychain = Keychain(service: WooConstants.keychainServiceName)) { + self.keychain = keychain + } + + // Returns the merchant AI API if available + public var aiProviderAPIKey: String? { + guard let key = keychain.aiProviderAPIKey else { + return nil + } + return key + } +} + +private extension Keychain { + private static let keychainAIProviderAPIKey = "aiProviderAPIKey" + + var aiProviderAPIKey: String? { + get { self[Keychain.keychainAIProviderAPIKey] } + set { self[Keychain.keychainAIProviderAPIKey] = newValue } + } +} + /// Protocol for `GenerativeContentRemote` mainly used for mocking. /// public protocol GenerativeContentRemoteProtocol { @@ -21,11 +47,13 @@ public protocol GenerativeContentRemoteProtocol { /// - Parameters: /// - siteID: WPCOM ID of the site. /// - base: Prompt for the AI-generated text. + /// - shouldUseMerchantAIKey: If should use the merchant's API key for AI functionalities /// - feature: Used by backend to track AI-generation usage and measure costs /// - responseFormat: enum parameter to specify response format. /// - Returns: AI-generated text based on the prompt if Jetpack AI is enabled. func generateText(siteID: Int64, base: String, + shouldUseMerchantAIKey: Bool, feature: GenerativeContentRemoteFeature, responseFormat: GenerativeContentRemoteResponseFormat) async throws -> String @@ -33,10 +61,12 @@ public protocol GenerativeContentRemoteProtocol { /// - Parameters: /// - siteID: WPCOM ID of the site. /// - string: String from which we should identify the language + /// - shouldUseMerchantAIKey: If should use the merchant's API key for AI functionalities /// - feature: Used by backend to track AI-generation usage and measure costs /// - Returns: ISO code of the language func identifyLanguage(siteID: Int64, string: String, + shouldUseMerchantAIKey: Bool, feature: GenerativeContentRemoteFeature) async throws -> String /// Generates a product using provided info @@ -46,6 +76,7 @@ public protocol GenerativeContentRemoteProtocol { /// - keywords: Keywords describing the product to input for AI prompt /// - language: Language to generate the product details /// - tone: Tone of AI - Represented by `AIToneVoice` + /// - shouldUseMerchantAIKey: If should use the merchant's API key for AI functionalities /// - currencySymbol: Currency symbol to generate product price /// - dimensionUnit: Weight unit to generate product dimensions /// - weightUnit: Weight unit to generate product weight @@ -57,6 +88,7 @@ public protocol GenerativeContentRemoteProtocol { keywords: String, language: String, tone: String, + shouldUseMerchantAIKey: Bool, currencySymbol: String, dimensionUnit: String?, weightUnit: String?, @@ -72,11 +104,27 @@ public final class GenerativeContentRemote: Remote, GenerativeContentRemoteProto } private var token: JWToken? + private var storage: AIProviderKeyStorage = AIProviderKeyStorage() public func generateText(siteID: Int64, base: String, + shouldUseMerchantAIKey: Bool, feature: GenerativeContentRemoteFeature, responseFormat: GenerativeContentRemoteResponseFormat) async throws -> String { + if shouldUseMerchantAIKey { + return try await generateTextUsingMerchantAPIKey(base: base, responseFormat: responseFormat) + } else { + return try await generateTextUsingJetpack(siteID: siteID, + base: base, + feature: feature, + responseFormat: responseFormat) + } + } + + private func generateTextUsingJetpack(siteID: Int64, + base: String, + feature: GenerativeContentRemoteFeature, + responseFormat: GenerativeContentRemoteResponseFormat) async throws -> String { do { guard let token, token.isTokenValid(for: siteID) else { throw GenerativeContentRemoteError.tokenNotFound @@ -90,9 +138,53 @@ public final class GenerativeContentRemote: Remote, GenerativeContentRemoteProto } } + private func generateTextUsingMerchantAPIKey(base: String, + responseFormat: GenerativeContentRemoteResponseFormat) async throws -> String { + guard let key = storage.aiProviderAPIKey else { + throw URLError(.unknown) + } + let selectedModel = UserDefaults.standard.string(forKey: "AIProviderModel") ?? "" + let selectedProvider = AIProvider(rawValue: UserDefaults.standard.string(forKey: "AIProvider") ?? "OpenAI") ?? .openAI + + let request = try createAIRequest( + provider: selectedProvider, + apiKey: key, + model: selectedModel, + prompt: base + ) + + let (data, response) = try await URLSession.shared.data(for: request) + + guard let httpResponse = response as? HTTPURLResponse, httpResponse.statusCode == 200 else { + throw URLError(.badServerResponse) + } + + let json = try JSONSerialization.jsonObject(with: data) as? [String: Any] + guard let choices = json?["choices"] as? [[String: Any]], + let message = choices.first?["message"] as? [String: Any], + let content = message["content"] as? String else { + throw URLError(.cannotParseResponse) + } + + return content.trimmingCharacters(in: .whitespacesAndNewlines) + } + public func identifyLanguage(siteID: Int64, string: String, + shouldUseMerchantAIKey: Bool, feature: GenerativeContentRemoteFeature) async throws -> String { + if shouldUseMerchantAIKey { + return try await identifyLanguageUsingMerchantAPIKey(string: string) + } else { + return try await identifyLanguageUsingJetpack(siteID: siteID, + string: string, + feature: feature) + } + } + + private func identifyLanguageUsingJetpack(siteID: Int64, + string: String, + feature: GenerativeContentRemoteFeature) async throws -> String { do { guard let token, token.isTokenValid(for: siteID) else { throw GenerativeContentRemoteError.tokenNotFound @@ -106,49 +198,171 @@ public final class GenerativeContentRemote: Remote, GenerativeContentRemoteProto } } + private func identifyLanguageUsingMerchantAPIKey(string: String) async throws -> String { + guard let key = storage.aiProviderAPIKey else { + throw URLError(.unknown) + } + let selectedModel = UserDefaults.standard.string(forKey: "AIProviderModel") ?? "" + let selectedProvider = AIProvider(rawValue: UserDefaults.standard.string(forKey: "AIProvider") ?? "OpenAI") ?? .openAI + let prompt = String(format: AIRequestPrompts.identifyLanguage, string) + + let request = try createAIRequest( + provider: selectedProvider, + apiKey: key, + model: selectedModel, + prompt: prompt + ) + + let (data, response) = try await URLSession.shared.data(for: request) + + guard let httpResponse = response as? HTTPURLResponse, httpResponse.statusCode == 200 else { + throw URLError(.badServerResponse) + } + + let json = try JSONSerialization.jsonObject(with: data) as? [String: Any] + return try parseAIResponse(json: json, provider: selectedProvider) + } + public func generateAIProduct(siteID: Int64, productName: String?, keywords: String, language: String, tone: String, + shouldUseMerchantAIKey: Bool, currencySymbol: String, dimensionUnit: String?, weightUnit: String?, categories: [ProductCategory], tags: [ProductTag]) async throws -> AIProduct { + if shouldUseMerchantAIKey { + return try await generateAIProductUsingMerchantAPIKey(productName: productName, + keywords: keywords, + language: language, + tone: tone, + currencySymbol: currencySymbol, + dimensionUnit: dimensionUnit, + weightUnit: weightUnit, + categories: categories, + tags: tags) + } else { + return try await generateAIProductUsingJetpack(siteID: siteID, + productName: productName, + keywords: keywords, + language: language, + tone: tone, + currencySymbol: currencySymbol, + dimensionUnit: dimensionUnit, + weightUnit: weightUnit, + categories: categories, + tags: tags) + } + } + private func generateAIProductUsingJetpack(siteID: Int64, + productName: String?, + keywords: String, + language: String, + tone: String, + currencySymbol: String, + dimensionUnit: String?, + weightUnit: String?, + categories: [ProductCategory], + tags: [ProductTag]) async throws -> AIProduct { do { guard let token, token.isTokenValid(for: siteID) else { throw GenerativeContentRemoteError.tokenNotFound } return try await generateAIProduct(siteID: siteID, - productName: productName, - keywords: keywords, - language: language, - tone: tone, - currencySymbol: currencySymbol, - dimensionUnit: dimensionUnit, - weightUnit: weightUnit, - categories: categories, - tags: tags, - token: token) + productName: productName, + keywords: keywords, + language: language, + tone: tone, + currencySymbol: currencySymbol, + dimensionUnit: dimensionUnit, + weightUnit: weightUnit, + categories: categories, + tags: tags, + token: token) } catch GenerativeContentRemoteError.tokenNotFound, WordPressApiError.unknown(code: TokenExpiredError.code, message: TokenExpiredError.message) { let token = try await fetchToken(siteID: siteID) self.token = token return try await generateAIProduct(siteID: siteID, - productName: productName, - keywords: keywords, - language: language, - tone: tone, - currencySymbol: currencySymbol, - dimensionUnit: dimensionUnit, - weightUnit: weightUnit, - categories: categories, - tags: tags, - token: token) + productName: productName, + keywords: keywords, + language: language, + tone: tone, + currencySymbol: currencySymbol, + dimensionUnit: dimensionUnit, + weightUnit: weightUnit, + categories: categories, + tags: tags, + token: token) } } + + private func generateAIProductUsingMerchantAPIKey(productName: String?, + keywords: String, + language: String, + tone: String, + currencySymbol: String, + dimensionUnit: String?, + weightUnit: String?, + categories: [ProductCategory], + tags: [ProductTag]) async throws -> AIProduct { + let selectedProvider = AIProvider(rawValue: UserDefaults.standard.string(forKey: "AIProvider") ?? "OpenAI") ?? .openAI + let selectedModel = UserDefaults.standard.string(forKey: "AIProviderModel") ?? "" + + var inputComponents = [String(format: AIRequestPrompts.inputComponents, keywords, tone)] + + if let productName = productName, !productName.isEmpty { + inputComponents.insert("name: ```\(productName)```", at: 1) + } + + let jsonResponseFormatDict = generateAIProductResponseFormat( + tags: tags, + categories: categories, + language: language, + tone: tone, + currencySymbol: currencySymbol, + dimensionUnit: dimensionUnit, + weightUnit: weightUnit + ) + + let expectedJsonFormat = formatExpectedJsonResponse(jsonResponseFormatDict) + + let prompt = inputComponents.joined(separator: "\n") + "\n" + expectedJsonFormat + + guard let key = storage.aiProviderAPIKey else { + throw URLError(.unknown) + } + + let request = try createAIRequest( + provider: selectedProvider, + apiKey: key, + model: selectedModel, + prompt: prompt + ) + + let (data, response) = try await URLSession.shared.data(for: request) + + guard let httpResponse = response as? HTTPURLResponse, httpResponse.statusCode == 200 else { + throw URLError(.badServerResponse) + } + + let json = try JSONSerialization.jsonObject(with: data) as? [String: Any] + + let contentString = try parseAIResponse(json: json, provider: selectedProvider) + + guard let data = contentString.data(using: .utf8), + let productJson = try? JSONSerialization.jsonObject(with: data) as? [String: Any] else { + throw URLError(.cannotParseResponse) + } + + // We don't need siteID for direct API calls, but need a new mapper + let mapper = AIProductMapper(siteID: 0) + return try mapper.map(dictionary: productJson) + } } private extension GenerativeContentRemote { @@ -189,11 +403,7 @@ private extension GenerativeContentRemote { string: String, feature: GenerativeContentRemoteFeature, token: JWToken) async throws -> String { - let prompt = [ - "What is the ISO language code of the language used in the below text?" + - "Do not include any explanations and only provide the ISO language code in your response.", - "Text: ```\(string)```" - ].joined(separator: "\n") + let prompt = String(format: AIRequestPrompts.identifyLanguage, string) let parameters: [String: Any] = [ParameterKey.token: token.token, ParameterKey.question: prompt, ParameterKey.stream: ParameterValue.stream, @@ -219,12 +429,7 @@ private extension GenerativeContentRemote { categories: [ProductCategory], tags: [ProductTag], token: JWToken) async throws -> AIProduct { - var inputComponents = [ - "You are a WooCommerce SEO and marketing expert, perform in-depth research about the product " + - "using the provided name, keywords and tone, and give your response in the below JSON format.", - "keywords: ```\(keywords)```", - "tone: ```\(tone)```", - ] + var inputComponents = [String(format: AIRequestPrompts.inputComponents, keywords, tone)] // Name will be added only if `productName` is available. // TODO: this code related to `productName` can be removed after releasing the new product creation with AI flow. Github issue: 13108 @@ -234,58 +439,17 @@ private extension GenerativeContentRemote { let input = inputComponents.joined(separator: "\n") - let jsonResponseFormatDict: [String: Any] = { - let tagsPrompt: String = { - guard !tags.isEmpty else { - return "Suggest an array of the best matching tags for this product." - } - - return "Given the list of available tags ```\(tags.map { $0.name }.joined(separator: ", "))```, " + - "suggest an array of the best matching tags for this product. You can suggest new tags as well." - }() - - let categoriesPrompt: String = { - guard !categories.isEmpty else { - return "Suggest an array of the best matching categories for this product." - } - - return "Given the list of available categories ```\(categories.map { $0.name }.joined(separator: ", "))```, " + - "suggest an array of the best matching categories for this product. You can suggest new categories as well." - }() - - let shippingPrompt = { - var dict = [String: String]() - if let weightUnit { - dict["weight"] = "Guess and provide only the number in \(weightUnit)" - } - - if let dimensionUnit { - dict["length"] = "Guess and provide only the number in \(dimensionUnit)" - dict["width"] = "Guess and provide only the number in \(dimensionUnit)" - dict["height"] = "Guess and provide only the number in \(dimensionUnit)" - } - return dict - }() - - // swiftlint:disable line_length - return ["names": "An array of strings, containing three different names of the product, written in the language with ISO code ```\(language)```", - "descriptions": "An array of strings, each containing three different product descriptions of around 100 words long each in a ```\(tone)``` tone, " - + "written in the language with ISO code ```\(language)```", - "short_descriptions": "An array of strings, each containing three different short descriptions of the product in a ```\(tone)``` tone, " - + "written in the language with ISO code ```\(language)```", - "virtual": "A boolean value that shows whether the product is virtual or physical", - "shipping": shippingPrompt, - "price": "Guess the price in \(currencySymbol), do not include the currency symbol, " - + "only provide the price as a number", - "tags": tagsPrompt, - "categories": categoriesPrompt] - }() + let jsonResponseFormatDict = generateAIProductResponseFormat( + tags: tags, + categories: categories, + language: language, + tone: tone, + currencySymbol: currencySymbol, + dimensionUnit: dimensionUnit, + weightUnit: weightUnit + ) - let expectedJsonFormat = - "Your response should be in JSON format and don't send anything extra. " + - "Don't include the word JSON in your response:" + - "\n" + - (jsonResponseFormatDict.toJSONEncoded() ?? "") + let expectedJsonFormat = formatExpectedJsonResponse(jsonResponseFormatDict) let prompt = input + "\n" + expectedJsonFormat @@ -346,3 +510,134 @@ private extension JWToken { expiryDate > Date() && siteID == currentSelectedSiteID } } + +private struct AIRequestPrompts { + static let identifyLanguage = [ + "What is the ISO language code of the language used in the below text?", + "Do not include any explanations and only provide the ISO language code in your response.", + "Text: ```%@" + ].joined(separator: "\n") + + static let inputComponents = [ + "You are a WooCommerce SEO and marketing expert, perform in-depth research about the product " + + "using the provided name, keywords, and tone, and give your response in the below JSON format.", + "keywords: ```%@```", + "tone: ```%@```" + ].joined(separator: "\n") +} + +private extension GenerativeContentRemote { + private enum AIProvider: String { + case openAI = "OpenAI" + case anthropic = "Anthropic" + + var requestURL: String { + switch self { + case .openAI: + return "https://api.openai.com/v1/chat/completions" + case .anthropic: + return "https://api.anthropic.com/v1/messages" + } + } + } + + private func createAIRequest(provider: AIProvider, apiKey: String, model: String, prompt: String) throws -> URLRequest { + let requestBody: [String: Any] = [ + "model": model, + "messages": [["role": "user", "content": prompt]], + "max_tokens": ParameterValue.maxTokens + ] + + let requestData = try JSONSerialization.data(withJSONObject: requestBody) + var request = URLRequest(url: URL(string: provider.requestURL)!) + request.httpMethod = "POST" + request.setValue("application/json", forHTTPHeaderField: "Content-Type") + request.httpBody = requestData + + switch provider { + case .openAI: + request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization") + case .anthropic: + request.setValue(apiKey, forHTTPHeaderField: "x-api-key") + request.setValue("2023-06-01", forHTTPHeaderField: "anthropic-version") + } + + return request + } + + private func parseAIResponse(json: [String: Any]?, provider: AIProvider) throws -> String { + if provider == .openAI { + // openAI + guard let choices = json?["choices"] as? [[String: Any]], + let message = choices.first?["message"] as? [String: Any], + let content = message["content"] as? String else { + throw URLError(.cannotParseResponse) + } + return content.trimmingCharacters(in: .whitespacesAndNewlines) + } else { + // Anthropic // TODO: Switch to be explicit + guard let content = json?["content"] as? [[String: Any]], + let text = content.first?["text"] as? String else { + throw URLError(.cannotParseResponse) + } + return text.trimmingCharacters(in: .whitespacesAndNewlines) + } + } + + private func formatExpectedJsonResponse(_ jsonResponseFormat: [String: Any]) -> String { + return "Your response should be in JSON format and don't send anything extra. " + + "Don't include the word JSON in your response:\n" + + (jsonResponseFormat.toJSONEncoded() ?? "") + } + + private func generateAIProductResponseFormat(tags: [ProductTag], + categories: [ProductCategory], + language: String, + tone: String, + currencySymbol: String, + dimensionUnit: String?, + weightUnit: String?) -> [String: Any] { + let tagsPrompt: String = { + guard !tags.isEmpty else { + return "Suggest an array of the best matching tags for this product." + } + return "Given the list of available tags ```\(tags.map { $0.name }.joined(separator: ", "))```, " + + "suggest an array of the best matching tags for this product. You can suggest new tags as well." + }() + + let categoriesPrompt: String = { + guard !categories.isEmpty else { + return "Suggest an array of the best matching categories for this product." + } + return "Given the list of available categories ```\(categories.map { $0.name }.joined(separator: ", "))```, " + + "suggest an array of the best matching categories for this product. You can suggest new categories as well." + }() + + let shippingPrompt = { + var dict = [String: String]() + if let weightUnit { + dict["weight"] = "Guess and provide only the number in \(weightUnit)" + } + if let dimensionUnit { + dict["length"] = "Guess and provide only the number in \(dimensionUnit)" + dict["width"] = "Guess and provide only the number in \(dimensionUnit)" + dict["height"] = "Guess and provide only the number in \(dimensionUnit)" + } + return dict + }() + + // swiftlint:disable line_length + return [ + "names": "An array of strings, containing three different names of the product, written in the language with ISO code ```\(language)```", + "descriptions": "An array of strings, each containing three different product descriptions of around 100 words long each in a ```\(tone)``` tone, " + + "written in the language with ISO code ```\(language)```", + "short_descriptions": "An array of strings, each containing three different short descriptions of the product in a ```\(tone)``` tone, " + + "written in the language with ISO code ```\(language)```", + "virtual": "A boolean value that shows whether the product is virtual or physical", + "shipping": shippingPrompt, + "price": "Guess the price in \(currencySymbol), do not include the currency symbol, only provide the price as a number", + "tags": tagsPrompt, + "categories": categoriesPrompt + ] + } +} diff --git a/Networking/NetworkingTests/Remote/GenerativeContentRemoteTests.swift b/Networking/NetworkingTests/Remote/GenerativeContentRemoteTests.swift index 682f826df8a..e763d6f15db 100644 --- a/Networking/NetworkingTests/Remote/GenerativeContentRemoteTests.swift +++ b/Networking/NetworkingTests/Remote/GenerativeContentRemoteTests.swift @@ -30,6 +30,7 @@ final class GenerativeContentRemoteTests: XCTestCase { // When _ = try await remote.generateText(siteID: sampleSiteID, base: "generate a product description for wapuu pencil", + shouldUseMerchantAIKey: false, feature: .productDescription, responseFormat: .text) @@ -48,6 +49,7 @@ final class GenerativeContentRemoteTests: XCTestCase { // When _ = try await remote.generateText(siteID: sampleSiteID, base: "generate a product description for wapuu pencil", + shouldUseMerchantAIKey: false, feature: .productDescription, responseFormat: .text) @@ -66,6 +68,7 @@ final class GenerativeContentRemoteTests: XCTestCase { // When let generatedText = try await remote.generateText(siteID: sampleSiteID, base: "generate a product description for wapuu pencil", + shouldUseMerchantAIKey: false, feature: .productDescription, responseFormat: .text) @@ -83,6 +86,7 @@ final class GenerativeContentRemoteTests: XCTestCase { await assertThrowsError { _ = try await remote.generateText(siteID: sampleSiteID, base: "generate a product description for wapuu pencil", + shouldUseMerchantAIKey: false, feature: .productDescription, responseFormat: .text) } errorAssert: { error in @@ -101,6 +105,7 @@ final class GenerativeContentRemoteTests: XCTestCase { await assertThrowsError { _ = try await remote.generateText(siteID: sampleSiteID, base: "generate a product description for wapuu pencil", + shouldUseMerchantAIKey: false, feature: .productDescription, responseFormat: .text) } errorAssert: { error in @@ -120,6 +125,7 @@ final class GenerativeContentRemoteTests: XCTestCase { // When _ = try await remote.generateText(siteID: sampleSiteID, base: "generate a product description for wapuu pencil", + shouldUseMerchantAIKey: false, feature: .productDescription, responseFormat: .text) // Then @@ -128,6 +134,7 @@ final class GenerativeContentRemoteTests: XCTestCase { // When _ = try await remote.generateText(siteID: sampleSiteID, base: "generate a product description for wapuu pencil", + shouldUseMerchantAIKey: false, feature: .productDescription, responseFormat: .text) @@ -140,6 +147,7 @@ final class GenerativeContentRemoteTests: XCTestCase { network.simulateResponse(requestUrlSuffix: jetpackAIQueryPath, filename: "generative-text-invalid-token") _ = try? await remote.generateText(siteID: sampleSiteID, base: "generate a product description for wapuu pencil", + shouldUseMerchantAIKey: false, feature: .productDescription, responseFormat: .text) @@ -159,6 +167,7 @@ final class GenerativeContentRemoteTests: XCTestCase { // When _ = try await remote.generateText(siteID: sampleSiteID, base: "generate a product description for wapuu pencil", + shouldUseMerchantAIKey: false, feature: .productDescription, responseFormat: .text) // Then @@ -167,6 +176,7 @@ final class GenerativeContentRemoteTests: XCTestCase { // When _ = try await remote.generateText(siteID: sampleSiteID, base: "generate a product description for wapuu pencil", + shouldUseMerchantAIKey: false, feature: .productDescription, responseFormat: .text) @@ -186,6 +196,7 @@ final class GenerativeContentRemoteTests: XCTestCase { // When _ = try await remote.identifyLanguage(siteID: sampleSiteID, string: "Woo is awesome.", + shouldUseMerchantAIKey: false, feature: .productDescription) // Then @@ -203,6 +214,7 @@ final class GenerativeContentRemoteTests: XCTestCase { // When let language = try await remote.identifyLanguage(siteID: sampleSiteID, string: "Woo is awesome.", + shouldUseMerchantAIKey: false, feature: .productDescription) // Then @@ -219,6 +231,7 @@ final class GenerativeContentRemoteTests: XCTestCase { await assertThrowsError { _ = try await remote.identifyLanguage(siteID: sampleSiteID, string: "Woo is awesome.", + shouldUseMerchantAIKey: false, feature: .productDescription) } errorAssert: { error in // Then @@ -236,6 +249,7 @@ final class GenerativeContentRemoteTests: XCTestCase { await assertThrowsError { _ = try await remote.identifyLanguage(siteID: sampleSiteID, string: "Woo is awesome.", + shouldUseMerchantAIKey: false, feature: .productDescription) } errorAssert: { error in // Then @@ -254,6 +268,7 @@ final class GenerativeContentRemoteTests: XCTestCase { // When _ = try await remote.identifyLanguage(siteID: sampleSiteID, string: "Woo is awesome.", + shouldUseMerchantAIKey: false, feature: .productDescription) // Then XCTAssertEqual(numberOfJwtRequests(in: network.requestsForResponseData), 1) @@ -261,6 +276,7 @@ final class GenerativeContentRemoteTests: XCTestCase { // When _ = try await remote.identifyLanguage(siteID: sampleSiteID, string: "Woo is awesome.", + shouldUseMerchantAIKey: false, feature: .productDescription) // Then @@ -272,6 +288,7 @@ final class GenerativeContentRemoteTests: XCTestCase { network.simulateResponse(requestUrlSuffix: jetpackAIQueryPath, filename: "identify-language-invalid-token") _ = try? await remote.identifyLanguage(siteID: sampleSiteID, string: "Woo is awesome.", + shouldUseMerchantAIKey: false, feature: .productDescription) // Then @@ -290,6 +307,7 @@ final class GenerativeContentRemoteTests: XCTestCase { // When _ = try await remote.identifyLanguage(siteID: sampleSiteID, string: "Woo is awesome.", + shouldUseMerchantAIKey: false, feature: .productDescription) // Then XCTAssertEqual(numberOfJwtRequests(in: network.requestsForResponseData), 1) @@ -297,6 +315,7 @@ final class GenerativeContentRemoteTests: XCTestCase { // When _ = try await remote.identifyLanguage(siteID: sampleSiteID, string: "Woo is awesome.", + shouldUseMerchantAIKey: false, feature: .productDescription) // Then @@ -318,6 +337,7 @@ final class GenerativeContentRemoteTests: XCTestCase { keywords: "Crunchy, Crispy", language: "en", tone: "Casual", + shouldUseMerchantAIKey: false, currencySymbol: "INR", dimensionUnit: "cm", weightUnit: "kg", @@ -342,6 +362,7 @@ final class GenerativeContentRemoteTests: XCTestCase { keywords: "Crunchy, Crispy", language: "en", tone: "Casual", + shouldUseMerchantAIKey: false, currencySymbol: "INR", dimensionUnit: "cm", weightUnit: "kg", @@ -366,6 +387,7 @@ final class GenerativeContentRemoteTests: XCTestCase { keywords: "Crunchy, Crispy", language: "en", tone: "Casual", + shouldUseMerchantAIKey: false, currencySymbol: "INR", dimensionUnit: "cm", weightUnit: "kg", @@ -390,6 +412,7 @@ final class GenerativeContentRemoteTests: XCTestCase { keywords: "Crunchy, Crispy", language: "en", tone: "Casual", + shouldUseMerchantAIKey: false, currencySymbol: "INR", dimensionUnit: "cm", weightUnit: "kg", @@ -417,6 +440,7 @@ final class GenerativeContentRemoteTests: XCTestCase { keywords: "Crunchy, Crispy", language: "en", tone: "Casual", + shouldUseMerchantAIKey: false, currencySymbol: "INR", dimensionUnit: "cm", weightUnit: "kg", @@ -442,6 +466,7 @@ final class GenerativeContentRemoteTests: XCTestCase { keywords: "Crunchy, Crispy", language: "en", tone: "Casual", + shouldUseMerchantAIKey: false, currencySymbol: "INR", dimensionUnit: "cm", weightUnit: "kg", @@ -468,6 +493,7 @@ final class GenerativeContentRemoteTests: XCTestCase { keywords: "Crunchy, Crispy", language: "en", tone: "Casual", + shouldUseMerchantAIKey: false, currencySymbol: "INR", dimensionUnit: "cm", weightUnit: "kg", @@ -492,6 +518,7 @@ final class GenerativeContentRemoteTests: XCTestCase { keywords: "Crunchy, Crispy", language: "en", tone: "Casual", + shouldUseMerchantAIKey: false, currencySymbol: "INR", dimensionUnit: "cm", weightUnit: "kg", @@ -525,6 +552,7 @@ final class GenerativeContentRemoteTests: XCTestCase { keywords: "Crunchy, Crispy", language: "en", tone: "Casual", + shouldUseMerchantAIKey: false, currencySymbol: "INR", dimensionUnit: "cm", weightUnit: "kg", @@ -548,6 +576,7 @@ final class GenerativeContentRemoteTests: XCTestCase { keywords: "Crunchy, Crispy", language: "en", tone: "Casual", + shouldUseMerchantAIKey: false, currencySymbol: "INR", dimensionUnit: "cm", weightUnit: "kg", @@ -572,6 +601,7 @@ final class GenerativeContentRemoteTests: XCTestCase { keywords: "Crunchy, Crispy", language: "en", tone: "Casual", + shouldUseMerchantAIKey: false, currencySymbol: "INR", dimensionUnit: "cm", weightUnit: "kg", @@ -597,6 +627,7 @@ final class GenerativeContentRemoteTests: XCTestCase { keywords: "Crunchy, Crispy", language: "en", tone: "Casual", + shouldUseMerchantAIKey: false, currencySymbol: "INR", dimensionUnit: "cm", weightUnit: "kg", @@ -611,6 +642,7 @@ final class GenerativeContentRemoteTests: XCTestCase { keywords: "Crunchy, Crispy", language: "en", tone: "Casual", + shouldUseMerchantAIKey: false, currencySymbol: "INR", dimensionUnit: "cm", weightUnit: "kg", @@ -629,6 +661,7 @@ final class GenerativeContentRemoteTests: XCTestCase { keywords: "Crunchy, Crispy", language: "en", tone: "Casual", + shouldUseMerchantAIKey: false, currencySymbol: "INR", dimensionUnit: "cm", weightUnit: "kg", @@ -654,6 +687,7 @@ final class GenerativeContentRemoteTests: XCTestCase { keywords: "Crunchy, Crispy", language: "en", tone: "Casual", + shouldUseMerchantAIKey: false, currencySymbol: "INR", dimensionUnit: "cm", weightUnit: "kg", @@ -668,6 +702,7 @@ final class GenerativeContentRemoteTests: XCTestCase { keywords: "Crunchy, Crispy", language: "en", tone: "Casual", + shouldUseMerchantAIKey: false, currencySymbol: "INR", dimensionUnit: "cm", weightUnit: "kg", diff --git a/WooCommerce/Classes/Analytics/WooAnalyticsEvent+ProductCreationAI.swift b/WooCommerce/Classes/Analytics/WooAnalyticsEvent+ProductCreationAI.swift index 293fd5ca1a5..6ca9497726c 100644 --- a/WooCommerce/Classes/Analytics/WooAnalyticsEvent+ProductCreationAI.swift +++ b/WooCommerce/Classes/Analytics/WooAnalyticsEvent+ProductCreationAI.swift @@ -12,6 +12,7 @@ extension WooAnalyticsEvent { case description case field case featureWordCount = "feature_word_count" + case aiSource = "ai_source" } static func entryPointDisplayed() -> WooAnalyticsEvent { @@ -19,9 +20,9 @@ extension WooAnalyticsEvent { properties: [:]) } - static func entryPointTapped() -> WooAnalyticsEvent { + static func entryPointTapped(_ aiSource: AISource) -> WooAnalyticsEvent { WooAnalyticsEvent(statName: .productCreationAIEntryPointTapped, - properties: [:]) + properties: [Key.aiSource.rawValue: aiSource.rawValue]) } static func productNameContinueTapped() -> WooAnalyticsEvent { diff --git a/WooCommerce/Classes/Analytics/WooAnalyticsStat.swift b/WooCommerce/Classes/Analytics/WooAnalyticsStat.swift index 49ba716d44f..f166e610be9 100644 --- a/WooCommerce/Classes/Analytics/WooAnalyticsStat.swift +++ b/WooCommerce/Classes/Analytics/WooAnalyticsStat.swift @@ -1088,6 +1088,7 @@ enum WooAnalyticsStat: String { case hubMenuSwitchStoreTapped = "hub_menu_switch_store_tapped" case hubMenuOptionTapped = "hub_menu_option_tapped" case hubMenuSettingsTapped = "hub_menu_settings_tapped" + case hubMenuAISettingsTapped = "hub_menu_ai_settings_tapped" // MARK: Coupons case couponsLoaded = "coupons_loaded" diff --git a/WooCommerce/Classes/Authentication/Keychain+Entries.swift b/WooCommerce/Classes/Authentication/Keychain+Entries.swift index 1b23d0bd833..676c68a8106 100644 --- a/WooCommerce/Classes/Authentication/Keychain+Entries.swift +++ b/WooCommerce/Classes/Authentication/Keychain+Entries.swift @@ -26,4 +26,11 @@ extension Keychain { get { self[WooConstants.siteCredentialPassword] } set { self[WooConstants.siteCredentialPassword] = newValue } } + + /// AI Provider API key + /// + var aiProviderAPIKey: String? { + get { self[WooConstants.aiProviderAPIKey] } + set { self[WooConstants.aiProviderAPIKey] = newValue } + } } diff --git a/WooCommerce/Classes/System/WooConstants.swift b/WooCommerce/Classes/System/WooConstants.swift index 88c5fd8031d..a4a3743d5ec 100644 --- a/WooCommerce/Classes/System/WooConstants.swift +++ b/WooCommerce/Classes/System/WooConstants.swift @@ -33,6 +33,10 @@ public enum WooConstants { /// static let siteCredentialPassword = "siteCredentialPassword" + /// Keychain Access's Key for the AI API key entered by the merchant in AI settings + /// + static let aiProviderAPIKey = "aiProviderAPIKey" + /// Keychain Access's Key for the current application password /// static let applicationPassword = "ApplicationPassword" diff --git a/WooCommerce/Classes/ViewRelated/AI Settings/AISettingsView.swift b/WooCommerce/Classes/ViewRelated/AI Settings/AISettingsView.swift new file mode 100644 index 00000000000..59323fd5f5b --- /dev/null +++ b/WooCommerce/Classes/ViewRelated/AI Settings/AISettingsView.swift @@ -0,0 +1,230 @@ +import SwiftUI +import Yosemite + +struct AISettingsView: View { + @ObservedObject private var viewModel: AISettingsViewModel + + // If we're already providing AI capabilities via WPCOM or JPAI we can + // override API key usage + private var shouldUseWPCOMJPAISource: Bool { + guard let site = ServiceLocator.stores.sessionManager.defaultSite else { + return false + } + if site.isWordPressComStore || site.isAIAssistantFeatureActive { + return true + } else { + return false + } + } + + init(viewModel: AISettingsViewModel) { + self.viewModel = viewModel + } + + var body: some View { + ScrollView { + VStack(alignment: .leading, spacing: 16) { + if shouldUseWPCOMJPAISource { + Text(Localization.builtInAIEnabled) + .font(.callout) + .foregroundColor(.secondary) + .padding() + .background( + RoundedRectangle(cornerRadius: 8) + .fill(Color(.systemGray6)) + ) + .overlay( + RoundedRectangle(cornerRadius: 8) + .stroke(Color(.gray), lineWidth: 1) + ) + .padding(.bottom, 8) + .frame(width: .infinity) + } + + HStack { + Text(Localization.aiProvider) + Picker(Localization.selectProvider, selection: $viewModel.selectedProvider) { + Text(Localization.openAI).tag("OpenAI") + Text(Localization.anthropic).tag("Anthropic") + } + .pickerStyle(MenuPickerStyle()) + .onChange(of: viewModel.selectedProvider) { newValue in + viewModel.updateProvider(newValue) + } + .disabled(shouldUseWPCOMJPAISource) + .opacity(shouldUseWPCOMJPAISource ? 0.5 : 1.0) + + if shouldUseWPCOMJPAISource { + Image(systemName: "lock.fill") + .foregroundColor(.gray) + } + } + + HStack { + Text(Localization.models) + Picker(Localization.selectModel, selection: $viewModel.selectedModel) { + ForEach(viewModel.selectedProvider == "OpenAI" ? viewModel.openAIModels : viewModel.anthropicModels, id: \.self) { model in + Text(model).tag(model) + } + } + .pickerStyle(MenuPickerStyle()) + .disabled(shouldUseWPCOMJPAISource) + .opacity(shouldUseWPCOMJPAISource ? 0.5 : 1.0) + + if shouldUseWPCOMJPAISource { + Image(systemName: "lock.fill") + .foregroundColor(.gray) + } + } + + Divider() + + VStack(alignment: .leading, spacing: 8) { + HStack { + TextField( + Localization.enterAPIKey, + text: Binding( + get: { viewModel.isEditingApiKey ? viewModel.apiKey : "**********" }, + set: { newValue in if viewModel.isEditingApiKey { viewModel.apiKey = newValue } } + ) + ) + .textFieldStyle(RoundedBorderTextFieldStyle(focused: viewModel.isEditingApiKey)) + .foregroundColor(.primary) + .privacySensitive() + .disabled(!viewModel.isEditingApiKey) + + if viewModel.isEditingApiKey, !viewModel.apiKey.isEmpty { + Button(action: viewModel.clearApiKey) { + Image(systemName: "xmark.circle.fill") + .foregroundColor(.gray) + } + } + + Button(action: viewModel.toggleEditing) { + Text(viewModel.isEditingApiKey ? Localization.save : Localization.edit) + } + .disabled(shouldUseWPCOMJPAISource) + .opacity(shouldUseWPCOMJPAISource ? 0.5 : 1.0) + + if shouldUseWPCOMJPAISource { + Image(systemName: "lock.fill") + .foregroundColor(.gray) + } + } + + Text(Localization.apiKeyDescription) + .font(.caption) + .foregroundColor(.secondary) + + Spacer() + } + Text(Localization.apiKeyDisclaimer) + .font(.caption) + .foregroundColor(.secondary) + } + .padding() + .onAppear { + viewModel.onAppear() + if shouldUseWPCOMJPAISource { + viewModel.selectedProvider = "OpenAI" + viewModel.selectedModel = "gpt-4o" + } + } + } + .navigationTitle(Localization.navigationTitle) + } +} + +private extension AISettingsView { + enum Localization { + static let navigationTitle = NSLocalizedString( + "aiSettings.navigationTitle", + value: "AI Settings", + comment: "Navigation title for the AI Settings screen" + ) + + static let aiProvider = NSLocalizedString( + "aiSettings.aiProvider", + value: "Provider", + comment: "Label for the AI provider selection in AI settings" + ) + + static let selectProvider = NSLocalizedString( + "aiSettings.selectProvider", + value: "Select Provider", + comment: "Accessibility label for the AI provider picker" + ) + + static let openAI = NSLocalizedString( + "aiSettings.openAI", + value: "OpenAI", + comment: "Label for OpenAI provider option" + ) + + static let anthropic = NSLocalizedString( + "aiSettings.anthropic", + value: "Anthropic", + comment: "Label for Anthropic provider option" + ) + + static let models = NSLocalizedString( + "aiSettings.models", + value: "Models", + comment: "Label for the AI models selection" + ) + + static let selectModel = NSLocalizedString( + "aiSettings.selectModel", + value: "Select Model", + comment: "Accessibility label for the AI model picker" + ) + + static let enterAPIKey = NSLocalizedString( + "aiSettings.enterAPIKey", + value: "Enter API Key", + comment: "Placeholder text for the API key input field" + ) + + static let save = NSLocalizedString( + "aiSettings.save", + value: "Save", + comment: "Button title to save API key" + ) + + static let edit = NSLocalizedString( + "aiSettings.edit", + value: "Edit", + comment: "Button title to edit API key" + ) + + static let apiKeyDescription = NSLocalizedString( + "aiSettings.apiKeyDescription", + value: "Enter your API key to use AI generation at public API costs.", + comment: "Description text explaining the purpose of the API key" + ) + + static let builtInAIEnabled = NSLocalizedString( + "aiSettings.builtInAIEnabled", + value: "AI capabilities are already enabled for this site.", + comment: "Message displayed when built-in AI feature is enabled" + ) + + static func currentAISource(_ source: String) -> String { + String(format: NSLocalizedString( + "aiSettings.currentAISource", + value: "Current AI source: %@", + comment: "Label showing the current AI source being used. %@ shows the provider name (e.g. Jetpack)" + ), source) + } + + static let apiKeyDisclaimer = NSLocalizedString( + "aiSettings.apiKeyDisclaimer", + value: "API keys open up access to potentially sensitive information. Do not share your API key with others or expose them.", + comment: "Warning message about keeping API keys secure" + ) + } +} + +#Preview { + AISettingsView(viewModel: AISettingsViewModel()) +} diff --git a/WooCommerce/Classes/ViewRelated/AI Settings/AISettingsViewModel.swift b/WooCommerce/Classes/ViewRelated/AI Settings/AISettingsViewModel.swift new file mode 100644 index 00000000000..10dd4b16467 --- /dev/null +++ b/WooCommerce/Classes/ViewRelated/AI Settings/AISettingsViewModel.swift @@ -0,0 +1,59 @@ +import SwiftUI +import KeychainAccess + +final class AISettingsViewModel: ObservableObject { + private var keychain = Keychain(service: WooConstants.keychainServiceName) + + @Published var apiKey: String // TODO: Make function and restrict set access + @Published var selectedModel: String + @Published var selectedProvider: String + @Published var isEditingApiKey: Bool + + private let defaults: UserDefaults + + let openAIModels = [ + "gpt-4o", + "gpt-4-turbo", + "gpt-3.5-turbo" + ] + + let anthropicModels = [ + "claude-3-haiku-20240307" + ] + + init(defaults: UserDefaults = .standard) { + self.defaults = defaults + self.apiKey = Keychain(service: WooConstants.keychainServiceName).aiProviderAPIKey ?? "" + self.selectedModel = defaults.string(forKey: "AIProviderModel") ?? "" + self.selectedProvider = defaults.string(forKey: "AIProvider") ?? "" + self.isEditingApiKey = false + } + + func onAppear() { + isEditingApiKey = apiKey.isEmpty + } + + func updateProvider(_ provider: String) { + selectedProvider = provider + selectedModel = provider == "OpenAI" ? openAIModels.first ?? "" : anthropicModels.first ?? "" + saveSettings() + } + + func clearApiKey() { + apiKey = "" + keychain.aiProviderAPIKey = nil + } + + func toggleEditing() { + if isEditingApiKey { + saveSettings() + } + isEditingApiKey.toggle() + } + + private func saveSettings() { + keychain.aiProviderAPIKey = apiKey + defaults.setValue(selectedModel, forKey: "AIProviderModel") + defaults.setValue(selectedProvider, forKey: "AIProvider") + } +} diff --git a/WooCommerce/Classes/ViewRelated/Hub Menu/HubMenu.swift b/WooCommerce/Classes/ViewRelated/Hub Menu/HubMenu.swift index ad0ccaedc7d..aee471ede0a 100644 --- a/WooCommerce/Classes/ViewRelated/Hub Menu/HubMenu.swift +++ b/WooCommerce/Classes/ViewRelated/Hub Menu/HubMenu.swift @@ -69,6 +69,8 @@ struct HubMenu: View { ServiceLocator.analytics.track(event: .Blaze.blazeCampaignListEntryPointSelected(source: .menu)) case HubMenuViewModel.PointOfSaleEntryPoint.id: viewModel.showsPOS = true + case HubMenuViewModel.AISettings.id: + ServiceLocator.analytics.track(.hubMenuAISettingsTapped) default: break } @@ -183,6 +185,10 @@ private extension HubMenu { BlazeCampaignListHostingControllerRepresentable(siteID: viewModel.siteID, selectedCampaignID: campaignID) case .blazeCampaignCreation: BlazeCampaignListHostingControllerRepresentable(siteID: viewModel.siteID, startsCampaignCreationOnAppear: true) + case .aiSettings: + // TODO: Pass eligibility, so we know what's the AI source + let viewModel = AISettingsViewModel() + AISettingsView(viewModel: viewModel) } } .navigationBarTitleDisplayMode(.inline) diff --git a/WooCommerce/Classes/ViewRelated/Hub Menu/HubMenuViewModel.swift b/WooCommerce/Classes/ViewRelated/Hub Menu/HubMenuViewModel.swift index f9ec4f2cf57..351dc3ae195 100644 --- a/WooCommerce/Classes/ViewRelated/Hub Menu/HubMenuViewModel.swift +++ b/WooCommerce/Classes/ViewRelated/Hub Menu/HubMenuViewModel.swift @@ -15,6 +15,7 @@ extension NSNotification.Name { /// Destination views that the hub menu can navigate to. enum HubMenuNavigationDestination: Hashable { + case aiSettings case payments case settings case blaze @@ -93,6 +94,7 @@ final class HubMenuViewModel: ObservableObject { private let inboxEligibilityChecker: InboxEligibilityChecker private let blazeEligibilityChecker: BlazeEligibilityCheckerProtocol private let googleAdsEligibilityChecker: GoogleAdsEligibilityChecker + private let productCreationAIEligibilityChecker: ProductCreationAIEligibilityCheckerProtocol private(set) lazy var posItemProvider: PointOfSaleItemServiceProtocol = { let storage = ServiceLocator.storageManager @@ -111,6 +113,7 @@ final class HubMenuViewModel: ObservableObject { @Published private var isSiteEligibleForBlaze = false @Published private var isSiteEligibleForGoogleAds = false @Published private var isSiteEligibleForInbox = false + @Published private var isSiteEligibleForProductAICreation = false private var cancellables: Set = [] @@ -149,6 +152,7 @@ final class HubMenuViewModel: ObservableObject { inboxEligibilityChecker: InboxEligibilityChecker = InboxEligibilityUseCase(), blazeEligibilityChecker: BlazeEligibilityCheckerProtocol = BlazeEligibilityChecker(), googleAdsEligibilityChecker: GoogleAdsEligibilityChecker = DefaultGoogleAdsEligibilityChecker(), + productCreationAIEligibilityChecker: ProductCreationAIEligibilityCheckerProtocol = ProductCreationAIEligibilityChecker(), analytics: Analytics = ServiceLocator.analytics) { self.siteID = siteID self.credentials = stores.sessionManager.defaultCredentials @@ -160,6 +164,7 @@ final class HubMenuViewModel: ObservableObject { self.inboxEligibilityChecker = inboxEligibilityChecker self.blazeEligibilityChecker = blazeEligibilityChecker self.googleAdsEligibilityChecker = googleAdsEligibilityChecker + self.productCreationAIEligibilityChecker = productCreationAIEligibilityChecker self.cardPresentPaymentsOnboarding = CardPresentPaymentsOnboardingUseCase() self.posEligibilityChecker = POSEligibilityChecker(siteSettings: ServiceLocator.selectedSiteSettings, currencySettings: ServiceLocator.currencySettings, @@ -321,18 +326,25 @@ private extension HubMenuViewModel { } func setupGeneralElements() { - $shouldShowNewFeatureBadgeOnPayments - .combineLatest($isSiteEligibleForInbox, - $isSiteEligibleForBlaze, - $isSiteEligibleForGoogleAds) + let eligibilityPublisher = $isSiteEligibleForInbox + .combineLatest($isSiteEligibleForBlaze, $isSiteEligibleForGoogleAds) + .map { (inbox, blaze, googleAds) -> (Bool, Bool, Bool) in + return (inbox, blaze, googleAds) + } + eligibilityPublisher + .combineLatest($shouldShowNewFeatureBadgeOnPayments, $isSiteEligibleForProductAICreation) .map { [weak self] combinedResult -> [HubMenuItem] in guard let self else { return [] } - let (shouldShowBadgeOnPayments, eligibleForInbox, eligibleForBlaze, eligibleForGoogleAds) = combinedResult - return createGeneralElements( + let ((eligibleForInbox, eligibleForBlaze, eligibleForGoogleAds), + shouldShowBadgeOnPayments, + eligibleForProductAICreation) = combinedResult + + return self.createGeneralElements( shouldShowBadgeOnPayments: shouldShowBadgeOnPayments, eligibleForGoogleAds: eligibleForGoogleAds, eligibleForBlaze: eligibleForBlaze, - eligibleForInbox: eligibleForInbox + eligibleForInbox: eligibleForInbox, + eligibleForProductAICreation: eligibleForProductAICreation ) } .assign(to: &$generalElements) @@ -341,11 +353,16 @@ private extension HubMenuViewModel { func createGeneralElements(shouldShowBadgeOnPayments: Bool, eligibleForGoogleAds: Bool, eligibleForBlaze: Bool, - eligibleForInbox: Bool) -> [HubMenuItem] { + eligibleForInbox: Bool, + eligibleForProductAICreation: Bool) -> [HubMenuItem] { var items: [HubMenuItem] = [ Payments(iconBadge: shouldShowBadgeOnPayments ? .dot : nil) ] + if eligibleForProductAICreation { + items.append(AISettings()) + } + if eligibleForGoogleAds { items.append(GoogleAds()) } @@ -424,8 +441,8 @@ private extension HubMenuViewModel { } func updateMenuItemEligibility(with site: Yosemite.Site) { - isSiteEligibleForInbox = inboxEligibilityChecker.isEligibleForInbox(siteID: site.siteID) + isSiteEligibleForProductAICreation = productCreationAIEligibilityChecker.isEligible Task { @MainActor in isSiteEligibleForGoogleAds = await googleAdsEligibilityChecker.isSiteEligible(siteID: site.siteID) @@ -546,6 +563,23 @@ extension HubMenuViewModel { let navigationDestination: HubMenuNavigationDestination? = .settings } + struct AISettings: HubMenuItem { + static var id = "ai-settings" + + let title: String = "AI Settings" + let description: String = "Manage your store's AI-powered features" + let icon: UIImage = UIImage(systemName: "wand.and.rays.inverse")! + let iconColor: UIColor = .withColorStudio(.green) + let accessibilityIdentifier: String = "menu-ai" + let trackingOption: String = "ai" + let iconBadge: HubMenuBadgeType? + let navigationDestination: HubMenuNavigationDestination? = .aiSettings + + init(iconBadge: HubMenuBadgeType? = nil) { + self.iconBadge = iconBadge + } + } + struct Payments: HubMenuItem { static var id = "payments" diff --git a/WooCommerce/Classes/ViewRelated/Products/AI/ProductDescriptionGenerationView.swift b/WooCommerce/Classes/ViewRelated/Products/AI/ProductDescriptionGenerationView.swift index dc2932f90a7..7fd180468af 100644 --- a/WooCommerce/Classes/ViewRelated/Products/AI/ProductDescriptionGenerationView.swift +++ b/WooCommerce/Classes/ViewRelated/Products/AI/ProductDescriptionGenerationView.swift @@ -22,7 +22,7 @@ final class ProductDescriptionGenerationHostingController: UIHostingController

{ init(productTypes: [BottomSheetProductType], + aiSource: AISource, onAIOption: @escaping () -> Void, onProductTypeOption: @escaping (BottomSheetProductType) -> Void) { let rootView = AddProductWithAIActionSheet(productTypes: productTypes, + aiSource: aiSource, onAIOption: onAIOption, onProductTypeOption: onProductTypeOption) super.init(rootView: rootView) @@ -28,13 +30,16 @@ struct AddProductWithAIActionSheet: View { @State private var isShowingManualOptions: Bool = false private let productTypes: [BottomSheetProductType] + private let aiSource: AISource private let onAIOption: () -> Void private let onProductTypeOption: (BottomSheetProductType) -> Void init(productTypes: [BottomSheetProductType], + aiSource: AISource, onAIOption: @escaping () -> Void, onProductTypeOption: @escaping (BottomSheetProductType) -> Void) { self.productTypes = productTypes + self.aiSource = aiSource self.onAIOption = onAIOption self.onProductTypeOption = onProductTypeOption } @@ -70,8 +75,13 @@ struct AddProductWithAIActionSheet: View { VStack(alignment: .leading, spacing: Constants.verticalSpacing) { Text(Localization.CreateProductWithAI.aiTitle) .bodyStyle() - Text(Localization.CreateProductWithAI.aiDescription) - .subheadlineStyle() + if aiSource == .internal { + Text(Localization.CreateProductWithAI.aiDescription) + .subheadlineStyle() + } else { + Text(Localization.CreateProductWithAI.merchantAIDescription) + .subheadlineStyle() + } AdaptiveStack(horizontalAlignment: .leading) { Text(Localization.CreateProductWithAI.legalText) Text(.init(Localization.CreateProductWithAI.learnMore)) @@ -161,6 +171,11 @@ private extension AddProductWithAIActionSheet { value: "Let us generate product details for you", comment: "Description of the option to add new product with AI assistance" ) + static let merchantAIDescription = NSLocalizedString( + "addProductWithAIActionSheet.createProductWithAI.merchantAiDescription", + value: "Generate product details using AI. Enter your API key under Settings > AI Settings", + comment: "Description of the option to add new product with AI assistance" + ) static let legalText = NSLocalizedString( "addProductWithAIActionSheet.createProductWithAI.legalText", value: "Powered by AI.", @@ -198,6 +213,7 @@ struct AddProductWithAIActionSheet_Previews: PreviewProvider { .grouped, .affiliate ], + aiSource: .internal, onAIOption: {}, onProductTypeOption: {_ in } ) diff --git a/WooCommerce/Classes/ViewRelated/Products/Add Product/AddProductWithAI/Preview/ProductDetailPreviewViewModel.swift b/WooCommerce/Classes/ViewRelated/Products/Add Product/AddProductWithAI/Preview/ProductDetailPreviewViewModel.swift index 0a592fde6ee..ec6822f749c 100644 --- a/WooCommerce/Classes/ViewRelated/Products/Add Product/AddProductWithAI/Preview/ProductDetailPreviewViewModel.swift +++ b/WooCommerce/Classes/ViewRelated/Products/Add Product/AddProductWithAI/Preview/ProductDetailPreviewViewModel.swift @@ -107,6 +107,10 @@ final class ProductDetailPreviewViewModel: ObservableObject { return productDescription != original } + private var shouldUseMerchantAIKey: Bool { + ProductCreationAIEligibilityChecker().aiSource == .merchant + } + private let productFeatures: String private let siteID: Int64 @@ -199,7 +203,8 @@ final class ProductDetailPreviewViewModel: ObservableObject { async let language = try identifyLanguage() let aiTone = userDefaults.aiTone(for: siteID) let aiProduct = try await generateProduct(language: language, - tone: aiTone) + tone: aiTone, + shouldUseMerchantAIKey: shouldUseMerchantAIKey) analytics.track(event: .ProductCreationAI.nameDescriptionOptionsGenerated( nameCount: aiProduct.names.count, shortDescriptionCount: aiProduct.shortDescriptions.count, @@ -529,6 +534,7 @@ private extension ProductDetailPreviewViewModel { let language = try await withCheckedThrowingContinuation { continuation in stores.dispatch(ProductAction.identifyLanguage(siteID: siteID, string: productInfo, + shouldUseMerchantAIKey: shouldUseMerchantAIKey, feature: .productCreation, completion: { result in continuation.resume(with: result) @@ -543,12 +549,14 @@ private extension ProductDetailPreviewViewModel { @MainActor func generateProduct(language: String, - tone: AIToneVoice) async throws -> AIProduct { + tone: AIToneVoice, + shouldUseMerchantAIKey: Bool) async throws -> AIProduct { let existingCategories = categoryResultController.fetchedObjects let existingTags = tagResultController.fetchedObjects return try await generateAIProduct(language: language, tone: tone, + shouldUseMerchantAIKey: shouldUseMerchantAIKey, existingCategories: existingCategories, existingTags: existingTags) } @@ -556,6 +564,7 @@ private extension ProductDetailPreviewViewModel { @MainActor func generateAIProduct(language: String, tone: AIToneVoice, + shouldUseMerchantAIKey: Bool, existingCategories: [ProductCategory], existingTags: [ProductTag]) async throws -> AIProduct { try await withCheckedThrowingContinuation { continuation in @@ -564,6 +573,7 @@ private extension ProductDetailPreviewViewModel { keywords: productFeatures, language: language, tone: tone.rawValue, + shouldUseMerchantAIKey: shouldUseMerchantAIKey, currencySymbol: currency, dimensionUnit: dimensionUnit, weightUnit: weightUnit, diff --git a/WooCommerce/Classes/ViewRelated/Products/Add Product/AddProductWithAI/ProductCreationAIEligibilityChecker.swift b/WooCommerce/Classes/ViewRelated/Products/Add Product/AddProductWithAI/ProductCreationAIEligibilityChecker.swift index 50c3335d130..3e2ed574f51 100644 --- a/WooCommerce/Classes/ViewRelated/Products/Add Product/AddProductWithAI/ProductCreationAIEligibilityChecker.swift +++ b/WooCommerce/Classes/ViewRelated/Products/Add Product/AddProductWithAI/ProductCreationAIEligibilityChecker.swift @@ -1,18 +1,31 @@ import Foundation import Yosemite +import Experiments /// Protocol for checking "add product using AI" eligibility for easier unit testing. protocol ProductCreationAIEligibilityCheckerProtocol { /// Checks if the user is eligible for the "add product using AI" feature. var isEligible: Bool { get } + var aiSource: AISource { get } +} + +enum AISource: String { + case none = "none" + case `internal` = "internal" + case merchant = "merchant" } /// Checks the eligibility for the "add product using AI" feature. final class ProductCreationAIEligibilityChecker: ProductCreationAIEligibilityCheckerProtocol { private let stores: StoresManager + private let featureFlagService: FeatureFlagService - init(stores: StoresManager = ServiceLocator.stores) { + private(set) var aiSource: AISource = .none + + init(stores: StoresManager = ServiceLocator.stores, + featureFlagService: FeatureFlagService = ServiceLocator.featureFlagService) { self.stores = stores + self.featureFlagService = featureFlagService } var isEligible: Bool { @@ -20,6 +33,20 @@ final class ProductCreationAIEligibilityChecker: ProductCreationAIEligibilityChe return false } - return site.isWordPressComStore || site.isAIAssistantFeatureActive + // By default, check first if we provide AI capabilities from WPCOM/JP + if site.isWordPressComStore || site.isAIAssistantFeatureActive { + aiSource = .internal + return true + } else { + // As fallback, allow personal API keys usage based on feature flag: + switch featureFlagService.isFeatureFlagEnabled(.allowMerchantAIAPIKey) { + case true: + aiSource = .merchant + return true + case false: + aiSource = .none + return false + } + } } } diff --git a/WooCommerce/WooCommerce.xcodeproj/project.pbxproj b/WooCommerce/WooCommerce.xcodeproj/project.pbxproj index 73437df3f9b..ab9f90aa074 100644 --- a/WooCommerce/WooCommerce.xcodeproj/project.pbxproj +++ b/WooCommerce/WooCommerce.xcodeproj/project.pbxproj @@ -1564,6 +1564,7 @@ 681BB5FE2D676061008AF8BB /* POSPadding.swift in Sources */ = {isa = PBXBuildFile; fileRef = 681BB5FD2D676060008AF8BB /* POSPadding.swift */; }; 682210ED2909666600814E14 /* CustomerSearchUICommandTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 682210EC2909666600814E14 /* CustomerSearchUICommandTests.swift */; }; 6827140F28A3988300E6E3F6 /* DismissableNoticeView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 6827140E28A3988300E6E3F6 /* DismissableNoticeView.swift */; }; + 682DB48F2D88230800E38449 /* AISettingsView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 682DB48E2D88230600E38449 /* AISettingsView.swift */; }; 6832C7CA26DA5C4500BA4088 /* LabeledTextViewTableViewCell.swift in Sources */ = {isa = PBXBuildFile; fileRef = 6832C7C926DA5C4500BA4088 /* LabeledTextViewTableViewCell.swift */; }; 6832C7CC26DA5FDF00BA4088 /* LabeledTextViewTableViewCell.xib in Resources */ = {isa = PBXBuildFile; fileRef = 6832C7CB26DA5FDE00BA4088 /* LabeledTextViewTableViewCell.xib */; }; 683421642ACE9391009021D7 /* ProductDiscountView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 683421632ACE9391009021D7 /* ProductDiscountView.swift */; }; @@ -1628,6 +1629,8 @@ 68E674AB2A4DAB8C0034BA1E /* CompletedUpgradeView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 68E674AA2A4DAB8C0034BA1E /* CompletedUpgradeView.swift */; }; 68E674AD2A4DAC010034BA1E /* CurrentPlanDetailsView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 68E674AC2A4DAC010034BA1E /* CurrentPlanDetailsView.swift */; }; 68E674AF2A4DACD50034BA1E /* UpgradeTopBarView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 68E674AE2A4DACD50034BA1E /* UpgradeTopBarView.swift */; }; + 68E7C0B42D8A707500934B04 /* AISettingsViewModel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 68E7C0B32D8A707100934B04 /* AISettingsViewModel.swift */; }; + 68E7C0B72D8A730200934B04 /* AISettingsViewModelTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 68E7C0B62D8A730200934B04 /* AISettingsViewModelTests.swift */; }; 68E952CC287536010095A23D /* SafariView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 68E952CB287536010095A23D /* SafariView.swift */; }; 68E952D0287587BF0095A23D /* CardReaderManualRowView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 68E952CF287587BF0095A23D /* CardReaderManualRowView.swift */; }; 68E952D22875A44B0095A23D /* CardReaderType+Manual.swift in Sources */ = {isa = PBXBuildFile; fileRef = 68E952D12875A44B0095A23D /* CardReaderType+Manual.swift */; }; @@ -4727,6 +4730,7 @@ 681BB5FD2D676060008AF8BB /* POSPadding.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = POSPadding.swift; sourceTree = ""; }; 682210EC2909666600814E14 /* CustomerSearchUICommandTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CustomerSearchUICommandTests.swift; sourceTree = ""; }; 6827140E28A3988300E6E3F6 /* DismissableNoticeView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = DismissableNoticeView.swift; sourceTree = ""; }; + 682DB48E2D88230600E38449 /* AISettingsView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = AISettingsView.swift; sourceTree = ""; }; 6832C7C926DA5C4500BA4088 /* LabeledTextViewTableViewCell.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LabeledTextViewTableViewCell.swift; sourceTree = ""; }; 6832C7CB26DA5FDE00BA4088 /* LabeledTextViewTableViewCell.xib */ = {isa = PBXFileReference; lastKnownFileType = file.xib; path = LabeledTextViewTableViewCell.xib; sourceTree = ""; }; 683421632ACE9391009021D7 /* ProductDiscountView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ProductDiscountView.swift; sourceTree = ""; }; @@ -4791,6 +4795,8 @@ 68E674AA2A4DAB8C0034BA1E /* CompletedUpgradeView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CompletedUpgradeView.swift; sourceTree = ""; }; 68E674AC2A4DAC010034BA1E /* CurrentPlanDetailsView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CurrentPlanDetailsView.swift; sourceTree = ""; }; 68E674AE2A4DACD50034BA1E /* UpgradeTopBarView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = UpgradeTopBarView.swift; sourceTree = ""; }; + 68E7C0B32D8A707100934B04 /* AISettingsViewModel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = AISettingsViewModel.swift; sourceTree = ""; }; + 68E7C0B62D8A730200934B04 /* AISettingsViewModelTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = AISettingsViewModelTests.swift; sourceTree = ""; }; 68E952CB287536010095A23D /* SafariView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SafariView.swift; sourceTree = ""; }; 68E952CF287587BF0095A23D /* CardReaderManualRowView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CardReaderManualRowView.swift; sourceTree = ""; }; 68E952D12875A44B0095A23D /* CardReaderType+Manual.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = "CardReaderType+Manual.swift"; sourceTree = ""; }; @@ -9828,6 +9834,15 @@ path = Customer; sourceTree = ""; }; + 682DB48D2D8822FC00E38449 /* AI Settings */ = { + isa = PBXGroup; + children = ( + 682DB48E2D88230600E38449 /* AISettingsView.swift */, + 68E7C0B32D8A707100934B04 /* AISettingsViewModel.swift */, + ); + path = "AI Settings"; + sourceTree = ""; + }; 6850C5EC2B69E6460026A93B /* Receipts */ = { isa = PBXGroup; children = ( @@ -9894,6 +9909,14 @@ path = Coupons; sourceTree = ""; }; + 68E7C0B52D8A72E200934B04 /* AI Settings */ = { + isa = PBXGroup; + children = ( + 68E7C0B62D8A730200934B04 /* AISettingsViewModelTests.swift */, + ); + path = "AI Settings"; + sourceTree = ""; + }; 68F151DF2C0DA7800082AEC8 /* Models */ = { isa = PBXGroup; children = ( @@ -10724,6 +10747,7 @@ B56DB3EF2049C06D00D4AA8E /* ViewRelated */ = { isa = PBXGroup; children = ( + 682DB48D2D8822FC00E38449 /* AI Settings */, B626C7192876599B0083820C /* Custom Fields */, 86023FA82B15CA8D00A28F07 /* Themes */, DED91DF72AD78A0C00CDCC53 /* Blaze */, @@ -12827,6 +12851,7 @@ D816DDBA22265D8000903E59 /* ViewRelated */ = { isa = PBXGroup; children = ( + 68E7C0B52D8A72E200934B04 /* AI Settings */, 864059FE2C6F67A000DA04DC /* Custom Fields */, 86023FAB2B16D80E00A28F07 /* Themes */, EE45E2C02A42C9C70085F227 /* Feature Highlight */, @@ -16115,6 +16140,7 @@ CCCFFC5A2934EF5E006130AF /* StatsIntervalDataParser.swift in Sources */, DE5746342B4512900034B10D /* BlazeBudgetSettingView.swift in Sources */, 26771A14256FFA8700EE030E /* IssueRefundCoordinatingController.swift in Sources */, + 682DB48F2D88230800E38449 /* AISettingsView.swift in Sources */, 027A2E162513356100DA6ACB /* AppleIDCredentialChecker.swift in Sources */, 26B98758273C5BE30090E8CA /* EditCustomerNoteViewModelProtocol.swift in Sources */, DE78082D2BDF9852005D1E30 /* OrderStatsV4Interval+Chart.swift in Sources */, @@ -16239,6 +16265,7 @@ 267F60132A0C24D700CD1E4E /* PrivacyBannerViewController.swift in Sources */, D8610BD2256F291000A5DF27 /* JetpackErrorViewModel.swift in Sources */, DAAF53B82CF75701006D8880 /* WooShippingAddPackageViewModel.swift in Sources */, + 68E7C0B42D8A707500934B04 /* AISettingsViewModel.swift in Sources */, 26C6E8E626E6B5F500C7BB0F /* StateSelectorViewModel.swift in Sources */, B99B30CD2A85200D0066743D /* AddressFormViewModel.swift in Sources */, 205E79462C204387001BA266 /* PointOfSaleCardPresentPaymentCancelledOnReaderMessageViewModel.swift in Sources */, @@ -17445,6 +17472,7 @@ 0375799D2822F9040083F2E1 /* MockCardPresentPaymentsOnboardingPresenter.swift in Sources */, 455800CC24C6F83F00A8D117 /* ProductSettingsSectionsTests.swift in Sources */, 86EC6EB92CD0BA6A00D7D2FE /* CustomFieldEditorViewModelTests.swift in Sources */, + 68E7C0B72D8A730200934B04 /* AISettingsViewModelTests.swift in Sources */, CE86C8332CC8F9BB00B1764D /* WooShippingServiceCardViewModelTests.swift in Sources */, 26F92DBE2C7ECAB20074A208 /* EditOrderFormTests.swift in Sources */, D85B833F2230F268002168F3 /* SummaryTableViewCellTests.swift in Sources */, diff --git a/WooCommerce/WooCommerceTests/Mocks/MockFeatureFlagService.swift b/WooCommerce/WooCommerceTests/Mocks/MockFeatureFlagService.swift index 04ae3185a36..d12cb3eb6f4 100644 --- a/WooCommerce/WooCommerceTests/Mocks/MockFeatureFlagService.swift +++ b/WooCommerce/WooCommerceTests/Mocks/MockFeatureFlagService.swift @@ -23,6 +23,7 @@ final class MockFeatureFlagService: FeatureFlagService { var favoriteProducts: Bool var isProductGlobalUniqueIdentifierSupported: Bool var hideSitesInStorePicker: Bool + var allowMerchantAIAPIKey: Bool init(isInboxOn: Bool = false, isShowInboxCTAEnabled: Bool = false, @@ -44,7 +45,8 @@ final class MockFeatureFlagService: FeatureFlagService { viewEditCustomFieldsInProductsAndOrders: Bool = false, favoriteProducts: Bool = false, isProductGlobalUniqueIdentifierSupported: Bool = false, - hideSitesInStorePicker: Bool = false) { + hideSitesInStorePicker: Bool = false, + allowMerchantAIAPIKey: Bool = false) { self.isInboxOn = isInboxOn self.isShowInboxCTAEnabled = isShowInboxCTAEnabled self.isUpdateOrderOptimisticallyOn = isUpdateOrderOptimisticallyOn @@ -66,6 +68,7 @@ final class MockFeatureFlagService: FeatureFlagService { self.favoriteProducts = favoriteProducts self.isProductGlobalUniqueIdentifierSupported = isProductGlobalUniqueIdentifierSupported self.hideSitesInStorePicker = hideSitesInStorePicker + self.allowMerchantAIAPIKey = allowMerchantAIAPIKey } func isFeatureFlagEnabled(_ featureFlag: FeatureFlag) -> Bool { @@ -112,6 +115,8 @@ final class MockFeatureFlagService: FeatureFlagService { return isProductGlobalUniqueIdentifierSupported case .hideSitesInStorePicker: return hideSitesInStorePicker + case .allowMerchantAIAPIKey: + return allowMerchantAIAPIKey default: return false } diff --git a/WooCommerce/WooCommerceTests/Mocks/MockProductCreationAIEligibilityChecker.swift b/WooCommerce/WooCommerceTests/Mocks/MockProductCreationAIEligibilityChecker.swift index 21aec37a16e..21da20d8a0d 100644 --- a/WooCommerce/WooCommerceTests/Mocks/MockProductCreationAIEligibilityChecker.swift +++ b/WooCommerce/WooCommerceTests/Mocks/MockProductCreationAIEligibilityChecker.swift @@ -12,4 +12,8 @@ final class MockProductCreationAIEligibilityChecker: ProductCreationAIEligibilit var isEligible: Bool { eligible } + + var aiSource: AISource { + .none + } } diff --git a/WooCommerce/WooCommerceTests/ViewRelated/AI Settings/AISettingsViewModelTests.swift b/WooCommerce/WooCommerceTests/ViewRelated/AI Settings/AISettingsViewModelTests.swift new file mode 100644 index 00000000000..7ed848b7318 --- /dev/null +++ b/WooCommerce/WooCommerceTests/ViewRelated/AI Settings/AISettingsViewModelTests.swift @@ -0,0 +1,164 @@ +import XCTest +@testable import WooCommerce + +final class AISettingsViewModelTests: XCTestCase { + var sut: AISettingsViewModel! + var defaults: UserDefaults! + let suiteName = #file + + private let testAIProvider = "test-ai-model" + private let testAIModel = "test-ai-provider" + + let expectedOpenAIModels = [ + "gpt-4o", + "gpt-4-turbo", + "gpt-3.5-turbo" + ] + + let expectedAnthropicModels = [ + "claude-3-haiku-20240307" + ] + + override func setUp() { + super.setUp() + defaults = UserDefaults(suiteName: suiteName) + defaults.removePersistentDomain(forName: suiteName) + sut = AISettingsViewModel(defaults: defaults) + } + + override func tearDown() { + defaults.removePersistentDomain(forName: suiteName) + defaults = nil + sut = nil + super.tearDown() + } + + func test_sut_when_init_then_loads_values_from_userdefaults() { + // Given + guard let defaults = UserDefaults(suiteName: suiteName) else { + return XCTFail() + } + defaults.set(testAIModel, forKey: "AIProviderModel") + defaults.set(testAIProvider, forKey: "AIProvider") + + // When + sut = AISettingsViewModel(defaults: defaults) + + // Then + XCTAssertEqual(sut.selectedModel, testAIModel) + XCTAssertEqual(sut.selectedProvider, testAIProvider) + } + + func test_init_should_use_empty_values_when_userdefaults_is_empty() { + XCTAssertEqual(sut.selectedModel, "") + XCTAssertEqual(sut.selectedProvider, "") + } + + func test_sut_when_init_then_has_correct_openai_models() { + XCTAssertEqual(sut.openAIModels, expectedOpenAIModels) + } + + func test_sut_when_init_then_has_correct_anthropic_models() { + XCTAssertEqual(sut.anthropicModels, expectedAnthropicModels) + } + + func test_sut_when_update_provider_to_openai_then_selects_first_openai_model_and_saves() { + // When + sut.updateProvider("OpenAI") + + // Then + XCTAssertEqual(sut.selectedProvider, "OpenAI") + XCTAssertEqual(sut.selectedModel, expectedOpenAIModels.first) + + // Verify that is saved to UserDefaults + XCTAssertEqual(defaults.string(forKey: "AIProvider"), "OpenAI") + XCTAssertEqual(defaults.string(forKey: "AIProviderModel"), expectedOpenAIModels.first) + } + + func test_sut_when_update_provider_to_anthropic_then_selects_first_anthropic_model_and_saves() { + // When + sut.updateProvider("Anthropic") + + // Then + XCTAssertEqual(sut.selectedProvider, "Anthropic") + XCTAssertEqual(sut.selectedModel, expectedAnthropicModels.first) + + // And verify saved to defaults + XCTAssertEqual(defaults.string(forKey: "AIProvider"), "Anthropic") + XCTAssertEqual(defaults.string(forKey: "AIProviderModel"), expectedAnthropicModels.first) + } + + func test_sut_when_update_provider_then_persists_all_values_to_userdefaults() { + // Given + sut.apiKey = "new-api-key" + + // When + sut.updateProvider("OpenAI") + + // Then + XCTAssertEqual(defaults.string(forKey: "AIProviderModel"), expectedOpenAIModels.first) + XCTAssertEqual(defaults.string(forKey: "AIProvider"), "OpenAI") + } + + func test_sut_when_onAppear_with_empty_apiKey_then_enables_editing() { + // Given + sut.apiKey = "" + sut.isEditingApiKey = false + + // When + sut.onAppear() + + // Then + XCTAssertTrue(sut.isEditingApiKey) + } + + func test_sut_when_onAppear_with_apiKey_then_disables_editing() { + // Given + sut.apiKey = "some-key" + sut.isEditingApiKey = true + + // When + sut.onAppear() + + // Then + XCTAssertFalse(sut.isEditingApiKey) + } + + func test_sut_when_clearApiKey_then_empties_apiKey() { + // Given + sut.apiKey = "some-key" + + // When + sut.clearApiKey() + + // Then + XCTAssertTrue(sut.apiKey.isEmpty) + } + + func test_sut_when_toggleEditing_then_toggles_editing_state() { + // Given + sut.isEditingApiKey = false + + // When + sut.toggleEditing() + + // Then + XCTAssertTrue(sut.isEditingApiKey) + } + + func test_sut_when_togglEediting_from_editing_then_saves_settings() { + // Given + sut.isEditingApiKey = true + sut.apiKey = "new-key" + sut.selectedModel = "test-model" + sut.selectedProvider = "test-provider" + + // When + sut.toggleEditing() + + // Then + XCTAssertFalse(sut.isEditingApiKey) + XCTAssertEqual(defaults.string(forKey: "AIProviderModel"), "test-model") + XCTAssertEqual(defaults.string(forKey: "AIProvider"), "test-provider") + } +} diff --git a/WooCommerce/WooCommerceTests/ViewRelated/Products/AI/ProductDescriptionGenerationViewModelTests.swift b/WooCommerce/WooCommerceTests/ViewRelated/Products/AI/ProductDescriptionGenerationViewModelTests.swift index 423c480d295..0ed2d0ae126 100644 --- a/WooCommerce/WooCommerceTests/ViewRelated/Products/AI/ProductDescriptionGenerationViewModelTests.swift +++ b/WooCommerce/WooCommerceTests/ViewRelated/Products/AI/ProductDescriptionGenerationViewModelTests.swift @@ -209,9 +209,9 @@ final class ProductDescriptionGenerationViewModelTests: XCTestCase { viewModel.name = "Fun" stores.whenReceivingAction(ofType: ProductAction.self) { action in switch action { - case let .generateProductDescription(_, _, _, _, completion): + case let .generateProductDescription(_, _, _, _, _, completion): completion(.success("Must buy")) - case let .identifyLanguage(_, _, _, completion): + case let .identifyLanguage(_, _, _, _, completion): completion(.success("en")) identifyLanguageRequestCounter += 1 default: @@ -251,9 +251,9 @@ final class ProductDescriptionGenerationViewModelTests: XCTestCase { viewModel.name = "Fun" stores.whenReceivingAction(ofType: ProductAction.self) { action in switch action { - case let .generateProductDescription(_, _, _, _, completion): + case let .generateProductDescription(_, _, _, _, _, completion): completion(.success("Must buy")) - case let .identifyLanguage(_, _, _, completion): + case let .identifyLanguage(_, _, _, _, completion): completion(.success("en")) identifyLanguageRequestCounter += 1 default: @@ -448,9 +448,9 @@ private extension ProductDescriptionGenerationViewModelTests { identifyLaunguage: Result = .success("en")) { stores.whenReceivingAction(ofType: ProductAction.self) { action in switch action { - case let .generateProductDescription(_, _, _, _, completion): + case let .generateProductDescription(_, _, _, _, _, completion): completion(generatedDescription) - case let .identifyLanguage(_, _, _, completion): + case let .identifyLanguage(_, _, _, _, completion): completion(identifyLaunguage) default: return XCTFail("Unexpected action: \(action)") diff --git a/WooCommerce/WooCommerceTests/ViewRelated/Products/AI/ProductSharingMessageGenerationViewModelTests.swift b/WooCommerce/WooCommerceTests/ViewRelated/Products/AI/ProductSharingMessageGenerationViewModelTests.swift index 9080dc57e7c..c52bedd1be4 100644 --- a/WooCommerce/WooCommerceTests/ViewRelated/Products/AI/ProductSharingMessageGenerationViewModelTests.swift +++ b/WooCommerce/WooCommerceTests/ViewRelated/Products/AI/ProductSharingMessageGenerationViewModelTests.swift @@ -44,10 +44,10 @@ final class ProductSharingMessageGenerationViewModelTests: XCTestCase { XCTAssertFalse(viewModel.generationInProgress) stores.whenReceivingAction(ofType: ProductAction.self) { action in switch action { - case let .generateProductSharingMessage(_, _, _, _, _, completion): + case let .generateProductSharingMessage(_, _, _, _, _, _, completion): XCTAssertTrue(viewModel.generationInProgress) completion(.success("Check this out!")) - case let .identifyLanguage(_, _, _, completion): + case let .identifyLanguage(_, _, _, _, completion): completion(.success("en")) default: return @@ -72,9 +72,9 @@ final class ProductSharingMessageGenerationViewModelTests: XCTestCase { stores: stores) stores.whenReceivingAction(ofType: ProductAction.self) { action in switch action { - case let .generateProductSharingMessage(_, _, _, _, _, completion): + case let .generateProductSharingMessage(_, _, _, _, _, _, completion): completion(.success(expectedString)) - case let .identifyLanguage(_, _, _, completion): + case let .identifyLanguage(_, _, _, _, completion): completion(.success("en")) default: return @@ -99,9 +99,9 @@ final class ProductSharingMessageGenerationViewModelTests: XCTestCase { stores: stores) stores.whenReceivingAction(ofType: ProductAction.self) { action in switch action { - case let .generateProductSharingMessage(_, _, _, _, _, completion): + case let .generateProductSharingMessage(_, _, _, _, _, _, completion): completion(.failure(NSError(domain: "Test", code: 500))) - case let .identifyLanguage(_, _, _, completion): + case let .identifyLanguage(_, _, _, _, completion): completion(.success("en")) default: return @@ -126,7 +126,7 @@ final class ProductSharingMessageGenerationViewModelTests: XCTestCase { stores: stores) stores.whenReceivingAction(ofType: ProductAction.self) { action in switch action { - case let .identifyLanguage(_, _, _, completion): + case let .identifyLanguage(_, _, _, _, completion): completion(.failure(NSError(domain: "Test", code: 500))) default: return @@ -154,9 +154,9 @@ final class ProductSharingMessageGenerationViewModelTests: XCTestCase { analytics: analytics) stores.whenReceivingAction(ofType: ProductAction.self) { action in switch action { - case let .generateProductSharingMessage(_, _, _, _, _, completion): + case let .generateProductSharingMessage(_, _, _, _, _, _, completion): completion(.success("Test")) - case let .identifyLanguage(_, _, _, completion): + case let .identifyLanguage(_, _, _, _, completion): completion(.success(expectedLanguage)) default: return @@ -207,9 +207,9 @@ final class ProductSharingMessageGenerationViewModelTests: XCTestCase { analytics: analytics) stores.whenReceivingAction(ofType: ProductAction.self) { action in switch action { - case let .generateProductSharingMessage(_, _, _, _, _, completion): + case let .generateProductSharingMessage(_, _, _, _, _, _, completion): completion(.failure(NSError(domain: "Test", code: 500))) - case let .identifyLanguage(_, _, _, completion): + case let .identifyLanguage(_, _, _, _, completion): completion(.success("en")) default: return @@ -240,9 +240,9 @@ final class ProductSharingMessageGenerationViewModelTests: XCTestCase { analytics: analytics) stores.whenReceivingAction(ofType: ProductAction.self) { action in switch action { - case let .generateProductSharingMessage(_, _, _, _, _, completion): + case let .generateProductSharingMessage(_, _, _, _, _, _, completion): completion(.success("Test")) - case let .identifyLanguage(_, _, _, completion): + case let .identifyLanguage(_, _, _, _, completion): completion(.failure(NSError(domain: "Test", code: 500))) default: return @@ -293,9 +293,9 @@ final class ProductSharingMessageGenerationViewModelTests: XCTestCase { stores: stores) stores.whenReceivingAction(ofType: ProductAction.self) { action in switch action { - case let .generateProductSharingMessage(_, _, _, _, _, completion): + case let .generateProductSharingMessage(_, _, _, _, _, _, completion): completion(.success(expectedString)) - case let .identifyLanguage(_, _, _, completion): + case let .identifyLanguage(_, _, _, _, completion): completion(.success("en")) default: return @@ -392,9 +392,9 @@ final class ProductSharingMessageGenerationViewModelTests: XCTestCase { stores: stores) stores.whenReceivingAction(ofType: ProductAction.self) { action in switch action { - case let .generateProductSharingMessage(_, _, _, _, _, completion): + case let .generateProductSharingMessage(_, _, _, _, _, _, completion): completion(.success(expectedString)) - case let .identifyLanguage(_, _, _, completion): + case let .identifyLanguage(_, _, _, _, completion): completion(.success("en")) default: return @@ -422,9 +422,9 @@ final class ProductSharingMessageGenerationViewModelTests: XCTestCase { stores: stores) stores.whenReceivingAction(ofType: ProductAction.self) { action in switch action { - case let .generateProductSharingMessage(_, _, _, _, _, completion): + case let .generateProductSharingMessage(_, _, _, _, _, _, completion): completion(.success(expectedString)) - case let .identifyLanguage(_, _, _, completion): + case let .identifyLanguage(_, _, _, _, completion): completion(.success("en")) default: return @@ -457,9 +457,9 @@ final class ProductSharingMessageGenerationViewModelTests: XCTestCase { stores: stores) stores.whenReceivingAction(ofType: ProductAction.self) { action in switch action { - case let .generateProductSharingMessage(_, _, _, _, _, completion): + case let .generateProductSharingMessage(_, _, _, _, _, _, completion): completion(.success("Must buy")) - case let .identifyLanguage(_, _, _, completion): + case let .identifyLanguage(_, _, _, _, completion): completion(.success("en")) identifyLanguageRequestCounter += 1 default: @@ -495,9 +495,9 @@ final class ProductSharingMessageGenerationViewModelTests: XCTestCase { stores: stores) stores.whenReceivingAction(ofType: ProductAction.self) { action in switch action { - case let .generateProductSharingMessage(_, _, _, _, _, completion): + case let .generateProductSharingMessage(_, _, _, _, _, _, completion): completion(.success("Must buy")) - case let .identifyLanguage(_, _, _, completion): + case let .identifyLanguage(_, _, _, _, completion): completion(.success("en")) identifyLanguageRequestCounter += 1 default: diff --git a/WooCommerce/WooCommerceTests/ViewRelated/Products/Add Product/AddProductWithAI/ProductCreationAIEligibilityCheckerTests.swift b/WooCommerce/WooCommerceTests/ViewRelated/Products/Add Product/AddProductWithAI/ProductCreationAIEligibilityCheckerTests.swift index e23b6b74a1b..3a9add928d0 100644 --- a/WooCommerce/WooCommerceTests/ViewRelated/Products/Add Product/AddProductWithAI/ProductCreationAIEligibilityCheckerTests.swift +++ b/WooCommerce/WooCommerceTests/ViewRelated/Products/Add Product/AddProductWithAI/ProductCreationAIEligibilityCheckerTests.swift @@ -5,21 +5,25 @@ import XCTest final class ProductCreationAIEligibilityCheckerTests: XCTestCase { private var stores: MockStoresManager! + private var featureFlagService: MockFeatureFlagService! override func setUp() { super.setUp() stores = MockStoresManager(sessionManager: .makeForTesting()) + featureFlagService = MockFeatureFlagService(allowMerchantAIAPIKey: false) } override func tearDown() { stores = nil + featureFlagService = nil super.tearDown() } func test_isEligible_is_true_for_wpcom_store() throws { // Given updateDefaultStore(isWPCOMStore: true) - let checker = ProductCreationAIEligibilityChecker(stores: stores) + let checker = ProductCreationAIEligibilityChecker(stores: stores, + featureFlagService: featureFlagService) // When let isEligible = checker.isEligible @@ -31,7 +35,8 @@ final class ProductCreationAIEligibilityCheckerTests: XCTestCase { func test_isEligible_is_false_for_non_wpcom_store() throws { // Given updateDefaultStore(isWPCOMStore: false) - let checker = ProductCreationAIEligibilityChecker(stores: stores) + let checker = ProductCreationAIEligibilityChecker(stores: stores, + featureFlagService: featureFlagService) // When let isEligible = checker.isEligible @@ -42,13 +47,42 @@ final class ProductCreationAIEligibilityCheckerTests: XCTestCase { func test_isEligible_is_true_for_non_wpcom_store_when_ai_assistant_feature_is_active() throws { // Given updateDefaultStore(isWPCOMStore: false, isAIAssistantActive: true) - let checker = ProductCreationAIEligibilityChecker(stores: stores) + let checker = ProductCreationAIEligibilityChecker(stores: stores, + featureFlagService: featureFlagService) // When let isEligible = checker.isEligible // Then XCTAssertTrue(isEligible) } + + func test_allow_merchant_ai_api_key_as_fallback_when_flag_is_true_then_isEligible() { + // Given + updateDefaultStore(isWPCOMStore: false, isAIAssistantActive: false) + let enabledFlag = MockFeatureFlagService(allowMerchantAIAPIKey: true) + let checker = ProductCreationAIEligibilityChecker(stores: stores, + featureFlagService: enabledFlag) + + // When + let isEligible = checker.isEligible + + // Then + XCTAssertTrue(isEligible) + } + + func test_allow_merchant_ai_api_key_as_fallback_when_flag_is_false_then_is_not_eligible() { + // Given + updateDefaultStore(isWPCOMStore: false, isAIAssistantActive: false) + let disabledFlag = MockFeatureFlagService(allowMerchantAIAPIKey: false) + let checker = ProductCreationAIEligibilityChecker(stores: stores, + featureFlagService: disabledFlag) + + // When + let isEligible = checker.isEligible + + // Then + XCTAssertFalse(isEligible) + } } private extension ProductCreationAIEligibilityCheckerTests { diff --git a/WooCommerce/WooCommerceTests/ViewRelated/Products/Add Product/AddProductWithAI/ProductDetailPreviewViewModelTests.swift b/WooCommerce/WooCommerceTests/ViewRelated/Products/Add Product/AddProductWithAI/ProductDetailPreviewViewModelTests.swift index 55995927fbb..35a829a354c 100644 --- a/WooCommerce/WooCommerceTests/ViewRelated/Products/Add Product/AddProductWithAI/ProductDetailPreviewViewModelTests.swift +++ b/WooCommerce/WooCommerceTests/ViewRelated/Products/Add Product/AddProductWithAI/ProductDetailPreviewViewModelTests.swift @@ -184,9 +184,9 @@ final class ProductDetailPreviewViewModelTests: XCTestCase { stores.whenReceivingAction(ofType: ProductAction.self) { action in switch action { - case let .generateAIProduct(_, _, _, _, _, _, _, _, _, _, completion): + case let .generateAIProduct(_, _, _, _, _, _, _, _, _, _, _, completion): completion(.success(.fake())) - case let .identifyLanguage(_, string, _, completion): + case let .identifyLanguage(_, string, _, _, completion): // Then XCTAssertEqual(string, productFeatures) completion(.success("en")) @@ -219,10 +219,10 @@ final class ProductDetailPreviewViewModelTests: XCTestCase { stores.whenReceivingAction(ofType: ProductAction.self) { action in switch action { - case let .generateAIProduct(_, _, _, language, _, _, _, _, _, _, completion): + case let .generateAIProduct(_, _, _, language, _, _, _, _, _, _, _, completion): XCTAssertEqual(language, expectedLanguage) completion(.success(.fake())) - case let .identifyLanguage(_, _, _, completion): + case let .identifyLanguage(_, _, _, _, completion): identifyingLanguageRequestCount += 1 completion(.success(expectedLanguage)) default: @@ -359,6 +359,7 @@ final class ProductDetailPreviewViewModelTests: XCTestCase { keywords, language, tone, + _, currencySymbol, dimensionUnit, weightUnit, @@ -376,7 +377,7 @@ final class ProductDetailPreviewViewModelTests: XCTestCase { XCTAssertEqual(categories, sampleCategories) XCTAssertEqual(tags, sampleTags) completion(.success(.fake())) - case let .identifyLanguage(_, _, _, completion): + case let .identifyLanguage(_, _, _, _, completion): completion(.success(sampleLanguage)) default: break @@ -405,10 +406,10 @@ final class ProductDetailPreviewViewModelTests: XCTestCase { // When stores.whenReceivingAction(ofType: ProductAction.self) { action in switch action { - case let .generateAIProduct(_, _, _, _, _, _, _, _, _, _, completion): + case let .generateAIProduct(_, _, _, _, _, _, _, _, _, _, _, completion): XCTAssertTrue(viewModel.isGeneratingDetails) completion(.success(self.sampleAIProduct)) - case let .identifyLanguage(_, _, _, completion): + case let .identifyLanguage(_, _, _, _, completion): XCTAssertTrue(viewModel.isGeneratingDetails) completion(.success("en")) default: @@ -444,10 +445,10 @@ final class ProductDetailPreviewViewModelTests: XCTestCase { // When stores.whenReceivingAction(ofType: ProductAction.self) { action in switch action { - case let .generateAIProduct(_, _, _, _, _, _, _, _, _, _, completion): + case let .generateAIProduct(_, _, _, _, _, _, _, _, _, _, _, completion): XCTAssertEqual(viewModel.errorState, .none) completion(.failure(expectedError)) - case let .identifyLanguage(_, _, _, completion): + case let .identifyLanguage(_, _, _, _, completion): XCTAssertEqual(viewModel.errorState, .none) completion(.success("en")) default: @@ -1080,10 +1081,10 @@ final class ProductDetailPreviewViewModelTests: XCTestCase { // When stores.whenReceivingAction(ofType: ProductAction.self) { action in switch action { - case let .generateAIProduct(_, _, _, _, _, _, _, _, _, _, completion): + case let .generateAIProduct(_, _, _, _, _, _, _, _, _, _, _, completion): XCTAssertFalse(viewModel.isSavingProduct) completion(.success(aiProduct)) - case let .identifyLanguage(_, _, _, completion): + case let .identifyLanguage(_, _, _, _, completion): XCTAssertFalse(viewModel.isSavingProduct) completion(.success("en")) case let .addProduct(_, onCompletion): @@ -1457,13 +1458,13 @@ private extension ProductDetailPreviewViewModelTests { addedProductResult: (Result)? = nil) { stores.whenReceivingAction(ofType: ProductAction.self) { action in switch action { - case let .generateAIProduct(_, _, _, _, _, _, _, _, _, _, completion): + case let .generateAIProduct(_, _, _, _, _, _, _, _, _, _, _, completion): if let aiGeneratedProductResult { completion(aiGeneratedProductResult) } else { completion(.success(self.sampleAIProduct)) } - case let .identifyLanguage(_, _, _, completion): + case let .identifyLanguage(_, _, _, _, completion): completion(.success(identifiedLanguage)) case let .addProduct(product, onCompletion): if let addedProductResult { diff --git a/Yosemite/Yosemite/Actions/ProductAction.swift b/Yosemite/Yosemite/Actions/ProductAction.swift index e6b1c9b276f..bd70993cb2d 100644 --- a/Yosemite/Yosemite/Actions/ProductAction.swift +++ b/Yosemite/Yosemite/Actions/ProductAction.swift @@ -130,6 +130,7 @@ public enum ProductAction: Action { /// case identifyLanguage(siteID: Int64, string: String, + shouldUseMerchantAIKey: Bool, feature: GenerativeContentRemoteFeature, completion: (Result) -> Void) @@ -138,12 +139,14 @@ public enum ProductAction: Action { case generateProductName(siteID: Int64, keywords: String, language: String, + shouldUseMerchantAIKey: Bool, completion: (Result) -> Void) /// Generates a product description with Jetpack AI given the name and features. /// case generateProductDescription(siteID: Int64, name: String, + shouldUseMerchantAIKey: Bool, features: String, language: String, completion: (Result) -> Void) @@ -155,6 +158,7 @@ public enum ProductAction: Action { name: String, description: String, language: String, + shouldUseMerchantAIKey: Bool, completion: (Result) -> Void) /// Generates product details (e.g. name and description) with Jetpack AI given the scanned texts from an image and optional product name . @@ -163,6 +167,7 @@ public enum ProductAction: Action { productName: String?, scannedTexts: [String], language: String, + shouldUseMerchantAIKey: Bool, completion: (Result) -> Void) /// Fetches the total number of products in the site given the site ID. @@ -189,6 +194,7 @@ public enum ProductAction: Action { keywords: String, language: String, tone: String, + shouldUseMerchantAIKey: Bool, currencySymbol: String, dimensionUnit: String?, weightUnit: String?, diff --git a/Yosemite/Yosemite/Stores/ProductStore.swift b/Yosemite/Yosemite/Stores/ProductStore.swift index a68c1df7a88..e97e082cbe1 100644 --- a/Yosemite/Yosemite/Stores/ProductStore.swift +++ b/Yosemite/Yosemite/Stores/ProductStore.swift @@ -116,18 +116,40 @@ public class ProductStore: Store { replaceProductLocally(product: product, onCompletion: onCompletion) case let .checkIfStoreHasProducts(siteID, status, onCompletion): checkIfStoreHasProducts(siteID: siteID, status: status, onCompletion: onCompletion) - case let .identifyLanguage(siteID, string, feature, completion): + case let .identifyLanguage(siteID, string, shouldUseMerchantAIKey, feature, completion): identifyLanguage(siteID: siteID, - string: string, feature: feature, + string: string, + shouldUseMerchantAIKey: shouldUseMerchantAIKey, + feature: feature, completion: completion) - case let .generateProductDescription(siteID, name, features, language, completion): - generateProductDescription(siteID: siteID, name: name, features: features, language: language, completion: completion) - case let .generateProductSharingMessage(siteID, url, name, description, language, completion): - generateProductSharingMessage(siteID: siteID, url: url, name: name, description: description, language: language, completion: completion) - case let .generateProductName(siteID, keywords, language, completion): - generateProductName(siteID: siteID, keywords: keywords, language: language, completion: completion) - case let .generateProductDetails(siteID, productName, scannedTexts, language, completion): - generateProductDetails(siteID: siteID, productName: productName, scannedTexts: scannedTexts, language: language, completion: completion) + case let .generateProductDescription(siteID, name, shouldUseMerchantAIKey, features, language, completion): + generateProductDescription(siteID: siteID, + name: name, + shouldUseMerchantAIKey: shouldUseMerchantAIKey, + features: features, + language: language, + completion: completion) + case let .generateProductSharingMessage(siteID, url, name, description, language, shouldUseMerchantAIKey, completion): + generateProductSharingMessage(siteID: siteID, + url: url, + name: name, + description: description, + language: language, + shouldUseMerchantAIKey: shouldUseMerchantAIKey, + completion: completion) + case let .generateProductName(siteID, keywords, language, shouldUseMerchantAIKey, completion): + generateProductName(siteID: siteID, + keywords: keywords, + language: language, + shouldUseMerchantAIKey: shouldUseMerchantAIKey, + completion: completion) + case let .generateProductDetails(siteID, productName, scannedTexts, language, shouldUseMerchantAIKey, completion): + generateProductDetails(siteID: siteID, + productName: productName, + scannedTexts: scannedTexts, + language: language, + shouldUseMerchantAIKey: shouldUseMerchantAIKey, + completion: completion) case let .fetchNumberOfProducts(siteID, completion): fetchNumberOfProducts(siteID: siteID, completion: completion) case let .generateAIProduct(siteID, @@ -135,6 +157,7 @@ public class ProductStore: Store { keywords, language, tone, + shouldUseMerchantAIKey, currencySymbol, dimensionUnit, weightUnit, @@ -146,6 +169,7 @@ public class ProductStore: Store { keywords: keywords, language: language, tone: tone, + shouldUseMerchantAIKey: shouldUseMerchantAIKey, currencySymbol: currencySymbol, dimensionUnit: dimensionUnit, weightUnit: weightUnit, @@ -594,12 +618,14 @@ private extension ProductStore { func identifyLanguage(siteID: Int64, string: String, + shouldUseMerchantAIKey: Bool, feature: GenerativeContentRemoteFeature, completion: @escaping (Result) -> Void) { Task { @MainActor in let result = await Result { try await generativeContentRemote.identifyLanguage(siteID: siteID, string: string, + shouldUseMerchantAIKey: shouldUseMerchantAIKey, feature: feature) } completion(result) @@ -608,6 +634,7 @@ private extension ProductStore { func generateProductDescription(siteID: Int64, name: String, + shouldUseMerchantAIKey: Bool, features: String, language: String, completion: @escaping (Result) -> Void) { @@ -622,7 +649,11 @@ private extension ProductStore { Task { @MainActor in let result = await Result { - let description = try await generativeContentRemote.generateText(siteID: siteID, base: prompt, feature: .productDescription, responseFormat: .text) + let description = try await generativeContentRemote.generateText(siteID: siteID, + base: prompt, + shouldUseMerchantAIKey: shouldUseMerchantAIKey, + feature: .productDescription, + responseFormat: .text) return description } completion(result) @@ -634,6 +665,7 @@ private extension ProductStore { name: String, description: String, language: String, + shouldUseMerchantAIKey: Bool, completion: @escaping (Result) -> Void) { let prompt = [ // swiftlint:disable:next line_length @@ -649,7 +681,11 @@ private extension ProductStore { Task { @MainActor in let result = await Result { - let message = try await generativeContentRemote.generateText(siteID: siteID, base: prompt, feature: .productSharing, responseFormat: .text) + let message = try await generativeContentRemote.generateText(siteID: siteID, + base: prompt, + shouldUseMerchantAIKey: shouldUseMerchantAIKey, + feature: .productSharing, + responseFormat: .text) .trimmingCharacters(in: CharacterSet(["\""])) // Trims quotation mark return message } @@ -661,6 +697,7 @@ private extension ProductStore { productName: String?, scannedTexts: [String], language: String, + shouldUseMerchantAIKey: Bool, completion: @escaping (Result) -> Void) { let keywords: [String] = { guard let productName else { @@ -683,6 +720,7 @@ private extension ProductStore { do { let jsonString = try await generativeContentRemote.generateText(siteID: siteID, base: prompt, + shouldUseMerchantAIKey: shouldUseMerchantAIKey, feature: .productDetailsFromScannedTexts, responseFormat: .json) guard let jsonData = jsonString.data(using: .utf8) else { @@ -699,6 +737,7 @@ private extension ProductStore { func generateProductName(siteID: Int64, keywords: String, language: String, + shouldUseMerchantAIKey: Bool, completion: @escaping (Result) -> Void) { let prompt = [ "You are a WooCommerce SEO and marketing expert.", @@ -710,7 +749,11 @@ private extension ProductStore { Task { @MainActor in let result = await Result { - let description = try await generativeContentRemote.generateText(siteID: siteID, base: prompt, feature: .productName, responseFormat: .text) + let description = try await generativeContentRemote.generateText(siteID: siteID, + base: prompt, + shouldUseMerchantAIKey: shouldUseMerchantAIKey, + feature: .productName, + responseFormat: .text) return description } completion(result) @@ -733,6 +776,7 @@ private extension ProductStore { keywords: String, language: String, tone: String, + shouldUseMerchantAIKey: Bool, currencySymbol: String, dimensionUnit: String?, weightUnit: String?, @@ -746,6 +790,7 @@ private extension ProductStore { keywords: keywords, language: language, tone: tone, + shouldUseMerchantAIKey: shouldUseMerchantAIKey, currencySymbol: currencySymbol, dimensionUnit: dimensionUnit, weightUnit: weightUnit, diff --git a/Yosemite/YosemiteTests/Mocks/Networking/Remote/MockGenerativeContentRemote.swift b/Yosemite/YosemiteTests/Mocks/Networking/Remote/MockGenerativeContentRemote.swift index d5be1ddbada..a2ff5683335 100644 --- a/Yosemite/YosemiteTests/Mocks/Networking/Remote/MockGenerativeContentRemote.swift +++ b/Yosemite/YosemiteTests/Mocks/Networking/Remote/MockGenerativeContentRemote.swift @@ -39,6 +39,7 @@ final class MockGenerativeContentRemote { extension MockGenerativeContentRemote: GenerativeContentRemoteProtocol { func generateText(siteID: Int64, base: String, + shouldUseMerchantAIKey: Bool, feature: GenerativeContentRemoteFeature, responseFormat: GenerativeContentRemoteResponseFormat) async throws -> String { generateTextBase = base @@ -53,6 +54,7 @@ extension MockGenerativeContentRemote: GenerativeContentRemoteProtocol { func identifyLanguage(siteID: Int64, string: String, + shouldUseMerchantAIKey: Bool, feature: GenerativeContentRemoteFeature) async throws -> String { identifyLanguageString = string identifyLanguageFeature = feature @@ -68,6 +70,7 @@ extension MockGenerativeContentRemote: GenerativeContentRemoteProtocol { keywords: String, language: String, tone: String, + shouldUseMerchantAIKey: Bool, currencySymbol: String, dimensionUnit: String?, weightUnit: String?, diff --git a/Yosemite/YosemiteTests/Stores/ProductStoreTests.swift b/Yosemite/YosemiteTests/Stores/ProductStoreTests.swift index 2ece6794770..b5a6e25b189 100644 --- a/Yosemite/YosemiteTests/Stores/ProductStoreTests.swift +++ b/Yosemite/YosemiteTests/Stores/ProductStoreTests.swift @@ -1905,6 +1905,7 @@ final class ProductStoreTests: XCTestCase { let result = waitFor { promise in productStore.onAction(ProductAction.generateProductDescription(siteID: self.sampleSiteID, name: "A product", + shouldUseMerchantAIKey: false, features: "Trendy", language: "en") { result in promise(result) @@ -1931,6 +1932,7 @@ final class ProductStoreTests: XCTestCase { let result = waitFor { promise in productStore.onAction(ProductAction.generateProductDescription(siteID: self.sampleSiteID, name: "A product", + shouldUseMerchantAIKey: false, features: "Trendy", language: "en") { result in promise(result) @@ -1956,6 +1958,7 @@ final class ProductStoreTests: XCTestCase { waitFor { promise in productStore.onAction(ProductAction.generateProductDescription(siteID: self.sampleSiteID, name: "A product name", + shouldUseMerchantAIKey: false, features: "Trendy, cool, fun", language: "en") { _ in promise(()) @@ -1983,6 +1986,7 @@ final class ProductStoreTests: XCTestCase { waitFor { promise in productStore.onAction(ProductAction.generateProductDescription(siteID: self.sampleSiteID, name: "A product name", + shouldUseMerchantAIKey: false, features: "Trendy, cool, fun", language: "en") { _ in promise(()) @@ -2008,6 +2012,7 @@ final class ProductStoreTests: XCTestCase { waitFor { promise in productStore.onAction(ProductAction.generateProductDescription(siteID: self.sampleSiteID, name: "A product name", + shouldUseMerchantAIKey: false, features: "Trendy, cool, fun", language: "en") { _ in promise(()) @@ -2039,7 +2044,8 @@ final class ProductStoreTests: XCTestCase { url: "https://example.com", name: "Sample product", description: "Sample description", - language: "en" + language: "en", + shouldUseMerchantAIKey: false ) { result in promise(result) }) @@ -2068,7 +2074,8 @@ final class ProductStoreTests: XCTestCase { url: "https://example.com", name: "Sample product", description: "Sample description", - language: "en" + language: "en", + shouldUseMerchantAIKey: false ) { result in promise(result) }) @@ -2097,7 +2104,8 @@ final class ProductStoreTests: XCTestCase { url: "https://example.com", name: "Sample product", description: "Sample description", - language: "en" + language: "en", + shouldUseMerchantAIKey: false ) { result in promise(result) }) @@ -2113,7 +2121,7 @@ final class ProductStoreTests: XCTestCase { let expectedURL = "https://example.com" let expectedName = "Sample product" let expectedDescription = "Sample description" - let expectedLangugae = "en" + let expectedLanguage = "en" let generativeContentRemote = MockGenerativeContentRemote() generativeContentRemote.whenIdentifyingLanguage(thenReturn: .success("")) generativeContentRemote.whenGeneratingText(thenReturn: .success("")) @@ -2130,8 +2138,8 @@ final class ProductStoreTests: XCTestCase { url: expectedURL, name: expectedName, description: expectedDescription, - language: expectedLangugae - ) { result in + language: expectedLanguage, + shouldUseMerchantAIKey: false) { result in promise(()) }) } @@ -2141,7 +2149,7 @@ final class ProductStoreTests: XCTestCase { XCTAssertTrue(base.contains(expectedURL)) XCTAssertTrue(base.contains(expectedName)) XCTAssertTrue(base.contains(expectedDescription)) - XCTAssertTrue(base.contains(expectedLangugae)) + XCTAssertTrue(base.contains(expectedLanguage)) } func test_generateProductSharingMessage_uses_correct_feature() throws { @@ -2161,7 +2169,8 @@ final class ProductStoreTests: XCTestCase { url: "https://example.com", name: "Sample product", description: "Sample description", - language: "en" + language: "en", + shouldUseMerchantAIKey: false ) { result in promise(()) }) @@ -2189,7 +2198,8 @@ final class ProductStoreTests: XCTestCase { url: "https://example.com", name: "Sample product", description: "Sample description", - language: "en" + language: "en", + shouldUseMerchantAIKey: false ) { result in promise(()) }) @@ -2217,6 +2227,7 @@ final class ProductStoreTests: XCTestCase { let result = waitFor { promise in productStore.onAction(ProductAction.identifyLanguage(siteID: self.sampleSiteID, string: "Woo is awesome", + shouldUseMerchantAIKey: false, feature: .productSharing) { result in promise(result) }) @@ -2242,6 +2253,7 @@ final class ProductStoreTests: XCTestCase { let result = waitFor { promise in productStore.onAction(ProductAction.identifyLanguage(siteID: self.sampleSiteID, string: "Woo is awesome", + shouldUseMerchantAIKey: false, feature: .productSharing) { result in promise(result) }) @@ -2640,7 +2652,8 @@ final class ProductStoreTests: XCTestCase { productStore.onAction(ProductAction.generateProductDetails(siteID: self.sampleSiteID, productName: nil, scannedTexts: [""], - language: "en") { result in + language: "en", + shouldUseMerchantAIKey: false) { result in promise(result) }) } @@ -2666,7 +2679,8 @@ final class ProductStoreTests: XCTestCase { productStore.onAction(ProductAction.generateProductDetails(siteID: self.sampleSiteID, productName: nil, scannedTexts: [""], - language: "en") { result in + language: "en", + shouldUseMerchantAIKey: false) { result in promise(result) }) } @@ -2694,7 +2708,8 @@ final class ProductStoreTests: XCTestCase { productStore.onAction(ProductAction.generateProductDetails(siteID: self.sampleSiteID, productName: productName, scannedTexts: scannedTexts, - language: language) { _ in + language: language, + shouldUseMerchantAIKey: false) { _ in promise(()) }) } @@ -2721,7 +2736,8 @@ final class ProductStoreTests: XCTestCase { productStore.onAction(ProductAction.generateProductDetails(siteID: self.sampleSiteID, productName: nil, scannedTexts: [""], - language: "en") { _ in + language: "en", + shouldUseMerchantAIKey: false) { _ in promise(()) }) } @@ -2746,7 +2762,8 @@ final class ProductStoreTests: XCTestCase { productStore.onAction(ProductAction.generateProductDetails(siteID: self.sampleSiteID, productName: nil, scannedTexts: [""], - language: "en") { _ in + language: "en", + shouldUseMerchantAIKey: false) { _ in promise(()) }) } @@ -2771,7 +2788,10 @@ final class ProductStoreTests: XCTestCase { // When let result = waitFor { promise in - productStore.onAction(ProductAction.generateProductName(siteID: 123, keywords: "iPhone 15", language: "en") { result in + productStore.onAction(ProductAction.generateProductName(siteID: 123, + keywords: "iPhone 15", + language: "en", + shouldUseMerchantAIKey: false) { result in promise(result) }) } @@ -2793,7 +2813,10 @@ final class ProductStoreTests: XCTestCase { // When let result = waitFor { promise in - productStore.onAction(ProductAction.generateProductName(siteID: 123, keywords: "iPhone 15", language: "en") { result in + productStore.onAction(ProductAction.generateProductName(siteID: 123, + keywords: "iPhone 15", + language: "en", + shouldUseMerchantAIKey: false) { result in promise(result) }) } @@ -2816,7 +2839,7 @@ final class ProductStoreTests: XCTestCase { // When waitFor { promise in - productStore.onAction(ProductAction.generateProductName(siteID: 123, keywords: keyword, language: "en") { _ in + productStore.onAction(ProductAction.generateProductName(siteID: 123, keywords: keyword, language: "en", shouldUseMerchantAIKey: false) { _ in promise(()) }) } @@ -2838,7 +2861,7 @@ final class ProductStoreTests: XCTestCase { // When waitFor { promise in - productStore.onAction(ProductAction.generateProductName(siteID: 123, keywords: "keyword", language: "en") { _ in + productStore.onAction(ProductAction.generateProductName(siteID: 123, keywords: "keyword", language: "en", shouldUseMerchantAIKey: false) { _ in promise(()) }) } @@ -2860,7 +2883,7 @@ final class ProductStoreTests: XCTestCase { // When waitFor { promise in - productStore.onAction(ProductAction.generateProductName(siteID: 123, keywords: "keyword", language: "en") { _ in + productStore.onAction(ProductAction.generateProductName(siteID: 123, keywords: "keyword", language: "en", shouldUseMerchantAIKey: false) { _ in promise(()) }) } @@ -2929,6 +2952,7 @@ final class ProductStoreTests: XCTestCase { keywords: "Leather strip, silver", language: "en", tone: "Casual", + shouldUseMerchantAIKey: false, currencySymbol: "INR", dimensionUnit: "cm", weightUnit: "kg", @@ -2960,6 +2984,7 @@ final class ProductStoreTests: XCTestCase { keywords: "Leather strip, silver", language: "en", tone: "Casual", + shouldUseMerchantAIKey: false, currencySymbol: "INR", dimensionUnit: "cm", weightUnit: "kg",