forked from tauri-apps/wry
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathurl_scheme_handler.rs
More file actions
358 lines (315 loc) · 13.1 KB
/
url_scheme_handler.rs
File metadata and controls
358 lines (315 loc) · 13.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
// Copyright 2020-2024 Tauri Programme within The Commons Conservancy
// SPDX-License-Identifier: Apache-2.0
// SPDX-License-Identifier: MIT
use std::{
borrow::Cow,
ffi::{c_char, c_void, CStr},
panic::AssertUnwindSafe,
ptr::NonNull,
};
use http::{
header::{CONTENT_LENGTH, CONTENT_TYPE},
Request, Response as HttpResponse, StatusCode, Version,
};
use objc2::{
rc::Retained,
runtime::{AnyClass, AnyObject, ClassBuilder, ProtocolObject},
AllocAnyThread, ClassType, Message,
};
use objc2_foundation::{
NSData, NSHTTPURLResponse, NSMutableDictionary, NSObject, NSObjectProtocol, NSString, NSURL,
NSUUID,
};
use objc2_web_kit::{WKURLSchemeHandler, WKURLSchemeTask};
use crate::{wkwebview::WEBVIEW_STATE, RequestAsyncResponder, WryWebView};
/// Size cut-off in bytes (128 KiB) used when wrapping a response body in `NSData`.
///
/// Bodies shorter than this are always copied with `NSData::initWithBytes_length`; bodies at or
/// above it are handed to `NSData::from_vec` when they are owned, and copied only when borrowed.
const NO_COPY_DATA_THRESHOLD: usize = 128 * 1024;
/// Returns the Objective-C class that serves as URL scheme handler for the custom protocol `name`.
///
/// The class is declared as a subclass of `NSObject` carrying two instance variables
/// (`webview_id` and `protocol_index`) and the two scheme-handler methods
/// `webView:startURLSchemeTask:` / `webView:stopURLSchemeTask:`.
///
/// When a builder cannot be obtained (presumably because a class with the same name has already
/// been registered - confirm against the objc2 `ClassBuilder::new` docs), the class is looked up
/// by name instead.
///
/// # Panics
///
/// Panics if `name` contains an interior NUL byte, or if the class can neither be built nor found.
pub fn create(name: &str) -> &AnyClass {
  unsafe {
    // Include the address of WEBVIEW_STATE in the class name so that each dylib in the process
    // gets its own ObjC class with method pointers into its own code and data segments.
    let state_addr = std::ptr::addr_of!(WEBVIEW_STATE) as usize;
    let class_name_buf = format!("{name}URLSchemeHandler_{state_addr:x}\0");
    let class_name = CStr::from_bytes_with_nul(class_name_buf.as_bytes()).unwrap();

    // No builder available: fall back to whatever is registered under this name.
    let Some(mut builder) = ClassBuilder::new(class_name, NSObject::class()) else {
      return AnyClass::get(class_name).expect("Failed to get the class definition");
    };

    // Per-instance state read back by `start_task`.
    builder.add_ivar::<*mut c_char>(c"webview_id");
    builder.add_ivar::<usize>(c"protocol_index");

    // Scheme-handler entry points.
    builder.add_method(
      objc2::sel!(webView:startURLSchemeTask:),
      start_task as extern "C" fn(_, _, _, _),
    );
    builder.add_method(
      objc2::sel!(webView:stopURLSchemeTask:),
      stop_task as extern "C" fn(_, _, _, _),
    );

    builder.register()
  }
}
/// `webView:startURLSchemeTask:` implementation - the task handler for custom protocols.
///
/// Steps performed:
/// 1. Registers the task (keyed by `task.hash()`) on the webview and remembers the returned UUID.
/// 2. Reads the `webview_id` and `protocol_index` ivars from `this` and looks up the matching
///    protocol handler in `WEBVIEW_STATE`.
/// 3. Rebuilds the request (URI, method, body, headers) as an `http::Request<Vec<u8>>`.
/// 4. Calls the handler with a `RequestAsyncResponder`. When the responder is invoked it validates
///    that the webview and the task are still current before every `did*` call and converts the
///    `http::Response` into an `NSHTTPURLResponse` plus `NSData`.
///
/// If the request cannot be built, a 404 response is sent and the task is finished immediately.
/// If no handler is registered, nothing is sent; only a warning is traced (with the `tracing`
/// feature).
///
/// # Panics
///
/// Panics if the request has no URL, no absolute URL string, or no HTTP method.
extern "C" fn start_task(
  this: &AnyObject,
  _sel: objc2::runtime::Sel,
  webview: &WryWebView,
  task: &ProtocolObject<dyn WKURLSchemeTask>,
) {
  unsafe {
    #[cfg(feature = "tracing")]
    let span = tracing::info_span!(parent: None, "wry::custom_protocol::handle", uri = tracing::field::Empty)
      .entered();

    let task_key = task.hash(); // hash by task object address
    let task_uuid = webview.add_custom_task_key(task_key);

    let ivar = this.class().instance_variable(c"webview_id").unwrap();
    let webview_id_ptr: *mut c_char = *ivar.load(this);
    // NOTE: this `&str` points into memory owned by the handler object; it is only used
    // synchronously below. The asynchronous responder gets its own owned copy.
    let webview_id = CStr::from_ptr(webview_id_ptr)
      .to_str()
      .ok()
      .unwrap_or_default();

    let ivar = this.class().instance_variable(c"protocol_index").unwrap();
    let protocol_index: usize = *ivar.load(this);

    let function = WEBVIEW_STATE
      .read()
      .unwrap()
      .get(webview_id)
      .and_then(|v| v.protocol_ptrs.get(protocol_index))
      .cloned();

    if let Some(function) = function {
      // Get url request
      let request = task.request();
      let url = request.URL().unwrap();
      let uri = url.absoluteString().unwrap().to_string();

      #[cfg(feature = "tracing")]
      span.record("uri", uri.clone());

      // Get request method (GET, POST, PUT etc...)
      let method = request.HTTPMethod().unwrap().to_string();

      // Prepare our HttpRequest
      let mut http_request = Request::builder().uri(uri).method(method.as_str());

      // Get body
      let mut sent_form_body = Vec::new();
      let body = request.HTTPBody();
      let body_stream = request.HTTPBodyStream();
      if let Some(body) = body {
        sent_form_body = body.to_vec();
      } else if let Some(body_stream) = body_stream {
        body_stream.open();

        while body_stream.hasBytesAvailable() {
          // Read straight into the vector's spare capacity.
          sent_form_body.reserve(128);
          let p = sent_form_body.as_mut_ptr().add(sent_form_body.len());
          let read_length = sent_form_body.capacity() - sent_form_body.len();
          let count = body_stream.read_maxLength(NonNull::new(p).unwrap(), read_length);
          // `read:maxLength:` reports an error with a negative value and end-of-stream with 0.
          // Neither may reach `set_len`: a negative count cast to `usize` would produce a bogus
          // length, and a zero count would otherwise let the loop spin without progress.
          if count <= 0 {
            break;
          }
          sent_form_body.set_len(sent_form_body.len() + count as usize);
        }

        body_stream.close();
      }

      // Extract all headers fields
      let all_headers = request.allHTTPHeaderFields();

      // get all our headers values and inject them in our request
      if let Some(all_headers) = all_headers {
        for current_header in all_headers.allKeys().iter() {
          let header_value = all_headers.valueForKey(&current_header).unwrap();
          // inject the header into the request
          http_request = http_request.header(current_header.to_string(), header_value.to_string());
        }
      }

      // Answers the task with an empty 404 response and finishes it.
      let respond_with_404 = || {
        let urlresponse = NSHTTPURLResponse::alloc();
        let response = NSHTTPURLResponse::initWithURL_statusCode_HTTPVersion_headerFields(
          urlresponse,
          &url,
          StatusCode::NOT_FOUND.as_u16().try_into().unwrap(),
          Some(&NSString::from_str(
            format!("{:#?}", Version::HTTP_11).as_str(),
          )),
          None,
        )
        .unwrap();
        task.didReceiveResponse(&response);
        // Finish
        task.didFinish();
      };

      /// Fails when `webview_id` is no longer present in `WEBVIEW_STATE`.
      fn check_webview_id_valid(webview_id: &str) -> crate::Result<()> {
        if !WEBVIEW_STATE.read().unwrap().contains_key(webview_id) {
          return Err(crate::Error::CustomProtocolTaskInvalid);
        }
        Ok(())
      }

      /// Task may not live longer than async custom protocol handler.
      ///
      /// There are roughly 2 ways to cause segfault:
      /// 1. Task has stopped. pointer of the task not valid anymore.
      /// 2. Task had stopped, but the pointer of the task has allocated to a new task.
      ///    Outdated custom handler may call to the new task instance and cause segfault.
      fn check_task_is_valid(
        webview: &WryWebView,
        task_key: usize,
        current_uuid: Retained<NSUUID>,
      ) -> crate::Result<()> {
        let latest_task_uuid = webview.get_custom_task_uuid(task_key);
        let Some(latest_uuid) = latest_task_uuid else {
          return Err(crate::Error::CustomProtocolTaskInvalid);
        };
        if latest_uuid != current_uuid {
          return Err(crate::Error::CustomProtocolTaskInvalid);
        }
        Ok(())
      }

      // send response
      match http_request.body(sent_form_body) {
        Ok(final_request) => {
          // The responder may run after this function has returned, so it keeps strong
          // references to the webview and the task ...
          let webview = webview.retain();
          let task = task.retain();
          // ... and its own copy of the webview id, so that a late response never reads the
          // handler-owned C string behind `webview_id`.
          let owned_webview_id = webview_id.to_owned();

          let responder: Box<dyn FnOnce(HttpResponse<Cow<'static, [u8]>>)> =
            Box::new(move |sent_response| {
              let webview_id = owned_webview_id.as_str();

              // Consolidate checks before calling into `did*` methods.
              let validate = || -> crate::Result<()> {
                check_webview_id_valid(webview_id)?;
                check_task_is_valid(&webview, task_key, task_uuid.clone())?;
                Ok(())
              };

              // Perform an upfront validation
              if let Err(_e) = validate() {
                #[cfg(feature = "tracing")]
                tracing::warn!("Task invalid before sending response: {:?}", _e);
                return; // If invalid, return early without calling task methods.
              }

              /// Delivers `sent_response` to `task`, re-validating the webview and the task
              /// before each `did*` call and turning Objective-C exceptions into errors.
              unsafe fn response(
                // Retained, so the object itself stays alive; whether WebKit still accepts
                // calls on it is what `check_task_is_valid` guards.
                task: Retained<ProtocolObject<dyn WKURLSchemeTask>>,
                // Retained for the same reason; also holds the task-key bookkeeping.
                webview: Retained<WryWebView>,
                task_key: usize,
                task_uuid: Retained<NSUUID>,
                webview_id: &str,
                url: Retained<NSURL>,
                sent_response: HttpResponse<Cow<'_, [u8]>>,
              ) -> crate::Result<()> {
                // Validate
                check_webview_id_valid(webview_id)?;
                check_task_is_valid(&webview, task_key, task_uuid.clone())?;

                let content = sent_response.body();
                let content_len = content.len();

                // default: application/octet-stream, but should be provided by the client
                let wanted_mime = sent_response.headers().get(CONTENT_TYPE);
                // default to 200
                let wanted_status_code = sent_response.status().as_u16() as i32;
                // default to HTTP/1.1
                let wanted_version = format!("{:#?}", sent_response.version());

                let headers = NSMutableDictionary::new();

                // A content type that is not valid visible ASCII is skipped rather than
                // panicking, matching how the generic header loop below treats such values.
                if let Some(mime) = wanted_mime.and_then(|mime| mime.to_str().ok()) {
                  headers.insert(
                    &*NSString::from_str(CONTENT_TYPE.as_str()),
                    &*NSString::from_str(mime),
                  );
                }
                headers.insert(
                  &*NSString::from_str(CONTENT_LENGTH.as_str()),
                  &*NSString::from_str(&content_len.to_string()),
                );

                // add headers
                for (name, value) in sent_response.headers().iter() {
                  if let Ok(value) = value.to_str() {
                    headers.insert(
                      &*NSString::from_str(name.as_str()),
                      &*NSString::from_str(value),
                    );
                  }
                }

                let urlresponse = NSHTTPURLResponse::alloc();
                let response = NSHTTPURLResponse::initWithURL_statusCode_HTTPVersion_headerFields(
                  urlresponse,
                  &url,
                  wanted_status_code.try_into().unwrap(),
                  Some(&NSString::from_str(&wanted_version)),
                  Some(&headers),
                )
                .unwrap();

                // Re-validate before calling didReceiveResponse
                check_webview_id_valid(webview_id)?;
                check_task_is_valid(&webview, task_key, task_uuid.clone())?;

                // Use map_err to convert Option<Retained<Exception>> to crate::Error
                objc2::exception::catch(AssertUnwindSafe(|| {
                  task.didReceiveResponse(&response);
                }))
                .map_err(|_e| crate::Error::CustomProtocolTaskInvalid)?;

                let data = if content_len < NO_COPY_DATA_THRESHOLD {
                  let data = NSData::alloc();
                  // Keep small responses on the original copy path; no-copy deallocation costs more.
                  NSData::initWithBytes_length(data, content.as_ptr() as *mut c_void, content.len())
                } else {
                  match sent_response.into_body() {
                    Cow::Owned(content) => NSData::from_vec(content),
                    Cow::Borrowed(content) => {
                      let data = NSData::alloc();
                      // Copy borrowed responses because NSData cannot take ownership.
                      NSData::initWithBytes_length(
                        data,
                        content.as_ptr() as *mut c_void,
                        content.len(),
                      )
                    }
                  }
                };

                // Check validity again
                check_webview_id_valid(webview_id)?;
                check_task_is_valid(&webview, task_key, task_uuid.clone())?;

                objc2::exception::catch(AssertUnwindSafe(|| {
                  task.didReceiveData(&data);
                }))
                .map_err(|_e| crate::Error::CustomProtocolTaskInvalid)?;

                check_webview_id_valid(webview_id)?;
                check_task_is_valid(&webview, task_key, task_uuid)?;

                objc2::exception::catch(AssertUnwindSafe(|| {
                  task.didFinish();
                }))
                .map_err(|_e| crate::Error::CustomProtocolTaskInvalid)?;

                // The task is complete: drop its bookkeeping entry if the webview still exists.
                if WEBVIEW_STATE.read().unwrap().contains_key(webview_id) {
                  webview.remove_custom_task_key(task_key);
                  Ok(())
                } else {
                  Err(crate::Error::CustomProtocolTaskInvalid)
                }
              }

              #[cfg(feature = "tracing")]
              let _span = tracing::info_span!("wry::custom_protocol::call_handler").entered();

              if let Err(_e) = response(
                task,
                webview,
                task_key,
                task_uuid,
                webview_id,
                url,
                sent_response,
              ) {
                #[cfg(feature = "tracing")]
                tracing::error!("Error responding to task: {:?}", _e);
              }
            });

          #[cfg(feature = "tracing")]
          let _span = tracing::info_span!("wry::custom_protocol::call_handler").entered();
          function(
            webview_id,
            final_request,
            RequestAsyncResponder { responder },
          );
        }
        Err(_) => {
          respond_with_404();
          // The task has been finished here, so release its key just like the success path
          // does after `didFinish`.
          webview.remove_custom_task_key(task_key);
        }
      };
    } else {
      #[cfg(feature = "tracing")]
      tracing::warn!(
        "Either WebView or WebContext instance is dropped! This handler shouldn't be called."
      );
    };
  }
}
/// `webView:stopURLSchemeTask:` implementation.
///
/// Removes the webview's bookkeeping entry for `task`, using `task.hash()` as the key - the same
/// key `start_task` registers. Presumably this makes a responder that fires afterwards fail its
/// task-validity check instead of calling into the stopped task.
extern "C" fn stop_task(
  _this: &ProtocolObject<dyn WKURLSchemeHandler>,
  _sel: objc2::runtime::Sel,
  webview: &WryWebView,
  task: &ProtocolObject<dyn WKURLSchemeTask>,
) {
  let stopped_task_key = task.hash();
  webview.remove_custom_task_key(stopped_task_key);
}