Skip to content

Commit

Permalink
Add AstraDB support
Browse files Browse the repository at this point in the history
  • Loading branch information
wizardishungry committed Feb 4, 2025
1 parent d477553 commit 5e87d11
Show file tree
Hide file tree
Showing 10 changed files with 373 additions and 22 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
SOURCE ?= file go_bindata github github_ee bitbucket aws_s3 google_cloud_storage godoc_vfs gitlab
DATABASE ?= postgres mysql redshift cassandra spanner cockroachdb yugabytedb clickhouse mongodb sqlserver firebird neo4j pgx pgx5 rqlite
DATABASE ?= postgres mysql redshift cassandra spanner cockroachdb yugabytedb clickhouse mongodb sqlserver firebird neo4j pgx pgx5 rqlite astra
DATABASE_TEST ?= $(DATABASE) sqlite sqlite3 sqlcipher
VERSION ?= $(shell git describe --tags 2>/dev/null | cut -c 2-)
TEST_FLAGS ?=
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ Database drivers run migrations. [Add a new database?](database/driver.go)
* [PGX v5](database/pgx/v5)
* [Redshift](database/redshift)
* [Ql](database/ql)
* [Cassandra / ScyllaDB](database/cassandra)
* [Cassandra / ScyllaDB / AstraDB](database/cassandra)
* [SQLite](database/sqlite)
* [SQLite3](database/sqlite3) ([todo #165](https://github.com/mattes/migrate/issues/165))
* [SQLCipher](database/sqlcipher)
Expand Down
23 changes: 22 additions & 1 deletion database/cassandra/README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Cassandra / ScyllaDB
# Cassandra / ScyllaDB / AstraDB

* `Drop()` method will not work on Cassandra 2.X because it rely on
system_schema table which comes with 3.X
Expand All @@ -13,6 +13,10 @@ system_schema table which comes with 3.X
* The `Drop()` method` works for ScyllaDB 5.1


**AstraDB**

* Astra uses different parameters for authentication. See below.

## Usage
`cassandra://host:port/keyspace?param1=value&param2=value2`

Expand All @@ -36,6 +40,23 @@ system_schema table which comes with 3.X

`timeout` is parsed using [time.ParseDuration(s string)](https://golang.org/pkg/time/#ParseDuration)

### [AstraDB](https://docs.datastax.com/)

`astra:///keyspace?bundle=bundle.zip&token=token` or
`astra:///keyspace?token=token&database_id=database_id`. *Note the triple slash.*

Astra supports two authentication schemes;
[bundle](https://pkg.go.dev/github.com/datastax/gocql-astra#NewClusterFromURL) and
[token](https://pkg.go.dev/github.com/datastax/gocql-astra#NewClusterFromURL).
The additional parameters are:


| URL Query | Default value | Description |
|------------|-------------|-----------|
| `token` | | Astra Bearer Token (beginning with AstraCS) |
| `database_id` | | Database ID |
| `bundle` | | Path to secure connect bundle |
| `api_url` | `https://api.astra.datastax.com` | Custom Astra Endpoint |

## Upgrading from v1

Expand Down
18 changes: 18 additions & 0 deletions database/cassandra/astra.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package cassandra

import (
"time"

"github.com/gocql/gocql"
)

// These are stubs to keep from depending on the astra driver.
// The astra package assigns to these from the astra driver.
var (
GocqlastraNewClusterFromURL = func(url string, databaseID string, token string, timeout time.Duration) (*gocql.ClusterConfig, error) {
panic("should not be used for cassandra")
}
GocqlastraNewClusterFromBundle = func(path string, username string, password string, timeout time.Duration) (*gocql.ClusterConfig, error) {
panic("should not be used for cassandra")
}
)
19 changes: 19 additions & 0 deletions database/cassandra/astra/astra.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package cassandra

import (
gocqlastra "github.com/datastax/gocql-astra"
"github.com/golang-migrate/migrate/v4/database"
"github.com/golang-migrate/migrate/v4/database/cassandra"
)

func init() {
db := new(Astra)
database.Register("astra", db)
}

type Astra = cassandra.Cassandra

func init() {
cassandra.GocqlastraNewClusterFromURL = gocqlastra.NewClusterFromURL
cassandra.GocqlastraNewClusterFromBundle = gocqlastra.NewClusterFromBundle
}
148 changes: 148 additions & 0 deletions database/cassandra/astra/astra_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
package cassandra

import (
"errors"
"net/url"
"testing"
"time"

"github.com/gocql/gocql"
cas "github.com/golang-migrate/migrate/v4/database/cassandra"
)

func TestAstra(t *testing.T) {
type mockResult struct {
timeout time.Duration

// NewClusterFromBundle
path string
username string
password string

// NewClusterFromURL
apiUrl string
databaseID string
token string
}

var (
errNewClusterFromBundle = errors.New("NewClusterFromBundle")
errNewClusterFromURL = errors.New("NewClusterFromURL")
)

test := func(t *testing.T, url string) (mockResult, error) {
t.Helper()

var mr mockResult

// Since we can't actually call the Astra API, we mock the calls and return an error so we never dial.
cas.GocqlastraNewClusterFromBundle = func(path string, username string, password string, timeout time.Duration) (*gocql.ClusterConfig, error) {
mr.path = path
mr.username = username
mr.password = password
mr.timeout = timeout
return nil, errNewClusterFromBundle
}

cas.GocqlastraNewClusterFromURL = func(apiUrl string, databaseID string, token string, timeout time.Duration) (*gocql.ClusterConfig, error) {
mr.apiUrl = apiUrl
mr.databaseID = databaseID
mr.token = token
mr.timeout = timeout
return nil, errNewClusterFromURL
}

astra := &Astra{}

_, err := astra.Open(url)
return mr, err
}

t.Run("Token", func(t *testing.T) {
mr, err := test(t, "astra:///testks?token=token&database_id=database_id")
if err != errNewClusterFromURL {
t.Error("Expected", errNewClusterFromURL, "but got", err)
}
if mr.token != "token" {
t.Error("Expected token to be 'token' but got", mr.token)
}
if mr.databaseID != "database_id" {
t.Error("Expected database_id to be 'database_id' but got", mr.databaseID)
}
})
t.Run("Bundle", func(t *testing.T) {
mr, err := test(t, "astra:///testks?bundle=bundle.zip&token=AstraCS:password")
if err != errNewClusterFromBundle {
t.Error("Expected", errNewClusterFromBundle, "but got", err)
}
if mr.path != "bundle.zip" {
t.Error("Expected path to be 'bundle.zip' but got", mr.path)
}
if mr.username != "token" {
t.Error("Expected username to be 'token' but got", mr.username)
}
if mr.password != "AstraCS:password" {
t.Error("Expected password to be 'AstraCS:password' but got", mr.password)
}
})

t.Run("No Keyspace", func(t *testing.T) {
astra := &Astra{}
_, err := astra.Open("astra://")
if err != cas.ErrNoKeyspace {
t.Error("Expected", cas.ErrNoKeyspace, "but got", err)
}
})

t.Run("AstraMissing", func(t *testing.T) {
astra := &Astra{}
_, err := astra.Open("astra:///testks")
if err != cas.ErrAstraMissing {
t.Error("Expected", cas.ErrAstraMissing, "but got", err)
}
})
t.Run("No Token", func(t *testing.T) {
astra := &Astra{}
_, err := astra.Open("astra:///testks?database_id=database_id")
if err != cas.ErrAstraMissing {
t.Error("Expected", cas.ErrAstraMissing, "but got", err)
}
})
t.Run("No DatabaseID", func(t *testing.T) {
astra := &Astra{}
_, err := astra.Open("astra:///testks?token=AstraCS:password")
if err != cas.ErrAstraMissing {
t.Error("Expected", cas.ErrAstraMissing, "but got", err)
}
})
t.Run("No Bundle", func(t *testing.T) {
astra := &Astra{}
_, err := astra.Open("astra:///testks?token=AstraCS:password")
if err != cas.ErrAstraMissing {
t.Error("Expected", cas.ErrAstraMissing, "but got", err)
}
})
t.Run("Custom API URL", func(t *testing.T) {
mr, err := test(t, "astra:///testks?token=token&database_id=database_id&api_url=api_url")
if err != errNewClusterFromURL {
t.Error("Expected", errNewClusterFromURL, "but got", err)
}
if mr.apiUrl != "api_url" {
t.Error("Expected api_url to be 'api_url' but got", mr.apiUrl)
}
})
}

func TestTripleSlashInURLMeansNoHost(t *testing.T) {
const str = "astra:///testks?token=token&database_id=database_id"
u, err := url.Parse(str)
if err != nil {
t.Fatal(err)
}
if u.Host != "" {
t.Error("Expected host to be empty but got", u.Host)
}
if u.Path != "/testks" {
t.Error("Expected path to be '/testks' but got", u.Path)
}
}
51 changes: 46 additions & 5 deletions database/cassandra/cassandra.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"go.uber.org/atomic"

gocqlastra "github.com/datastax/gocql-astra"
"github.com/gocql/gocql"
"github.com/golang-migrate/migrate/v4/database"
"github.com/golang-migrate/migrate/v4/database/multistmt"
Expand All @@ -35,6 +36,7 @@ var (
ErrNoKeyspace = errors.New("no keyspace provided")
ErrDatabaseDirty = errors.New("database is dirty")
ErrClosedSession = errors.New("session is closed")
ErrAstraMissing = errors.New("missing required parameters for Astra connection")
)

type Config struct {
Expand Down Expand Up @@ -84,25 +86,64 @@ func WithInstance(session *gocql.Session, config *Config) (database.Driver, erro
}

func (c *Cassandra) Open(url string) (database.Driver, error) {
const (
timeout = 1 * time.Minute
)
u, err := nurl.Parse(url)
if err != nil {
return nil, err
}

isAstra := u.Scheme == "astra"

username := u.Query().Get("username")
password := u.Query().Get("password")

if isAstra {
username = "token"
}

// Check for missing mandatory attributes
if len(u.Path) == 0 {
return nil, ErrNoKeyspace
}

cluster := gocql.NewCluster(u.Host)
var cluster *gocql.ClusterConfig

if isAstra {
bundle := u.Query().Get("bundle")
databaseID := u.Query().Get("database_id")
token := u.Query().Get("token")
apiUrl := u.Query().Get("api_url")
if apiUrl == "" {
apiUrl = gocqlastra.AstraAPIURL
}

if bundle == "" && databaseID != "" && token != "" {
cluster, err = GocqlastraNewClusterFromURL(apiUrl, databaseID, token, timeout)
if err != nil {
return nil, err
}
} else if bundle != "" && token != "" {
cluster, err = GocqlastraNewClusterFromBundle(bundle, username, token, timeout)
if err != nil {
return nil, err
}
} else {
return nil, ErrAstraMissing
}
} else {
cluster = gocql.NewCluster(u.Host)
}

cluster.Keyspace = strings.TrimPrefix(u.Path, "/")
cluster.Consistency = gocql.All
cluster.Timeout = 1 * time.Minute
cluster.Timeout = timeout

if len(u.Query().Get("username")) > 0 && len(u.Query().Get("password")) > 0 {
if !isAstra && len(username) > 0 && len(password) > 0 {
authenticator := gocql.PasswordAuthenticator{
Username: u.Query().Get("username"),
Password: u.Query().Get("password"),
Username: username,
Password: password,
}
cluster.Authenticator = authenticator
}
Expand Down
Loading

0 comments on commit 5e87d11

Please sign in to comment.