Skip to content

Commit 2b10aaa

Browse files
authored
implement Slice op (#2260)
1 parent 9f804af commit 2b10aaa

File tree

2 files changed

+215
-0
lines changed

2 files changed

+215
-0
lines changed

candle-onnx/src/eval.rs

+80
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ pub fn dtype(dt: DataType) -> Option<DType> {
1414
DataType::Float16 => Some(DType::F16),
1515
DataType::Float => Some(DType::F32),
1616
DataType::Double => Some(DType::F64),
17+
DataType::Bool => Some(DType::U8),
1718
_ => None,
1819
}
1920
}
@@ -1053,6 +1054,85 @@ fn simple_eval_(
10531054
),
10541055
}
10551056
}
1057+
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#slice
1058+
"Slice" => {
1059+
let data = get(&node.input[0])?;
1060+
let starts = get(&node.input[1])?;
1061+
let ends = get(&node.input[2])?;
1062+
let default_axes;
1063+
let default_steps;
1064+
let axes: &Tensor;
1065+
let steps: &Tensor;
1066+
// If axes are omitted, they are set to [0, ..., r-1]. If steps are omitted,
1067+
// they are set to [1, ..., 1] of length len(starts)
1068+
match node.input.len() {
1069+
3 => {
1070+
let len = starts.dims()[0];
1071+
default_axes = Some(Tensor::arange(0, len as i64, starts.device())?);
1072+
axes = default_axes.as_ref().unwrap();
1073+
default_steps = Some(Tensor::ones((len,), DType::I64, starts.device())?);
1074+
steps = default_steps.as_ref().unwrap();
1075+
}
1076+
4 => {
1077+
let len = starts.dims()[0];
1078+
axes = get(&node.input[3])?;
1079+
default_steps = Some(Tensor::ones((len,), DType::I64, starts.device())?);
1080+
steps = default_steps.as_ref().unwrap();
1081+
}
1082+
5 => {
1083+
steps = get(&node.input[4])?;
1084+
axes = get(&node.input[3])?;
1085+
}
1086+
_ => bail!(
1087+
"Slice node is invalid, expected 3-5 inputs, got {}: {:?}",
1088+
node.input.len(),
1089+
node
1090+
),
1091+
}
1092+
1093+
let mut out = data.clone();
1094+
for (i, axis) in axes.to_vec1::<i64>()?.into_iter().enumerate() {
1095+
// All negative elements of axes are made non-negative by
1096+
// adding r to them, where r = rank(input).
1097+
let axis = if axis < 0 {
1098+
axis + data.rank() as i64
1099+
} else {
1100+
axis
1101+
} as usize;
1102+
1103+
let data_dim = data.dims()[axis] as i64;
1104+
let mut s = starts.get(i)?.to_scalar::<i64>()?;
1105+
let mut e = ends.get(i)?.to_scalar::<i64>()?;
1106+
// All negative values in starts[i] and ends[i] have
1107+
// dims[axes[i]] added to them, where dims are the
1108+
// dimensions of input.
1109+
if s < 0 {
1110+
s += data_dim;
1111+
}
1112+
if e < 0 {
1113+
e += data_dim;
1114+
}
1115+
1116+
let p = steps.get(i)?.to_scalar::<i64>()?;
1117+
// starts[i] is clamped into the range [0, dims[axes[i]]]
1118+
// for positive stepping and [0, dims[axes[i]]-1] for
1119+
// negative stepping.
1120+
// for positive stepping ends[axes[i]] is clamped to
1121+
// [0, dims[axes[i]]], while for negative stepping it is
1122+
// clamped to [-1, dims[axes[i]]-1].
1123+
if p >= 0 {
1124+
s = s.clamp(0, data_dim);
1125+
e = e.clamp(0, data_dim);
1126+
} else {
1127+
s = s.clamp(0, data_dim - 1);
1128+
e = e.clamp(-1, data_dim - 1);
1129+
}
1130+
1131+
let indexes = Tensor::arange_step(s, e, p, data.device())?;
1132+
out = out.index_select(&indexes, axis)?
1133+
}
1134+
values.insert(node.output[0].clone(), out);
1135+
}
10561136
// https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-13
10571137
// TODO: This version is only compatible with ReduceMean V13 and below.
10581138
"ReduceMean" => {

candle-onnx/tests/ops.rs

+135
Original file line numberDiff line numberDiff line change
@@ -3272,3 +3272,138 @@ fn test_pad() -> Result<()> {
32723272
assert_eq!(actual.to_vec2::<f64>()?, expected.to_vec2::<f64>()?);
32733273
Ok(())
32743274
}
3275+
3276+
#[test]
fn test_slice() -> Result<()> {
    // Builds the graph input/output declarations from plain names.
    fn value_infos(names: &[&str]) -> Vec<ValueInfoProto> {
        names
            .iter()
            .map(|name| ValueInfoProto {
                name: name.to_string(),
                r#type: None,
                doc_string: "".to_string(),
            })
            .collect()
    }

    /*
       Case 1: all five inputs are provided.
       data = [
           [1, 2, 3, 4],
           [5, 6, 7, 8],
       ]
       axes = [0, 1]
       starts = [1, 0]
       ends = [2, 3]
       steps = [1, 2]
       result = [
           [5, 7],
       ]
    */
    let io_names = ["data", "starts", "ends", "axes", "steps"];
    let model = create_model_proto_with_graph(Some(GraphProto {
        node: vec![NodeProto {
            op_type: "Slice".to_string(),
            input: io_names.iter().map(|s| s.to_string()).collect(),
            output: vec!["result".to_string()],
            ..NodeProto::default()
        }],
        input: value_infos(&io_names),
        output: value_infos(&["result"]),
        ..GraphProto::default()
    }));

    let inputs = HashMap::from_iter([
        (
            "data".to_string(),
            Tensor::from_vec(vec![1i64, 2, 3, 4, 5, 6, 7, 8], (2, 4), &Device::Cpu)?,
        ),
        (
            "starts".to_string(),
            Tensor::from_vec(vec![1i64, 0], (2,), &Device::Cpu)?,
        ),
        (
            "ends".to_string(),
            Tensor::from_vec(vec![2i64, 3], (2,), &Device::Cpu)?,
        ),
        (
            "axes".to_string(),
            Tensor::from_vec(vec![0i64, 1], (2,), &Device::Cpu)?,
        ),
        (
            "steps".to_string(),
            Tensor::from_vec(vec![1i64, 2], (2,), &Device::Cpu)?,
        ),
    ]);
    let outputs = candle_onnx::simple_eval(&model, inputs)?;
    assert_eq!(
        outputs.get("result").unwrap().to_vec2::<i64>()?,
        vec![vec![5i64, 7]]
    );

    /*
       Case 2: axes and steps omitted (3-input form, defaults apply).
       data = [
           [1, 2, 3, 4],
           [5, 6, 7, 8],
       ]
       starts = [0, 1]
       ends = [-1, 1000]
       result = [
           [2, 3, 4],
       ]
    */
    let io_names = ["data", "starts", "ends"];
    let model = create_model_proto_with_graph(Some(GraphProto {
        node: vec![NodeProto {
            op_type: "Slice".to_string(),
            input: io_names.iter().map(|s| s.to_string()).collect(),
            output: vec!["result".to_string()],
            ..NodeProto::default()
        }],
        input: value_infos(&io_names),
        output: value_infos(&["result"]),
        ..GraphProto::default()
    }));

    let inputs = HashMap::from_iter([
        (
            "data".to_string(),
            Tensor::from_vec(vec![1i64, 2, 3, 4, 5, 6, 7, 8], (2, 4), &Device::Cpu)?,
        ),
        (
            "starts".to_string(),
            Tensor::from_vec(vec![0i64, 1], (2,), &Device::Cpu)?,
        ),
        (
            "ends".to_string(),
            Tensor::from_vec(vec![-1i64, 1000], (2,), &Device::Cpu)?,
        ),
    ]);
    let outputs = candle_onnx::simple_eval(&model, inputs)?;
    assert_eq!(
        outputs.get("result").unwrap().to_vec2::<i64>()?,
        vec![vec![2i64, 3, 4]]
    );

    Ok(())
}

0 commit comments

Comments
 (0)