Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions third_party/datastax/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,7 @@ func (c *client) handleServerPreparedQuery(raw *frame.RawFrame, msg *message.Pre
// function to handle and delete query of prepared type
func (c *client) prepareDeleteType(raw *frame.RawFrame, msg *message.Prepare, id [16]byte) ([]*message.ColumnMetadata, []*message.ColumnMetadata, error) {
var returnColumns, variableColumns, columnsWithInOp []string
deleteQueryMetadata, err := c.proxy.translator.ToSpannerDelete(c.keyspace, msg.Query)
deleteQueryMetadata, err := c.proxy.translator.ToSpannerDelete(c.keyspace, msg.Query, false)
if err != nil {
c.proxy.logger.Error(translatorErrorMessage, zap.String(Query, msg.Query), zap.Error(err))
c.sender.Send(raw.Header, &message.Invalid{ErrorMessage: err.Error()})
Expand Down Expand Up @@ -879,7 +879,7 @@ func (c *client) prepareInsertType(raw *frame.RawFrame, msg *message.Prepare, id
// function to handle and select query of prepared type
func (c *client) prepareSelectType(raw *frame.RawFrame, msg *message.Prepare, id [16]byte) ([]*message.ColumnMetadata, []*message.ColumnMetadata, error) {
var variableColumns, columnsWithInOp []string
queryMetadata, err := c.proxy.translator.ToSpannerSelect(c.keyspace, msg.Query)
queryMetadata, err := c.proxy.translator.ToSpannerSelect(c.keyspace, msg.Query, false)
if err != nil {
c.proxy.logger.Error(translatorErrorMessage, zap.String(Query, msg.Query), zap.Error(err))
c.sender.Send(raw.Header, &message.Invalid{ErrorMessage: err.Error()})
Expand Down Expand Up @@ -932,7 +932,7 @@ func (c *client) prepareSelectType(raw *frame.RawFrame, msg *message.Prepare, id
// function to handle update query of prepared type
func (c *client) prepareUpdateType(raw *frame.RawFrame, msg *message.Prepare, id [16]byte) ([]*message.ColumnMetadata, []*message.ColumnMetadata, error) {
var returnColumns, variableColumns, columnsWithInOp []string
updateQueryMetadata, err := c.proxy.translator.ToSpannerUpdate(c.keyspace, msg.Query)
updateQueryMetadata, err := c.proxy.translator.ToSpannerUpdate(c.keyspace, msg.Query, false)
if err != nil {
c.proxy.logger.Error(translatorErrorMessage, zap.String(Query, msg.Query), zap.Error(err))
c.sender.Send(raw.Header, &message.Invalid{ErrorMessage: err.Error()})
Expand Down Expand Up @@ -1514,10 +1514,10 @@ func (c *client) handleQuery(raw *frame.RawFrame, msg *partialQuery) {
}
} else {
var result *message.RowsResult

var isSimpleQuery = true
switch queryType {
case selectType:
queryMetadata, err := c.proxy.translator.ToSpannerSelect(c.keyspace, msg.query)
queryMetadata, err := c.proxy.translator.ToSpannerSelect(c.keyspace, msg.query, isSimpleQuery)
if err != nil {
c.proxy.logger.Error(translatorErrorMessage, zap.String(Query, msg.query), zap.Error(err))
c.sender.Send(raw.Header, &message.Invalid{ErrorMessage: err.Error()})
Expand Down Expand Up @@ -1594,7 +1594,7 @@ func (c *client) handleQuery(raw *frame.RawFrame, msg *partialQuery) {

c.sender.Send(raw.Header, result)
case deleteType:
queryMetadata, err := c.proxy.translator.ToSpannerDelete(c.keyspace, msg.query)
queryMetadata, err := c.proxy.translator.ToSpannerDelete(c.keyspace, msg.query, isSimpleQuery)
if err != nil {
c.proxy.logger.Error(translatorErrorMessage, zap.String(Query, msg.query), zap.Error(err))
c.sender.Send(raw.Header, &message.Invalid{ErrorMessage: err.Error()})
Expand Down Expand Up @@ -1633,7 +1633,7 @@ func (c *client) handleQuery(raw *frame.RawFrame, msg *partialQuery) {
result.Metadata.ColumnCount = int32(len(VariableMetadata))
c.sender.Send(raw.Header, result)
case updateType:
queryMetadata, err := c.proxy.translator.ToSpannerUpdate(c.keyspace, msg.query)
queryMetadata, err := c.proxy.translator.ToSpannerUpdate(c.keyspace, msg.query, isSimpleQuery)
if err != nil {
c.proxy.logger.Error(translatorErrorMessage, zap.String(Query, msg.query), zap.Error(err))
c.sender.Send(raw.Header, &message.Invalid{ErrorMessage: err.Error()})
Expand Down
4 changes: 2 additions & 2 deletions translator/translator_delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func createSpannerDeleteQuery(table string, clauses []Clause) string {
// - query: CQL Delete query
//
// Returns: DeleteQueryMap struct and error if any
func (t *Translator) ToSpannerDelete(keyspace string, queryStr string) (*DeleteQueryMap, error) {
func (t *Translator) ToSpannerDelete(keyspace string, queryStr string, isSimpleQuery bool) (*DeleteQueryMap, error) {
lowerQuery := strings.ToLower(queryStr)
query := renameLiterals(queryStr)
p, err := NewCqlParser(query, t.Debug)
Expand Down Expand Up @@ -95,7 +95,7 @@ func (t *Translator) ToSpannerDelete(keyspace string, queryStr string) (*DeleteQ
var clauseResponse ClauseResponse

if hasWhere(lowerQuery) {
resp, err := parseWhereByClause(deleteObj.WhereSpec(), tableSpec.TableName, t.TableConfig)
resp, err := parseWhereByClause(deleteObj.WhereSpec(), tableSpec.TableName, t.TableConfig, isSimpleQuery)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion translator/translator_delete_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ func TestTranslator_ToSpannerDelete(t *testing.T) {
Logger: tt.fields.Logger,
TableConfig: tableConfig,
}
got, err := tr.ToSpannerDelete(tt.args.keyspace, tt.args.query)
got, err := tr.ToSpannerDelete(tt.args.keyspace, tt.args.query, false)
if (err != nil) != tt.wantErr {
t.Errorf("Translator.ToSpannerDelete() error = %v, wantErr %v", err, tt.wantErr)
return
Expand Down
4 changes: 2 additions & 2 deletions translator/translator_select.go
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ func getSpannerSelectQuery(t *Translator, data *SelectQueryMap) (string, error)
// - originalQuery: CQL Select statement
//
// Returns: SelectQueryMap struct and error if any
func (t *Translator) ToSpannerSelect(keyspace string, originalQuery string) (*SelectQueryMap, error) {
func (t *Translator) ToSpannerSelect(keyspace string, originalQuery string, isSimpleQuery bool) (*SelectQueryMap, error) {
lowerQuery := strings.ToLower(originalQuery)
//Create copy of cassandra query where literals are substituted with a suffix
query := renameLiterals(originalQuery)
Expand Down Expand Up @@ -413,7 +413,7 @@ func (t *Translator) ToSpannerSelect(keyspace string, originalQuery string) (*Se
var clauseResponse ClauseResponse

if hasWhere(lowerQuery) {
resp, err := parseWhereByClause(selectObj.WhereSpec(), tableSpec.TableName, t.TableConfig)
resp, err := parseWhereByClause(selectObj.WhereSpec(), tableSpec.TableName, t.TableConfig, isSimpleQuery)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion translator/translator_select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ func TestTranslator_ToSpannerSelect(t *testing.T) {
TableConfig: tableConfig,
UseRowTimestamp: false,
}
got, err := tr.ToSpannerSelect(tt.args.keyspace, tt.args.query)
got, err := tr.ToSpannerSelect(tt.args.keyspace, tt.args.query, false)
if (err != nil) != tt.wantErr {
t.Errorf("Translator.ToSpannerSelect() error = %v, wantErr %v", err, tt.wantErr)
return
Expand Down
4 changes: 2 additions & 2 deletions translator/translator_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ func createSpannerSelectQueryForMapUpdate(table string, columns string, clauses
// - query: CQL Update query
//
// Returns: UpdateQueryMap struct and error if any
func (t *Translator) ToSpannerUpdate(keyspace string, queryStr string) (*UpdateQueryMap, error) {
func (t *Translator) ToSpannerUpdate(keyspace string, queryStr string, isSimpleQuery bool) (*UpdateQueryMap, error) {
lowerQuery := strings.ToLower(queryStr)
query := renameLiterals(queryStr)
p, err := NewCqlParser(query, t.Debug)
Expand Down Expand Up @@ -352,7 +352,7 @@ func (t *Translator) ToSpannerUpdate(keyspace string, queryStr string) (*UpdateQ
var clauseResponse ClauseResponse

if hasWhere(lowerQuery) {
resp, err := parseWhereByClause(updateObj.WhereSpec(), tableName, t.TableConfig)
resp, err := parseWhereByClause(updateObj.WhereSpec(), tableName, t.TableConfig, isSimpleQuery)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions translator/translator_update_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ func TestTranslator_ToSpannerUpdate(t *testing.T) {
UseRowTimestamp: true,
UseRowTTL: true,
}
got, err := tr.ToSpannerUpdate(tt.args.keyspace, tt.args.query)
got, err := tr.ToSpannerUpdate(tt.args.keyspace, tt.args.query, false)
if (err != nil) != tt.wantErr {
t.Errorf("Translator.ToSpannerUpdate() error = %v, wantErr %v", err, tt.wantErr)
return
Expand Down Expand Up @@ -326,7 +326,7 @@ func TestTranslator_ToSpannerUpdateWhenUsingTSTTLIsDisabled(t *testing.T) {
UseRowTimestamp: false,
UseRowTTL: false,
}
got, err := tr.ToSpannerUpdate(tt.args.keyspace, tt.args.query)
got, err := tr.ToSpannerUpdate(tt.args.keyspace, tt.args.query, false)
if (err != nil) != tt.wantErr {
t.Errorf("Translator.ToSpannerUpdate() error = %v, wantErr %v", err, tt.wantErr)
return
Expand Down
21 changes: 12 additions & 9 deletions translator/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -415,9 +415,10 @@ func NewCqlParser(cqlQuery string, isDebug bool) (*cql.CqlParser, error) {
// - input: The Where Spec context from the antlr Parser.
// - tableName - Table Name
// - tableConfig - JSON Config which maintains column and its datatypes info.
// - isSimpleQuery - Whether or not this where clause is part of query statement.
//
// Returns: ClauseResponse and an error if any.
func parseWhereByClause(input cql.IWhereSpecContext, tableName string, tableConfig *tableConfig.TableConfig) (*ClauseResponse, error) {
func parseWhereByClause(input cql.IWhereSpecContext, tableName string, tableConfig *tableConfig.TableConfig, isSimpleQuery bool) (*ClauseResponse, error) {
if input == nil {
return nil, errors.New("no input parameters found for clauses")
}
Expand Down Expand Up @@ -487,16 +488,18 @@ func parseWhereByClause(input cql.IWhereSpecContext, tableName string, tableConf
if value == "" {
return nil, errors.New("could not parse value from query for one of the clauses")
}
value = strings.ReplaceAll(value, "'", "")

if value != questionMark {

val, err := formatValues(value, columnType.SpannerType, columnType.CQLType)
if err != nil {
return nil, err
if isSimpleQuery && value == "'?'" {
params[placeholder] = questionMark
} else {
value = strings.ReplaceAll(value, "'", "")
if value != questionMark {
val, err := formatValues(value, columnType.SpannerType, columnType.CQLType)
if err != nil {
return nil, err
}
params[placeholder] = val
}

params[placeholder] = val
}
} else {
lower := strings.ToLower(val.GetText())
Expand Down
43 changes: 42 additions & 1 deletion translator/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -880,6 +880,7 @@ func TestParseWhereByClause(t *testing.T) {
input string
tableName string
tableConfig tableConfig.TableConfig
isQuery bool
keyspace string
expectedResult *ClauseResponse
expectedErr error
Expand All @@ -905,6 +906,46 @@ func TestParseWhereByClause(t *testing.T) {
},
expectedErr: nil,
},
{
name: "Valid input with question mark in quotes",
input: "WHERE column1 = '?'",
tableName: "test_table",
keyspace: "key_space",
expectedResult: &ClauseResponse{
Clauses: []Clause{
{
Column: "column1",
Operator: "=",
Value: "@value1",
IsPrimaryKey: true,
},
},
Params: map[string]interface{}{
"value1": "?",
},
ParamKeys: []string{"value1"},
},
expectedErr: nil,
},
{
name: "Valid input with question mark",
input: "WHERE column1 = ?",
tableName: "test_table",
keyspace: "key_space",
expectedResult: &ClauseResponse{
Clauses: []Clause{
{
Column: "column1",
Operator: "=",
Value: "@value1",
IsPrimaryKey: true,
},
},
Params: map[string]interface{}{},
ParamKeys: []string{"value1"},
},
expectedErr: nil,
},
{
name: "Valid input with multiple clauses",
input: "WHERE column1 = 'test' AND column2 < '0x0000000000000003'",
Expand Down Expand Up @@ -1012,7 +1053,7 @@ func TestParseWhereByClause(t *testing.T) {
input = p.WhereSpec()

}
result, err := parseWhereByClause(input, test.tableName, tableConfig)
result, err := parseWhereByClause(input, test.tableName, tableConfig, true)

// Validate error
if (err == nil && test.expectedErr != nil) || (err != nil && test.expectedErr == nil) || (err != nil && test.expectedErr != nil && err.Error() != test.expectedErr.Error()) {
Expand Down
Loading