diff --git a/CHANGELOG.md b/CHANGELOG.md index 8dfb22828..c8cdcb802 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed ### Fixed +- Cassandra version unmarshal fix (CASSGO-49) ## [1.6.0] - 2023-08-28 diff --git a/host_source.go b/host_source.go index a0bab9ad0..9cda3478e 100644 --- a/host_source.go +++ b/host_source.go @@ -56,6 +56,7 @@ const ( type cassVersion struct { Major, Minor, Patch int + Qualifier string } func (c *cassVersion) Set(v string) error { @@ -87,13 +88,30 @@ func (c *cassVersion) unmarshal(data []byte) error { c.Minor, err = strconv.Atoi(v[1]) if err != nil { - return fmt.Errorf("invalid minor version %v: %v", v[1], err) + vMinor := strings.Split(v[1], "-") + if len(vMinor) < 2 { + return fmt.Errorf("invalid minor version %v: %v", v[1], err) + } + c.Minor, err = strconv.Atoi(vMinor[0]) + if err != nil { + return fmt.Errorf("invalid minor version %v: %v", v[1], err) + } + c.Qualifier = v[1][strings.Index(v[1], "-")+1:] + return nil } if len(v) > 2 { c.Patch, err = strconv.Atoi(v[2]) if err != nil { - return fmt.Errorf("invalid patch version %v: %v", v[2], err) + vPatch := strings.Split(v[2], "-") + if len(vPatch) < 2 { + return fmt.Errorf("invalid patch version %v: %v", v[2], err) + } + c.Patch, err = strconv.Atoi(vPatch[0]) + if err != nil { + return fmt.Errorf("invalid patch version %v: %v", v[2], err) + } + c.Qualifier = v[2][strings.Index(v[2], "-")+1:] } } @@ -121,6 +139,9 @@ func (c cassVersion) AtLeast(major, minor, patch int) bool { } func (c cassVersion) String() string { + if c.Qualifier != "" { + return fmt.Sprintf("%d.%d.%d-%v", c.Major, c.Minor, c.Patch, c.Qualifier) + } return fmt.Sprintf("v%d.%d.%d", c.Major, c.Minor, c.Patch) } diff --git a/host_source_test.go b/host_source_test.go index 081384237..cd4dfc025 100644 --- a/host_source_test.go +++ b/host_source_test.go @@ -29,6 +29,7 @@ package gocql import ( "errors" + "fmt" "net" "sync" "sync/atomic" @@ -41,9 +42,13 @@ func TestUnmarshalCassVersion(t *testing.T) { data string version cassVersion }{ - {"3.2", cassVersion{3, 2, 0}}, - {"2.10.1-SNAPSHOT", cassVersion{2, 10, 1}}, - {"1.2.3", cassVersion{1, 2, 3}}, + {"3.2", cassVersion{3, 2, 0, ""}}, + {"2.10.1-SNAPSHOT", cassVersion{2, 10, 1, ""}}, + {"1.2.3", cassVersion{1, 2, 3, ""}}, + {"4.0-rc2", cassVersion{4, 0, 0, "rc2"}}, + {"4.3.2-rc1", cassVersion{4, 3, 2, "rc1"}}, + {"4.3.2-rc1-qualifier1", cassVersion{4, 3, 2, "rc1-qualifier1"}}, + {"4.3-rc1-qualifier1", cassVersion{4, 3, 0, "rc1-qualifier1"}}, } for i, test := range tests { @@ -53,6 +58,7 @@ func TestUnmarshalCassVersion(t *testing.T) { } else if *v != test.version { t.Errorf("%d: expected %#+v got %#+v", i, test.version, *v) } + fmt.Println(v.String()) } } @@ -60,14 +66,17 @@ func TestCassVersionBefore(t *testing.T) { tests := [...]struct { version cassVersion major, minor, patch int + Qualifier string }{ - {cassVersion{1, 0, 0}, 0, 0, 0}, - {cassVersion{0, 1, 0}, 0, 0, 0}, - {cassVersion{0, 0, 1}, 0, 0, 0}, + {cassVersion{1, 0, 0, ""}, 0, 0, 0, ""}, + {cassVersion{0, 1, 0, ""}, 0, 0, 0, ""}, + {cassVersion{0, 0, 1, ""}, 0, 0, 0, ""}, - {cassVersion{1, 0, 0}, 0, 1, 0}, - {cassVersion{0, 1, 0}, 0, 0, 1}, - {cassVersion{4, 1, 0}, 3, 1, 2}, + {cassVersion{1, 0, 0, ""}, 0, 1, 0, ""}, + {cassVersion{0, 1, 0, ""}, 0, 0, 1, ""}, + {cassVersion{4, 1, 0, ""}, 3, 1, 2, ""}, + + {cassVersion{4, 1, 0, ""}, 3, 1, 2, ""}, } for i, test := range tests {