Skip to content

Commit d80b119

Browse files
author
Yinwei Li
committed
[bugfix] Change AnnSearchRequest and the params settings for search and hybrid_search.
Changed AnnSearchRequest's params from KeyValuePair to Vec<KeyValuePair> to enable different search options for different anns_field. Changed the default params setting method, following pymilvus. Added more graceful error handling for raw_id. Now, if the search result is null, the code will throw an error and continue to run. Signed-off-by: Yinwei Li <yinwei.li@zilliz.com>
1 parent e3bd6fb commit d80b119

File tree

4 files changed

+549
-77
lines changed

4 files changed

+549
-77
lines changed

examples/query_search.rs

Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
use std::collections::HashMap;
2+
use std::vec;
3+
4+
use milvus::client::Client;
5+
use milvus::collection::SearchResult;
6+
use milvus::data::FieldColumn;
7+
use milvus::error::Result;
8+
use milvus::index::{IndexParams, IndexType, MetricType};
9+
use milvus::query::*;
10+
use milvus::schema::{CollectionSchemaBuilder, FieldSchema};
11+
use milvus::value::Value;
12+
use rand::Rng;
13+
14+
// Dimension passed to the float-vector field when the schema is built; also
// used to split flat vector data into per-entity chunks when printing.
const DIM: i64 = 8;
15+
// Number of rows generated and inserted by `prepare_data`.
const NUM_ENTITIES: usize = 10000;
16+
// Name of the float-vector field.
const PICTURE: &str = "picture";
17+
// Name of the primary-key (int64) field.
const USER_ID: &str = "id";
18+
// Name of the int64 `age` field.
const AGE: &str = "age";
19+
// Name of the double `deposit` field.
const DEPOSIT: &str = "deposit";
20+
// Name of the collection every step of this example operates on.
const COLLECTION_NAME: &str = "test_query_search_collection";
21+
22+
#[tokio::main]
23+
async fn main() -> Result<()> {
24+
let client = Client::new("http://localhost:19530").await?;
25+
prepare_data(&client).await?;
26+
query_test(&client).await?;
27+
search_test(&client).await?;
28+
get_test(&client).await?;
29+
Ok(())
30+
}
31+
32+
// query test
33+
async fn query_test(client: &Client) -> Result<()> {
34+
let options = QueryOptions::new()
35+
.limit(50)
36+
.output_fields(vec![USER_ID.to_string(), AGE.to_string()]);
37+
let res = client.query(COLLECTION_NAME, "10<age<20", &options).await?;
38+
println!("==========Query test begin==========");
39+
println!("Query result:");
40+
41+
// Extract id and age columns from the result
42+
let id_column = res.iter().find(|col| col.name == USER_ID).unwrap();
43+
let age_column = res.iter().find(|col| col.name == AGE).unwrap();
44+
45+
// Get the data vectors from the columns
46+
let ids: Vec<i64> = id_column.value.clone().try_into().unwrap();
47+
let ages: Vec<i64> = age_column.value.clone().try_into().unwrap();
48+
49+
// Print the results in the requested format
50+
for (id, age) in ids.iter().zip(ages.iter()) {
51+
println!("id: {} age: {}", id, age);
52+
}
53+
println!("==========Query test end==========\n");
54+
Ok(())
55+
}
56+
57+
// search test
58+
async fn search_test(client: &Client) -> Result<()> {
59+
let vector_to_search = Value::from(
60+
(0..DIM as usize)
61+
.map(|_| rand::thread_rng().gen_range(0.0..1.0))
62+
.collect::<Vec<f32>>(),
63+
);
64+
println!("==========Search test begin==========");
65+
// Prepare search options
66+
let options = SearchOptions::new()
67+
.limit(10)
68+
.output_fields(vec![
69+
USER_ID.to_string(),
70+
AGE.to_string(),
71+
PICTURE.to_string(),
72+
])
73+
.add_param("anns_field", "picture")
74+
.add_param("metric_type", "L2");
75+
76+
// Search
77+
let res = client
78+
.search(COLLECTION_NAME, vec![vector_to_search], Some(options))
79+
.await?;
80+
81+
println!("Search result:");
82+
print_search_results(&res);
83+
println!("==========Search test end==========\n");
84+
Ok(())
85+
}
86+
87+
// get test
88+
async fn get_test(client: &Client) -> Result<()> {
89+
// Prepare get options
90+
let options = GetOptions::new().output_fields(vec![
91+
USER_ID.to_string(),
92+
AGE.to_string(),
93+
DEPOSIT.to_string(),
94+
PICTURE.to_string(),
95+
]);
96+
// Get
97+
let res = client
98+
.get(
99+
COLLECTION_NAME,
100+
IdType::Int64(vec![1, 2, 3, 4, 5]),
101+
Some(options),
102+
)
103+
.await?;
104+
println!("==========Get test begin==========");
105+
println!("Get result:");
106+
print_get_results(&res);
107+
println!("==========Get test end==========\n");
108+
Ok(())
109+
}
110+
111+
// prepare data
112+
async fn prepare_data(client: &Client) -> Result<()> {
113+
// Prepare data
114+
if client.has_collection(COLLECTION_NAME).await? {
115+
client.drop_collection(COLLECTION_NAME).await?;
116+
}
117+
println!("==========Prepare data begin==========");
118+
// 1. create collection
119+
let schema = CollectionSchemaBuilder::new(COLLECTION_NAME, "test_query_search_collection")
120+
.add_field(FieldSchema::new_primary_int64(USER_ID, "user if", false))
121+
.add_field(FieldSchema::new_int64(AGE, "age of user"))
122+
.add_field(FieldSchema::new_double(DEPOSIT, ""))
123+
.add_field(FieldSchema::new_float_vector(PICTURE, "", DIM))
124+
.build()?;
125+
126+
client.create_collection(schema.clone(), None).await?;
127+
// 2. insert data
128+
let ids = (0..NUM_ENTITIES).map(|i| i as i64).collect::<Vec<_>>();
129+
let age = (0..NUM_ENTITIES)
130+
.map(|i| (i % 100) as i64)
131+
.collect::<Vec<_>>();
132+
let deposit = (0..NUM_ENTITIES).map(|i| i as f64).collect::<Vec<_>>();
133+
let picture = (0..NUM_ENTITIES * DIM as usize)
134+
.map(|_| rand::thread_rng().gen_range(0.0..1.0))
135+
.collect::<Vec<f32>>();
136+
137+
let id_column = FieldColumn::new(schema.get_field(USER_ID).unwrap(), ids);
138+
let age_column = FieldColumn::new(schema.get_field(AGE).unwrap(), age);
139+
let deposit_column = FieldColumn::new(schema.get_field(DEPOSIT).unwrap(), deposit);
140+
let picture_column = FieldColumn::new(schema.get_field(PICTURE).unwrap(), picture);
141+
142+
client
143+
.insert(
144+
COLLECTION_NAME,
145+
vec![id_column, age_column, deposit_column, picture_column],
146+
None,
147+
)
148+
.await?;
149+
client.flush(COLLECTION_NAME).await?;
150+
println!("Finish flush collections:{}", COLLECTION_NAME);
151+
152+
// 3. create index
153+
let index_params = IndexParams::new(
154+
"picture_index".to_string(),
155+
IndexType::IvfFlat,
156+
MetricType::L2,
157+
HashMap::from([("nlist".to_string(), "1024".to_string())]),
158+
);
159+
client
160+
.create_index(COLLECTION_NAME, PICTURE, index_params)
161+
.await?;
162+
client.load_collection(COLLECTION_NAME, None).await?;
163+
println!("==========Prepare data end==========\n");
164+
Ok(())
165+
}
166+
167+
168+
// Print functions.
169+
// You can ignore this part.
170+
171+
fn print_search_results(res: &Vec<SearchResult<'_>>) {
172+
let id_column = res
173+
.iter()
174+
.map(|col| {
175+
col.field
176+
.iter()
177+
.find(|x| x.name == USER_ID)
178+
.unwrap()
179+
.value
180+
.clone()
181+
})
182+
.collect::<Vec<_>>();
183+
let age_column = res
184+
.iter()
185+
.map(|col| {
186+
col.field
187+
.iter()
188+
.find(|x| x.name == AGE)
189+
.unwrap()
190+
.value
191+
.clone()
192+
})
193+
.collect::<Vec<_>>();
194+
let picture_column = res
195+
.iter()
196+
.map(|col| {
197+
col.field
198+
.iter()
199+
.find(|x| x.name == PICTURE)
200+
.unwrap()
201+
.value
202+
.clone()
203+
})
204+
.collect::<Vec<_>>();
205+
let score_column = res.iter().map(|col| col.score.clone()).collect::<Vec<_>>();
206+
for (ids, ages, pictures, scores) in id_column
207+
.iter()
208+
.zip(age_column.iter())
209+
.zip(picture_column.iter())
210+
.zip(score_column.iter())
211+
.map(|(((id, age), picture), score)| {
212+
(id.clone(), age.clone(), picture.clone(), score.clone())
213+
})
214+
{
215+
let id_column: Vec<i64> = ids.clone().try_into().unwrap();
216+
let age_column: Vec<i64> = ages.clone().try_into().unwrap();
217+
let picture_column: Vec<f32> = pictures.clone().try_into().unwrap();
218+
let score_column: Vec<f32> = scores.clone().try_into().unwrap();
219+
for (id, age, picture, score) in id_column
220+
.iter()
221+
.zip(age_column.iter())
222+
.zip(picture_column.chunks(DIM as usize))
223+
.zip(score_column.iter())
224+
.map(|(((id, age), picture), score)| {
225+
(id.clone(), age.clone(), picture.to_vec(), score.clone())
226+
})
227+
{
228+
println!(
229+
"id: {} age: {} picture: {:?} score: {}",
230+
id, age, picture, score
231+
);
232+
}
233+
}
234+
}
235+
236+
fn print_get_results(res: &Vec<FieldColumn>) {
237+
let id_column = res.iter().find(|col| col.name == USER_ID).unwrap();
238+
let age_column = res.iter().find(|col| col.name == AGE).unwrap();
239+
let deposit_column = res.iter().find(|col| col.name == DEPOSIT).unwrap();
240+
let picture_column = res.iter().find(|col| col.name == PICTURE).unwrap();
241+
242+
let ids: Vec<i64> = id_column.value.clone().try_into().unwrap();
243+
let ages: Vec<i64> = age_column.value.clone().try_into().unwrap();
244+
let deposits: Vec<f64> = deposit_column.value.clone().try_into().unwrap();
245+
let pictures: Vec<f32> = picture_column.value.clone().try_into().unwrap();
246+
for (id, age, deposit, picture) in ids
247+
.iter()
248+
.zip(ages.iter())
249+
.zip(deposits.iter())
250+
.zip(pictures.chunks(DIM as usize))
251+
.map(|(((id, age), deposit), picture)| {
252+
(id.clone(), age.clone(), deposit.clone(), picture.to_vec())
253+
})
254+
{
255+
println!(
256+
"id: {} age: {} deposit: {} picture: {:?}",
257+
id, age, deposit, picture
258+
);
259+
}
260+
}

src/collection.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -751,7 +751,7 @@ pub type ParamValue = serde_json::Value;
751751
pub use serde_json::json as ParamValue;
752752

753753
// search result for a single vector
754-
#[derive(Clone)]
754+
#[derive(Clone,Debug)]
755755
pub struct SearchResult<'a> {
756756
pub size: i64,
757757
pub id: Vec<Value<'a>>,

src/iterator.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1452,7 +1452,7 @@ impl SearchIterator {
14521452
partition_names: self.options.partition_names.clone(),
14531453
dsl: self.options.filter.clone(),
14541454
nq: self.data.len() as _,
1455-
placeholder_group: crate::query::get_place_holder_group(self.data.clone())?,
1455+
placeholder_group: crate::query::get_place_holder_group(&self.data)?,
14561456
dsl_type: crate::proto::common::DslType::BoolExprV1 as _,
14571457
output_fields: self.options.output_fields.clone(),
14581458
search_params,

0 commit comments

Comments
 (0)