Skip to content

Commit 13bfdca

Browse files
authored
Give request as argument to OakAbiNativeExtension (#2700)
1 parent cf29c6f commit 13bfdca

File tree

5 files changed

+78
-80
lines changed

5 files changed

+78
-80
lines changed

oak_functions/loader/src/lookup.rs

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -84,27 +84,17 @@ where
8484
&mut self,
8585
wasm_state: &mut WasmState,
8686
args: wasmi::RuntimeArgs,
87+
request: Vec<u8>,
8788
) -> Result<Result<(), OakStatus>, wasmi::Trap> {
88-
let key_ptr = args.nth_checked(0)?;
89-
let key_len = args.nth_checked(1)?;
89+
// TODO(#2699), TODO(#2664): Do not write value to Wasm State here.
9090
let value_ptr_ptr = args.nth_checked(2)?;
9191
let value_len_ptr = args.nth_checked(3)?;
9292

93-
let extension_args = wasm_state
94-
.read_extension_args(key_ptr, key_len)
95-
.map_err(|err| {
96-
self.log_error(&format!(
97-
"storage_get_item(): Unable to read key from guest memory: {:?}",
98-
err
99-
));
100-
OakStatus::ErrInvalidArgs
101-
});
102-
103-
let extension_result = extension_args
104-
.and_then(|key| storage_get_item(self, key))
105-
.and_then(|value| {
106-
wasm_state.write_extension_result(value, value_ptr_ptr, value_len_ptr)
107-
});
93+
// The request is the key to lookup.
94+
let key = request;
95+
let extension_result = storage_get_item(self, key).and_then(|value| {
96+
wasm_state.write_extension_result(value, value_ptr_ptr, value_len_ptr)
97+
});
10898

10999
Ok(extension_result)
110100
}

oak_functions/loader/src/metrics.rs

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -69,24 +69,12 @@ impl ExtensionFactory for PrivateMetricsProxyFactory {
6969
impl OakApiNativeExtension for PrivateMetricsExtension<Logger> {
7070
fn invoke(
7171
&mut self,
72-
wasm_state: &mut WasmState,
73-
args: wasmi::RuntimeArgs,
72+
_wasm_state: &mut WasmState,
73+
_args: wasmi::RuntimeArgs,
74+
request: Vec<u8>,
7475
) -> Result<Result<(), OakStatus>, wasmi::Trap> {
75-
let buf_ptr = args.nth_checked(0)?;
76-
let buf_len = args.nth_checked(1)?;
77-
78-
let args = wasm_state
79-
.read_extension_args(buf_ptr, buf_len)
80-
.map_err(|err| {
81-
self.log_error(&format!(
82-
"report_metric(): Unable to read label from guest memory: {:?}",
83-
err
84-
));
85-
OakStatus::ErrInvalidArgs
86-
});
87-
88-
let result = args.and_then(|metric_message| report_metric(self, metric_message));
89-
76+
// TODO(#2664): Remove WasmState from invoke.
77+
let result = report_metric(self, request);
9078
Ok(result)
9179
}
9280

oak_functions/loader/src/server.rs

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ pub trait OakApiNativeExtension {
122122
&mut self,
123123
wasm_state: &mut WasmState,
124124
args: wasmi::RuntimeArgs,
125+
request: Vec<u8>,
125126
) -> Result<Result<(), OakStatus>, wasmi::Trap>;
126127

127128
/// Metadata about this Extension, including the exported host function name, the function's
@@ -369,13 +370,29 @@ impl WasmState {
369370
})
370371
.expect("Fail to find extension with given handle.");
371372

372-
// We invoke the found extension.
373-
let result = from_oak_status_result(extension.invoke(self, args)?);
373+
// We read the request from the Wasm memory.
374+
let request_ptr: AbiPointer = args.nth_checked(1)?;
375+
let request_len: AbiPointerOffset = args.nth_checked(2)?;
376+
377+
let request = self
378+
.read_extension_args(request_ptr, request_len)
379+
.map_err(|err| {
380+
self.log_error(&format!(
381+
"Handle {:?}: Unable to read input from guest memory: {:?}",
382+
handle, err
383+
));
384+
OakStatus::ErrInvalidArgs
385+
});
386+
387+
let result = match request {
388+
Ok(request) => extension.invoke(self, args, request)?,
389+
Err(err) => Err(err),
390+
};
374391

375392
// We put the extension indices back.
376393
self.extensions_indices = Some(extensions_indices);
377394

378-
result
395+
from_oak_status_result(result)
379396
}
380397

381398
pub fn alloc(&mut self, len: u32) -> AbiPointer {
@@ -394,6 +411,10 @@ impl WasmState {
394411
_ => panic!("invalid value type returned from `alloc`"),
395412
}
396413
}
414+
415+
fn log_error(&self, message: &str) {
416+
self.logger.log_sensitive(Level::Error, message)
417+
}
397418
}
398419

399420
impl wasmi::Externals for WasmState {
@@ -429,9 +450,31 @@ impl wasmi::Externals for WasmState {
429450

430451
None => panic!("Unimplemented function at {}", index),
431452
};
432-
let result = from_oak_status_result(extension.invoke(self, args)?);
453+
454+
// Careful: We assume that here for the ABI call the first two arguments are the
455+
// request (which is true). We will remove this, when we call every
456+
// extension through `invoke`.
457+
let request_ptr: AbiPointer = args.nth_checked(0)?;
458+
let request_len: AbiPointerOffset = args.nth_checked(1)?;
459+
460+
let request = self
461+
.read_extension_args(request_ptr, request_len)
462+
.map_err(|err| {
463+
self.log_error(&format!(
464+
"Handle {:?}: Unable to read input from guest memory: {:?}",
465+
extension.get_handle(),
466+
err
467+
));
468+
OakStatus::ErrInvalidArgs
469+
});
470+
471+
let result = match request {
472+
Ok(request) => extension.invoke(self, args, request)?,
473+
Err(err) => Err(err),
474+
};
475+
433476
self.extensions_indices = Some(extensions_indices);
434-
result
477+
from_oak_status_result(result)
435478
}
436479
}
437480
}

oak_functions/loader/src/testing.rs

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
use crate::{
1818
logger::Logger,
1919
server::{
20-
AbiExtensionHandle, AbiPointer, AbiPointerOffset, BoxedExtension, BoxedExtensionFactory,
21-
ExtensionFactory, OakApiNativeExtension, ABI_USIZE,
20+
AbiPointer, BoxedExtension, BoxedExtensionFactory, ExtensionFactory, OakApiNativeExtension,
21+
ABI_USIZE,
2222
},
2323
};
2424

@@ -34,26 +34,13 @@ impl OakApiNativeExtension for TestingExtension<Logger> {
3434
&mut self,
3535
wasm_state: &mut crate::server::WasmState,
3636
args: wasmi::RuntimeArgs,
37+
request: Vec<u8>,
3738
) -> Result<Result<(), oak_functions_abi::proto::OakStatus>, wasmi::Trap> {
38-
// For consistency we also get the first argument, but we do not need it, as we did read the
39-
// handle already to decide to call the invoke of this extension.
40-
let _handle: AbiExtensionHandle = args.nth_checked(0)?;
41-
let request_ptr: AbiPointer = args.nth_checked(1)?;
42-
let request_len: AbiPointerOffset = args.nth_checked(2)?;
39+
// TODO(#2699), TODO(#2664): Do not write response to Wasm State here.
4340
let response_ptr_ptr: AbiPointer = args.nth_checked(3)?;
4441
let response_len_ptr: AbiPointer = args.nth_checked(4)?;
4542

46-
let extension_args = wasm_state
47-
.read_extension_args(request_ptr, request_len)
48-
.map_err(|err| {
49-
self.log_error(&format!(
50-
"testing(): Unable to read input from guest memory: {:?}",
51-
err
52-
));
53-
OakStatus::ErrInvalidArgs
54-
});
55-
56-
let result = extension_args.and_then(testing).and_then(|result| {
43+
let result = testing(request).and_then(|result| {
5744
wasm_state.write_extension_result(result, response_ptr_ptr, response_len_ptr)
5845
});
5946

oak_functions/loader/src/tf.rs

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
use crate::{
1818
logger::Logger,
1919
server::{
20-
AbiPointer, AbiPointerOffset, BoxedExtension, BoxedExtensionFactory, ExtensionFactory,
21-
OakApiNativeExtension, WasmState, ABI_USIZE,
20+
AbiPointer, BoxedExtension, BoxedExtensionFactory, ExtensionFactory, OakApiNativeExtension,
21+
WasmState, ABI_USIZE,
2222
},
2323
};
2424
use anyhow::Context;
@@ -39,31 +39,21 @@ impl OakApiNativeExtension for TensorFlowModel<Logger> {
3939
&mut self,
4040
wasm_state: &mut WasmState,
4141
args: wasmi::RuntimeArgs,
42+
request: Vec<u8>,
4243
) -> Result<Result<(), OakStatus>, wasmi::Trap> {
43-
let input_ptr: AbiPointer = args.nth_checked(0)?;
44-
let input_len: AbiPointerOffset = args.nth_checked(1)?;
44+
// TODO(#2699), TODO(#2664): Do not write inference to Wasm State here.
4545
let inference_ptr_ptr: AbiPointer = args.nth_checked(2)?;
4646
let inference_len_ptr: AbiPointer = args.nth_checked(3)?;
4747

48-
let extension_args = wasm_state
49-
.read_extension_args(input_ptr, input_len)
50-
.map_err(|err| {
51-
self.log_error(&format!(
52-
"tf_model_infer(): Unable to read input from guest memory: {:?}",
53-
err
54-
));
55-
OakStatus::ErrInvalidArgs
56-
});
57-
58-
let result = extension_args
59-
.and_then(|input| tf_model_infer(self, input))
60-
.and_then(|encoded_inference| {
61-
wasm_state.write_extension_result(
62-
encoded_inference,
63-
inference_ptr_ptr,
64-
inference_len_ptr,
65-
)
66-
});
48+
let input = request;
49+
50+
let result = tf_model_infer(self, input).and_then(|encoded_inference| {
51+
wasm_state.write_extension_result(
52+
encoded_inference,
53+
inference_ptr_ptr,
54+
inference_len_ptr,
55+
)
56+
});
6757

6858
Ok(result)
6959
}

0 commit comments

Comments
 (0)