diff --git a/native/kotlin/api/kotlin/src/integrationTest/kotlin/ApiUrlDiscoveryTest.kt b/native/kotlin/api/kotlin/src/integrationTest/kotlin/ApiUrlDiscoveryTest.kt index 979569a1e..7dcce7e77 100644 --- a/native/kotlin/api/kotlin/src/integrationTest/kotlin/ApiUrlDiscoveryTest.kt +++ b/native/kotlin/api/kotlin/src/integrationTest/kotlin/ApiUrlDiscoveryTest.kt @@ -342,5 +342,6 @@ private fun FetchAndParseApiRootFailure.getRequestExecutionErrorReason(): Reques private fun RequestExecutionException.reason(): RequestExecutionErrorReason? { return when (this) { is RequestExecutionException.RequestExecutionFailed -> this.reason + is RequestExecutionException.MediaFileNotFound -> null } } diff --git a/native/kotlin/api/kotlin/src/integrationTest/kotlin/MediaEndpointTest.kt b/native/kotlin/api/kotlin/src/integrationTest/kotlin/MediaEndpointTest.kt index 8b77edf63..e9acf0062 100644 --- a/native/kotlin/api/kotlin/src/integrationTest/kotlin/MediaEndpointTest.kt +++ b/native/kotlin/api/kotlin/src/integrationTest/kotlin/MediaEndpointTest.kt @@ -67,10 +67,7 @@ class MediaEndpointTest { val title = "Testing media upload from Kotlin" val response = client.request { requestBuilder -> requestBuilder.media().create( - params = MediaCreateParams(title = title), - "test_media.jpg", - "image/jpeg", - null + params = MediaCreateParams(title = title, filePath = "test_media.jpg") ) }.assertSuccessAndRetrieveData().data assertEquals(title, response.title.rendered) diff --git a/native/kotlin/api/kotlin/src/integrationTest/kotlin/MockRequestExecutor.kt b/native/kotlin/api/kotlin/src/integrationTest/kotlin/MockRequestExecutor.kt index 2db1f170a..51446f795 100644 --- a/native/kotlin/api/kotlin/src/integrationTest/kotlin/MockRequestExecutor.kt +++ b/native/kotlin/api/kotlin/src/integrationTest/kotlin/MockRequestExecutor.kt @@ -2,9 +2,9 @@ package rs.wordpress.api.kotlin import kotlinx.coroutines.delay import okio.FileNotFoundException -import uniffi.wp_api.MediaUploadRequest import uniffi.wp_api.RequestContext import uniffi.wp_api.RequestExecutor +import uniffi.wp_api.WpMultipartFormRequest import uniffi.wp_api.WpNetworkHeaderMap import uniffi.wp_api.WpNetworkRequest import uniffi.wp_api.WpNetworkResponse @@ -41,7 +41,7 @@ class MockRequestExecutor(private var stubs: List = listOf()) : RequestExe throw NoStubFoundException("No stub found for ${request.url()}") } - override suspend fun uploadMedia(mediaUploadRequest: MediaUploadRequest): WpNetworkResponse { + override suspend fun upload(request: WpMultipartFormRequest): WpNetworkResponse { TODO("Not yet implemented") } diff --git a/native/kotlin/api/kotlin/src/main/kotlin/rs/wordpress/api/kotlin/WpRequestExecutor.kt b/native/kotlin/api/kotlin/src/main/kotlin/rs/wordpress/api/kotlin/WpRequestExecutor.kt index c15826bfd..6b52971e0 100644 --- a/native/kotlin/api/kotlin/src/main/kotlin/rs/wordpress/api/kotlin/WpRequestExecutor.kt +++ b/native/kotlin/api/kotlin/src/main/kotlin/rs/wordpress/api/kotlin/WpRequestExecutor.kt @@ -14,13 +14,12 @@ import okhttp3.RequestBody import okhttp3.RequestBody.Companion.asRequestBody import okhttp3.RequestBody.Companion.toRequestBody import uniffi.wp_api.InvalidSslErrorReason -import uniffi.wp_api.MediaUploadRequest -import uniffi.wp_api.MediaUploadRequestExecutionException import uniffi.wp_api.RequestContext import uniffi.wp_api.RequestExecutionErrorReason import uniffi.wp_api.RequestExecutionException import uniffi.wp_api.RequestExecutor import uniffi.wp_api.RequestMethod +import uniffi.wp_api.WpMultipartFormRequest import uniffi.wp_api.WpNetworkHeaderMap import uniffi.wp_api.WpNetworkRequest import uniffi.wp_api.WpNetworkResponse @@ -85,54 +84,57 @@ class WpRequestExecutor( } } - override suspend fun uploadMedia(mediaUploadRequest: MediaUploadRequest): WpNetworkResponse = + override suspend fun upload(request: WpMultipartFormRequest): WpNetworkResponse = withContext(dispatcher) { - val requestBuilder = Request.Builder().url(mediaUploadRequest.url()) + val requestBuilder = Request.Builder().url(request.url()) val multipartBodyBuilder = MultipartBody.Builder() .setType(MultipartBody.FORM) - mediaUploadRequest.mediaParams().forEach { (k, v) -> + request.fields().forEach { (k, v) -> multipartBodyBuilder.addFormDataPart(k, v) } - val file = fileResolver.getFile(mediaUploadRequest.filePath()) - if (file == null || !file.canBeUploaded()) { - throw MediaUploadRequestExecutionException.MediaFileNotFound(mediaUploadRequest.filePath()) + request.files().forEach { (name, fileInfo) -> + val file = fileResolver.getFile(fileInfo.filePath) + if (file == null || !file.canBeUploaded()) { + throw RequestExecutionException.MediaFileNotFound(filePath = fileInfo.filePath) + } + val mimeType = fileInfo.mimeType ?: "application/octet-stream" + val requestBody = getRequestBody(file, mimeType, uploadListener) + val filename = fileInfo.fileName ?: file.name + multipartBodyBuilder.addFormDataPart( + name = name, + filename = filename, + body = requestBody + ) } - val progressRequestBody = getRequestBody(file, mediaUploadRequest, uploadListener) - multipartBodyBuilder.addFormDataPart( - name = "file", - filename = file.name, - body = progressRequestBody - ) requestBuilder.method( - method = mediaUploadRequest.method().toString(), + method = request.method().toString(), body = multipartBodyBuilder.build() ) - mediaUploadRequest.headerMap().toMap().forEach { (key, values) -> + request.headerMap().toMap().forEach { (key, values) -> values.forEach { value -> requestBuilder.addHeader(key, value) } } val call = httpClient.getClient().newCall(requestBuilder.build()) - // Notify about the call creation so it can be cancelled if needed uploadListener?.onUploadStarted(CancellableCall(call)) call.execute().use { response -> return@withContext WpNetworkResponse( body = response.body?.bytes() ?: ByteArray(0), statusCode = response.code.toUShort(), responseHeaderMap = WpNetworkHeaderMap.fromMultiMap(response.headers.toMultimap()), - requestUrl = mediaUploadRequest.url(), - requestHeaderMap = mediaUploadRequest.headerMap() + requestUrl = request.url(), + requestHeaderMap = request.headerMap() ) } } private fun getRequestBody( file: File, - mediaUploadRequest: MediaUploadRequest, + mimeType: String, uploadListener: UploadListener? ): RequestBody { - val fileRequestBody = file.asRequestBody(mediaUploadRequest.fileContentType().toMediaType()) + val fileRequestBody = file.asRequestBody(mimeType.toMediaType()) return if (uploadListener != null) { ProgressRequestBody( delegate = fileRequestBody, diff --git a/native/swift/Example/Example/UI/UploadView.swift b/native/swift/Example/Example/UI/UploadView.swift index 0efce33cf..55d94a651 100644 --- a/native/swift/Example/Example/UI/UploadView.swift +++ b/native/swift/Example/Example/UI/UploadView.swift @@ -190,8 +190,7 @@ private class UploadViewModel: ObservableObject { NSLog("Uploading \(item)") _ = try await api.uploadMedia( - params: .init(), - fromLocalFileURL: file, + params: .init(filePath: file.path), fulfilling: child ) diff --git a/native/swift/Sources/wordpress-api/Exports.swift b/native/swift/Sources/wordpress-api/Exports.swift index caf647cbe..e2b82ffc0 100644 --- a/native/swift/Sources/wordpress-api/Exports.swift +++ b/native/swift/Sources/wordpress-api/Exports.swift @@ -145,7 +145,6 @@ public typealias RevisionsRequestListWithEmbedContextResponse = WordPressAPIInte // MARK: - Media public typealias SparseMedia = WordPressAPIInternal.SparseMedia -public typealias MediaUploadRequest = WordPressAPIInternal.MediaUploadRequest public typealias MediaWithEditContext = WordPressAPIInternal.MediaWithEditContext public typealias MediaWithViewContext = WordPressAPIInternal.MediaWithViewContext public typealias MediaWithEmbedContext = WordPressAPIInternal.MediaWithEmbedContext diff --git a/native/swift/Sources/wordpress-api/SafeRequestExecutor.swift b/native/swift/Sources/wordpress-api/SafeRequestExecutor.swift index a52e79609..82d0b087e 100644 --- a/native/swift/Sources/wordpress-api/SafeRequestExecutor.swift +++ b/native/swift/Sources/wordpress-api/SafeRequestExecutor.swift @@ -1,6 +1,10 @@ import Foundation import WordPressAPIInternal +#if canImport(UniformTypeIdentifiers) +import UniformTypeIdentifiers +#endif + #if canImport(FoundationNetworking) import FoundationNetworking #endif @@ -11,9 +15,7 @@ import Combine public protocol SafeRequestExecutor: RequestExecutor, Sendable { func execute(_ request: WpNetworkRequest) async -> Result - func uploadMedia( - mediaUploadRequest: MediaUploadRequest - ) async -> Result + func upload(request: WpMultipartFormRequest) async -> Result #if PROGRESS_REPORTING_ENABLED /// Returns a publisher that emits zero or one `Progress` instance representing the overall progress of the task @@ -28,8 +30,8 @@ extension SafeRequestExecutor { return try result.get() } - public func uploadMedia(mediaUploadRequest: MediaUploadRequest) async throws -> WpNetworkResponse { - let result = await uploadMedia(mediaUploadRequest: mediaUploadRequest) + public func upload(request: WpMultipartFormRequest) async throws -> WpNetworkResponse { + let result = await upload(request: request) return try result.get() } } @@ -59,20 +61,8 @@ public final class WpRequestExecutor: SafeRequestExecutor { await perform(request) } - public func uploadMedia( - mediaUploadRequest: MediaUploadRequest - ) async -> Result { - (await perform(mediaUploadRequest)) - .mapError { error in - switch error { - case let .RequestExecutionFailed(statusCode, redirects, reason): - MediaUploadRequestExecutionError.RequestExecutionFailed( - statusCode: statusCode, - redirects: redirects, - reason: reason - ) - } - } + public func upload(request: WpMultipartFormRequest) async -> Result { + await perform(request) } public func cancel(context: RequestContext) { @@ -93,6 +83,10 @@ public final class WpRequestExecutor: SafeRequestExecutor { return .success(try WpNetworkResponse(data: data, request: request, response: response)) } catch { + if let error = error as? RequestExecutionError { + return .failure(error) + } + if errorIsHttpsError(error) { return handleHttpsError(error, for: request) } @@ -380,7 +374,7 @@ extension WpNetworkRequest: NetworkRequestContent { } } -extension MediaUploadRequest: NetworkRequestContent { +extension WpMultipartFormRequest: NetworkRequestContent { func encodeBody(into request: inout URLRequest) throws { // Do nothing. @@ -394,10 +388,33 @@ extension MediaUploadRequest: NetworkRequestContent { var request = try buildURLRequest(additionalHeaders: headers) var form = [MultipartFormField]() - for (name, value) in mediaParams() { + for (name, value) in fields() { form.append(.init(text: value, name: name)) } - try form.append(.init(fileAtPath: filePath(), name: "file")) + for (name, file) in files() { + var mimeType = file.mimeType + + #if canImport(UniformTypeIdentifiers) + if mimeType == nil { + mimeType = UTType( + filenameExtension: URL(fileURLWithPath: file.filePath).pathExtension + )?.preferredMIMEType + } + #endif + + do { + try form.append( + .init( + fileAtPath: file.filePath, + name: name, + filename: file.fileName, + mimeType: mimeType + ) + ) + } catch { + throw RequestExecutionError.MediaFileNotFound(filePath: file.filePath) + } + } let boundery = String(format: "wordpressrs.%08x", Int.random(in: Int.min.. MediaRequestCreateResponse { - precondition(localFileURL.isFileURL) precondition(progress.completedUnitCount == 0 && progress.totalUnitCount > 0) precondition(progress.cancellationHandler == nil) - let requestId = WpUuid() + let context = RequestContext() - let fileContentType: String - if let mimeType { - fileContentType = mimeType - } else if let mimeType = UTType(filenameExtension: localFileURL.pathExtension)?.preferredMIMEType { - fileContentType = mimeType - } else { - fileContentType = "application/octet-stream" + let uploadTask = Task { + try await media.createCancellation(params: params, context: context) } - let cancellable = requestExecutor - .progress(forRequestWithId: requestId.uuidString()) - .sink { - progress.addChild($0, withPendingUnitCount: progress.totalUnitCount - progress.completedUnitCount) + let progressObserver = Task { + // A request id will be put into the `RequestContext` during the execution of the `media.create` above. + // This loop waits for the request id becomes available + let requestId: String + while true { + try await Task.sleep(nanoseconds: 100_000) + try Task.checkCancellation() + + guard let id = context.requestIds().first else { + continue + } + + requestId = id + break } - defer { - cancellable.cancel() - } - let uploadTask = Task { - try await media.create( - params: params, - filePath: localFileURL.path, - fileContentType: fileContentType, - requestId: requestId - ) + // Get the progress of the `URLSessionTask` of the given request id. + guard let task = await requestExecutor + .progress(forRequestWithId: requestId) + .values + .first(where: { _ in true }) else { return } + + try Task.checkCancellation() + + progress.addChild(task, withPendingUnitCount: progress.totalUnitCount - progress.completedUnitCount) } progress.cancellationHandler = { uploadTask.cancel() + progressObserver.cancel() } return try await withTaskCancellationHandler { diff --git a/native/swift/Tests/integration-tests/MediaTests.swift b/native/swift/Tests/integration-tests/MediaTests.swift index 86e12a61c..c0076d189 100644 --- a/native/swift/Tests/integration-tests/MediaTests.swift +++ b/native/swift/Tests/integration-tests/MediaTests.swift @@ -10,10 +10,7 @@ struct MediaTests { func uploadImage() async throws { let file = try #require(Bundle.module.url(forResource: "test-data/test_media.jpg", withExtension: nil)) let response = try await api.media.create( - params: .init(title: "Image", altText: "This is a test image"), - filePath: file.path, - fileContentType: "image/jpeg", - requestId: nil + params: .init(title: "Image", altText: "This is a test image", filePath: file.path) ) #expect(response.data.mimeType == "image/jpeg") #expect(response.data.title.raw == "Image") @@ -22,6 +19,19 @@ struct MediaTests { try await restoreTestServer() } + @Test + func fileNotFoundError() async throws { + let file = "/path/to/a/non-existent-file.jpg" + await #expect( + throws: WpApiError.MediaFileNotFound(filePath: file), + performing: { + _ = try await api.media.create(params: .init(filePath: file)) + } + ) + + try await restoreTestServer() + } + #if os(macOS) @Test func uploadProgress() async throws { @@ -30,8 +40,7 @@ struct MediaTests { let file = try #require(Bundle.module.url(forResource: "test-data/test_media.jpg", withExtension: nil)) let response = try await api.uploadMedia( - params: .init(), - fromLocalFileURL: file, + params: .init(filePath: file.path), fulfilling: progress ) #expect(response.data.mimeType == "image/jpeg") @@ -51,8 +60,7 @@ struct MediaTests { performing: { let task = Task { _ = try await api.uploadMedia( - params: .init(), - fromLocalFileURL: file, + params: .init(filePath: file.path), fulfilling: progress ) Issue.record("The creating post function should throw") @@ -80,8 +88,7 @@ struct MediaTests { performing: { let task = Task { _ = try await api.uploadMedia( - params: .init(), - fromLocalFileURL: file, + params: .init(filePath: file.path), fulfilling: progress ) Issue.record("The creating post function should throw") diff --git a/native/swift/Tests/wordpress-api/Support/HTTPStubs.swift b/native/swift/Tests/wordpress-api/Support/HTTPStubs.swift index 33d0f5e98..b5544e9d7 100644 --- a/native/swift/Tests/wordpress-api/Support/HTTPStubs.swift +++ b/native/swift/Tests/wordpress-api/Support/HTTPStubs.swift @@ -48,9 +48,7 @@ final class HTTPStubs: SafeRequestExecutor { } } - func uploadMedia( - mediaUploadRequest: MediaUploadRequest - ) async -> Result { + func upload(request: WpMultipartFormRequest) async -> Result { preconditionFailure("This method is not yet implemented") } diff --git a/wp_api/src/api_error.rs b/wp_api/src/api_error.rs index 2f1f2b38d..b6972417f 100644 --- a/wp_api/src/api_error.rs +++ b/wp_api/src/api_error.rs @@ -540,12 +540,18 @@ pub enum RequestExecutionError { redirects: Option>, reason: RequestExecutionErrorReason, }, + MediaFileNotFound { + file_path: String, + }, } impl WpSupportsLocalization for RequestExecutionError { fn message_bundle(&self) -> MessageBundle<'_> { match self { RequestExecutionError::RequestExecutionFailed { reason, .. } => reason.message_bundle(), + RequestExecutionError::MediaFileNotFound { file_path } => { + WpMessages::media_file_not_found(file_path) + } } } } @@ -695,31 +701,6 @@ impl WpSupportsLocalization for RequestExecutionErrorReason { } } -#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error, uniffi::Error, WpDeriveLocalizable)] -pub enum MediaUploadRequestExecutionError { - RequestExecutionFailed { - status_code: Option, - redirects: Option>, - reason: RequestExecutionErrorReason, - }, - MediaFileNotFound { - file_path: String, - }, -} - -impl WpSupportsLocalization for MediaUploadRequestExecutionError { - fn message_bundle(&self) -> MessageBundle<'_> { - match self { - MediaUploadRequestExecutionError::RequestExecutionFailed { reason, .. } => { - reason.message_bundle() - } - MediaUploadRequestExecutionError::MediaFileNotFound { file_path } => { - WpMessages::media_file_not_found(file_path) - } - } - } -} - impl From for WpApiError { fn from(value: RequestExecutionError) -> Self { match value { @@ -732,23 +713,7 @@ impl From for WpApiError { redirects, reason, }, - } - } -} - -impl From for WpApiError { - fn from(value: MediaUploadRequestExecutionError) -> Self { - match value { - MediaUploadRequestExecutionError::RequestExecutionFailed { - status_code, - redirects, - reason, - } => Self::RequestExecutionFailed { - status_code, - redirects, - reason, - }, - MediaUploadRequestExecutionError::MediaFileNotFound { file_path } => { + RequestExecutionError::MediaFileNotFound { file_path } => { Self::MediaFileNotFound { file_path } } } diff --git a/wp_api/src/media.rs b/wp_api/src/media.rs index a048e5f81..3c236a8c5 100644 --- a/wp_api/src/media.rs +++ b/wp_api/src/media.rs @@ -6,6 +6,7 @@ use crate::{ PostCommentStatus, PostId, PostPingStatus, PostStatus, WpApiParamPostsOrderBy, WpApiParamPostsSearchColumn, }, + request::{MultipartFormFile, RequiresMultipartForm}, url_query::{ AppendUrlQueryPairs, FromUrlQueryPairs, QueryPairs, QueryPairsExtension, UrlQueryPairsMap, }, @@ -310,6 +311,23 @@ pub struct MediaCreateParams { #[serde(skip_serializing_if = "Option::is_none")] pub post_id: Option, // meta field is omitted for now: https://github.com/Automattic/wordpress-rs/issues/381 + #[serde(skip)] + pub file_path: String, +} + +impl RequiresMultipartForm for MediaCreateParams { + fn multipart_form_files(&self) -> HashMap { + let mut files = HashMap::new(); + files.insert( + "file".to_string(), + MultipartFormFile { + file_path: self.file_path.clone(), + mime_type: None, + file_name: None, + }, + ); + files + } } impl From for HashMap { diff --git a/wp_api/src/middleware.rs b/wp_api/src/middleware.rs index e3f905d7a..b8114d49f 100644 --- a/wp_api/src/middleware.rs +++ b/wp_api/src/middleware.rs @@ -2,7 +2,7 @@ use crate::{ api_client::IsWpApiClientDelegate, api_error::{RequestExecutionError, RequestExecutionErrorReason}, request::RequestContext, - request::{RequestExecutor, WpNetworkRequest, WpNetworkResponse}, + request::{RequestExecutor, WpMultipartFormRequest, WpNetworkRequest, WpNetworkResponse}, }; use std::{fmt::Debug, sync::Arc, time::Duration}; @@ -124,6 +124,28 @@ pub trait PerformsRequests { Ok(response) } + + async fn perform_upload( + &self, + request: Arc, + context: Option>, + ) -> Result { + if let Some(context) = &context { + context.add_request_id(request.uuid.clone()); + } + + let response = self.get_request_executor().upload(request.clone()).await?; + + if let Some(reason) = RequestExecutionErrorReason::try_from_response(&response) { + return Err(RequestExecutionError::RequestExecutionFailed { + status_code: Some(response.status_code), + redirects: None, + reason, + }); + } + + Ok(response) + } } impl PerformsRequests for T @@ -259,13 +281,7 @@ mod tests { use super::*; mod api_discovery_authentication_middleware { - use crate::{ - api_error::MediaUploadRequestExecutionError, - request::{ - WpNetworkHeaderMap, - endpoint::{WpEndpointUrl, media_endpoint::MediaUploadRequest}, - }, - }; + use crate::request::{WpNetworkHeaderMap, endpoint::WpEndpointUrl}; use super::*; use async_trait::async_trait; @@ -286,17 +302,11 @@ mod tests { (self.execute_fn)(request) } - async fn upload_media( + async fn upload( &self, - _: Arc, - ) -> Result { - Err(MediaUploadRequestExecutionError::RequestExecutionFailed { - status_code: None, - redirects: None, - reason: RequestExecutionErrorReason::GenericError { - error_message: "upload_media is not used".to_string(), - }, - }) + _request: Arc, + ) -> Result { + unimplemented!() } async fn sleep(&self, _: u64) {} @@ -398,13 +408,7 @@ mod tests { mod retry_after_middleware { use super::*; - use crate::{ - api_error::MediaUploadRequestExecutionError, - request::{ - WpNetworkHeaderMap, - endpoint::{WpEndpointUrl, media_endpoint::MediaUploadRequest}, - }, - }; + use crate::request::{WpNetworkHeaderMap, endpoint::WpEndpointUrl}; use async_trait::async_trait; use http::HeaderMap; use std::sync::atomic::{AtomicBool, Ordering}; @@ -437,11 +441,11 @@ mod tests { } } - async fn upload_media( + async fn upload( &self, - _: Arc, - ) -> Result { - Err(MediaUploadRequestExecutionError::RequestExecutionFailed { + _request: Arc, + ) -> Result { + Err(RequestExecutionError::RequestExecutionFailed { status_code: None, redirects: None, reason: RequestExecutionErrorReason::GenericError { diff --git a/wp_api/src/prelude.rs b/wp_api/src/prelude.rs index 902a3d115..d2f1ce6c8 100644 --- a/wp_api/src/prelude.rs +++ b/wp_api/src/prelude.rs @@ -2,8 +2,8 @@ pub use crate::{ WpApiParamOrder, WpAppNotifier, WpContext, api_client::{IsWpApiClientDelegate, WpApiClient, WpApiClientDelegate, WpApiRequestBuilder}, api_error::{ - InvalidSslErrorReason, MaybeWpError, MediaUploadRequestExecutionError, ParsedRequestError, - RequestExecutionError, RequestExecutionErrorReason, WpApiError, WpError, WpErrorCode, + InvalidSslErrorReason, MaybeWpError, ParsedRequestError, RequestExecutionError, + RequestExecutionErrorReason, WpApiError, WpError, WpErrorCode, }, auth::{WpAuthentication, WpAuthenticationProvider}, date::WpGmtDateTime, @@ -14,7 +14,7 @@ pub use crate::{ request::{ NetworkRequestAccessor, RequestExecutor, WpNetworkHeaderMap, WpNetworkRequest, WpNetworkResponse, - endpoint::{ApiUrlResolver, WpOrgSiteApiUrlResolver, media_endpoint::MediaUploadRequest}, + endpoint::{ApiUrlResolver, WpOrgSiteApiUrlResolver}, }, uuid::{WpUuid, WpUuidParseError}, wp_content_i64_id, wp_content_string_id, diff --git a/wp_api/src/request.rs b/wp_api/src/request.rs index dff87550a..756462579 100644 --- a/wp_api/src/request.rs +++ b/wp_api/src/request.rs @@ -1,8 +1,8 @@ use self::endpoint::WpEndpointUrl; use crate::{ api_error::{ - MediaUploadRequestExecutionError, ParsedRequestError, RequestExecutionError, - RequestExecutionErrorReason, WpApiError, WpErrorCode, + ParsedRequestError, RequestExecutionError, RequestExecutionErrorReason, WpApiError, + WpErrorCode, }, auth::WpAuthenticationProvider, url_query::{FromUrlQueryPairs, UrlQueryPairsMap}, @@ -15,7 +15,6 @@ use endpoint::{ ApplicationPasswordsRequestBuilder, ApplicationPasswordsRequestRetrieveCurrentWithEditContextResponse, }, - media_endpoint::MediaUploadRequest, }; use http::{HeaderMap, HeaderName, HeaderValue}; use regex::Regex; @@ -115,6 +114,37 @@ impl InnerRequestBuilder { } } + pub fn post_multipart(&self, url: ApiEndpointUrl, params: &T) -> WpMultipartFormRequest + where + T: ?Sized + Serialize + RequiresMultipartForm, + { + let mut fields = HashMap::new(); + if let Ok(serde_json::Value::Object(object)) = serde_json::to_value(params) { + for (key, value) in object { + if let serde_json::Value::String(s) = value { + fields.insert(key, s); + } else { + fields.insert(key, value.to_string()); + } + } + } + + let mut header_map = self.header_map_for_post_request(); + header_map.inner.insert( + http::header::CONTENT_TYPE, + HeaderValue::from_static(CONTENT_TYPE_MULTIPART), + ); + + WpMultipartFormRequest { + uuid: Uuid::new_v4().into(), + method: RequestMethod::POST, + url: url.into(), + header_map: header_map.into(), + fields, + files: params.multipart_form_files(), + } + } + fn header_map(&self) -> WpNetworkHeaderMap { let mut header_map = HeaderMap::new(); header_map.insert( @@ -145,10 +175,10 @@ pub trait RequestExecutor: Send + Sync { request: Arc, ) -> Result; - async fn upload_media( + async fn upload( &self, - media_upload_request: Arc, - ) -> Result; + request: Arc, + ) -> Result; async fn sleep(&self, millis: u64); @@ -303,6 +333,76 @@ impl NetworkRequestAccessor for WpNetworkRequest { } } +#[derive(Debug, Clone, PartialEq, Eq, uniffi::Record)] +pub struct MultipartFormFile { + pub file_path: String, + pub mime_type: Option, + pub file_name: Option, +} + +pub trait RequiresMultipartForm { + fn multipart_form_files(&self) -> HashMap; +} + +#[derive(uniffi::Object)] +pub struct WpMultipartFormRequest { + pub(crate) uuid: String, + pub(crate) method: RequestMethod, + pub(crate) url: WpEndpointUrl, + pub(crate) header_map: Arc, + pub(crate) fields: HashMap, + pub(crate) files: HashMap, +} + +#[uniffi::export] +impl WpMultipartFormRequest { + pub fn fields(&self) -> HashMap { + self.fields.clone() + } + + pub fn files(&self) -> HashMap { + self.files.clone() + } +} + +impl std::fmt::Debug for WpMultipartFormRequest { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut s = format!( + indoc::indoc! {" + WpMultipartFormRequest {{ + method: '{:?}', + url: '{:?}', + header_map: '{:?}', + fields: '{:?}', + files: '{:?}' + }} + "}, + self.method, self.url, self.header_map, self.fields, self.files + ); + s.pop(); // Remove the new line at the end + write!(f, "{s}") + } +} + +#[uniffi::export] +impl NetworkRequestAccessor for WpMultipartFormRequest { + fn request_id(&self) -> String { + self.uuid.clone() + } + + fn method(&self) -> RequestMethod { + self.method.clone() + } + + fn url(&self) -> WpEndpointUrl { + self.url.clone() + } + + fn header_map(&self) -> Arc { + self.header_map.clone() + } +} + // Has custom `Debug` trait implementation #[derive(uniffi::Record)] pub struct WpNetworkResponse { diff --git a/wp_api/src/request/endpoint/media_endpoint.rs b/wp_api/src/request/endpoint/media_endpoint.rs index d5b507b2d..5d5a890fb 100644 --- a/wp_api/src/request/endpoint/media_endpoint.rs +++ b/wp_api/src/request/endpoint/media_endpoint.rs @@ -1,15 +1,7 @@ -use super::{AsNamespace, DerivedRequest, WpEndpointUrl, WpNamespace}; -use crate::{ - api_error::WpApiError, - media::{MediaCreateParams, MediaId, MediaListParams, MediaUpdateParams, MediaWithEditContext}, - request::{ - CONTENT_TYPE_MULTIPART, NetworkRequestAccessor, ParsedResponse, RequestMethod, - WpNetworkHeaderMap, WpNetworkResponse, - }, - uuid::WpUuid, +use super::{AsNamespace, DerivedRequest, WpNamespace}; +use crate::media::{ + MediaCreateParams, MediaId, MediaListParams, MediaUpdateParams, MediaWithEditContext, }; -use http::HeaderValue; -use std::{collections::HashMap, sync::Arc}; use wp_derive_request_builder::WpDerivedRequest; #[derive(WpDerivedRequest)] @@ -22,6 +14,8 @@ enum MediaRequest { Delete, #[post(url = "/media/", params = &MediaUpdateParams, output = MediaWithEditContext)] Update, + #[post(url = "/media", params = &MediaCreateParams, output = MediaWithEditContext, multipart = true)] + Create, } impl DerivedRequest for MediaRequest { @@ -39,166 +33,6 @@ impl DerivedRequest for MediaRequest { } } -impl MediaRequestEndpoint { - pub fn create(&self) -> crate::request::endpoint::ApiEndpointUrl { - Arc::unwrap_or_clone(self.api_url_resolver.resolve( - MediaRequest::namespace().namespace_value().to_string(), - vec!["media".to_string()], - )) - .inner - .into() - } -} - -#[derive(Debug, serde::Serialize, serde::Deserialize, uniffi::Record)] -#[serde(transparent)] -pub struct MediaRequestCreateResponse { - pub data: crate::media::MediaWithEditContext, - #[serde(skip)] - pub header_map: std::sync::Arc, -} - -impl From for ParsedResponse { - fn from(value: MediaRequestCreateResponse) -> Self { - Self { - data: value.data, - header_map: value.header_map, - next_page_params: None, - prev_page_params: None, - } - } -} -impl From> for MediaRequestCreateResponse { - fn from(value: ParsedResponse) -> Self { - Self { - data: value.data, - header_map: value.header_map, - } - } -} - -#[uniffi::export] -fn parse_as_media_request_create_response( - response: WpNetworkResponse, -) -> Result { - response.parse() -} - -#[derive(uniffi::Object)] -pub struct MediaUploadRequest { - pub(crate) uuid: String, - pub(crate) method: RequestMethod, - pub(crate) url: WpEndpointUrl, - pub(crate) header_map: Arc, - pub(crate) file_path: String, - pub(crate) file_content_type: String, - pub(crate) media_params: HashMap, -} - -#[uniffi::export] -impl MediaUploadRequest { - pub fn file_path(&self) -> String { - self.file_path.clone() - } - - pub fn file_content_type(&self) -> String { - self.file_content_type.clone() - } - - pub fn media_params(&self) -> HashMap { - self.media_params.clone() - } -} - -impl std::fmt::Debug for MediaUploadRequest { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut s = format!( - indoc::indoc! {" - MediaUploadRequest {{ - method: '{:?}', - url: '{:?}', - header_map: '{:?}', - file_path: '{:?}' - file_content_type: '{:?}' - media_params: '{:?}' - }} - "}, - self.method, - self.url, - self.header_map, - self.file_path, - self.file_content_type, - self.media_params - ); - s.pop(); // Remove the new line at the end - write!(f, "{s}") - } -} - -#[uniffi::export] -impl NetworkRequestAccessor for MediaUploadRequest { - fn request_id(&self) -> String { - self.uuid.clone() - } - - fn method(&self) -> RequestMethod { - self.method.clone() - } - - fn url(&self) -> WpEndpointUrl { - self.url.clone() - } - - fn header_map(&self) -> Arc { - self.header_map.clone() - } -} - -impl MediaRequestBuilder { - pub fn create( - &self, - params: MediaCreateParams, - file_path: String, - file_content_type: String, - request_id: Option>, - ) -> MediaUploadRequest { - let mut header_map = self.inner.header_map(); - header_map.inner.insert( - http::header::CONTENT_TYPE, - HeaderValue::from_static(CONTENT_TYPE_MULTIPART), - ); - MediaUploadRequest { - uuid: request_id.unwrap_or_default().uuid_string(), - method: RequestMethod::POST, - url: self.endpoint.create().into(), - header_map: header_map.into(), - file_path, - file_content_type, - media_params: params.into(), - } - } -} - -#[uniffi::export] -impl MediaRequestExecutor { - pub async fn create( - &self, - params: MediaCreateParams, - file_path: String, - file_content_type: String, - request_id: Option>, - ) -> Result { - let request = self - .request_builder - .create(params, file_path, file_content_type, request_id); - self.delegate - .request_executor - .upload_media(Arc::new(request)) - .await? - .parse() - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/wp_api/src/reqwest_request_executor.rs b/wp_api/src/reqwest_request_executor.rs index 29b49a242..7c2347542 100644 --- a/wp_api/src/reqwest_request_executor.rs +++ b/wp_api/src/reqwest_request_executor.rs @@ -1,13 +1,9 @@ use crate::{ - api_error::{ - InvalidSslErrorReason, MediaUploadRequestExecutionError, RequestExecutionError, - RequestExecutionErrorReason, - }, + api_error::{InvalidSslErrorReason, RequestExecutionError, RequestExecutionErrorReason}, request::RequestContext, request::{ - NetworkRequestAccessor, RequestExecutor, RequestMethod, WpNetworkHeaderMap, - WpNetworkRequest, WpNetworkResponse, endpoint::media_endpoint::MediaUploadRequest, - user_agent, + NetworkRequestAccessor, RequestExecutor, RequestMethod, WpMultipartFormRequest, + WpNetworkHeaderMap, WpNetworkRequest, WpNetworkResponse, user_agent, }, }; use async_trait::async_trait; @@ -87,47 +83,6 @@ impl ReqwestRequestExecutor { }) } - pub async fn upload_media_request( - &self, - media_upload_request: Arc, - ) -> Result { - let request = self - .client - .request( - Self::request_method(media_upload_request.method()), - media_upload_request.url().0.as_str(), - ) - .headers(media_upload_request.header_map().to_header_map()); - let file_path = media_upload_request.file_path(); - let mut file_header_map = HeaderMap::new(); - file_header_map.insert( - http::header::CONTENT_TYPE, - HeaderValue::from_str(&media_upload_request.file_content_type()).unwrap(), - ); - let mut form = reqwest::multipart::Form::new().part( - "file", - Part::file(file_path) - .await - .unwrap() - .headers(file_header_map), - ); - for (k, v) in media_upload_request.media_params() { - form = form.text(k, v) - } - - let request = request.multipart(form); - let mut response = request.send().await?; - - let header_map = std::mem::take(response.headers_mut()); - Ok(WpNetworkResponse { - status_code: response.status().as_u16(), - body: response.bytes().await.unwrap().to_vec(), - response_header_map: Arc::new(WpNetworkHeaderMap::new(header_map)), - request_url: media_upload_request.url(), - request_header_map: media_upload_request.header_map(), - }) - } - pub fn request_method(method: RequestMethod) -> http::Method { match method { RequestMethod::GET => reqwest::Method::GET, @@ -148,21 +103,50 @@ impl RequestExecutor for ReqwestRequestExecutor { self.async_request(request).await.map_err(|e| e.into()) } - async fn upload_media( + async fn upload( &self, - media_upload_request: Arc, - ) -> Result { - self.upload_media_request(media_upload_request) - .await - .map_err( - |err| MediaUploadRequestExecutionError::RequestExecutionFailed { - status_code: err.status().map(|s| s.as_u16()), - redirects: None, - reason: RequestExecutionErrorReason::GenericError { - error_message: err.to_string(), - }, - }, + upload_request: Arc, + ) -> Result { + let request = self + .client + .request( + Self::request_method(upload_request.method()), + upload_request.url().0.as_str(), ) + .headers(upload_request.header_map().to_header_map()); + let mut form = reqwest::multipart::Form::new(); + + for (name, file) in upload_request.files() { + let file_path = file.file_path; + let mut file_header_map = HeaderMap::new(); + if let Some(mime_type) = &file.mime_type { + file_header_map.insert( + http::header::CONTENT_TYPE, + HeaderValue::from_str(mime_type).unwrap(), + ); + } + let part = Part::file(file_path) + .await + .unwrap() + .headers(file_header_map); + form = form.part(name, part); + } + + for (k, v) in upload_request.fields() { + form = form.text(k, v) + } + + let request = request.multipart(form); + let mut response = request.send().await?; + + let header_map = std::mem::take(response.headers_mut()); + Ok(WpNetworkResponse { + status_code: response.status().as_u16(), + body: response.bytes().await.unwrap().to_vec(), + response_header_map: Arc::new(WpNetworkHeaderMap::new(header_map)), + request_url: upload_request.url(), + request_header_map: upload_request.header_map(), + }) } async fn sleep(&self, millis: u64) { diff --git a/wp_api/src/wordpress_org/client.rs b/wp_api/src/wordpress_org/client.rs index b92a895f6..bb322af54 100644 --- a/wp_api/src/wordpress_org/client.rs +++ b/wp_api/src/wordpress_org/client.rs @@ -244,6 +244,11 @@ impl From for WordPressOrgApiClientError { redirects, reason, }, + RequestExecutionError::MediaFileNotFound { .. } => { + WordPressOrgApiClientError::RequestEncodingError { + reason: "file not found".to_string(), + } + } } } } diff --git a/wp_api/src/wp_com/support_tickets.rs b/wp_api/src/wp_com/support_tickets.rs index 324ce334a..7906e375b 100644 --- a/wp_api/src/wp_com/support_tickets.rs +++ b/wp_api/src/wp_com/support_tickets.rs @@ -2,7 +2,11 @@ use std::collections::HashMap; use serde::{Deserialize, Serialize}; -use crate::{date::WpGmtDateTime, impl_as_query_value_for_new_type}; +use crate::{ + date::WpGmtDateTime, + impl_as_query_value_for_new_type, + request::{MultipartFormFile, RequiresMultipartForm}, +}; #[derive(Debug, PartialEq, Eq, Serialize, uniffi::Record)] pub struct CreateSupportTicketParams { @@ -15,9 +19,30 @@ pub struct CreateSupportTicketParams { #[uniffi(default = [])] pub tags: Vec, #[uniffi(default = [])] + #[serde(skip)] pub attachments: Vec, } +impl RequiresMultipartForm for CreateSupportTicketParams { + fn multipart_form_files(&self) -> HashMap { + self.attachments + .iter() + .enumerate() + .map(|(i, file_path)| { + ( + // TODO: The backend is not ready yet. This name may need to be changed. + format!("attachment_{i}"), + MultipartFormFile { + file_path: file_path.clone(), + mime_type: None, + file_name: None, + }, + ) + }) + .collect() + } +} + #[derive(Debug, PartialEq, Eq, Serialize, Deserialize, uniffi::Record)] pub struct SupportConversationSummary { pub id: ConversationId, diff --git a/wp_api_integration_tests/src/mock.rs b/wp_api_integration_tests/src/mock.rs index 212d969a0..d87bf4d76 100644 --- a/wp_api_integration_tests/src/mock.rs +++ b/wp_api_integration_tests/src/mock.rs @@ -1,26 +1,20 @@ use async_trait::async_trait; use std::sync::Arc; use wp_api::{ - prelude::*, request::RequestContext, request::endpoint::media_endpoint::MediaUploadRequest, + prelude::*, + request::{RequestContext, WpMultipartFormRequest}, }; #[derive(Debug)] pub struct MockExecutor { execute_fn: fn(Arc) -> Result, - upload_media_fn: - fn(Arc) -> Result, } impl MockExecutor { pub fn with_execute_fn( execute_fn: fn(Arc) -> Result, ) -> Self { - Self { - execute_fn, - upload_media_fn: |_: Arc| { - panic!("Upload media is not implemented for `MockExecutor`") - }, - } + Self { execute_fn } } } @@ -33,11 +27,11 @@ impl RequestExecutor for MockExecutor { (self.execute_fn)(request) } - async fn upload_media( + async fn upload( &self, - media_upload_request: Arc, - ) -> Result { - (self.upload_media_fn)(media_upload_request) + _request: Arc, + ) -> Result { + unimplemented!() } async fn sleep(&self, _: u64) {} diff --git a/wp_api_integration_tests/tests/test_app_notifier_immut.rs b/wp_api_integration_tests/tests/test_app_notifier_immut.rs index f55408904..868123b67 100644 --- a/wp_api_integration_tests/tests/test_app_notifier_immut.rs +++ b/wp_api_integration_tests/tests/test_app_notifier_immut.rs @@ -2,7 +2,10 @@ use std::sync::{ Mutex, atomic::{AtomicBool, Ordering}, }; -use wp_api::{request::RequestContext, users::UserListParams}; +use wp_api::{ + request::{RequestContext, WpMultipartFormRequest}, + users::UserListParams, +}; use wp_api_integration_tests::prelude::*; #[tokio::test] @@ -125,11 +128,11 @@ impl RequestExecutor for TrackedRequestExecutor { self.executor.execute(request).await } - async fn upload_media( + async fn upload( &self, - media_upload_request: Arc, - ) -> Result { - self.upload_media(media_upload_request).await + _request: Arc, + ) -> Result { + unimplemented!() } async fn sleep(&self, _: u64) {} diff --git a/wp_api_integration_tests/tests/test_media_err.rs b/wp_api_integration_tests/tests/test_media_err.rs index a51c5285d..28bf4c36b 100644 --- a/wp_api_integration_tests/tests/test_media_err.rs +++ b/wp_api_integration_tests/tests/test_media_err.rs @@ -3,8 +3,7 @@ use wp_api::{ media::{MediaCreateParams, MediaId, MediaListParams, MediaUpdateParams}, posts::WpApiParamPostsOrderBy, prelude::*, - request::RequestContext, - request::endpoint::media_endpoint::MediaUploadRequest, + request::{RequestContext, WpMultipartFormRequest}, users::UserId, }; use wp_api_integration_tests::prelude::*; @@ -14,12 +13,10 @@ use wp_api_integration_tests::prelude::*; async fn create_media_err_cannot_create() { api_client_as_subscriber() .media() - .create( - MediaCreateParams::default(), - MEDIA_TEST_FILE_PATH.to_string(), - MEDIA_TEST_FILE_CONTENT_TYPE.to_string(), - None, - ) + .create(&MediaCreateParams { + file_path: MEDIA_TEST_FILE_PATH.to_string(), + ..Default::default() + }) .await .assert_wp_error(WpErrorCode::CannotCreate) } @@ -29,12 +26,10 @@ async fn create_media_err_cannot_create() { async fn create_media_err_upload_no_data() { api_client_with_medir_err_networking(MediaErrNetworkingTestType::UploadNoData) .media() - .create( - MediaCreateParams::default(), - MEDIA_TEST_FILE_PATH.to_string(), - MEDIA_TEST_FILE_CONTENT_TYPE.to_string(), - None, - ) + .create(&MediaCreateParams { + file_path: MEDIA_TEST_FILE_PATH.to_string(), + ..Default::default() + }) .await .assert_wp_error(WpErrorCode::UploadNoData) } @@ -210,50 +205,39 @@ impl RequestExecutor for MediaErrNetworking { }) } - async fn upload_media( + async fn upload( &self, - media_upload_request: Arc, - ) -> Result { - let mut request = self + upload_request: Arc, + ) -> Result { + let request = self .client .request( - ReqwestRequestExecutor::request_method(media_upload_request.method()), - media_upload_request.url().0.as_str(), + ReqwestRequestExecutor::request_method(upload_request.method()), + upload_request.url().0.as_str(), ) - .headers(media_upload_request.header_map().to_header_map()); - let mut file_header_map = HeaderMap::new(); - file_header_map.insert( - http::header::CONTENT_TYPE, - HeaderValue::from_str(&media_upload_request.file_content_type()).unwrap(), - ); + .headers(upload_request.header_map().to_header_map()); let mut form = reqwest::multipart::Form::new(); + match self.test_type { MediaErrNetworkingTestType::UploadNoData => { // don't add the file } } - for (k, v) in media_upload_request.media_params() { + + for (k, v) in upload_request.fields() { form = form.text(k, v) } - request = request.multipart(form); - let mut response = request.send().await.map_err(|err| { - MediaUploadRequestExecutionError::RequestExecutionFailed { - status_code: err.status().map(|s| s.as_u16()), - redirects: None, - reason: RequestExecutionErrorReason::GenericError { - error_message: err.to_string(), - }, - } - })?; + let request = request.multipart(form); + let mut response = request.send().await?; let header_map = std::mem::take(response.headers_mut()); Ok(WpNetworkResponse { status_code: response.status().as_u16(), body: response.bytes().await.unwrap().to_vec(), response_header_map: Arc::new(WpNetworkHeaderMap::new(header_map)), - request_url: media_upload_request.url(), - request_header_map: media_upload_request.header_map(), + request_url: upload_request.url(), + request_header_map: upload_request.header_map(), }) } diff --git a/wp_api_integration_tests/tests/test_media_mut.rs b/wp_api_integration_tests/tests/test_media_mut.rs index fe6429c7a..1abd07803 100644 --- a/wp_api_integration_tests/tests/test_media_mut.rs +++ b/wp_api_integration_tests/tests/test_media_mut.rs @@ -11,15 +11,11 @@ async fn upload_media() { let title = "Foo media"; let created_media = api_client() .media() - .create( - MediaCreateParams { - title: Some(title.to_string()), - ..Default::default() - }, - MEDIA_TEST_FILE_PATH.to_string(), - MEDIA_TEST_FILE_CONTENT_TYPE.to_string(), - None, - ) + .create(&MediaCreateParams { + title: Some(title.to_string()), + file_path: MEDIA_TEST_FILE_PATH.to_string(), + ..Default::default() + }) .await .assert_response(); assert_eq!(created_media.data.title.rendered.as_str(), title); diff --git a/wp_derive_request_builder/src/generate.rs b/wp_derive_request_builder/src/generate.rs index 4ba0b5600..a5095a2e1 100644 --- a/wp_derive_request_builder/src/generate.rs +++ b/wp_derive_request_builder/src/generate.rs @@ -78,6 +78,12 @@ fn generate_async_request_executor( #fn_signature_body } }; + let perform_call = if variant.attr.multipart { + quote! { self.perform_upload(std::sync::Arc::new(request), context.clone()).await? } + } else { + quote! { self.perform(std::sync::Arc::new(request), context.clone()).await? } + }; + let cancellable = quote! { pub async #fn_signature_cancellable -> Result<#response_type_ident, #error_type> { use #crate_ident::api_error::MaybeWpError; @@ -86,7 +92,7 @@ fn generate_async_request_executor( let perform_request = async || { #request_from_request_builder let request_url: String = request.url().into(); - let response = self.perform(std::sync::Arc::new(request), context.clone()).await?; + let response = #perform_call; let response_status_code = response.status_code; let parsed_response = response.parse(); let unauthorized = parsed_response.is_unauthorized_error().unwrap_or_default() || (response_status_code == 401 && self.fetch_authentication_state().await.map(|auth_state| auth_state.is_unauthorized()).unwrap_or_default()); @@ -254,6 +260,7 @@ fn generate_request_builder(config: &Config, parsed_enum: &ParsedEnum) -> TokenS let static_inner_request_builder_type = &config.static_types.inner_request_builder; let static_auth_provider_type = &config.static_types.auth_provider; let static_wp_network_request_type = &config.static_types.wp_network_request; + let static_wp_multipart_form_request_type = &config.static_types.wp_multipart_form_request; let generated_endpoint_ident = &config.generated_idents.endpoint; let generated_request_builder_ident = &config.generated_idents.request_builder; @@ -283,10 +290,18 @@ fn generate_request_builder(config: &Config, parsed_enum: &ParsedEnum) -> TokenS variant.attr.request_type, &context_and_filter_handler, ); - let fn_body_build_request_from_url = - fn_body_build_request_from_url(params_type.as_ref(), variant.attr.request_type); + let fn_body_build_request_from_url = fn_body_build_request_from_url( + params_type.as_ref(), + variant.attr.request_type, + variant.attr.multipart, + ); + let return_type = if variant.attr.multipart { + static_wp_multipart_form_request_type + } else { + static_wp_network_request_type + }; quote! { - pub #fn_signature -> #static_wp_network_request_type { + pub #fn_signature -> #return_type { #url_from_endpoint #fn_body_build_request_from_url } @@ -503,6 +518,7 @@ pub struct ConfigStaticTypes { pub inner_request_builder: TokenStream, pub auth_provider: TokenStream, pub wp_network_request: TokenStream, + pub wp_multipart_form_request: TokenStream, } impl ConfigStaticTypes { @@ -513,6 +529,7 @@ impl ConfigStaticTypes { inner_request_builder: quote! { #crate_ident::request::InnerRequestBuilder }, auth_provider: quote! { std::sync::Arc<#crate_ident::auth::WpAuthenticationProvider> }, wp_network_request: quote! { #crate_ident::request::WpNetworkRequest }, + wp_multipart_form_request: quote! { #crate_ident::request::WpMultipartFormRequest }, } } } diff --git a/wp_derive_request_builder/src/generate/helpers_to_generate_tokens.rs b/wp_derive_request_builder/src/generate/helpers_to_generate_tokens.rs index e4d68e48b..20344e00d 100644 --- a/wp_derive_request_builder/src/generate/helpers_to_generate_tokens.rs +++ b/wp_derive_request_builder/src/generate/helpers_to_generate_tokens.rs @@ -399,6 +399,7 @@ pub fn fn_body_context_query_pairs( pub fn fn_body_build_request_from_url( params_type: Option<&ParamsType>, request_type: RequestType, + multipart: bool, ) -> TokenStream { match request_type { RequestType::ContextualGet | RequestType::ContextualPaged | RequestType::Get => quote! { @@ -408,7 +409,17 @@ pub fn fn_body_build_request_from_url( self.inner.delete(url) }, RequestType::Post => { - if params_type.is_some() { + if multipart { + if params_type.is_some() { + quote! { + self.inner.post_multipart(url, params) + } + } else { + quote! { + compile_error!("multipart POST requires params") + } + } + } else if params_type.is_some() { quote! { self.inner.post(url, Some(params)) } @@ -1081,35 +1092,46 @@ mod tests { } #[rstest] - #[case(None, RequestType::ContextualGet, "self . inner . get (url)")] + #[case(None, RequestType::ContextualGet, false, "self . inner . get (url)")] #[case( referenced_params_type("UserListParams"), RequestType::ContextualGet, + false, "self . inner . get (url)" )] - #[case(None, RequestType::Delete, "self . inner . delete (url)")] + #[case(None, RequestType::Delete, false, "self . inner . delete (url)")] #[case( referenced_params_type("UserListParams"), RequestType::Delete, + false, "self . inner . delete (url)" )] #[case( None, RequestType::Post, + false, "self . inner . post (url , None :: < & () >)" )] #[case( referenced_params_type("UserListParams"), RequestType::Post, + false, "self . inner . post (url , Some (params))" )] + #[case( + referenced_params_type("UserListParams"), + RequestType::Post, + true, + "self . inner . post_multipart (url , params)" + )] fn test_fn_body_build_request_from_url( #[case] params: Option, #[case] request_type: RequestType, + #[case] multipart: bool, #[case] expected_str: &str, ) { assert_eq!( - fn_body_build_request_from_url(params.as_ref(), request_type).to_string(), + fn_body_build_request_from_url(params.as_ref(), request_type, multipart).to_string(), expected_str ); } diff --git a/wp_derive_request_builder/src/variant_attr.rs b/wp_derive_request_builder/src/variant_attr.rs index ed6faccf2..7c8aa64fc 100644 --- a/wp_derive_request_builder/src/variant_attr.rs +++ b/wp_derive_request_builder/src/variant_attr.rs @@ -25,6 +25,7 @@ pub struct ParsedVariantAttribute { pub output: Vec, pub filter_by: Option, pub available_contexts: Vec, + pub multipart: bool, } impl ParsedVariantAttribute { @@ -35,6 +36,7 @@ impl ParsedVariantAttribute { output: Vec, filter_by: Option>, available_contexts: Vec, + multipart: bool, ) -> Result { let non_empty_token_tree_or_none = |tokens: Option>| -> Option> { @@ -63,6 +65,7 @@ impl ParsedVariantAttribute { tokens: TokenStream::from_iter(tokens), }), available_contexts, + multipart, }) } @@ -232,6 +235,39 @@ impl ParsedVariantAttribute { .collect() } + fn parse_bool_flag( + bool_tokens: Option<&Vec>, + error_span: Span, + ) -> syn::Result { + if let Some(bool_tokens) = bool_tokens { + if bool_tokens.len() != 1 { + return Err(ItemVariantAttributeParseError::BoolFlagShouldBeASingleValue + .into_syn_error(error_span)); + } + let first_token = bool_tokens + .first() + .expect("Already verified that there is only one token"); + + match first_token { + TokenTree::Ident(ident) => { + let ident_str = ident.to_string(); + match ident_str.as_str() { + "true" => Ok(true), + "false" => Ok(false), + _ => Err(ItemVariantAttributeParseError::BoolFlagUnexpectedValue { + unexpected_value: ident_str, + } + .into_syn_error(error_span)), + } + } + _ => Err(ItemVariantAttributeParseError::BoolFlagUnexpectedTokens + .into_syn_error(error_span)), + } + } else { + Ok(false) + } + } + fn available_contexts( available_contexts_tokens: Option<&Vec>, error_span: Span, @@ -305,6 +341,7 @@ impl Parse for ParsedVariantAttribute { let mut output_tokens = None; let mut filter_by_tokens = None; let mut available_contexts_tokens = None; + let mut multipart_tokens = None; for (ident, tokens) in pair_vec.into_iter() { match ident.to_string().as_str() { @@ -313,6 +350,7 @@ impl Parse for ParsedVariantAttribute { "output" => output_tokens = Some(tokens), "filter_by" => filter_by_tokens = Some(tokens), "available_contexts" => available_contexts_tokens = Some(tokens), + "multipart" => multipart_tokens = Some(tokens), _ => { return Err(ItemVariantAttributeParseError::ExpectingKeyValuePairs .into_syn_error(meta_list_span)); @@ -351,6 +389,8 @@ impl Parse for ParsedVariantAttribute { available_contexts_tokens.as_ref(), input.span(), )?; + let multipart = + ParsedVariantAttribute::parse_bool_flag(multipart_tokens.as_ref(), input.span())?; ParsedVariantAttribute::new( request_type, @@ -359,6 +399,7 @@ impl Parse for ParsedVariantAttribute { output, filter_by_tokens, available_contexts, + multipart, ) .map_err(|e| e.into_syn_error(meta_list_span)) } @@ -403,6 +444,15 @@ enum ItemVariantAttributeParseError { AvailableContextShouldBeLiteral, #[error("Only 'contextual_get', 'contextual_paged', 'get', 'post' & 'delete' are supported")] UnsupportedRequestType, + #[error("Boolean flag contains unexpected tokens. It should be either 'true' or 'false'")] + BoolFlagUnexpectedTokens, + #[error("Boolean flag should be a single value: 'true' or 'false'")] + BoolFlagShouldBeASingleValue, + #[error( + "Boolean flag contains unexpected value: '{}'. Expected 'true' or 'false'", + unexpected_value + )] + BoolFlagUnexpectedValue { unexpected_value: String }, } impl ItemVariantAttributeParseError { @@ -466,4 +516,27 @@ mod tests { expected_url_parts ); } + + #[test] + fn test_multipart_flag_default() { + let parsed: ParsedVariantAttribute = + syn::parse_str(r#"#[post(url = "/test", output = TestOutput)]"#).unwrap(); + assert!(!parsed.multipart); + } + + #[test] + fn test_multipart_flag_true() { + let parsed: ParsedVariantAttribute = + syn::parse_str(r#"#[post(url = "/test", output = TestOutput, multipart = true)]"#) + .unwrap(); + assert!(parsed.multipart); + } + + #[test] + fn test_multipart_flag_false() { + let parsed: ParsedVariantAttribute = + syn::parse_str(r#"#[post(url = "/test", output = TestOutput, multipart = false)]"#) + .unwrap(); + assert!(!parsed.multipart); + } }