diff --git a/Modules/Sources/Networking/Remote/POSCatalogSyncRemote.swift b/Modules/Sources/Networking/Remote/POSCatalogSyncRemote.swift index 5228428bf75..b7854e37543 100644 --- a/Modules/Sources/Networking/Remote/POSCatalogSyncRemote.swift +++ b/Modules/Sources/Networking/Remote/POSCatalogSyncRemote.swift @@ -39,6 +39,18 @@ public protocol POSCatalogSyncRemoteProtocol { /// - pageNumber: Page number for pagination. /// - Returns: Paginated list of POS product variations. func loadProductVariations(siteID: Int64, pageNumber: Int) async throws -> PagedItems + + /// Gets the total count of products for the specified site. + /// + /// - Parameter siteID: Site ID to get product count for. + /// - Returns: Total number of products. + func getProductCount(siteID: Int64) async throws -> Int + + /// Gets the total count of product variations for the specified site. + /// + /// - Parameter siteID: Site ID to get variation count for. + /// - Returns: Total number of variations. + func getProductVariationCount(siteID: Int64) async throws -> Int } /// POS Catalog Sync: Remote Endpoints @@ -174,6 +186,58 @@ public class POSCatalogSyncRemote: Remote, POSCatalogSyncRemoteProtocol { return createPagedItems(items: variations, responseHeaders: responseHeaders, currentPageNumber: pageNumber) } + + // MARK: - Count Endpoints + + /// Gets the total count of products for the specified site. + /// + /// - Parameter siteID: Site ID to get product count for. + /// - Returns: Total number of products. + public func getProductCount(siteID: Int64) async throws -> Int { + let path = Path.products + let parameters = [ + ParameterKey.page: String(1), + ParameterKey.perPage: String(1), + ParameterKey.fields: POSProductVariation.requestFields.first ?? "" + ] + + let request = JetpackRequest( + wooApiVersion: .mark3, + method: .get, + siteID: siteID, + path: path, + parameters: parameters, + availableAsRESTRequest: true + ) + let responseHeaders = try await enqueueWithResponseHeaders(request) + + return totalItemsCount(from: responseHeaders) ?? 0 + } + + /// Gets the total count of product variations for the specified site. + /// + /// - Parameter siteID: Site ID to get variation count for. + /// - Returns: Total number of variations. + public func getProductVariationCount(siteID: Int64) async throws -> Int { + let path = Path.variations + let parameters = [ + ParameterKey.page: String(1), + ParameterKey.perPage: String(1), + ParameterKey.fields: POSProductVariation.requestFields.first ?? "" + ] + + let request = JetpackRequest( + wooApiVersion: .wcAnalytics, + method: .get, + siteID: siteID, + path: path, + parameters: parameters, + availableAsRESTRequest: true + ) + let responseHeaders = try await enqueueWithResponseHeaders(request) + + return totalItemsCount(from: responseHeaders) ?? 0 + } } // MARK: - Constants diff --git a/Modules/Sources/NetworkingCore/Remote/Remote.swift b/Modules/Sources/NetworkingCore/Remote/Remote.swift index 5d2c6959db2..3c8d20e9614 100644 --- a/Modules/Sources/NetworkingCore/Remote/Remote.swift +++ b/Modules/Sources/NetworkingCore/Remote/Remote.swift @@ -248,6 +248,23 @@ open class Remote: NSObject { throw mapNetworkError(error: error, for: request) } } + + /// Enqueues the specified Network Request using Swift Concurrency, for fetching the headers + /// + /// - Important: + /// - No data will be parsed. This is intended for use with `HEAD` requests, but will make whatever request you specify + /// + /// - Parameter request: Request that should be performed. + /// - Returns: The headers from the response + public func enqueueWithResponseHeaders(_ request: Request) async throws -> [String: String] { + do { + let (_, headers) = try await network.responseDataAndHeaders(for: request) + return headers ?? [:] + } catch { + handleResponseError(error: error, for: request) + throw mapNetworkError(error: error, for: request) + } + } } private extension Remote { @@ -382,12 +399,16 @@ public extension Remote { let hasMorePages = totalPages.map { currentPageNumber < $0 } ?? true + let totalItems = totalItemsCount(from: responseHeaders) + + return PagedItems(items: items, hasMorePages: hasMorePages, totalItems: totalItems) + } + + func totalItemsCount(from responseHeaders: [String: String]?) -> Int? { // Extract total count from response headers (case insensitive) - let totalItems = responseHeaders?.first(where: { + responseHeaders?.first(where: { $0.key.lowercased() == PaginationHeaderKey.totalCount.lowercased() }).flatMap { Int($0.value) } - - return PagedItems(items: items, hasMorePages: hasMorePages, totalItems: totalItems) } } diff --git a/Modules/Sources/Yosemite/Tools/POS/POSCatalogSizeChecker.swift b/Modules/Sources/Yosemite/Tools/POS/POSCatalogSizeChecker.swift new file mode 100644 index 00000000000..0b3847fd167 --- /dev/null +++ b/Modules/Sources/Yosemite/Tools/POS/POSCatalogSizeChecker.swift @@ -0,0 +1,56 @@ +import Foundation +import Networking + +/// Protocol for checking the size of a remote POS catalog +public protocol POSCatalogSizeCheckerProtocol { + /// Checks the size of the remote catalog for the specified site + /// - Parameter siteID: The site ID to check catalog size for + /// - Returns: The size information of the catalog + /// - Throws: Network or parsing errors + func checkCatalogSize(for siteID: Int64) async throws -> POSCatalogSize +} + +/// Implementation of catalog size checker that uses the sync remote to get counts +public struct POSCatalogSizeChecker: POSCatalogSizeCheckerProtocol { + private let syncRemote: POSCatalogSyncRemoteProtocol + + public init(syncRemote: POSCatalogSyncRemoteProtocol) { + self.syncRemote = syncRemote + } + + public func checkCatalogSize(for siteID: Int64) async throws -> POSCatalogSize { + // Make concurrent requests to get both counts + async let productCount = syncRemote.getProductCount(siteID: siteID) + async let variationCount = syncRemote.getProductVariationCount(siteID: siteID) + + do { + return try await POSCatalogSize( + productCount: productCount, + variationCount: variationCount + ) + } catch { + DDLogError( + "⚠️ Failed to check POS catalog size for site \(siteID): \(error)" + ) + throw error + } + } +} + +public struct POSCatalogSize: Equatable { + /// Number of products in the catalog + public let productCount: Int + + /// Number of product variations in the catalog + public let variationCount: Int + + /// Total number of items (products + variations) + public var totalCount: Int { + productCount + variationCount + } + + public init(productCount: Int, variationCount: Int) { + self.productCount = productCount + self.variationCount = variationCount + } +} diff --git a/Modules/Sources/Yosemite/Tools/POS/POSCatalogSyncCoordinator.swift b/Modules/Sources/Yosemite/Tools/POS/POSCatalogSyncCoordinator.swift index 983d5bd80cc..afcf920502e 100644 --- a/Modules/Sources/Yosemite/Tools/POS/POSCatalogSyncCoordinator.swift +++ b/Modules/Sources/Yosemite/Tools/POS/POSCatalogSyncCoordinator.swift @@ -34,6 +34,8 @@ public actor POSCatalogSyncCoordinator: POSCatalogSyncCoordinatorProtocol { private let incrementalSyncService: POSCatalogIncrementalSyncServiceProtocol private let grdbManager: GRDBManagerProtocol private let maxIncrementalSyncAge: TimeInterval + private let catalogSizeLimit: Int + private let catalogSizeChecker: POSCatalogSizeCheckerProtocol /// Tracks ongoing full syncs by site ID to prevent duplicates private var ongoingSyncs: Set = [] @@ -43,11 +45,15 @@ public actor POSCatalogSyncCoordinator: POSCatalogSyncCoordinatorProtocol { public init(fullSyncService: POSCatalogFullSyncServiceProtocol, incrementalSyncService: POSCatalogIncrementalSyncServiceProtocol, grdbManager: GRDBManagerProtocol, - maxIncrementalSyncAge: TimeInterval = 300) { + maxIncrementalSyncAge: TimeInterval = 300, + catalogSizeLimit: Int? = nil, + catalogSizeChecker: POSCatalogSizeCheckerProtocol) { self.fullSyncService = fullSyncService self.incrementalSyncService = incrementalSyncService self.grdbManager = grdbManager self.maxIncrementalSyncAge = maxIncrementalSyncAge + self.catalogSizeLimit = catalogSizeLimit ?? Constants.defaultSizeLimitForPOSCatalog + self.catalogSizeChecker = catalogSizeChecker } public func performFullSync(for siteID: Int64) async throws { @@ -71,7 +77,20 @@ public actor POSCatalogSyncCoordinator: POSCatalogSyncCoordinatorProtocol { DDLogInfo("✅ POSCatalogSyncCoordinator completed full sync for site \(siteID)") } + /// Determines if a full sync should be performed based on the age of the last sync + /// - Parameters: + /// - siteID: The site ID to check + /// - maxAge: Maximum age before a sync is considered stale + /// - Returns: True if a sync should be performed public func shouldPerformFullSync(for siteID: Int64, maxAge: TimeInterval) async -> Bool { + return await shouldPerformFullSync(for: siteID, maxAge: maxAge, maxCatalogSize: catalogSizeLimit) + } + + private func shouldPerformFullSync(for siteID: Int64, maxAge: TimeInterval, maxCatalogSize: Int) async -> Bool { + guard await isCatalogSizeWithinLimit(for: siteID, maxCatalogSize: maxCatalogSize) else { + return false + } + if !siteExistsInDatabase(siteID: siteID) { DDLogInfo("📋 POSCatalogSyncCoordinator: Site \(siteID) not found in database, sync needed") return true @@ -86,20 +105,35 @@ public actor POSCatalogSyncCoordinator: POSCatalogSyncCoordinatorProtocol { let shouldSync = age > maxAge if shouldSync { - DDLogInfo("📋 POSCatalogSyncCoordinator: Last sync for site \(siteID) was \(Int(age))s ago (max: \(Int(maxAge))s), sync needed") + DDLogInfo("📋 POSCatalogSyncCoordinator: Last sync for site \(siteID) was \(Int(age))s ago " + + "(max: \(Int(maxAge))s), sync needed") } else { - DDLogInfo("📋 POSCatalogSyncCoordinator: Last sync for site \(siteID) was \(Int(age))s ago (max: \(Int(maxAge))s), sync not needed") + DDLogInfo("📋 POSCatalogSyncCoordinator: Last sync for site \(siteID) was \(Int(age))s ago " + + "(max: \(Int(maxAge))s), sync not needed") } return shouldSync } + /// Performs an incremental sync if applicable based on sync conditions + /// - Parameters: + /// - siteID: The site ID to sync catalog for + /// - forceSync: Whether to bypass age checks and always sync + /// - Throws: POSCatalogSyncError.syncAlreadyInProgress if a sync is already running for this site public func performIncrementalSyncIfApplicable(for siteID: Int64, forceSync: Bool) async throws { + try await performIncrementalSyncIfApplicable(for: siteID, forceSync: forceSync, maxCatalogSize: catalogSizeLimit) + } + + private func performIncrementalSyncIfApplicable(for siteID: Int64, forceSync: Bool, maxCatalogSize: Int) async throws { if ongoingIncrementalSyncs.contains(siteID) { DDLogInfo("⚠️ POSCatalogSyncCoordinator: Incremental sync already in progress for site \(siteID)") throw POSCatalogSyncError.syncAlreadyInProgress(siteID: siteID) } + guard await isCatalogSizeWithinLimit(for: siteID, maxCatalogSize: maxCatalogSize) else { + return + } + guard let lastFullSyncDate = await lastFullSyncDate(for: siteID) else { DDLogInfo("📋 POSCatalogSyncCoordinator: No full sync performed yet for site \(siteID), skipping incremental sync") return @@ -130,6 +164,28 @@ public actor POSCatalogSyncCoordinator: POSCatalogSyncCoordinatorProtocol { // MARK: - Private + /// Checks if the catalog size is within the specified sync limit + /// - Parameters: + /// - siteID: The site ID to check + /// - maxCatalogSize: Maximum allowed catalog size for syncing + /// - Returns: True if catalog size is within limit or if size cannot be determined + private func isCatalogSizeWithinLimit(for siteID: Int64, maxCatalogSize: Int) async -> Bool { + guard let catalogSize = try? await catalogSizeChecker.checkCatalogSize(for: siteID) else { + DDLogError("📋 POSCatalogSyncCoordinator: Could not get catalog size for site \(siteID)") + return false + } + + guard catalogSize.totalCount <= maxCatalogSize else { + DDLogInfo("📋 POSCatalogSyncCoordinator: Site \(siteID) has catalog size \(catalogSize.totalCount), " + + "greater than \(maxCatalogSize), should not sync.") + return false + } + + DDLogInfo("📋 POSCatalogSyncCoordinator: Site \(siteID) has catalog size \(catalogSize.totalCount), with " + + "\(catalogSize.productCount) products and \(catalogSize.variationCount) variations") + return true + } + private func lastFullSyncDate(for siteID: Int64) async -> Date? { do { return try await grdbManager.databaseConnection.read { db in @@ -164,3 +220,9 @@ public actor POSCatalogSyncCoordinator: POSCatalogSyncCoordinatorProtocol { } } } + +private extension POSCatalogSyncCoordinator { + enum Constants { + static let defaultSizeLimitForPOSCatalog = 1000 + } +} diff --git a/Modules/Tests/NetworkingTests/Remote/POSCatalogSyncRemoteTests.swift b/Modules/Tests/NetworkingTests/Remote/POSCatalogSyncRemoteTests.swift index 7c34814be64..c7f9e4aadbc 100644 --- a/Modules/Tests/NetworkingTests/Remote/POSCatalogSyncRemoteTests.swift +++ b/Modules/Tests/NetworkingTests/Remote/POSCatalogSyncRemoteTests.swift @@ -416,4 +416,178 @@ struct POSCatalogSyncRemoteTests { #expect(fieldNames.contains("stock_quantity")) #expect(fieldNames.contains("stock_status")) } + + // MARK: - Count Endpoints Tests + + @Test func getProductCount_uses_correct_path() async throws { + // Given + let remote = POSCatalogSyncRemote(network: network) + network.responseHeaders = ["X-WP-Total": "150"] + + // When + _ = try? await remote.getProductCount(siteID: sampleSiteID) + + // Then - verify correct path was used + let request = try #require(network.requestsForResponseData.last as? JetpackRequest) + #expect(request.path.contains("products")) + } + + @Test func getProductCount_returns_count_from_total_header() async throws { + // Given + let remote = POSCatalogSyncRemote(network: network) + let expectedCount = 500 + network.responseHeaders = ["X-WP-Total": "\(expectedCount)"] + network.simulateResponse(requestUrlSuffix: "products", filename: "empty-data-array") + + // When + let count = try await remote.getProductCount(siteID: sampleSiteID) + + // Then + #expect(count == expectedCount) + } + + @Test func getProductCount_returns_zero_when_header_missing() async throws { + // Given + let remote = POSCatalogSyncRemote(network: network) + network.responseHeaders = nil + network.simulateResponse(requestUrlSuffix: "products", filename: "empty-data-array") + + // When + let count = try await remote.getProductCount(siteID: sampleSiteID) + + // Then + #expect(count == 0) + } + + @Test func getProductCount_returns_zero_when_header_invalid() async throws { + // Given + let remote = POSCatalogSyncRemote(network: network) + network.responseHeaders = ["X-WP-Total": "invalid-number"] + network.simulateResponse(requestUrlSuffix: "products", filename: "empty-data-array") + + // When + let count = try await remote.getProductCount(siteID: sampleSiteID) + + // Then + #expect(count == 0) + } + + @Test func getProductCount_relays_networking_error() async throws { + // Given + let remote = POSCatalogSyncRemote(network: network) + + // When/Then + await #expect(throws: NetworkError.notFound()) { + try await remote.getProductCount(siteID: sampleSiteID) + } + } + + @Test func getProductVariationCount_uses_correct_path() async throws { + // Given + let remote = POSCatalogSyncRemote(network: network) + network.responseHeaders = ["X-WP-Total": "75"] + network.simulateResponse(requestUrlSuffix: "variations", filename: "empty-data-array") + + // When + _ = try? await remote.getProductVariationCount(siteID: sampleSiteID) + + // Then - verify correct path was used + let request = try #require(network.requestsForResponseData.last as? JetpackRequest) + #expect(request.path.contains("variations")) + } + + @Test func getProductVariationCount_returns_count_from_total_header() async throws { + // Given + let remote = POSCatalogSyncRemote(network: network) + let expectedCount = 250 + network.responseHeaders = ["X-WP-Total": "\(expectedCount)"] + network.simulateResponse(requestUrlSuffix: "variations", filename: "empty-data-array") + + // When + let count = try await remote.getProductVariationCount(siteID: sampleSiteID) + + // Then + #expect(count == expectedCount) + } + + @Test func getProductVariationCount_returns_zero_when_header_missing() async throws { + // Given + let remote = POSCatalogSyncRemote(network: network) + network.responseHeaders = nil + network.simulateResponse(requestUrlSuffix: "variations", filename: "empty-data-array") + + // When + let count = try await remote.getProductVariationCount(siteID: sampleSiteID) + + // Then + #expect(count == 0) + } + + @Test func getProductVariationCount_returns_zero_when_header_invalid() async throws { + // Given + let remote = POSCatalogSyncRemote(network: network) + network.responseHeaders = ["X-WP-Total": "not-a-number"] + network.simulateResponse(requestUrlSuffix: "variations", filename: "empty-data-array") + + // When + let count = try await remote.getProductVariationCount(siteID: sampleSiteID) + + // Then + #expect(count == 0) + } + + @Test func getProductVariationCount_relays_networking_error() async throws { + // Given + let remote = POSCatalogSyncRemote(network: network) + + // When/Then + await #expect(throws: NetworkError.notFound()) { + try await remote.getProductVariationCount(siteID: sampleSiteID) + } + } + + @Test func getProductCount_handles_very_large_counts() async throws { + // Given + let remote = POSCatalogSyncRemote(network: network) + let largeCount = 999999 + network.responseHeaders = ["X-WP-Total": "\(largeCount)"] + network.simulateResponse(requestUrlSuffix: "products", filename: "empty-data-array") + + // When + let count = try await remote.getProductCount(siteID: sampleSiteID) + + // Then + #expect(count == largeCount) + } + + @Test func getProductVariationCount_handles_very_large_counts() async throws { + // Given + let remote = POSCatalogSyncRemote(network: network) + let largeCount = 888888 + network.responseHeaders = ["X-WP-Total": "\(largeCount)"] + network.simulateResponse(requestUrlSuffix: "variations", filename: "empty-data-array") + + // When + let count = try await remote.getProductVariationCount(siteID: sampleSiteID) + + // Then + #expect(count == largeCount) + } + + @Test func count_endpoints_use_correct_api_versions() async throws { + // Given + let remote = POSCatalogSyncRemote(network: network) + network.responseHeaders = ["X-WP-Total": "10"] + + // When - make both count calls + _ = try? await remote.getProductCount(siteID: sampleSiteID) + _ = try? await remote.getProductVariationCount(siteID: sampleSiteID) + + // Then - verify API versions match the data endpoints + // Products should use .mark3, variations should use .wcAnalytics (based on the load endpoints) + // This is verified by checking that the correct paths are called + let requests = network.requestsForResponseData.compactMap { $0 as? JetpackRequest } + #expect(requests.contains { $0.path.contains("products") }) + #expect(requests.contains { $0.path.contains("variations") }) + } } diff --git a/Modules/Tests/NetworkingTests/Remote/RemoteTests.swift b/Modules/Tests/NetworkingTests/Remote/RemoteTests.swift index ad2bca7ac67..ba3bc97e821 100644 --- a/Modules/Tests/NetworkingTests/Remote/RemoteTests.swift +++ b/Modules/Tests/NetworkingTests/Remote/RemoteTests.swift @@ -1016,6 +1016,92 @@ final class RemoteTests: XCTestCase { XCTAssertTrue(result.1 as? NetworkError == error) } } + + // MARK: - Tests for enqueueWithResponseHeaders + + /// Verifies that `enqueueWithResponseHeaders` properly wraps up the received request and returns headers + /// + func test_enqueueWithResponseHeaders_wraps_up_request_and_returns_headers() async throws { + // Given + let network = MockNetwork() + let remote = Remote(network: network) + let expectedHeaders = ["Content-Type": "application/json", "X-Total-Count": "150"] + + network.simulateResponse(requestUrlSuffix: "something", filename: "order") + network.responseHeaders = expectedHeaders + + // When + let headers = try await remote.enqueueWithResponseHeaders(request) + + // Then + let receivedRequest = try XCTUnwrap(network.requestsForResponseData.first as? JetpackRequest) + XCTAssertEqual(network.requestsForResponseData.count, 1) + XCTAssertEqual(receivedRequest.method, request.method) + XCTAssertEqual(receivedRequest.path, request.path) + XCTAssertEqual(headers, expectedHeaders) + } + + /// Verifies that `enqueueWithResponseHeaders` returns empty dictionary when no headers are provided + /// + func test_enqueueWithResponseHeaders_returns_empty_dictionary_when_no_headers() async throws { + // Given + let network = MockNetwork() + let remote = Remote(network: network) + + network.simulateResponse(requestUrlSuffix: "something", filename: "order") + + // When + let headers = try await remote.enqueueWithResponseHeaders(request) + + // Then + XCTAssertEqual(headers, [:]) + } + + /// Verifies that `enqueueWithResponseHeaders` propagates NetworkError properly + /// + func test_enqueueWithResponseHeaders_propagates_NetworkError() async throws { + // Given + let network = MockNetwork() + let remote = Remote(network: network) + let expectedError = NetworkError.notFound() + + network.simulateError(requestUrlSuffix: "something", error: expectedError) + + // When/Then + do { + _ = try await remote.enqueueWithResponseHeaders(request) + XCTFail("Expected error to be thrown") + } catch { + XCTAssertTrue(error as? NetworkError == expectedError) + } + } + + /// Verifies that `enqueueWithResponseHeaders` handles various header types correctly + /// + func test_enqueueWithResponseHeaders_handles_various_header_types() async throws { + // Given + let network = MockNetwork() + let remote = Remote(network: network) + let expectedHeaders = [ + "Content-Type": "application/json", + "X-Total-Count": "500", + "X-WC-Total": "250", + "Cache-Control": "no-cache", + "Set-Cookie": "session=abc123" + ] + + network.simulateResponse(requestUrlSuffix: "something", filename: "order") + network.responseHeaders = expectedHeaders + + // When + let headers = try await remote.enqueueWithResponseHeaders(request) + + // Then + XCTAssertEqual(headers.count, expectedHeaders.count) + for (key, value) in expectedHeaders { + XCTAssertEqual(headers[key], value, "Header \(key) should match expected value") + } + } } diff --git a/Modules/Tests/YosemiteTests/Mocks/MockPOSCatalogSizeChecker.swift b/Modules/Tests/YosemiteTests/Mocks/MockPOSCatalogSizeChecker.swift new file mode 100644 index 00000000000..e88ac49fdae --- /dev/null +++ b/Modules/Tests/YosemiteTests/Mocks/MockPOSCatalogSizeChecker.swift @@ -0,0 +1,21 @@ +import Foundation +@testable import Yosemite + +final class MockPOSCatalogSizeChecker: POSCatalogSizeCheckerProtocol { + // MARK: - checkCatalogSize tracking + private(set) var checkCatalogSizeCallCount = 0 + private(set) var lastCheckedSiteID: Int64? + var checkCatalogSizeResult: Result = .success(POSCatalogSize(productCount: 100, variationCount: 50)) // 150 total - well under limit + + func checkCatalogSize(for siteID: Int64) async throws -> POSCatalogSize { + checkCatalogSizeCallCount += 1 + lastCheckedSiteID = siteID + + switch checkCatalogSizeResult { + case .success(let size): + return size + case .failure(let error): + throw error + } + } +} diff --git a/Modules/Tests/YosemiteTests/Mocks/MockPOSCatalogSyncRemote.swift b/Modules/Tests/YosemiteTests/Mocks/MockPOSCatalogSyncRemote.swift index f85787c9ff5..64a126b93ac 100644 --- a/Modules/Tests/YosemiteTests/Mocks/MockPOSCatalogSyncRemote.swift +++ b/Modules/Tests/YosemiteTests/Mocks/MockPOSCatalogSyncRemote.swift @@ -125,4 +125,50 @@ final class MockPOSCatalogSyncRemote: POSCatalogSyncRemoteProtocol { } return fallbackVariationResult } + + // MARK: - Protocol Methods - Catalog size + + // MARK: - getProductCount tracking + private(set) var getProductCountCallCount = 0 + private(set) var lastProductCountSiteID: Int64? + var getProductCountResult: Result = .success(0) + var productCountDelay: UInt64 = 0 + + // MARK: - getProductVariationCount tracking + private(set) var getProductVariationCountCallCount = 0 + private(set) var lastVariationCountSiteID: Int64? + var getProductVariationCountResult: Result = .success(0) + var variationCountDelay: UInt64 = 0 + + func getProductCount(siteID: Int64) async throws -> Int { + getProductCountCallCount += 1 + lastProductCountSiteID = siteID + + if productCountDelay > 0 { + try await Task.sleep(nanoseconds: productCountDelay) + } + + switch getProductCountResult { + case .success(let count): + return count + case .failure(let error): + throw error + } + } + + func getProductVariationCount(siteID: Int64) async throws -> Int { + getProductVariationCountCallCount += 1 + lastVariationCountSiteID = siteID + + if variationCountDelay > 0 { + try await Task.sleep(nanoseconds: variationCountDelay) + } + + switch getProductVariationCountResult { + case .success(let count): + return count + case .failure(let error): + throw error + } + } } diff --git a/Modules/Tests/YosemiteTests/Tools/POS/POSCatalogSizeCheckerTests.swift b/Modules/Tests/YosemiteTests/Tools/POS/POSCatalogSizeCheckerTests.swift new file mode 100644 index 00000000000..b5571f1ccc0 --- /dev/null +++ b/Modules/Tests/YosemiteTests/Tools/POS/POSCatalogSizeCheckerTests.swift @@ -0,0 +1,98 @@ +import Foundation +import Testing +@testable import Yosemite + +struct POSCatalogSizeCheckerTests { + private let mockSyncRemote: MockPOSCatalogSyncRemote + private let sut: POSCatalogSizeChecker + private let sampleSiteID: Int64 = 134 + + init() throws { + self.mockSyncRemote = MockPOSCatalogSyncRemote() + self.sut = POSCatalogSizeChecker(syncRemote: mockSyncRemote) + } + + @Test func checkCatalogSize_returns_combined_count_from_remote() async throws { + // Given + mockSyncRemote.getProductCountResult = .success(400) + mockSyncRemote.getProductVariationCountResult = .success(300) + + // When + let catalogSize = try await sut.checkCatalogSize(for: sampleSiteID) + + // Then + #expect(catalogSize.productCount == 400) + #expect(catalogSize.variationCount == 300) + #expect(catalogSize.totalCount == 700) + #expect(mockSyncRemote.getProductCountCallCount == 1) + #expect(mockSyncRemote.getProductVariationCountCallCount == 1) + #expect(mockSyncRemote.lastProductCountSiteID == sampleSiteID) + #expect(mockSyncRemote.lastVariationCountSiteID == sampleSiteID) + } + + @Test func checkCatalogSize_handles_zero_counts() async throws { + // Given + mockSyncRemote.getProductCountResult = .success(0) + mockSyncRemote.getProductVariationCountResult = .success(0) + + // When + let catalogSize = try await sut.checkCatalogSize(for: sampleSiteID) + + // Then + #expect(catalogSize.productCount == 0) + #expect(catalogSize.variationCount == 0) + #expect(catalogSize.totalCount == 0) + } + + @Test func checkCatalogSize_handles_product_count_only() async throws { + // Given + mockSyncRemote.getProductCountResult = .success(500) + mockSyncRemote.getProductVariationCountResult = .success(0) + + // When + let catalogSize = try await sut.checkCatalogSize(for: sampleSiteID) + + // Then + #expect(catalogSize.productCount == 500) + #expect(catalogSize.variationCount == 0) + #expect(catalogSize.totalCount == 500) + } + + @Test func checkCatalogSize_handles_variation_count_only() async throws { + // Given + mockSyncRemote.getProductCountResult = .success(0) + mockSyncRemote.getProductVariationCountResult = .success(750) + + // When + let catalogSize = try await sut.checkCatalogSize(for: sampleSiteID) + + // Then + #expect(catalogSize.productCount == 0) + #expect(catalogSize.variationCount == 750) + #expect(catalogSize.totalCount == 750) + } + + @Test func checkCatalogSize_propagates_product_count_error() async throws { + // Given + let expectedError = NSError(domain: "test", code: 404, userInfo: [NSLocalizedDescriptionKey: "Products not found"]) + mockSyncRemote.getProductCountResult = .failure(expectedError) + mockSyncRemote.getProductVariationCountResult = .success(100) + + // When/Then + await #expect(throws: expectedError) { + _ = try await sut.checkCatalogSize(for: sampleSiteID) + } + } + + @Test func checkCatalogSize_propagates_variation_count_error() async throws { + // Given + let expectedError = NSError(domain: "test", code: 500, userInfo: [NSLocalizedDescriptionKey: "Server error"]) + mockSyncRemote.getProductCountResult = .success(200) + mockSyncRemote.getProductVariationCountResult = .failure(expectedError) + + // When/Then + await #expect(throws: expectedError) { + _ = try await sut.checkCatalogSize(for: sampleSiteID) + } + } +} diff --git a/Modules/Tests/YosemiteTests/Tools/POS/POSCatalogSyncCoordinatorTests.swift b/Modules/Tests/YosemiteTests/Tools/POS/POSCatalogSyncCoordinatorTests.swift index 7cc1c0d6e0d..d33a84844e6 100644 --- a/Modules/Tests/YosemiteTests/Tools/POS/POSCatalogSyncCoordinatorTests.swift +++ b/Modules/Tests/YosemiteTests/Tools/POS/POSCatalogSyncCoordinatorTests.swift @@ -7,6 +7,7 @@ struct POSCatalogSyncCoordinatorTests { private let mockSyncService: MockPOSCatalogFullSyncService private let mockIncrementalSyncService: MockPOSCatalogIncrementalSyncService private let grdbManager: GRDBManager + private let mockCatalogSizeChecker: MockPOSCatalogSizeChecker private let sut: POSCatalogSyncCoordinator private let sampleSiteID: Int64 = 134 @@ -14,10 +15,12 @@ struct POSCatalogSyncCoordinatorTests { self.mockSyncService = MockPOSCatalogFullSyncService() self.mockIncrementalSyncService = MockPOSCatalogIncrementalSyncService() self.grdbManager = try GRDBManager() + self.mockCatalogSizeChecker = MockPOSCatalogSizeChecker() self.sut = POSCatalogSyncCoordinator( fullSyncService: mockSyncService, incrementalSyncService: mockIncrementalSyncService, - grdbManager: grdbManager + grdbManager: grdbManager, + catalogSizeChecker: mockCatalogSizeChecker ) } @@ -130,6 +133,79 @@ struct POSCatalogSyncCoordinatorTests { #expect(shouldSync == true) } + // MARK: - Catalog Size Check Tests + + @Test func shouldPerformFullSync_returns_false_when_catalog_size_exceeds_limit() async throws { + // Given - catalog size is above the 1000 item limit + mockCatalogSizeChecker.checkCatalogSizeResult = .success(POSCatalogSize(productCount: 800, variationCount: 300)) // 1100 total + try createSiteInDatabase(siteID: sampleSiteID, lastFullSyncDate: nil) + + // When + let shouldSync = await sut.shouldPerformFullSync(for: sampleSiteID, maxAge: 60 * 60) + + // Then + #expect(shouldSync == false) + #expect(mockCatalogSizeChecker.checkCatalogSizeCallCount == 1) + #expect(mockCatalogSizeChecker.lastCheckedSiteID == sampleSiteID) + } + + @Test func shouldPerformFullSync_returns_true_when_catalog_size_is_at_limit() async throws { + // Given - catalog size is exactly at the 1000 item limit + mockCatalogSizeChecker.checkCatalogSizeResult = .success(POSCatalogSize(productCount: 600, variationCount: 400)) // 1000 total + try createSiteInDatabase(siteID: sampleSiteID, lastFullSyncDate: nil) + + // When + let shouldSync = await sut.shouldPerformFullSync(for: sampleSiteID, maxAge: 60 * 60) + + // Then + #expect(shouldSync == true) + #expect(mockCatalogSizeChecker.checkCatalogSizeCallCount == 1) + #expect(mockCatalogSizeChecker.lastCheckedSiteID == sampleSiteID) + } + + @Test func shouldPerformFullSync_returns_true_when_catalog_size_is_under_limit() async throws { + // Given - catalog size is below the 1000 item limit + mockCatalogSizeChecker.checkCatalogSizeResult = .success(POSCatalogSize(productCount: 300, variationCount: 200)) // 500 total + try createSiteInDatabase(siteID: sampleSiteID, lastFullSyncDate: nil) + + // When + let shouldSync = await sut.shouldPerformFullSync(for: sampleSiteID, maxAge: 60 * 60) + + // Then + #expect(shouldSync == true) + #expect(mockCatalogSizeChecker.checkCatalogSizeCallCount == 1) + #expect(mockCatalogSizeChecker.lastCheckedSiteID == sampleSiteID) + } + + @Test func shouldPerformFullSync_returns_false_when_catalog_size_check_fails() async throws { + // Given - catalog size check throws an error + let sizeCheckError = NSError(domain: "size_check", code: 500, userInfo: [NSLocalizedDescriptionKey: "Network error"]) + mockCatalogSizeChecker.checkCatalogSizeResult = .failure(sizeCheckError) + try createSiteInDatabase(siteID: sampleSiteID, lastFullSyncDate: nil) + + // When + let shouldSync = await sut.shouldPerformFullSync(for: sampleSiteID, maxAge: 60 * 60) + + // Then + #expect(shouldSync == false) + #expect(mockCatalogSizeChecker.checkCatalogSizeCallCount == 1) + #expect(mockCatalogSizeChecker.lastCheckedSiteID == sampleSiteID) + } + + @Test func shouldPerformFullSync_respects_time_only_when_catalog_size_is_acceptable() async throws { + // Given - catalog size is acceptable but sync is recent + mockCatalogSizeChecker.checkCatalogSizeResult = .success(POSCatalogSize(productCount: 200, variationCount: 100)) // 300 total + let thirtyMinutesAgo = Date().addingTimeInterval(-30 * 60) + try createSiteInDatabase(siteID: sampleSiteID, lastFullSyncDate: thirtyMinutesAgo) + + // When - max age is 1 hour + let shouldSync = await sut.shouldPerformFullSync(for: sampleSiteID, maxAge: 60 * 60) + + // Then - should not sync because time hasn't passed yet + #expect(shouldSync == false) + #expect(mockCatalogSizeChecker.checkCatalogSizeCallCount == 1) + } + // MARK: - Database Check Tests @Test func shouldPerformFullSync_returns_true_when_site_not_in_database() async { @@ -246,7 +322,8 @@ struct POSCatalogSyncCoordinatorTests { fullSyncService: mockSyncService, incrementalSyncService: mockIncrementalSyncService, grdbManager: grdbManager, - maxIncrementalSyncAge: maxAge + maxIncrementalSyncAge: maxAge, + catalogSizeChecker: mockCatalogSizeChecker ) // When @@ -268,7 +345,8 @@ struct POSCatalogSyncCoordinatorTests { fullSyncService: mockSyncService, incrementalSyncService: mockIncrementalSyncService, grdbManager: grdbManager, - maxIncrementalSyncAge: maxAge + maxIncrementalSyncAge: maxAge, + catalogSizeChecker: mockCatalogSizeChecker ) // When @@ -304,7 +382,8 @@ struct POSCatalogSyncCoordinatorTests { fullSyncService: mockSyncService, incrementalSyncService: mockIncrementalSyncService, grdbManager: grdbManager, - maxIncrementalSyncAge: maxAge + maxIncrementalSyncAge: maxAge, + catalogSizeChecker: mockCatalogSizeChecker ) // When @@ -400,6 +479,98 @@ struct POSCatalogSyncCoordinatorTests { #expect(mockIncrementalSyncService.startIncrementalSyncCallCount == 2) } + // MARK: - Incremental Sync Catalog Size Tests + + @Test(arguments: [true, false]) + func performIncrementalSyncIfApplicable_skips_sync_when_catalog_size_exceeds_limit(forceSync: Bool) async throws { + // Given - catalog size is above the 1000 item limit + mockCatalogSizeChecker.checkCatalogSizeResult = .success(POSCatalogSize(productCount: 700, variationCount: 400)) // 1100 total + let fullSyncDate = Date().addingTimeInterval(-3600) + try createSiteInDatabase(siteID: sampleSiteID, lastFullSyncDate: fullSyncDate) + + // When + try await sut.performIncrementalSyncIfApplicable(for: sampleSiteID, forceSync: forceSync) + + // Then + #expect(mockIncrementalSyncService.startIncrementalSyncCallCount == 0) + #expect(mockCatalogSizeChecker.checkCatalogSizeCallCount == 1) + #expect(mockCatalogSizeChecker.lastCheckedSiteID == sampleSiteID) + } + + @Test(arguments: [true, false]) + func performIncrementalSyncIfApplicable_performs_sync_when_catalog_size_is_at_limit(forceSync: Bool) async throws { + // Given - catalog size is exactly at the 1000 item limit + mockCatalogSizeChecker.checkCatalogSizeResult = .success(POSCatalogSize(productCount: 500, variationCount: 500)) // 1000 total + let fullSyncDate = Date().addingTimeInterval(-3600) + try createSiteInDatabase(siteID: sampleSiteID, lastFullSyncDate: fullSyncDate) + + // When + try await sut.performIncrementalSyncIfApplicable(for: sampleSiteID, forceSync: forceSync) + + // Then + #expect(mockIncrementalSyncService.startIncrementalSyncCallCount == 1) + #expect(mockCatalogSizeChecker.checkCatalogSizeCallCount == 1) + #expect(mockCatalogSizeChecker.lastCheckedSiteID == sampleSiteID) + } + + @Test(arguments: [true, false]) + func performIncrementalSyncIfApplicable_performs_sync_when_catalog_size_is_under_limit(forceSync: Bool) async throws { + // Given - catalog size is below the 1000 item limit + mockCatalogSizeChecker.checkCatalogSizeResult = .success(POSCatalogSize(productCount: 200, variationCount: 150)) // 350 total + let fullSyncDate = Date().addingTimeInterval(-3600) + try createSiteInDatabase(siteID: sampleSiteID, lastFullSyncDate: fullSyncDate) + + // When + try await sut.performIncrementalSyncIfApplicable(for: sampleSiteID, forceSync: forceSync) + + // Then + #expect(mockIncrementalSyncService.startIncrementalSyncCallCount == 1) + #expect(mockCatalogSizeChecker.checkCatalogSizeCallCount == 1) + #expect(mockCatalogSizeChecker.lastCheckedSiteID == sampleSiteID) + } + + @Test(arguments: [true, false]) + func performIncrementalSyncIfApplicable_skips_sync_when_catalog_size_check_fails(forceSync: Bool) async throws { + // Given - catalog size check throws an error + let sizeCheckError = NSError(domain: "size_check", code: 500, userInfo: [NSLocalizedDescriptionKey: "Network error"]) + mockCatalogSizeChecker.checkCatalogSizeResult = .failure(sizeCheckError) + let fullSyncDate = Date().addingTimeInterval(-3600) + try createSiteInDatabase(siteID: sampleSiteID, lastFullSyncDate: fullSyncDate) + + // When + try await sut.performIncrementalSyncIfApplicable(for: sampleSiteID, forceSync: forceSync) + + // Then - should skip sync when size check fails + #expect(mockIncrementalSyncService.startIncrementalSyncCallCount == 0) + #expect(mockCatalogSizeChecker.checkCatalogSizeCallCount == 1) + #expect(mockCatalogSizeChecker.lastCheckedSiteID == sampleSiteID) + } + + @Test func performIncrementalSyncIfApplicable_checks_size_before_age_check() async throws { + // Given - catalog is over limit but would otherwise sync due to age + mockCatalogSizeChecker.checkCatalogSizeResult = .success(POSCatalogSize(productCount: 800, variationCount: 300)) // 1100 total + let maxAge: TimeInterval = 2 + let staleIncrementalSyncDate = Date().addingTimeInterval(-(maxAge + 1)) // Older than max age + let fullSyncDate = Date().addingTimeInterval(-3600) + try createSiteInDatabase(siteID: sampleSiteID, lastFullSyncDate: fullSyncDate, lastIncrementalSyncDate: staleIncrementalSyncDate) + + let sut = POSCatalogSyncCoordinator( + fullSyncService: mockSyncService, + incrementalSyncService: mockIncrementalSyncService, + grdbManager: grdbManager, + maxIncrementalSyncAge: maxAge, + catalogSizeChecker: mockCatalogSizeChecker + ) + + // When + try await sut.performIncrementalSyncIfApplicable(for: sampleSiteID, forceSync: false) + + // Then - should skip sync due to size limit, regardless of age + #expect(mockIncrementalSyncService.startIncrementalSyncCallCount == 0) + #expect(mockCatalogSizeChecker.checkCatalogSizeCallCount == 1) + #expect(mockCatalogSizeChecker.lastCheckedSiteID == sampleSiteID) + } + // MARK: - Helper Methods private func createSiteInDatabase(siteID: Int64, lastFullSyncDate: Date? = nil, lastIncrementalSyncDate: Date? = nil) throws { diff --git a/WooCommerce/Classes/Yosemite/AuthenticatedState.swift b/WooCommerce/Classes/Yosemite/AuthenticatedState.swift index 1f251d3ed8b..a3ad1c7c5d6 100644 --- a/WooCommerce/Classes/Yosemite/AuthenticatedState.swift +++ b/WooCommerce/Classes/Yosemite/AuthenticatedState.swift @@ -149,10 +149,13 @@ class AuthenticatedState: StoresManagerState { if ServiceLocator.featureFlagService.isFeatureFlagEnabled(.pointOfSaleLocalCatalogi1), let fullSyncService = POSCatalogFullSyncService(credentials: credentials, grdbManager: ServiceLocator.grdbManager), let incrementalSyncService = POSCatalogIncrementalSyncService(credentials: credentials, grdbManager: ServiceLocator.grdbManager) { + let syncRemote = POSCatalogSyncRemote(network: network) + let catalogSizeChecker = POSCatalogSizeChecker(syncRemote: syncRemote) posCatalogSyncCoordinator = POSCatalogSyncCoordinator( fullSyncService: fullSyncService, incrementalSyncService: incrementalSyncService, - grdbManager: ServiceLocator.grdbManager + grdbManager: ServiceLocator.grdbManager, + catalogSizeChecker: catalogSizeChecker ) } else { posCatalogSyncCoordinator = nil