Skip to content

Commit 6736eca

Browse files
committed
small improvements
1 parent ace1265 commit 6736eca

File tree

4 files changed

+33
-27
lines changed

4 files changed

+33
-27
lines changed

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,12 @@ Config the sharding middleware, register the tables which you want to shard. See
4141
```go
4242
db.Use(sharding.Register(sharding.Config{
4343
ShardingKey: "user_id",
44-
ShardingNumber: 64,
45-
PrimaryKeyGenerator: PKSnowflake,
44+
NumberOfShards: 64,
45+
PrimaryKeyGenerator: sharding.PKSnowflake,
4646
}, "orders").Register(sharding.Config{
4747
ShardingKey: "user_id",
48-
ShardingNumber: 256,
49-
PrimaryKeyGenerator: PKSnowflake,
48+
NumberOfShards: 256,
49+
PrimaryKeyGenerator: sharding.PKSnowflake,
5050
// This case for show up give notifications, audit_logs table use same sharding rule.
5151
}, Notification{}, AuditLog{}))
5252
```

examples/order.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ func main() {
3333

3434
middleware := sharding.Register(sharding.Config{
3535
ShardingKey: "user_id",
36-
ShardingNumber: 64,
36+
NumberOfShards: 64,
3737
PrimaryKeyGenerator: sharding.PKSnowflake,
3838
}, "orders")
3939
db.Use(middleware)

sharding.go

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,11 @@ type Config struct {
4242
// For example, for a product order table, you may want to split the rows by `user_id`.
4343
ShardingKey string
4444

45-
// ShardingNumber specifies how many tables you want to sharding.
46-
ShardingNumber uint
45+
// NumberOfShards specifies how many tables you want to sharding.
46+
NumberOfShards uint
4747

48-
// TableFormat specifies the sharding table suffix format.
49-
TableFormat string
48+
// tableFormat specifies the sharding table suffix format.
49+
tableFormat string
5050

5151
// ShardingAlgorithm specifies a function to generate the sharding
5252
// table's suffix by the column value.
@@ -104,15 +104,18 @@ func (s *Sharding) Register(config Config, tables ...interface{}) *Sharding {
104104
}
105105

106106
for t, c := range s.configs {
107+
if c.NumberOfShards > 1024 && c.PrimaryKeyGenerator == PKSnowflake {
108+
panic("Snowflake NumberOfShards should less than 1024")
109+
}
110+
107111
if c.PrimaryKeyGenerator == PKSnowflake {
108112
c.PrimaryKeyGeneratorFn = func(index int64) int64 {
109113
return s.snowflakeNodes[index].Generate().Int64()
110114
}
111115
} else if c.PrimaryKeyGenerator == PKPGSequence {
112-
sname := "gorm_sharding_serial_for_" + t
113116
c.PrimaryKeyGeneratorFn = func(index int64) int64 {
114117
var id int64
115-
err := s.DB.Raw("SELECT nextval('" + sname + "')").Scan(&id).Error
118+
err := s.DB.Raw("SELECT nextval('" + pgSeqName(t) + "')").Scan(&id).Error
116119
if err != nil {
117120
panic(err)
118121
}
@@ -127,17 +130,17 @@ func (s *Sharding) Register(config Config, tables ...interface{}) *Sharding {
127130
}
128131

129132
if c.ShardingAlgorithm == nil {
130-
if c.ShardingNumber == 0 {
131-
panic("specify ShardingNumber or ShardingAlgorithm")
133+
if c.NumberOfShards == 0 {
134+
panic("specify NumberOfShards or ShardingAlgorithm")
132135
}
133-
if c.ShardingNumber < 10 {
134-
c.TableFormat = "_%01d"
135-
} else if c.ShardingNumber < 100 {
136-
c.TableFormat = "_%02d"
137-
} else if c.ShardingNumber < 1000 {
138-
c.TableFormat = "_%03d"
139-
} else if c.ShardingNumber < 10000 {
140-
c.TableFormat = "_%04d"
136+
if c.NumberOfShards < 10 {
137+
c.tableFormat = "_%01d"
138+
} else if c.NumberOfShards < 100 {
139+
c.tableFormat = "_%02d"
140+
} else if c.NumberOfShards < 1000 {
141+
c.tableFormat = "_%03d"
142+
} else if c.NumberOfShards < 10000 {
143+
c.tableFormat = "_%04d"
141144
}
142145
c.ShardingAlgorithm = func(value interface{}) (suffix string, err error) {
143146
id := 0
@@ -155,14 +158,14 @@ func (s *Sharding) Register(config Config, tables ...interface{}) *Sharding {
155158
return "", fmt.Errorf("default algorithm only support integer and string column," +
156159
"if you use other type, specify you own ShardingAlgorithm")
157160
}
158-
return fmt.Sprintf(c.TableFormat, id%int(c.ShardingNumber)), nil
161+
return fmt.Sprintf(c.tableFormat, id%int(c.NumberOfShards)), nil
159162
}
160163
}
161164

162165
if c.ShardingAlgorithmByPrimaryKey == nil {
163166
if c.PrimaryKeyGenerator == PKSnowflake {
164167
c.ShardingAlgorithmByPrimaryKey = func(id int64) (suffix string) {
165-
return fmt.Sprintf(c.TableFormat, snowflake.ParseInt64(id).Node())
168+
return fmt.Sprintf(c.tableFormat, snowflake.ParseInt64(id).Node())
166169
}
167170
}
168171
}
@@ -193,8 +196,7 @@ func (s *Sharding) Initialize(db *gorm.DB) error {
193196

194197
for t, c := range s.configs {
195198
if c.PrimaryKeyGenerator == PKPGSequence {
196-
sname := "gorm_sharding_serial_for_" + t
197-
err := s.DB.Exec("CREATE SEQUENCE IF NOT EXISTS " + sname).Error
199+
err := s.DB.Exec("CREATE SEQUENCE IF NOT EXISTS " + pgSeqName(t)).Error
198200
if err != nil {
199201
return fmt.Errorf("init postgresql sequence error, %w", err)
200202
}
@@ -451,3 +453,7 @@ func getBindValue(value interface{}, args []interface{}) (interface{}, error) {
451453
}
452454
return args[pos-1], nil
453455
}
456+
457+
func pgSeqName(table string) string {
458+
return fmt.Sprintf("gorm_sharding_%s_id_seq", table)
459+
}

sharding_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ var (
5050
shardingConfig = Config{
5151
DoubleWrite: true,
5252
ShardingKey: "user_id",
53-
ShardingNumber: 4,
53+
NumberOfShards: 4,
5454
PrimaryKeyGenerator: PKSnowflake,
5555
}
5656

@@ -231,7 +231,7 @@ func TestPKPGSequence(t *testing.T) {
231231
middleware := Register(shardingConfig, &Order{})
232232
db.Use(middleware)
233233

234-
db.Exec("SELECT setval('gorm_sharding_serial_for_orders', 42)")
234+
db.Exec("SELECT setval('" + pgSeqName("orders") + "', 42)")
235235
db.Create(&Order{UserID: 100, Product: "iPhone"})
236236
expected := `INSERT INTO "orders_0" ("user_id", "product", "id") VALUES ($1, $2, 43) RETURNING "id"`
237237
assert.Equal(t, expected, middleware.LastQuery())

0 commit comments

Comments
 (0)