@@ -3,25 +3,34 @@ package sharding
33import (
44 "errors"
55 "fmt"
6+ "hash/crc32"
67 "strconv"
78 "strings"
89 "sync"
910
11+ "github.com/bwmarrin/snowflake"
1012 "github.com/longbridgeapp/sqlparser"
1113 "gorm.io/gorm"
1214 "gorm.io/gorm/schema"
1315)
1416
17+ const (
18+ PKSnowflake = iota // Use Snowflake primary key generator
19+ PKPGSequence // Use PostgreSQL sequence primary key generator
20+ PKCustom // Use custom primary key generator
21+ )
22+
1523var (
1624 ErrMissingShardingKey = errors .New ("sharding key or id required, and use operator =" )
1725 ErrInvalidID = errors .New ("invalid id format" )
1826)
1927
2028type Sharding struct {
2129 * gorm.DB
22- ConnPool * ConnPool
23- configs map [string ]Config
24- querys sync.Map
30+ ConnPool * ConnPool
31+ configs map [string ]Config
32+ querys sync.Map
33+ snowflakeNodes []* snowflake.Node
2534}
2635
2736// Config specifies the configuration for sharding.
@@ -33,6 +42,12 @@ type Config struct {
3342 // For example, for a product order table, you may want to split the rows by `user_id`.
3443 ShardingKey string
3544
45+ // NumberOfShards specifies how many tables you want to sharding.
46+ NumberOfShards uint
47+
48+ // tableFormat specifies the sharding table suffix format.
49+ tableFormat string
50+
3651 // ShardingAlgorithm specifies a function to generate the sharding
3752 // table's suffix by the column value.
3853 // For example, this function implements a mod sharding algorithm.
@@ -47,26 +62,27 @@ type Config struct {
4762
4863 // ShardingAlgorithmByPrimaryKey specifies a function to generate the sharding
4964 // table's suffix by the primary key. Used when no sharding key specified.
50- // For example, this function use the LongKey library to generate the suffix.
65+ // For example, this function use the Snowflake library to generate the suffix.
5166 //
5267 // func(id int64) (suffix string) {
53- // return fmt.Sprintf("_%02d", longkey.TableIdx (id))
68+ // return fmt.Sprintf("_%02d", snowflake.ParseInt64 (id).Node( ))
5469 // }
5570 ShardingAlgorithmByPrimaryKey func (id int64 ) (suffix string )
5671
57- // PrimaryKeyGenerate specifies a function to generate the primary key.
72+ // PrimaryKeyGenerator specifies the primary key generate algorithm .
5873 // Used only when insert and the record does not contains an id field.
59- // We recommend you use the
60- // [LongKey](https://github.com/longbridgeapp/longkey) component,
61- // it is a distributed primary key generator.
74+ // Options are PKSnowflake, PKPGSequence and PKCustom.
75+ // When use PKCustom, you should also specify PrimaryKeyGeneratorFn.
76+ PrimaryKeyGenerator int
77+
78+ // PrimaryKeyGeneratorFn specifies a function to generate the primary key.
6279 // When use auto-increment like generator, the tableIdx argument could ignored.
63- //
64- // For example, this function use the LongKey library to generate the primary key.
80+ // For example, this function use the Snowflake library to generate the primary key.
6581 //
6682 // func(tableIdx int64) int64 {
67- // return longkey.Next(tableIdx )
83+ // return nodes[tableIdx].Generate().Int64( )
6884 // }
69- PrimaryKeyGenerate func (tableIdx int64 ) int64
85+ PrimaryKeyGeneratorFn func (tableIdx int64 ) int64
7086}
7187
7288func Register (config Config , tables ... interface {}) * Sharding {
@@ -87,6 +103,75 @@ func (s *Sharding) Register(config Config, tables ...interface{}) *Sharding {
87103 }
88104 }
89105
106+ for t , c := range s .configs {
107+ if c .NumberOfShards > 1024 && c .PrimaryKeyGenerator == PKSnowflake {
108+ panic ("Snowflake NumberOfShards should less than 1024" )
109+ }
110+
111+ if c .PrimaryKeyGenerator == PKSnowflake {
112+ c .PrimaryKeyGeneratorFn = func (index int64 ) int64 {
113+ return s .snowflakeNodes [index ].Generate ().Int64 ()
114+ }
115+ } else if c .PrimaryKeyGenerator == PKPGSequence {
116+ c .PrimaryKeyGeneratorFn = func (index int64 ) int64 {
117+ var id int64
118+ err := s .DB .Raw ("SELECT nextval('" + pgSeqName (t ) + "')" ).Scan (& id ).Error
119+ if err != nil {
120+ panic (err )
121+ }
122+ return id
123+ }
124+ } else if c .PrimaryKeyGenerator == PKCustom {
125+ if c .PrimaryKeyGeneratorFn == nil {
126+ panic ("PrimaryKeyGeneratorFn not configured" )
127+ }
128+ } else {
129+ panic ("PrimaryKeyGenerator can only be one of PKSnowflake, PKPGSequence and PKCustom" )
130+ }
131+
132+ if c .ShardingAlgorithm == nil {
133+ if c .NumberOfShards == 0 {
134+ panic ("specify NumberOfShards or ShardingAlgorithm" )
135+ }
136+ if c .NumberOfShards < 10 {
137+ c .tableFormat = "_%01d"
138+ } else if c .NumberOfShards < 100 {
139+ c .tableFormat = "_%02d"
140+ } else if c .NumberOfShards < 1000 {
141+ c .tableFormat = "_%03d"
142+ } else if c .NumberOfShards < 10000 {
143+ c .tableFormat = "_%04d"
144+ }
145+ c .ShardingAlgorithm = func (value interface {}) (suffix string , err error ) {
146+ id := 0
147+ switch value := value .(type ) {
148+ case int :
149+ id = value
150+ case int64 :
151+ id = int (value )
152+ case string :
153+ id , err = strconv .Atoi (value )
154+ if err != nil {
155+ id = int (crc32 .ChecksumIEEE ([]byte (value )))
156+ }
157+ default :
158+ return "" , fmt .Errorf ("default algorithm only support integer and string column," +
159+ "if you use other type, specify you own ShardingAlgorithm" )
160+ }
161+ return fmt .Sprintf (c .tableFormat , id % int (c .NumberOfShards )), nil
162+ }
163+ }
164+
165+ if c .ShardingAlgorithmByPrimaryKey == nil {
166+ if c .PrimaryKeyGenerator == PKSnowflake {
167+ c .ShardingAlgorithmByPrimaryKey = func (id int64 ) (suffix string ) {
168+ return fmt .Sprintf (c .tableFormat , snowflake .ParseInt64 (id ).Node ())
169+ }
170+ }
171+ }
172+ s .configs [t ] = c
173+ }
174+
90175 return s
91176}
92177
@@ -108,6 +193,25 @@ func (s *Sharding) LastQuery() string {
108193func (s * Sharding ) Initialize (db * gorm.DB ) error {
109194 s .DB = db
110195 s .registerConnPool (db )
196+
197+ for t , c := range s .configs {
198+ if c .PrimaryKeyGenerator == PKPGSequence {
199+ err := s .DB .Exec ("CREATE SEQUENCE IF NOT EXISTS " + pgSeqName (t )).Error
200+ if err != nil {
201+ return fmt .Errorf ("init postgresql sequence error, %w" , err )
202+ }
203+ }
204+ }
205+
206+ s .snowflakeNodes = make ([]* snowflake.Node , 1024 )
207+ for i := int64 (0 ); i < 1024 ; i ++ {
208+ n , err := snowflake .NewNode (i )
209+ if err != nil {
210+ return fmt .Errorf ("init snowflake node error, %w" , err )
211+ }
212+ s .snowflakeNodes [i ] = n
213+ }
214+
111215 return nil
112216}
113217
@@ -208,7 +312,7 @@ func (s *Sharding) resolve(query string, args ...interface{}) (ftQuery, stQuery,
208312 if err != nil {
209313 return ftQuery , stQuery , tableName , err
210314 }
211- id := r .PrimaryKeyGenerate (int64 (tblIdx ))
315+ id := r .PrimaryKeyGeneratorFn (int64 (tblIdx ))
212316 insertNames = append (insertNames , & sqlparser.Ident {Name : "id" })
213317 insertValues = append (insertValues , & sqlparser.NumberLit {Value : strconv .FormatInt (id , 10 )})
214318 }
@@ -349,3 +453,7 @@ func getBindValue(value interface{}, args []interface{}) (interface{}, error) {
349453 }
350454 return args [pos - 1 ], nil
351455}
456+
457+ func pgSeqName (table string ) string {
458+ return fmt .Sprintf ("gorm_sharding_%s_id_seq" , table )
459+ }
0 commit comments