Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
234 changes: 159 additions & 75 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ func (mc *mysqlConn) close() {
}

// Closes the network connection and unsets internal variables. Do not call this
// function after successfully authentication, call Close instead. This function
// function after successful authentication, call Close instead. This function
// is called before auth or on auth failure because MySQL will have already
// closed the network connection.
func (mc *mysqlConn) cleanup() {
Expand Down Expand Up @@ -246,100 +246,184 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
}

func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
// Number of ? should be same to len(args)
if strings.Count(query, "?") != len(args) {
return "", driver.ErrSkip
}
noBackslashEscapes := (mc.status & statusNoBackslashEscapes) != 0
const (
stateNormal = iota
stateString
stateEscape
stateEOLComment
stateSlashStarComment
stateBacktick
)

const (
QUOTE_BYTE = byte('\'')
DBL_QUOTE_BYTE = byte('"')
BACKSLASH_BYTE = byte('\\')
QUESTION_MARK_BYTE = byte('?')
SLASH_BYTE = byte('/')
STAR_BYTE = byte('*')
HASH_BYTE = byte('#')
MINUS_BYTE = byte('-')
LINE_FEED_BYTE = byte('\n')
BACKTICK_BYTE = byte('`')
)

buf, err := mc.buf.takeCompleteBuffer()
if err != nil {
// can not take the buffer. Something must be wrong with the connection
mc.cleanup()
// interpolateParams would be called before sending any query.
// So its safe to retry.
return "", driver.ErrBadConn
}
buf = buf[:0]
state := stateNormal
singleQuotes := false
lastChar := byte(0)
argPos := 0

for i := 0; i < len(query); i++ {
q := strings.IndexByte(query[i:], '?')
if q == -1 {
buf = append(buf, query[i:]...)
break
}
buf = append(buf, query[i:i+q]...)
i += q

arg := args[argPos]
argPos++

if arg == nil {
buf = append(buf, "NULL"...)
lenQuery := len(query)
lastIdx := 0

for i := 0; i < lenQuery; i++ {
currentChar := query[i]
if state == stateEscape && !((currentChar == QUOTE_BYTE && singleQuotes) || (currentChar == DBL_QUOTE_BYTE && !singleQuotes)) {
state = stateString
lastChar = currentChar
continue
}

switch v := arg.(type) {
case int64:
buf = strconv.AppendInt(buf, v, 10)
case uint64:
// Handle uint64 explicitly because our custom ConvertValue emits unsigned values
buf = strconv.AppendUint(buf, v, 10)
case float64:
buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
case bool:
if v {
buf = append(buf, '1')
} else {
buf = append(buf, '0')
switch currentChar {
case STAR_BYTE:
if state == stateNormal && lastChar == SLASH_BYTE {
state = stateSlashStarComment
}
case time.Time:
if v.IsZero() {
buf = append(buf, "'0000-00-00'"...)
} else {
buf = append(buf, '\'')
buf, err = appendDateTime(buf, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
if err != nil {
return "", err
}
buf = append(buf, '\'')
case SLASH_BYTE:
if state == stateSlashStarComment && lastChar == STAR_BYTE {
state = stateNormal
// Clear lastChar so the '/' that closed the comment isn't
// reused to start a new comment with a following '*'.
lastChar = 0
continue
}
case json.RawMessage:
buf = append(buf, '\'')
if mc.status&statusNoBackslashEscapes == 0 {
buf = escapeBytesBackslash(buf, v)
} else {
buf = escapeBytesQuotes(buf, v)
case HASH_BYTE:
if state == stateNormal {
state = stateEOLComment
}
buf = append(buf, '\'')
case []byte:
if v == nil {
buf = append(buf, "NULL"...)
} else {
buf = append(buf, "_binary'"...)
if mc.status&statusNoBackslashEscapes == 0 {
buf = escapeBytesBackslash(buf, v)
case MINUS_BYTE:
if state == stateNormal && lastChar == MINUS_BYTE {
Comment thread
methane marked this conversation as resolved.
// -- only starts a comment if followed by whitespace or control char
if i+1 < lenQuery {
nextChar := query[i+1]
if nextChar == ' ' || nextChar == '\t' || nextChar == '\n' || nextChar == '\r' {
state = stateEOLComment
}
} else {
buf = escapeBytesQuotes(buf, v)
state = stateEOLComment
}
buf = append(buf, '\'')
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
case string:
buf = append(buf, '\'')
if mc.status&statusNoBackslashEscapes == 0 {
buf = escapeStringBackslash(buf, v)
} else {
buf = escapeStringQuotes(buf, v)
case LINE_FEED_BYTE:
if state == stateEOLComment {
state = stateNormal
}
buf = append(buf, '\'')
default:
return "", driver.ErrSkip
}
case DBL_QUOTE_BYTE:
if state == stateNormal {
state = stateString
singleQuotes = false
} else if state == stateString && !singleQuotes {
state = stateNormal
} else if state == stateEscape {
state = stateString
}
case QUOTE_BYTE:
if state == stateNormal {
state = stateString
singleQuotes = true
} else if state == stateString && singleQuotes {
state = stateNormal
} else if state == stateEscape {
state = stateString
}
case BACKSLASH_BYTE:
if state == stateString && !noBackslashEscapes {
state = stateEscape
}
case QUESTION_MARK_BYTE:
if state == stateNormal {
if argPos >= len(args) {
return "", driver.ErrSkip
}
buf = append(buf, query[lastIdx:i]...)
arg := args[argPos]
argPos++

if arg == nil {
buf = append(buf, "NULL"...)
lastIdx = i + 1
break
}

switch v := arg.(type) {
case int64:
buf = strconv.AppendInt(buf, v, 10)
case uint64:
buf = strconv.AppendUint(buf, v, 10)
case float64:
buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
case bool:
if v {
buf = append(buf, '1')
} else {
buf = append(buf, '0')
}
case time.Time:
if v.IsZero() {
buf = append(buf, "'0000-00-00'"...)
} else {
buf = append(buf, '\'')
buf, err = appendDateTime(buf, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
if err != nil {
return "", err
}
buf = append(buf, '\'')
}
case json.RawMessage:
if noBackslashEscapes {
buf = escapeBytesQuotes(buf, v, false)
} else {
buf = escapeBytesBackslash(buf, v, false)
}
case []byte:
if v == nil {
buf = append(buf, "NULL"...)
} else {
if noBackslashEscapes {
buf = escapeBytesQuotes(buf, v, true)
} else {
buf = escapeBytesBackslash(buf, v, true)
}
}
case string:
if noBackslashEscapes {
buf = escapeStringQuotes(buf, v)
} else {
buf = escapeStringBackslash(buf, v)
}
default:
return "", driver.ErrSkip
}

if len(buf)+4 > mc.maxAllowedPacket {
return "", driver.ErrSkip
if len(buf)+4 > mc.maxAllowedPacket {
return "", driver.ErrSkip
}
lastIdx = i + 1
}
case BACKTICK_BYTE:
if state == stateBacktick {
state = stateNormal
} else if state == stateNormal {
state = stateBacktick
}
}
lastChar = currentChar
}
buf = append(buf, query[lastIdx:]...)
if argPos != len(args) {
return "", driver.ErrSkip
}
Expand Down
76 changes: 58 additions & 18 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,24 +80,6 @@ func TestInterpolateParamsTooManyPlaceholders(t *testing.T) {
}
}

// We don't support placeholder in string literal for now.
// https://github.com/go-sql-driver/mysql/pull/490
func TestInterpolateParamsPlaceholderInString(t *testing.T) {
mc := &mysqlConn{
buf: newBuffer(),
maxAllowedPacket: maxPacketSize,
cfg: &Config{
InterpolateParams: true,
},
}

q, err := mc.interpolateParams("SELECT 'abc?xyz',?", []driver.Value{int64(42)})
// When InterpolateParams support string literal, this should return `"SELECT 'abc?xyz', 42`
if err != driver.ErrSkip {
t.Errorf("Expected err=driver.ErrSkip, got err=%#v, q=%#v", err, q)
}
}

func TestInterpolateParamsUint64(t *testing.T) {
mc := &mysqlConn{
buf: newBuffer(),
Expand Down Expand Up @@ -206,6 +188,64 @@ func (bc badConnection) Close() error {
return nil
}

func TestInterpolateParamsWithComments(t *testing.T) {
mc := &mysqlConn{
buf: newBuffer(),
maxAllowedPacket: maxPacketSize,
cfg: &Config{
InterpolateParams: true,
},
}

tests := []struct {
query string
args []driver.Value
expected string
shouldSkip bool
}{
// ? in single-line comment (--) should not be replaced
{"SELECT 1 -- ?\n, ?", []driver.Value{int64(42)}, "SELECT 1 -- ?\n, 42", false},
// ? in single-line comment (#) should not be replaced
{"SELECT 1 # ?\n, ?", []driver.Value{int64(42)}, "SELECT 1 # ?\n, 42", false},
// ? in multi-line comment should not be replaced
{"SELECT /* ? */ ?", []driver.Value{int64(42)}, "SELECT /* ? */ 42", false},
// ? in string literal should not be replaced
{"SELECT '?', ?", []driver.Value{int64(42)}, "SELECT '?', 42", false},
// ? in backtick identifier should not be replaced
{"SELECT `?`, ?", []driver.Value{int64(42)}, "SELECT `?`, 42", false},
// ? in backslash-escaped string literal should not be replaced
{"SELECT 'C:\\path\\?x.txt', ?", []driver.Value{int64(42)}, "SELECT 'C:\\path\\?x.txt', 42", false},
// ? in backslash-escaped string literal should not be replaced
{"SELECT '\\'?', col FROM tbl WHERE id = ? AND desc = 'foo\\'bar?'", []driver.Value{int64(42)}, "SELECT '\\'?', col FROM tbl WHERE id = 42 AND desc = 'foo\\'bar?'", false},
// Multiple comments and real placeholders
{"SELECT ? -- comment ?\n, ? /* ? */ , ? # ?\n, ?", []driver.Value{int64(1), int64(2), int64(3)}, "SELECT 1 -- comment ?\n, 2 /* ? */ , 3 # ?\n, ?", true},
// 2--1: -- followed by digit is NOT a comment (it's the number 2 minus minus 1)
{"SELECT ?--1", []driver.Value{int64(2)}, "SELECT 2--1", false},
// /* */*: After closing block comment, */* should NOT start a new comment
{"SELECT /* comment */* ?, ?", []driver.Value{int64(1), int64(2)}, "SELECT /* comment */* 1, 2", false},
// /* */*: More complex case with actual comment after
{"SELECT /* c1 */*/* c2 */ ?, ?", []driver.Value{int64(1), int64(2)}, "SELECT /* c1 */*/* c2 */ 1, 2", false},
}

for i, test := range tests {

q, err := mc.interpolateParams(test.query, test.args)
if test.shouldSkip {
if err != driver.ErrSkip {
t.Errorf("Test %d: Expected driver.ErrSkip, got err=%#v, q=%#v", i, err, q)
}
continue
}
if err != nil {
t.Errorf("Test %d: Expected err=nil, got %#v", i, err)
continue
}
if q != test.expected {
t.Errorf("Test %d: Expected: %q\nGot: %q", i, test.expected, q)
}
}
}

// chunkedConn is a net.Conn that serves pre-built data chunks, one per Read
// call. This simulates the behavior seen with TLS connections, where the
// server's TLS library typically produces a separate TLS record per write
Expand Down
Loading
Loading