681 lines
16 KiB
Go
681 lines
16 KiB
Go
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
|
//
|
|
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
|
|
//
|
|
// This Source Code Form is subject to the terms of the Mozilla Public
|
|
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
|
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
|
|
package mysql
|
|
|
|
import (
|
|
"crypto/sha1"
|
|
"crypto/tls"
|
|
"database/sql/driver"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net/url"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// Package-level state shared by the driver.
var (
	// errLog is the logger the driver reports errors to.
	errLog *log.Logger

	// tlsConfigRegister maps the names accepted by the "tls" DSN parameter
	// to the custom tls.Config registered under that name.
	tlsConfigRegister map[string]*tls.Config
)

// Errors returned by parseDSN for a malformed network address.
var (
	errInvalidDSNUnescaped = errors.New("Invalid DSN: Did you forget to escape a param value?")
	errInvalidDSNAddr      = errors.New("Invalid DSN: Network Address not terminated (missing closing brace)")
)

// init creates the empty TLS config registry and the default error logger,
// which writes to stderr with a "[MySQL] " prefix plus date, time and the
// short file name of the call site.
func init() {
	tlsConfigRegister = make(map[string]*tls.Config)
	errLog = log.New(os.Stderr, "[MySQL] ", log.Lshortfile|log.Ldate|log.Ltime)
}
|
|
|
|
// RegisterTLSConfig registers a custom tls.Config to be used with sql.Open.
|
|
// Use the key as a value in the DSN where tls=value.
|
|
//
|
|
// rootCertPool := x509.NewCertPool()
|
|
// pem, err := ioutil.ReadFile("/path/ca-cert.pem")
|
|
// if err != nil {
|
|
// log.Fatal(err)
|
|
// }
|
|
// if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
|
|
// log.Fatal("Failed to append PEM.")
|
|
// }
|
|
// clientCert := make([]tls.Certificate, 0, 1)
|
|
// certs, err := tls.LoadX509KeyPair("/path/client-cert.pem", "/path/client-key.pem")
|
|
// if err != nil {
|
|
// log.Fatal(err)
|
|
// }
|
|
// clientCert = append(clientCert, certs)
|
|
// mysql.RegisterTLSConfig("custom", &tls.Config{
|
|
// RootCAs: rootCertPool,
|
|
// Certificates: clientCert,
|
|
// })
|
|
// db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom")
|
|
//
|
|
func RegisterTLSConfig(key string, config *tls.Config) error {
|
|
if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" {
|
|
return fmt.Errorf("Key '%s' is reserved", key)
|
|
}
|
|
|
|
tlsConfigRegister[key] = config
|
|
return nil
|
|
}
|
|
|
|
// DeregisterTLSConfig removes the tls.Config associated with key.
|
|
func DeregisterTLSConfig(key string) {
|
|
delete(tlsConfigRegister, key)
|
|
}
|
|
|
|
// parseDSN parses the DSN string to a config
|
|
func parseDSN(dsn string) (cfg *config, err error) {
|
|
cfg = new(config)
|
|
|
|
// TODO: use strings.IndexByte when we can depend on Go 1.2
|
|
|
|
// [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN]
|
|
// Find the last '/' (since the password or the net addr might contain a '/')
|
|
for i := len(dsn) - 1; i >= 0; i-- {
|
|
if dsn[i] == '/' {
|
|
var j, k int
|
|
|
|
// left part is empty if i <= 0
|
|
if i > 0 {
|
|
// [username[:password]@][protocol[(address)]]
|
|
// Find the last '@' in dsn[:i]
|
|
for j = i; j >= 0; j-- {
|
|
if dsn[j] == '@' {
|
|
// username[:password]
|
|
// Find the first ':' in dsn[:j]
|
|
for k = 0; k < j; k++ {
|
|
if dsn[k] == ':' {
|
|
cfg.passwd = dsn[k+1 : j]
|
|
break
|
|
}
|
|
}
|
|
cfg.user = dsn[:k]
|
|
|
|
break
|
|
}
|
|
}
|
|
|
|
// [protocol[(address)]]
|
|
// Find the first '(' in dsn[j+1:i]
|
|
for k = j + 1; k < i; k++ {
|
|
if dsn[k] == '(' {
|
|
// dsn[i-1] must be == ')' if an adress is specified
|
|
if dsn[i-1] != ')' {
|
|
if strings.ContainsRune(dsn[k+1:i], ')') {
|
|
return nil, errInvalidDSNUnescaped
|
|
}
|
|
return nil, errInvalidDSNAddr
|
|
}
|
|
cfg.addr = dsn[k+1 : i-1]
|
|
break
|
|
}
|
|
}
|
|
cfg.net = dsn[j+1 : k]
|
|
}
|
|
|
|
// dbname[?param1=value1&...¶mN=valueN]
|
|
// Find the first '?' in dsn[i+1:]
|
|
for j = i + 1; j < len(dsn); j++ {
|
|
if dsn[j] == '?' {
|
|
if err = parseDSNParams(cfg, dsn[j+1:]); err != nil {
|
|
return
|
|
}
|
|
break
|
|
}
|
|
}
|
|
cfg.dbname = dsn[i+1 : j]
|
|
|
|
break
|
|
}
|
|
}
|
|
|
|
// Set default network if empty
|
|
if cfg.net == "" {
|
|
cfg.net = "tcp"
|
|
}
|
|
|
|
// Set default adress if empty
|
|
if cfg.addr == "" {
|
|
switch cfg.net {
|
|
case "tcp":
|
|
cfg.addr = "127.0.0.1:3306"
|
|
case "unix":
|
|
cfg.addr = "/tmp/mysql.sock"
|
|
default:
|
|
return nil, errors.New("Default addr for network '" + cfg.net + "' unknown")
|
|
}
|
|
|
|
}
|
|
|
|
// Set default location if empty
|
|
if cfg.loc == nil {
|
|
cfg.loc = time.UTC
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
// parseDSNParams parses the DSN "query string"
|
|
// Values must be url.QueryEscape'ed
|
|
func parseDSNParams(cfg *config, params string) (err error) {
|
|
for _, v := range strings.Split(params, "&") {
|
|
param := strings.SplitN(v, "=", 2)
|
|
if len(param) != 2 {
|
|
continue
|
|
}
|
|
|
|
// cfg params
|
|
switch value := param[1]; param[0] {
|
|
|
|
// Disable INFILE whitelist / enable all files
|
|
case "allowAllFiles":
|
|
var isBool bool
|
|
cfg.allowAllFiles, isBool = readBool(value)
|
|
if !isBool {
|
|
return fmt.Errorf("Invalid Bool value: %s", value)
|
|
}
|
|
|
|
// Switch "rowsAffected" mode
|
|
case "clientFoundRows":
|
|
var isBool bool
|
|
cfg.clientFoundRows, isBool = readBool(value)
|
|
if !isBool {
|
|
return fmt.Errorf("Invalid Bool value: %s", value)
|
|
}
|
|
|
|
// Use old authentication mode (pre MySQL 4.1)
|
|
case "allowOldPasswords":
|
|
var isBool bool
|
|
cfg.allowOldPasswords, isBool = readBool(value)
|
|
if !isBool {
|
|
return fmt.Errorf("Invalid Bool value: %s", value)
|
|
}
|
|
|
|
// Time Location
|
|
case "loc":
|
|
if value, err = url.QueryUnescape(value); err != nil {
|
|
return
|
|
}
|
|
cfg.loc, err = time.LoadLocation(value)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
// Dial Timeout
|
|
case "timeout":
|
|
cfg.timeout, err = time.ParseDuration(value)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
// TLS-Encryption
|
|
case "tls":
|
|
boolValue, isBool := readBool(value)
|
|
if isBool {
|
|
if boolValue {
|
|
cfg.tls = &tls.Config{}
|
|
}
|
|
} else {
|
|
if strings.ToLower(value) == "skip-verify" {
|
|
cfg.tls = &tls.Config{InsecureSkipVerify: true}
|
|
} else if tlsConfig, ok := tlsConfigRegister[value]; ok {
|
|
cfg.tls = tlsConfig
|
|
} else {
|
|
return fmt.Errorf("Invalid value / unknown config name: %s", value)
|
|
}
|
|
}
|
|
|
|
default:
|
|
// lazy init
|
|
if cfg.params == nil {
|
|
cfg.params = make(map[string]string)
|
|
}
|
|
|
|
if cfg.params[param[0]], err = url.QueryUnescape(value); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
// readBool interprets input as a boolean.
// "1", "true", "TRUE" and "True" yield true; "0", "false", "FALSE" and
// "False" yield false. For any other input valid is false (and value is
// false as well).
func readBool(input string) (value bool, valid bool) {
	if input == "1" || input == "true" || input == "TRUE" || input == "True" {
		return true, true
	}
	if input == "0" || input == "false" || input == "FALSE" || input == "False" {
		return false, true
	}

	// Not a valid bool value.
	return false, false
}
|
|
|
|
/******************************************************************************
|
|
* Authentication *
|
|
******************************************************************************/
|
|
|
|
// scramblePassword computes the authentication token of the 4.1+ method:
//
//	token = SHA1(scramble + SHA1(SHA1(password))) XOR SHA1(password)
//
// It returns nil for an empty password. The scramble slice passed in is
// only read; the result is a freshly allocated slice.
func scramblePassword(scramble, password []byte) []byte {
	if len(password) == 0 {
		return nil
	}

	hasher := sha1.New()

	// stage1 = SHA1(password)
	hasher.Write(password)
	stage1 := hasher.Sum(nil)

	// stage2 = SHA1(stage1)
	hasher.Reset()
	hasher.Write(stage1)
	stage2 := hasher.Sum(nil)

	// token = SHA1(scramble + stage2)
	hasher.Reset()
	hasher.Write(scramble)
	hasher.Write(stage2)
	token := hasher.Sum(nil)

	// token = token XOR stage1
	for i, b := range stage1 {
		token[i] ^= b
	}
	return token
}
|
|
|
|
// myRnd is the pseudo random number generator used by the pre 4.1
// (old password) scramble.
// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c
type myRnd struct {
	seed1, seed2 uint32
}

// myRndMaxVal is the modulus both seeds are kept below.
const myRndMaxVal = 0x3FFFFFFF

// newMyRnd returns a generator whose two seeds are the given values
// reduced modulo myRndMaxVal.
func newMyRnd(seed1, seed2 uint32) *myRnd {
	rnd := new(myRnd)
	rnd.seed1 = seed1 % myRndMaxVal
	rnd.seed2 = seed2 % myRndMaxVal
	return rnd
}

// NextByte advances the generator and returns the next value, which is
// seed1 scaled from [0, myRndMaxVal) to [0, 31).
//
// Tested to be equivalent to MariaDB's floating point variant
// http://play.golang.org/p/QHvhd4qved
// http://play.golang.org/p/RG0q4ElWDx
func (r *myRnd) NextByte() byte {
	next1 := (3*r.seed1 + r.seed2) % myRndMaxVal
	next2 := (next1 + r.seed2 + 33) % myRndMaxVal
	r.seed1, r.seed2 = next1, next2

	// Widen to 64 bit so the multiplication cannot overflow.
	return byte(uint64(next1) * 31 / myRndMaxVal)
}
|
|
|
|
// pwHash generates the two-word binary hash of a byte string using the
// insecure pre 4.1 method. Spaces and tabs in password are ignored, and the
// top bit of both result words is always cleared.
func pwHash(password []byte) (result [2]uint32) {
	// Mask that removes the sign bit: (1<<31)-1.
	const signMask = 0x7FFFFFFF

	var (
		nr  uint32 = 1345345333
		nr2 uint32 = 0x12345671
		add uint32 = 7
	)

	for i := 0; i < len(password); i++ {
		ch := password[i]
		if ch != ' ' && ch != '\t' {
			c := uint32(ch)
			nr ^= (((nr & 63) + add) * c) + (nr << 8)
			nr2 += (nr2 << 8) ^ nr
			add += c
		}
	}

	result[0] = nr & signMask
	result[1] = nr2 & signMask
	return
}
|
|
|
|
// Encrypt password using insecure pre 4.1 method
|
|
func scrambleOldPassword(scramble, password []byte) []byte {
|
|
if len(password) == 0 {
|
|
return nil
|
|
}
|
|
|
|
scramble = scramble[:8]
|
|
|
|
hashPw := pwHash(password)
|
|
hashSc := pwHash(scramble)
|
|
|
|
r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1])
|
|
|
|
var out [8]byte
|
|
for i := range out {
|
|
out[i] = r.NextByte() + 64
|
|
}
|
|
|
|
mask := r.NextByte()
|
|
for i := range out {
|
|
out[i] ^= mask
|
|
}
|
|
|
|
return out[:]
|
|
}
|
|
|
|
/******************************************************************************
|
|
* Time related utils *
|
|
******************************************************************************/
|
|
|
|
// NullTime represents a time.Time that may be NULL.
|
|
// NullTime implements the Scanner interface so
|
|
// it can be used as a scan destination:
|
|
//
|
|
// var nt NullTime
|
|
// err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt)
|
|
// ...
|
|
// if nt.Valid {
|
|
// // use nt.Time
|
|
// } else {
|
|
// // NULL value
|
|
// }
|
|
//
|
|
// This NullTime implementation is not driver-specific
|
|
type NullTime struct {
|
|
Time time.Time
|
|
Valid bool // Valid is true if Time is not NULL
|
|
}
|
|
|
|
// Scan implements the Scanner interface.
|
|
// The value type must be time.Time or string / []byte (formatted time-string),
|
|
// otherwise Scan fails.
|
|
func (nt *NullTime) Scan(value interface{}) (err error) {
|
|
if value == nil {
|
|
nt.Time, nt.Valid = time.Time{}, false
|
|
return
|
|
}
|
|
|
|
switch v := value.(type) {
|
|
case time.Time:
|
|
nt.Time, nt.Valid = v, true
|
|
return
|
|
case []byte:
|
|
nt.Time, err = parseDateTime(string(v), time.UTC)
|
|
nt.Valid = (err == nil)
|
|
return
|
|
case string:
|
|
nt.Time, err = parseDateTime(v, time.UTC)
|
|
nt.Valid = (err == nil)
|
|
return
|
|
}
|
|
|
|
nt.Valid = false
|
|
return fmt.Errorf("Can't convert %T to time.Time", value)
|
|
}
|
|
|
|
// Value implements the driver Valuer interface.
|
|
func (nt NullTime) Value() (driver.Value, error) {
|
|
if !nt.Valid {
|
|
return nil, nil
|
|
}
|
|
return nt.Time, nil
|
|
}
|
|
|
|
func parseDateTime(str string, loc *time.Location) (t time.Time, err error) {
|
|
switch len(str) {
|
|
case 10: // YYYY-MM-DD
|
|
if str == "0000-00-00" {
|
|
return
|
|
}
|
|
t, err = time.Parse(timeFormat[:10], str)
|
|
case 19: // YYYY-MM-DD HH:MM:SS
|
|
if str == "0000-00-00 00:00:00" {
|
|
return
|
|
}
|
|
t, err = time.Parse(timeFormat, str)
|
|
default:
|
|
err = fmt.Errorf("Invalid Time-String: %s", str)
|
|
return
|
|
}
|
|
|
|
// Adjust location
|
|
if err == nil && loc != time.UTC {
|
|
y, mo, d := t.Date()
|
|
h, mi, s := t.Clock()
|
|
t, err = time.Date(y, mo, d, h, mi, s, t.Nanosecond(), loc), nil
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
// parseBinaryDateTime decodes a binary date/datetime payload into a
// time.Time located in loc. num selects the layout of data:
//
//	0   no payload, yields the zero time.Time
//	4   year (uint16, little endian), month, day
//	7   ... followed by hour, minute, second
//	11  ... followed by microseconds (uint32, little endian)
//
// Any other num is reported as an error. data is not length-checked here:
// it must hold at least num bytes, otherwise indexing it panics.
func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Value, error) {
	if num == 0 {
		return time.Time{}, nil
	}
	if num != 4 && num != 7 && num != 11 {
		return nil, fmt.Errorf("Invalid DATETIME-packet length %d", num)
	}

	year := int(binary.LittleEndian.Uint16(data[:2]))
	month := time.Month(data[2])
	day := int(data[3])

	var hour, minute, second, nsec int
	if num >= 7 {
		hour, minute, second = int(data[4]), int(data[5]), int(data[6])
	}
	if num == 11 {
		// The payload carries microseconds; time.Date wants nanoseconds.
		nsec = int(binary.LittleEndian.Uint32(data[7:11])) * 1000
	}

	return time.Date(year, month, day, hour, minute, second, nsec, loc), nil
}
|
|
|
|
// formatBinaryDate renders a binary date payload as "YYYY-MM-DD" text and
// returns it as a []byte.
//
// num == 0 yields "0000-00-00"; num == 4 reads the year (uint16, little
// endian), month and day from data. Any other num is reported as an error.
// data must hold at least 4 bytes when num == 4.
func formatBinaryDate(num uint64, data []byte) (driver.Value, error) {
	if num == 0 {
		return []byte("0000-00-00"), nil
	}
	if num != 4 {
		return nil, fmt.Errorf("Invalid DATE-packet length %d", num)
	}

	year := binary.LittleEndian.Uint16(data[:2])
	month, day := data[2], data[3]
	text := fmt.Sprintf("%04d-%02d-%02d", year, month, day)
	return []byte(text), nil
}
|
|
|
|
// formatBinaryDateTime renders a binary datetime payload as text and
// returns it as a []byte. num selects the layout of data:
//
//	0   no payload             -> "0000-00-00 00:00:00"
//	4   year, month, day       -> "YYYY-MM-DD 00:00:00"
//	7   ... hour, min, sec     -> "YYYY-MM-DD HH:MM:SS"
//	11  ... microseconds       -> "YYYY-MM-DD HH:MM:SS.UUUUUU"
//
// Year and microseconds are little endian (uint16 / uint32). Any other num
// is reported as an error. data must hold at least num bytes.
func formatBinaryDateTime(num uint64, data []byte) (driver.Value, error) {
	if num == 0 {
		return []byte("0000-00-00 00:00:00"), nil
	}
	if num != 4 && num != 7 && num != 11 {
		return nil, fmt.Errorf("Invalid DATETIME-packet length %d", num)
	}

	year := binary.LittleEndian.Uint16(data[:2])
	month, day := data[2], data[3]
	if num == 4 {
		text := fmt.Sprintf("%04d-%02d-%02d 00:00:00", year, month, day)
		return []byte(text), nil
	}

	hour, minute, second := data[4], data[5], data[6]
	if num == 7 {
		text := fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d",
			year, month, day, hour, minute, second)
		return []byte(text), nil
	}

	micros := binary.LittleEndian.Uint32(data[7:11])
	text := fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d.%06d",
		year, month, day, hour, minute, second, micros)
	return []byte(text), nil
}
|
|
|
|
/******************************************************************************
|
|
* Convert from and to bytes *
|
|
******************************************************************************/
|
|
|
|
// uint64ToBytes returns the 8 bytes of n in little-endian order
// (least significant byte first).
func uint64ToBytes(n uint64) []byte {
	out := make([]byte, 8)
	for i := range out {
		out[i] = byte(n >> (8 * uint(i)))
	}
	return out
}
|
|
|
|
// uint64ToString returns the decimal ASCII representation of n as a byte
// slice. 20 bytes are enough for the largest uint64.
func uint64ToString(n uint64) []byte {
	var buf [20]byte
	pos := len(buf)

	// Fill the buffer from the right, one decimal digit per round; the
	// loop body runs at least once so that 0 is rendered as "0".
	for {
		pos--
		buf[pos] = '0' + byte(n%10)
		n /= 10
		if n == 0 {
			break
		}
	}

	return buf[pos:]
}
|
|
|
|
// stringToInt treats b as the decimal ASCII representation of an unsigned
// integer and returns its value. An empty slice yields 0. The input is not
// validated: every byte is taken as a digit by subtracting '0'.
func stringToInt(b []byte) int {
	n := 0
	for _, c := range b {
		n = n*10 + int(c-0x30)
	}
	return n
}
|
|
|
|
// returns the string read as a bytes slice, wheter the value is NULL,
|
|
// the number of bytes read and an error, in case the string is longer than
|
|
// the input slice
|
|
func readLengthEnodedString(b []byte) ([]byte, bool, int, error) {
|
|
// Get length
|
|
num, isNull, n := readLengthEncodedInteger(b)
|
|
if num < 1 {
|
|
return b[n:n], isNull, n, nil
|
|
}
|
|
|
|
n += int(num)
|
|
|
|
// Check data length
|
|
if len(b) >= n {
|
|
return b[n-int(num) : n], false, n, nil
|
|
}
|
|
return nil, false, n, io.EOF
|
|
}
|
|
|
|
// returns the number of bytes skipped and an error, in case the string is
|
|
// longer than the input slice
|
|
func skipLengthEnodedString(b []byte) (int, error) {
|
|
// Get length
|
|
num, _, n := readLengthEncodedInteger(b)
|
|
if num < 1 {
|
|
return n, nil
|
|
}
|
|
|
|
n += int(num)
|
|
|
|
// Check data length
|
|
if len(b) >= n {
|
|
return n, nil
|
|
}
|
|
return n, io.EOF
|
|
}
|
|
|
|
// readLengthEncodedInteger decodes a length-encoded integer from the start
// of b. It returns the number read, whether the value is NULL, and the
// number of bytes consumed.
//
// The first byte selects the encoding:
//
//	0-250  the value itself (1 byte consumed)
//	0xfb   NULL (1 byte consumed)
//	0xfc   value in the following 2 bytes, little endian (3 consumed)
//	0xfd   value in the following 3 bytes, little endian (4 consumed)
//	0xfe   value in the following 8 bytes, little endian (9 consumed)
//
// b is not length-checked: it must hold the complete encoding, otherwise
// indexing it panics.
func readLengthEncodedInteger(b []byte) (uint64, bool, int) {
	switch b[0] {

	// 251: NULL
	case 0xfb:
		return 0, true, 1

	// 252: value of following 2
	case 0xfc:
		return uint64(b[1]) | uint64(b[2])<<8, false, 3

	// 253: value of following 3
	case 0xfd:
		return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16, false, 4

	// 254: value of following 8
	case 0xfe:
		// The most significant byte belongs at bit 56; shifting it by 54
		// would overlap the previous byte and corrupt the value.
		return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 |
				uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 |
				uint64(b[7])<<48 | uint64(b[8])<<56,
			false, 9
	}

	// 0-250: value of first byte
	return uint64(b[0]), false, 1
}
|
|
|
|
// appendLengthEncodedInteger encodes n as a length-encoded integer, appends
// the encoding to b and returns the extended slice.
//
// The shortest encoding that can hold n is chosen:
//
//	n <= 250        the value itself (1 byte)
//	n <= 0xffff     0xfc followed by 2 bytes, little endian
//	n <= 0xffffff   0xfd followed by 3 bytes, little endian
//	otherwise       0xfe followed by 8 bytes, little endian
//
// This mirrors the encodings readLengthEncodedInteger understands.
func appendLengthEncodedInteger(b []byte, n uint64) []byte {
	switch {
	case n <= 250:
		return append(b, byte(n))

	case n <= 0xffff:
		return append(b, 0xfc, byte(n), byte(n>>8))

	case n <= 0xffffff:
		return append(b, 0xfd, byte(n), byte(n>>8), byte(n>>16))
	}

	// Larger values need the full 8 byte form; returning b unchanged here
	// would silently drop the integer from the output.
	return append(b, 0xfe,
		byte(n), byte(n>>8), byte(n>>16), byte(n>>24),
		byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56))
}
|