Skip to content

Commit be539d2

Browse files
committed
Rewrite last_insert_id type.
1 parent 4e3acb1 commit be539d2

19 files changed

+221
-124
lines changed

src/database/proxy.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,14 @@ pub trait ProxyDatabaseTrait: Send + Sync + std::fmt::Debug {
3131
#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
3232
pub struct ProxyExecResult {
3333
/// The last inserted id on auto-increment
34-
pub last_insert_id: u64,
34+
pub last_insert_id: Option<u64>,
3535
/// The number of rows affected by the database operation
3636
pub rows_affected: u64,
3737
}
3838

3939
impl ProxyExecResult {
4040
/// Create a new [ProxyExecResult] from the last inserted id and the number of rows affected
41-
pub fn new(last_insert_id: u64, rows_affected: u64) -> Self {
41+
pub fn new(last_insert_id: Option<u64>, rows_affected: u64) -> Self {
4242
Self {
4343
last_insert_id,
4444
rows_affected,

src/entity/base_entity.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@ pub trait EntityTrait: EntityName {
314314
///
315315
/// let insert_result = cake::Entity::insert(apple).exec(&db).await?;
316316
///
317-
/// assert_eq!(dbg!(insert_result.last_insert_id), 15);
317+
/// assert_eq!(dbg!(insert_result.last_insert_id), Some(15));
318318
///
319319
/// assert_eq!(
320320
/// db.into_transaction_log(),
@@ -356,7 +356,7 @@ pub trait EntityTrait: EntityName {
356356
///
357357
/// let insert_result = cake::Entity::insert(apple).exec(&db).await?;
358358
///
359-
/// assert_eq!(insert_result.last_insert_id, 15);
359+
/// assert_eq!(insert_result.last_insert_id, Some(15));
360360
///
361361
/// assert_eq!(
362362
/// db.into_transaction_log(),
@@ -407,7 +407,7 @@ pub trait EntityTrait: EntityName {
407407
///
408408
/// let insert_result = cake::Entity::insert_many([apple, orange]).exec(&db).await?;
409409
///
410-
/// assert_eq!(insert_result.last_insert_id, 28);
410+
/// assert_eq!(insert_result.last_insert_id, Some(28));
411411
///
412412
/// assert_eq!(
413413
/// db.into_transaction_log(),
@@ -453,7 +453,7 @@ pub trait EntityTrait: EntityName {
453453
///
454454
/// let insert_result = cake::Entity::insert_many([apple, orange]).exec(&db).await?;
455455
///
456-
/// assert_eq!(insert_result.last_insert_id, 28);
456+
/// assert_eq!(insert_result.last_insert_id, Some(28));
457457
///
458458
/// assert_eq!(
459459
/// db.into_transaction_log(),

src/executor/execute.rs

+6-10
Original file line numberDiff line numberDiff line change
@@ -32,28 +32,24 @@ pub(crate) enum ExecResultHolder {
3232
impl ExecResult {
3333
/// Get the last id after `AUTOINCREMENT` is done on the primary key
3434
///
35-
/// # Panics
36-
///
37-
/// Postgres does not support retrieving last insert id this way except through `RETURNING` clause
38-
pub fn last_insert_id(&self) -> u64 {
35+
/// Postgres always returns `None`
36+
pub fn last_insert_id(&self) -> Option<u64> {
3937
match &self.result {
4038
#[cfg(feature = "sqlx-mysql")]
41-
ExecResultHolder::SqlxMySql(result) => result.last_insert_id(),
39+
ExecResultHolder::SqlxMySql(result) => Some(result.last_insert_id()),
4240
#[cfg(feature = "sqlx-postgres")]
43-
ExecResultHolder::SqlxPostgres(_) => {
44-
panic!("Should not retrieve last_insert_id this way")
45-
}
41+
ExecResultHolder::SqlxPostgres(_) => None,
4642
#[cfg(feature = "sqlx-sqlite")]
4743
ExecResultHolder::SqlxSqlite(result) => {
4844
let last_insert_rowid = result.last_insert_rowid();
4945
if last_insert_rowid < 0 {
5046
unreachable!("negative last_insert_rowid")
5147
} else {
52-
last_insert_rowid as u64
48+
Some(last_insert_rowid as u64)
5349
}
5450
}
5551
#[cfg(feature = "mock")]
56-
ExecResultHolder::Mock(result) => result.last_insert_id,
52+
ExecResultHolder::Mock(result) => Some(result.last_insert_id),
5753
#[cfg(feature = "proxy")]
5854
ExecResultHolder::Proxy(result) => result.last_insert_id,
5955
#[allow(unreachable_patterns)]

src/executor/insert.rs

+26-16
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ where
2424
A: ActiveModelTrait,
2525
{
2626
/// The id performed when AUTOINCREMENT was performed on the PrimaryKey
27-
pub last_insert_id: <<<A as ActiveModelTrait>::Entity as EntityTrait>::PrimaryKey as PrimaryKeyTrait>::ValueType,
27+
pub last_insert_id: Option<<<<A as ActiveModelTrait>::Entity as EntityTrait>::PrimaryKey as PrimaryKeyTrait>::ValueType>,
2828
}
2929

3030
/// The types of results for an INSERT operation
@@ -226,7 +226,7 @@ where
226226
if res.rows_affected() == 0 {
227227
return Err(DbErr::RecordNotInserted);
228228
}
229-
FromValueTuple::from_value_tuple(value_tuple)
229+
Some(FromValueTuple::from_value_tuple(value_tuple))
230230
}
231231
(None, true) => {
232232
let mut rows = db.query_all(statement).await?;
@@ -237,24 +237,32 @@ where
237237
let cols = PrimaryKey::<A>::iter()
238238
.map(|col| col.to_string())
239239
.collect::<Vec<_>>();
240-
row.try_get_many("", cols.as_ref())
241-
.map_err(|_| DbErr::UnpackInsertId)?
240+
Some(
241+
row.try_get_many("", cols.as_ref())
242+
.map_err(|_| DbErr::UnpackInsertId)?,
243+
)
242244
}
243245
(None, false) => {
244246
let res = db.execute(statement).await?;
245247
if res.rows_affected() == 0 {
246248
return Err(DbErr::RecordNotInserted);
247249
}
248-
let last_insert_id = res.last_insert_id();
249-
// For MySQL, the affected-rows number:
250-
// - The affected-rows value per row is `1` if the row is inserted as a new row,
251-
// - `2` if an existing row is updated,
252-
// - and `0` if an existing row is set to its current values.
253-
// Reference: https://dev.mysql.com/doc/refman/8.4/en/insert-on-duplicate.html
254-
if db_backend == DbBackend::MySql && last_insert_id == 0 {
255-
return Err(DbErr::RecordNotInserted);
250+
if let Some(last_insert_id) = res.last_insert_id() {
251+
// For MySQL, the affected-rows number:
252+
// - The affected-rows value per row is `1` if the row is inserted as a new row,
253+
// - `2` if an existing row is updated,
254+
// - and `0` if an existing row is set to its current values.
255+
// Reference: https://dev.mysql.com/doc/refman/8.4/en/insert-on-duplicate.html
256+
if db_backend == DbBackend::MySql && last_insert_id == 0 {
257+
return Err(DbErr::RecordNotInserted);
258+
}
259+
Some(
260+
ValueTypeOf::<A>::try_from_u64(last_insert_id)
261+
.map_err(|_| DbErr::UnpackInsertId)?,
262+
)
263+
} else {
264+
None
256265
}
257-
ValueTypeOf::<A>::try_from_u64(last_insert_id).map_err(|_| DbErr::UnpackInsertId)?
258266
}
259267
};
260268

@@ -301,9 +309,11 @@ where
301309
}
302310
false => {
303311
let insert_res = exec_insert::<A, _>(primary_key, insert_statement, db).await?;
304-
<A::Entity as EntityTrait>::find_by_id(insert_res.last_insert_id)
305-
.one(db)
306-
.await?
312+
<A::Entity as EntityTrait>::find_by_id(insert_res.last_insert_id.ok_or(
313+
DbErr::RecordNotFound("No last insert id returned from the database".to_owned()),
314+
)?)
315+
.one(db)
316+
.await?
307317
}
308318
};
309319
match found {

tests/byte_primary_key_tests.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ pub async fn create_and_update(db: &DatabaseConnection) -> Result<(), DbErr> {
3030

3131
assert_eq!(Entity::find().one(db).await?, Some(model.clone()));
3232

33-
assert_eq!(res.last_insert_id, model.id);
33+
assert_eq!(res.last_insert_id, Some(model.id.clone()));
3434

3535
let updated_active_model = ActiveModel {
3636
value: Set("First Row (Updated)".to_owned()),

tests/crud/create_baker.rs

+16-9
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,21 @@ pub async fn test_create_baker(db: &DbConn) {
2727
let baker_bob = baker::ActiveModel {
2828
name: Set("Baker Bob".to_owned()),
2929
contact_details: Set(serde_json::json!(baker_bob_contact)),
30-
bakery_id: Set(Some(bakery_insert_res.last_insert_id)),
30+
bakery_id: Set(bakery_insert_res.last_insert_id),
3131
..Default::default()
3232
};
3333
let res = Baker::insert(baker_bob)
3434
.exec(db)
3535
.await
3636
.expect("could not insert baker");
3737

38-
let baker: Option<baker::Model> = Baker::find_by_id(res.last_insert_id)
39-
.one(db)
40-
.await
41-
.expect("could not find baker");
38+
let baker: Option<baker::Model> = Baker::find_by_id(
39+
res.last_insert_id
40+
.expect("could not get last insert id for baker"),
41+
)
42+
.one(db)
43+
.await
44+
.expect("could not find baker");
4245

4346
assert!(baker.is_some());
4447
let baker_model = baker.unwrap();
@@ -63,10 +66,14 @@ pub async fn test_create_baker(db: &DbConn) {
6366
"SeaSide Bakery"
6467
);
6568

66-
let bakery: Option<bakery::Model> = Bakery::find_by_id(bakery_insert_res.last_insert_id)
67-
.one(db)
68-
.await
69-
.unwrap();
69+
let bakery: Option<bakery::Model> = Bakery::find_by_id(
70+
bakery_insert_res
71+
.last_insert_id
72+
.expect("could not get last insert id for bakery"),
73+
)
74+
.one(db)
75+
.await
76+
.unwrap();
7077

7178
let related_bakers: Vec<baker::Model> = bakery
7279
.unwrap()

tests/crud/create_cake.rs

+25-13
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ pub async fn test_create_cake(db: &DbConn) {
1919
"home": "0395555555",
2020
"address": "12 Test St, Testville, Vic, Australia"
2121
})),
22-
bakery_id: Set(Some(bakery_insert_res.last_insert_id)),
22+
bakery_id: Set(bakery_insert_res.last_insert_id),
2323
..Default::default()
2424
};
2525
let baker_insert_res = Baker::insert(baker_bob)
@@ -33,7 +33,7 @@ pub async fn test_create_cake(db: &DbConn) {
3333
price: Set(rust_dec(-10.25)),
3434
gluten_free: Set(false),
3535
serial: Set(uuid),
36-
bakery_id: Set(Some(bakery_insert_res.last_insert_id)),
36+
bakery_id: Set(bakery_insert_res.last_insert_id),
3737
..Default::default()
3838
};
3939

@@ -42,22 +42,30 @@ pub async fn test_create_cake(db: &DbConn) {
4242
.await
4343
.expect("could not insert cake");
4444

45-
let cake: Option<cake::Model> = Cake::find_by_id(cake_insert_res.last_insert_id)
46-
.one(db)
47-
.await
48-
.expect("could not find cake");
45+
let cake: Option<cake::Model> = Cake::find_by_id(
46+
cake_insert_res
47+
.last_insert_id
48+
.expect("could not get last insert id for cake"),
49+
)
50+
.one(db)
51+
.await
52+
.expect("could not find cake");
4953

5054
let cake_baker = cakes_bakers::ActiveModel {
51-
cake_id: Set(cake_insert_res.last_insert_id),
52-
baker_id: Set(baker_insert_res.last_insert_id),
55+
cake_id: Set(cake_insert_res
56+
.last_insert_id
57+
.expect("could not get last insert id for cake")),
58+
baker_id: Set(baker_insert_res
59+
.last_insert_id
60+
.expect("could not get last insert id for baker")),
5361
};
5462
let cake_baker_res = CakesBakers::insert(cake_baker.clone())
5563
.exec(db)
5664
.await
5765
.expect("could not insert cake_baker");
5866
assert_eq!(
5967
cake_baker_res.last_insert_id,
60-
(cake_baker.cake_id.unwrap(), cake_baker.baker_id.unwrap())
68+
Some((cake_baker.cake_id.unwrap(), cake_baker.baker_id.unwrap()))
6169
);
6270

6371
assert!(cake.is_some());
@@ -85,10 +93,14 @@ pub async fn test_create_cake(db: &DbConn) {
8593
assert_eq!(related_bakers.len(), 1);
8694
assert_eq!(related_bakers[0].name, "Baker Bob");
8795

88-
let baker: Option<baker::Model> = Baker::find_by_id(baker_insert_res.last_insert_id)
89-
.one(db)
90-
.await
91-
.expect("could not find baker");
96+
let baker: Option<baker::Model> = Baker::find_by_id(
97+
baker_insert_res
98+
.last_insert_id
99+
.expect("could not get last insert id for baker"),
100+
)
101+
.one(db)
102+
.await
103+
.expect("could not find baker");
92104

93105
let related_cakes: Vec<cake::Model> = baker
94106
.unwrap()

tests/crud/create_lineitem.rs

+33-15
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ pub async fn test_create_lineitem(db: &DbConn) {
2222
"home": "0395555555",
2323
"address": "12 Test St, Testville, Vic, Australia"
2424
})),
25-
bakery_id: Set(Some(bakery_insert_res.last_insert_id)),
25+
bakery_id: Set(bakery_insert_res.last_insert_id),
2626
..Default::default()
2727
};
2828
let baker_insert_res = Baker::insert(baker_bob)
@@ -36,7 +36,7 @@ pub async fn test_create_lineitem(db: &DbConn) {
3636
price: Set(rust_dec(10.25)),
3737
gluten_free: Set(false),
3838
serial: Set(Uuid::new_v4()),
39-
bakery_id: Set(Some(bakery_insert_res.last_insert_id)),
39+
bakery_id: Set(bakery_insert_res.last_insert_id),
4040
..Default::default()
4141
};
4242

@@ -47,16 +47,20 @@ pub async fn test_create_lineitem(db: &DbConn) {
4747

4848
// Cake_Baker
4949
let cake_baker = cakes_bakers::ActiveModel {
50-
cake_id: Set(cake_insert_res.last_insert_id),
51-
baker_id: Set(baker_insert_res.last_insert_id),
50+
cake_id: Set(cake_insert_res
51+
.last_insert_id
52+
.expect("could not get last insert id for cake")),
53+
baker_id: Set(baker_insert_res
54+
.last_insert_id
55+
.expect("could not get last insert id for baker")),
5256
};
5357
let cake_baker_res = CakesBakers::insert(cake_baker.clone())
5458
.exec(db)
5559
.await
5660
.expect("could not insert cake_baker");
5761
assert_eq!(
5862
cake_baker_res.last_insert_id,
59-
(cake_baker.cake_id.unwrap(), cake_baker.baker_id.unwrap())
63+
Some((cake_baker.cake_id.unwrap(), cake_baker.baker_id.unwrap()))
6064
);
6165

6266
// Customer
@@ -72,8 +76,12 @@ pub async fn test_create_lineitem(db: &DbConn) {
7276

7377
// Order
7478
let order_1 = order::ActiveModel {
75-
bakery_id: Set(bakery_insert_res.last_insert_id),
76-
customer_id: Set(customer_insert_res.last_insert_id),
79+
bakery_id: Set(bakery_insert_res
80+
.last_insert_id
81+
.expect("could not get last insert id for bakery")),
82+
customer_id: Set(customer_insert_res
83+
.last_insert_id
84+
.expect("could not get last insert id for customer")),
7785
total: Set(rust_dec(7.55)),
7886
placed_at: Set(Utc::now().naive_utc()),
7987
..Default::default()
@@ -85,8 +93,12 @@ pub async fn test_create_lineitem(db: &DbConn) {
8593

8694
// Lineitem
8795
let lineitem_1 = lineitem::ActiveModel {
88-
cake_id: Set(cake_insert_res.last_insert_id),
89-
order_id: Set(order_insert_res.last_insert_id),
96+
cake_id: Set(cake_insert_res
97+
.last_insert_id
98+
.expect("could not get last insert id for cake")),
99+
order_id: Set(order_insert_res
100+
.last_insert_id
101+
.expect("could not get last insert id for order")),
90102
price: Set(rust_dec(7.55)),
91103
quantity: Set(1),
92104
..Default::default()
@@ -96,11 +108,14 @@ pub async fn test_create_lineitem(db: &DbConn) {
96108
.await
97109
.expect("could not insert lineitem");
98110

99-
let lineitem: Option<lineitem::Model> =
100-
Lineitem::find_by_id(lineitem_insert_res.last_insert_id)
101-
.one(db)
102-
.await
103-
.expect("could not find lineitem");
111+
let lineitem: Option<lineitem::Model> = Lineitem::find_by_id(
112+
lineitem_insert_res
113+
.last_insert_id
114+
.expect("could not get last insert id for lineitem"),
115+
)
116+
.one(db)
117+
.await
118+
.expect("could not find lineitem");
104119

105120
assert!(lineitem.is_some());
106121
let lineitem_model = lineitem.unwrap();
@@ -121,5 +136,8 @@ pub async fn test_create_lineitem(db: &DbConn) {
121136
.expect("could not find order");
122137

123138
let order_model = order.unwrap();
124-
assert_eq!(order_model.customer_id, customer_insert_res.last_insert_id);
139+
assert_eq!(
140+
Some(order_model.customer_id),
141+
customer_insert_res.last_insert_id
142+
);
125143
}

0 commit comments

Comments
 (0)