Skip to content

Commit 4341eac

Browse files
committed
the editor part of the model editor API
1 parent e65fada commit 4341eac

File tree

5 files changed

+125
-5
lines changed

5 files changed

+125
-5
lines changed

src/editor/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ impl Model {
255255
Ok(())
256256
}
257257

258-
pub fn into_session(self, options: SessionBuilder) -> Result<Session> {
258+
pub fn into_session(self, mut options: SessionBuilder) -> Result<Session> {
259259
let env = get_environment()?;
260260
let mut session_ptr = ptr::null_mut();
261261
ortsys![@editor:

src/session/builder/editable.rs

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
use alloc::sync::Arc;
2+
use core::{
3+
ops::Deref,
4+
ptr::{self, NonNull}
5+
};
6+
7+
use super::{PrepackedWeights, SessionBuilder};
8+
use crate::{AsPointer, Error, Result, editor::Model, ortsys, session::Session};
9+
10+
pub struct EditableSession {
11+
session: Session,
12+
builder: SessionBuilder,
13+
prepacked_weights: Option<PrepackedWeights>
14+
}
15+
16+
impl EditableSession {
17+
pub(crate) fn new(session: NonNull<ort_sys::OrtSession>, mut builder: SessionBuilder) -> Result<Self> {
18+
// Prepacked weights are passed to `FinalizeModelEditorSession`; steal them from the builder so we can add them later.
19+
let prepacked_weights = builder.prepacked_weights.take();
20+
Ok(Self {
21+
session: builder.commit_finalize(session)?,
22+
builder,
23+
prepacked_weights
24+
})
25+
}
26+
27+
pub fn apply_model(&mut self, model: &Model) -> Result<()> {
28+
ortsys![@editor:
29+
unsafe ApplyModelToModelEditorSession(
30+
self.session.ptr_mut(),
31+
model.ptr().cast_mut()
32+
)?
33+
];
34+
Ok(())
35+
}
36+
37+
pub fn into_session(mut self) -> Result<Session> {
38+
ortsys![@editor:
39+
unsafe FinalizeModelEditorSession(
40+
self.session.ptr_mut(),
41+
self.builder.ptr(),
42+
self.prepacked_weights.as_ref().map(|p| p.ptr()).unwrap_or_else(ptr::null)
43+
)?
44+
];
45+
46+
if let Some(prepacked_weights) = self.prepacked_weights {
47+
let Some(inner) = Arc::get_mut(&mut self.session.inner) else {
48+
return Err(Error::new("Expected to have exclusive access to session inner"));
49+
};
50+
51+
// add to extras so it outlives the session
52+
inner._extras.push(Box::new(prepacked_weights));
53+
}
54+
55+
Ok(self.session)
56+
}
57+
}
58+
59+
impl Deref for EditableSession {
60+
type Target = Session;
61+
62+
fn deref(&self) -> &Self::Target {
63+
&self.session
64+
}
65+
}

src/session/builder/impl_commit.rs

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use std::path::PathBuf;
1515

1616
use smallvec::SmallVec;
1717

18-
use super::SessionBuilder;
18+
use super::{EditableSession, SessionBuilder};
1919
#[cfg(feature = "std")]
2020
use crate::error::{Error, ErrorCode};
2121
use crate::{
@@ -176,7 +176,7 @@ impl SessionBuilder {
176176
self.commit_finalize(unsafe { NonNull::new_unchecked(session_ptr) })
177177
}
178178

179-
pub(crate) fn commit_finalize(mut self, ptr: NonNull<ort_sys::OrtSession>) -> Result<Session> {
179+
pub(crate) fn commit_finalize(&mut self, ptr: NonNull<ort_sys::OrtSession>) -> Result<Session> {
180180
let allocator = match &self.memory_info {
181181
Some(info) => {
182182
let mut allocator_ptr: *mut ort_sys::OrtAllocator = ptr::null_mut();
@@ -218,4 +218,47 @@ impl SessionBuilder {
218218
outputs
219219
})
220220
}
221+
222+
#[cfg(feature = "std")]
223+
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
224+
pub fn edit_from_file<P>(self, model_filepath: P) -> Result<EditableSession>
225+
where
226+
P: AsRef<Path>
227+
{
228+
let mut session_ptr: *mut ort_sys::OrtSession = ptr::null_mut();
229+
let model_path = crate::util::path_to_os_char(model_filepath);
230+
231+
let env = get_environment()?;
232+
233+
ortsys![@editor:
234+
unsafe CreateModelEditorSession(
235+
env.ptr(),
236+
model_path.as_ptr(),
237+
self.session_options_ptr.as_ptr(),
238+
&mut session_ptr
239+
)?;
240+
nonNull(session_ptr)
241+
];
242+
243+
EditableSession::new(unsafe { NonNull::new_unchecked(session_ptr) }, self)
244+
}
245+
246+
pub fn edit_from_memory(self, model_bytes: &[u8]) -> Result<EditableSession> {
247+
let mut session_ptr: *mut ort_sys::OrtSession = ptr::null_mut();
248+
249+
let env = get_environment()?;
250+
251+
ortsys![@editor:
252+
unsafe CreateModelEditorSessionFromArray(
253+
env.ptr(),
254+
model_bytes.as_ptr().cast(),
255+
model_bytes.len() as _,
256+
self.session_options_ptr.as_ptr(),
257+
&mut session_ptr
258+
)?;
259+
nonNull(session_ptr)
260+
];
261+
262+
EditableSession::new(unsafe { NonNull::new_unchecked(session_ptr) }, self)
263+
}
221264
}

src/session/builder/mod.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,15 @@ use smallvec::SmallVec;
88

99
use crate::{AsPointer, error::Result, logging::LoggerFunction, memory::MemoryInfo, operator::OperatorDomain, ortsys, util::with_cstr, value::DynValue};
1010

11+
mod editable;
1112
mod impl_commit;
1213
mod impl_config_keys;
1314
mod impl_options;
1415

15-
pub use self::impl_options::{GraphOptimizationLevel, PrepackedWeights};
16+
pub use self::{
17+
editable::EditableSession,
18+
impl_options::{GraphOptimizationLevel, PrepackedWeights}
19+
};
1620

1721
/// Creates a session using the builder pattern.
1822
///

src/session/mod.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ use crate::{
3232
memory::Allocator,
3333
metadata::ModelMetadata,
3434
ortsys,
35-
util::{STACK_SESSION_INPUTS, STACK_SESSION_OUTPUTS, with_cstr_ptr_array},
35+
util::{STACK_SESSION_INPUTS, STACK_SESSION_OUTPUTS, with_cstr, with_cstr_ptr_array},
3636
value::{DynValue, Value, ValueType}
3737
};
3838

@@ -545,6 +545,14 @@ impl Session {
545545
ortsys![unsafe SetEpDynamicOptions(self.inner.session_ptr.as_ptr(), &key, &value, 1)?];
546546
Ok(())
547547
}
548+
549+
pub fn opset_for_domain(&self, domain: impl AsRef<str>) -> Result<u32> {
550+
with_cstr(domain.as_ref().as_bytes(), &|domain| {
551+
let mut opset = 0;
552+
ortsys![@editor: unsafe SessionGetOpsetForDomain(self.inner.session_ptr.as_ptr(), domain.as_ptr(), &mut opset)?];
553+
Ok(opset as u32)
554+
})
555+
}
548556
}
549557

550558
/// Workload type, used to signal to execution providers whether to prioritize performance or efficiency.

0 commit comments

Comments
 (0)