Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

basic support for prepared statements #18

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,24 @@ gorqlite.TraceOn(os.Stderr)

// turn off
gorqlite.TraceOff()


// using prepared statements
wr, err := conn.WritePrepared(
[]*gorqlite.PreparedStatement{
{
Query: "INSERT INTO secret_agents(id, name, secret) VALUES(?, ?, ?)",
Arguments: []interface{}{7, "James Bond", []byte{0x42}}
}
}
)
// alternatively
wr, err := conn.WriteOnePrepared(
&gorqlite.PreparedStatement{
Query: "INSERT INTO secret_agents(id, name, secret) VALUES(?, ?, ?)",
Arguments: []interface{}{7, "James Bond", []byte{0x42}},
},
)
```
## Important Notes

Expand Down
30 changes: 24 additions & 6 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ package gorqlite
this file has low level stuff:

rqliteApiGet()
rqliteApiPost()
rqliteApiPostPrepared()

There is some code duplication between those and they should
probably be combined into one function.
Expand All @@ -21,6 +21,11 @@ import "io/ioutil"
import "net/http"
import "time"

type PreparedStatement struct {
Query string
Arguments []interface{}
}

/* *****************************************************************

method: rqliteApiGet() - for api_STATUS
Expand Down Expand Up @@ -102,7 +107,7 @@ PeerLoop:

/* *****************************************************************

method: rqliteApiPost() - for api_QUERY and api_WRITE
method: rqliteApiPostPrepared() - for api_QUERY and api_WRITE

- lowest level interface - does not do any JSON unmarshaling
- handles 301s, etc.
Expand All @@ -114,7 +119,7 @@ PeerLoop:

* *****************************************************************/

func (conn *Connection) rqliteApiPost(apiOp apiOperation, sqlStatements []string) ([]byte, error) {
func (conn *Connection) rqliteApiPostPrepared(apiOp apiOperation, sqlStatements []*PreparedStatement) ([]byte, error) {
var responseBody []byte

switch apiOp {
Expand All @@ -123,13 +128,26 @@ func (conn *Connection) rqliteApiPost(apiOp apiOperation, sqlStatements []string
case api_WRITE:
trace("%s: rqliteApiGet() post called for a QUERY of %d statements", conn.ID, len(sqlStatements))
default:
return responseBody, errors.New("weird! called for an invalid apiOperation in rqliteApiPost()")
return responseBody, errors.New("weird! called for an invalid apiOperation in rqliteApiPostPrepared()")
}

// jsonify the statements. not really needed in the
// case of api_STATUS but doesn't hurt

jStatements, err := json.Marshal(sqlStatements)

formattedStatements := make([][]interface{}, 0, len(sqlStatements))

for _, statement := range sqlStatements {
formattedStatement := make([]interface{}, 0, len(statement.Arguments)+1)
formattedStatement = append(formattedStatement, statement.Query)

for _, argument := range statement.Arguments {
formattedStatement = append(formattedStatement, argument)
}
formattedStatements = append(formattedStatements, formattedStatement)
}

jStatements, err := json.Marshal(formattedStatements)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -200,4 +218,4 @@ PeerLoop:
stringBuffer.WriteString(fmt.Sprintf(" peer #%d: %s\n", n, v))
}
return responseBody, errors.New(stringBuffer.String())
}
}
31 changes: 28 additions & 3 deletions query.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,21 @@ func (conn *Connection) QueryOne(sqlStatement string) (qr QueryResult, err error
return qra[0], err
}

func (conn *Connection) QueryOnePrepared(statement *PreparedStatement) (qr QueryResult, err error) {
if conn.hasBeenClosed {
qr.Err = errClosed
return qr, errClosed
}
qra, err := conn.QueryPrepared([]*PreparedStatement{statement})
return qra[0], err
}

/*
Query() is used to perform SELECT operations in the database.
QueryPrepared() is used to perform SELECT operations in the database.

It takes an array of SQL statements and executes them in a single transaction, returning an array of QueryResult vars.
*/
func (conn *Connection) Query(sqlStatements []string) (results []QueryResult, err error) {
func (conn *Connection) QueryPrepared(sqlStatements []*PreparedStatement) (results []QueryResult, err error) {
results = make([]QueryResult, 0)

if conn.hasBeenClosed {
Expand All @@ -122,7 +131,7 @@ func (conn *Connection) Query(sqlStatements []string) (results []QueryResult, er
trace("%s: Query() for %d statements", conn.ID, len(sqlStatements))

// if we get an error POSTing, that's a showstopper
response, err := conn.rqliteApiPost(api_QUERY, sqlStatements)
response, err := conn.rqliteApiPostPrepared(api_QUERY, sqlStatements)
if err != nil {
trace("%s: rqliteApiCall() ERROR: %s", conn.ID, err.Error())
var errResult QueryResult
Expand Down Expand Up @@ -203,6 +212,22 @@ func (conn *Connection) Query(sqlStatements []string) (results []QueryResult, er
}
}

/*
Query() is used to perform SELECT operations in the database.

It takes an array of SQL statements and executes them in a single transaction, returning an array of QueryResult vars.
*/

func (conn *Connection) Query(sqlStatements []string) (results []QueryResult, err error) {
preparedStatements := make([]*PreparedStatement, 0, len(sqlStatements))
for _, sqlStatement := range sqlStatements {
preparedStatements = append(preparedStatements, &PreparedStatement{
Query: sqlStatement,
})
}
return conn.QueryPrepared(preparedStatements)
}

/* *****************************************************************

type: QueryResult
Expand Down
31 changes: 28 additions & 3 deletions query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ func TestQueryOne(t *testing.T) {
var wr WriteResult
var qr QueryResult
var wResults []WriteResult
var qResults []QueryResult
var err error

t.Logf("trying Open")
Expand Down Expand Up @@ -61,6 +60,19 @@ func TestQueryOne(t *testing.T) {
t.Fail()
}

t.Logf("trying QueryOnePrepared")
qr, err = conn.QueryOnePrepared(
&PreparedStatement{
Query: fmt.Sprintf("SELECT name, ts FROM %s WHERE id > ?", testTableName()),
Arguments: []interface{}{3},
},
)

if err != nil {
t.Logf("--> FAILED")
t.Fail()
}

t.Logf("trying Next()")
na := qr.Next()
if na != true {
Expand Down Expand Up @@ -171,11 +183,24 @@ func TestQueryOne(t *testing.T) {
t2 = append(t2, "SELECT id FROM "+testTableName()+"")
t2 = append(t2, "SELECT name FROM "+testTableName()+"")
t2 = append(t2, "SELECT id,name FROM "+testTableName()+"")
qResults, err = conn.Query(t2)
_, err = conn.Query(t2)
if err == nil {
t.Logf("--> FAILED")
t.Fail()
}

t.Logf("trying Query after Close")
_, err = conn.QueryPrepared(
[]*PreparedStatement{
{ Query: fmt.Sprintf("SELECT id FROM %s", testTableName()), },
{ Query: fmt.Sprintf("SELECT name FROM %s", testTableName()), },
{ Query: fmt.Sprintf("SELECT id, name FROM %s", testTableName()), },
},
)

if err == nil {
t.Logf("--> FAILED")
t.Fail()
}
_ = qResults

}
39 changes: 38 additions & 1 deletion write.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,53 @@ func (conn *Connection) WriteOne(sqlStatement string) (wr WriteResult, err error
return wra[0], err
}

/*
WriteOnePrepared() is a convenience method that wraps WritePrepared() into a single-statement
method.
*/

func (conn *Connection) WriteOnePrepared(statement *PreparedStatement) (wr WriteResult, err error) {
if conn.hasBeenClosed {
wr.Err = errClosed
return wr, errClosed
}
wra, err := conn.WritePrepared([]*PreparedStatement{statement})
return wra[0], err
}


/*
Write() is used to perform DDL/DML in the database. ALTER, CREATE, DELETE, DROP, INSERT, UPDATE, etc. all go through Write().

Write() takes an array of SQL statements, and returns an equal-sized array of WriteResults, each corresponding to the SQL statement that produced it.

All statements are executed as a single transaction.

Write() uses WritePrepared()

Write() returns an error if one is encountered during its operation. If it's something like a call to the rqlite API, then it'll return that error. If one statement out of several has an error, it will return a generic "there were %d statement errors" and you'll have to look at the individual statement's Err for more info.
*/
func (conn *Connection) Write(sqlStatements []string) (results []WriteResult, err error) {
preparedStatements := make([]*PreparedStatement, 0, len(sqlStatements))
for _, sqlStatement := range sqlStatements {
preparedStatements = append(preparedStatements, &PreparedStatement{
Query: sqlStatement,
})
}
return conn.WritePrepared(preparedStatements)
}

/*
WritePrepared() is used to perform DDL/DML in the database. ALTER, CREATE, DELETE, DROP, INSERT, UPDATE, etc. all go through Write().

WritePrepared() takes an array of SQL statements, and returns an equal-sized array of WriteResults, each corresponding to the SQL statement that produced it.

All statements are executed as a single transaction.

WritePrepared() returns an error if one is encountered during its operation. If it's something like a call to the rqlite API, then it'll return that error. If one statement out of several has an error, it will return a generic "there were %d statement errors" and you'll have to look at the individual statement's Err for more info.
*/

func (conn *Connection) WritePrepared(sqlStatements []*PreparedStatement) (results []WriteResult, err error) {
results = make([]WriteResult, 0)

if conn.hasBeenClosed {
Expand All @@ -91,7 +128,7 @@ func (conn *Connection) Write(sqlStatements []string) (results []WriteResult, er

trace("%s: Write() for %d statements", conn.ID, len(sqlStatements))

response, err := conn.rqliteApiPost(api_WRITE, sqlStatements)
response, err := conn.rqliteApiPostPrepared(api_WRITE, sqlStatements)
if err != nil {
trace("%s: rqliteApiCall() ERROR: %s", conn.ID, err.Error())
var errResult WriteResult
Expand Down
54 changes: 53 additions & 1 deletion write_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package gorqlite

import "testing"
import (
"fmt"
"testing"
)

// import "os"

Expand Down Expand Up @@ -49,6 +52,25 @@ func TestWriteOne(t *testing.T) {
t.Fail()
}

t.Logf("trying WriteOne INSERT")
wr, err = conn.WriteOnePrepared(
&PreparedStatement{
Query: fmt.Sprintf("INSERT INTO %s (id, name) VALUES (?, ?)", testTableName()),
Arguments: []interface{}{1, "aaa bbb ccc"},
},
)

if err != nil {
t.Logf("--> FAILED")
t.Fail()
}

t.Logf("checking WriteOnePrepared RowsAffected")
if wr.RowsAffected != 1 {
t.Logf("--> FAILED")
t.Fail()
}

t.Logf("trying WriteOne DROP")
wr, err = conn.WriteOne("DROP TABLE IF EXISTS " + testTableName() + "")
if err != nil {
Expand Down Expand Up @@ -96,6 +118,36 @@ func TestWrite(t *testing.T) {
t.Fail()
}

t.Logf("trying Write INSERT")
results, err = conn.WritePrepared(
[]*PreparedStatement{
{
Query: fmt.Sprintf("INSERT INTO %s (id, name) VALUES (?, ?)", testTableName()),
Arguments: []interface{}{1, "aaa bbb ccc"},
},
{
Query: fmt.Sprintf("INSERT INTO %s (id, name) VALUES (?, ?)", testTableName()),
Arguments: []interface{}{1, "aaa bbb ccc"},
},
{
Query: fmt.Sprintf("INSERT INTO %s (id, name) VALUES (?, ?)", testTableName()),
Arguments: []interface{}{1, "aaa bbb ccc"},
},
{
Query: fmt.Sprintf("INSERT INTO %s (id, name) VALUES (?, ?)", testTableName()),
Arguments: []interface{}{1, "aaa bbb ccc"},
},
},
)
if err != nil {
t.Logf("--> FAILED")
t.Fail()
}
if len(results) != 4 {
t.Logf("--> FAILED")
t.Fail()
}

t.Logf("trying Write DROP")
s = make([]string, 0)
s = append(s, "DROP TABLE IF EXISTS "+testTableName()+"")
Expand Down