Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions go/test/endtoend/vtgate/unsharded/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,12 @@ BEGIN
insert into allDefaults(id) values (128);
select 128 into val from dual;
END;
`,
`CREATE DEFINER=current_user() PROCEDURE with_definer(OUT val int)
BEGIN
insert into allDefaults(id) values (128);
select 128 into val from dual;
END;
`}
)

Expand Down
63 changes: 46 additions & 17 deletions go/vt/sqlparser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,38 @@ func (p *Parser) SplitStatement(blob string) (string, string, error) {
return blob, "", nil
}

var validCreatePrefixes = [][]int{
// These are the tokens (in order) for valid "create procedure" forms.
{CREATE, PROCEDURE},
{CREATE, DEFINER, '=', CURRENT_USER, PROCEDURE},
{CREATE, DEFINER, '=', CURRENT_USER, '(', ')', PROCEDURE},
{CREATE, DEFINER, '=', STRING, PROCEDURE},
{CREATE, DEFINER, '=', STRING, AT_ID, PROCEDURE},
{CREATE, DEFINER, '=', ID, PROCEDURE},
{CREATE, DEFINER, '=', ID, AT_ID, PROCEDURE},
}

// matchesCreateProcedurePrefix checks if the given token sequence
// is a create procedure statement or not.
func matchesCreateProcedurePrefix(tokens []int) bool {
// Check each candidate sequence.
for _, pattern := range validCreatePrefixes {
if len(tokens) >= len(pattern) {
match := true
for i, tok := range pattern {
if tokens[i] != tok {
match = false
break
}
}
if match {
return true
}
}
}
return false
}

// SplitStatementToPieces splits raw sql statement that may have multi sql pieces to sql pieces
// returns the sql pieces blob contains; or error if sql cannot be parsed.
func (p *Parser) SplitStatementToPieces(blob string) (pieces []string, err error) {
Expand All @@ -263,27 +295,25 @@ func (p *Parser) SplitStatementToPieces(blob string) (pieces []string, err error
var stmt string
stmtBegin := 0
emptyStatement := true
var prevToken int
var isCreateProcedureStatement bool
var startTokens []int // holds the first tokens of the current statement

loop:
for {
tkn, _ = tokenizer.Scan()
switch tkn {
case ';':
// Potential end of the statement.
stmt = blob[stmtBegin : tokenizer.Pos-1]
// We now try to parse the statement to see if its complete.
// If it is a create procedure, then it might not be complete, and we
// would need to scan to the next ;
if isCreateProcedureStatement && p.IsStatementIncomplete(stmt) {
// If it's a create procedure statement and is incomplete, skip appending.
if matchesCreateProcedurePrefix(startTokens) && p.IsStatementIncomplete(stmt) {
continue
}
if !emptyStatement {
pieces = append(pieces, stmt)
// We can now reset the variables for the next statement.
// It starts off as an empty statement and we don't know if it is
// a create procedure statement yet.
// It starts off as an empty statement.
emptyStatement = true
isCreateProcedureStatement = false
startTokens = startTokens[:0] // clear token slice
}
stmtBegin = tokenizer.Pos
case 0, eofChar:
Expand All @@ -296,16 +326,15 @@ loop:
}
break loop
case COMMENT:
// We want to ignore comments and not store them in the prevToken for knowing
// if the current statement is a create procedure statement.
// Skip comments entirely without altering the token list.
continue
case PROCEDURE:
if prevToken == CREATE {
isCreateProcedureStatement = true
}
fallthrough
default:
prevToken = tkn
// If we're at the very start of a statement, or we haven't filled out enough tokens
// for our valid prefix match (assuming our longest valid sequence is 10 tokens),
// accumulate the token.
if len(startTokens) < 10 {
startTokens = append(startTokens, tkn)
}
emptyStatement = false
}
}
Expand Down
33 changes: 28 additions & 5 deletions go/vt/sqlparser/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,35 @@ func TestSplitStatementToPieces(t *testing.T) {
// Test that we don't split on semicolons inside create procedure calls.
input: "select * from t1;create procedure p1 (in country CHAR(3), out cities INT) begin select count(*) from x where d = e; end;select * from t2",
lenWanted: 3,
}, {
// Create procedure with comments.
input: "select * from t1; /* comment1 */ create /* comment2 */ procedure /* comment3 */ p1 (in country CHAR(3), out cities INT) begin select count(*) from x where d = e; end;select * from t2",
lenWanted: 3,
}, {
// Create procedure with definer current_user.
input: "create DEFINER=CURRENT_USER procedure p1 (in country CHAR(3)) begin declare abc DECIMAL(14,2); DECLARE def DECIMAL(14,2); end",
lenWanted: 1,
}, {
// Create procedure with definer current_user().
input: "create DEFINER=CURRENT_USER() procedure p1 (in country CHAR(3)) begin declare abc DECIMAL(14,2); DECLARE def DECIMAL(14,2); end",
lenWanted: 1,
}, {
// Create procedure with definer string.
input: "create DEFINER='root' procedure p1 (in country CHAR(3)) begin declare abc DECIMAL(14,2); DECLARE def DECIMAL(14,2); end",
lenWanted: 1,
}, {
// Create procedure with definer string at_id.
input: "create DEFINER='root'@localhost procedure p1 (in country CHAR(3)) begin declare abc DECIMAL(14,2); DECLARE def DECIMAL(14,2); end",
lenWanted: 1,
}, {
// Create procedure with definer id.
input: "create DEFINER=`root` procedure p1 (in country CHAR(3)) begin declare abc DECIMAL(14,2); DECLARE def DECIMAL(14,2); end",
lenWanted: 1,
}, {
// Create procedure with definer id at_id.
input: "create DEFINER=`root`@`localhost` procedure p1 (in country CHAR(3)) begin declare abc DECIMAL(14,2); DECLARE def DECIMAL(14,2); end",
lenWanted: 1,
},
{
// Create procedure with comments.
input: "select * from t1; /* comment1 */ create /* comment2 */ procedure /* comment3 */ p1 (in country CHAR(3), out cities INT) begin select count(*) from x where d = e; end;select * from t2",
lenWanted: 3,
},
}

parser := NewTestParser()
Expand Down
Loading