336 lines
7.5 KiB
Go
336 lines
7.5 KiB
Go
package mssql
|
|
|
|
import (
|
|
"bytes"
|
|
"database/sql"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"os"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// MockTransport is an in-memory stand-in for a network transport used by the
// tests in this file. It embeds bytes.Buffer, so data written to it is kept
// in that buffer.
type MockTransport struct {
	bytes.Buffer
}

// Close does nothing and always reports success; there is no underlying
// resource to release.
func (m *MockTransport) Close() error {
	var noErr error
	return noErr
}
|
|
|
|
func TestSendLogin(t *testing.T) {
|
|
buf := newTdsBuffer(1024, new(MockTransport))
|
|
login := login{
|
|
TDSVersion: verTDS73,
|
|
PacketSize: 0x1000,
|
|
ClientProgVer: 0x01060100,
|
|
ClientPID: 100,
|
|
ClientTimeZone: -4 * 60,
|
|
ClientID: [6]byte{0x12, 0x34, 0x56, 0x78, 0x90, 0xab},
|
|
OptionFlags1: 0xe0,
|
|
OptionFlags3: 8,
|
|
HostName: "subdev1",
|
|
UserName: "test",
|
|
Password: "testpwd",
|
|
AppName: "appname",
|
|
ServerName: "servername",
|
|
CtlIntName: "library",
|
|
Language: "en",
|
|
Database: "database",
|
|
ClientLCID: 0x204,
|
|
AtchDBFile: "filepath",
|
|
}
|
|
err := sendLogin(buf, login)
|
|
if err != nil {
|
|
t.Error("sendLogin should succeed")
|
|
}
|
|
ref := []byte{
|
|
16, 1, 0, 222, 0, 0, 1, 0, 198 + 16, 0, 0, 0, 3, 0, 10, 115, 0, 16, 0, 0, 0, 1,
|
|
6, 1, 100, 0, 0, 0, 0, 0, 0, 0, 224, 0, 0, 8, 16, 255, 255, 255, 4, 2, 0,
|
|
0, 94, 0, 7, 0, 108, 0, 4, 0, 116, 0, 7, 0, 130, 0, 7, 0, 144, 0, 10, 0, 0,
|
|
0, 0, 0, 164, 0, 7, 0, 178, 0, 2, 0, 182, 0, 8, 0, 18, 52, 86, 120, 144, 171,
|
|
198, 0, 0, 0, 198, 0, 8, 0, 214, 0, 0, 0, 0, 0, 0, 0, 115, 0, 117, 0, 98,
|
|
0, 100, 0, 101, 0, 118, 0, 49, 0, 116, 0, 101, 0, 115, 0, 116, 0, 226, 165,
|
|
243, 165, 146, 165, 226, 165, 162, 165, 210, 165, 227, 165, 97, 0, 112,
|
|
0, 112, 0, 110, 0, 97, 0, 109, 0, 101, 0, 115, 0, 101, 0, 114, 0, 118, 0,
|
|
101, 0, 114, 0, 110, 0, 97, 0, 109, 0, 101, 0, 108, 0, 105, 0, 98, 0, 114,
|
|
0, 97, 0, 114, 0, 121, 0, 101, 0, 110, 0, 100, 0, 97, 0, 116, 0, 97, 0, 98,
|
|
0, 97, 0, 115, 0, 101, 0, 102, 0, 105, 0, 108, 0, 101, 0, 112, 0, 97, 0,
|
|
116, 0, 104, 0}
|
|
out := buf.buf[:buf.pos]
|
|
if !bytes.Equal(ref, out) {
|
|
t.Error("input output don't match")
|
|
fmt.Print(hex.Dump(ref))
|
|
fmt.Print(hex.Dump(out))
|
|
}
|
|
}
|
|
|
|
func TestSendSqlBatch(t *testing.T) {
|
|
addr := os.Getenv("HOST")
|
|
instance := os.Getenv("INSTANCE")
|
|
|
|
conn, err := connect(map[string]string{
|
|
"server": fmt.Sprintf("%s\\%s", addr, instance),
|
|
"user id": os.Getenv("SQLUSER"),
|
|
"password": os.Getenv("SQLPASSWORD"),
|
|
"database": os.Getenv("DATABASE"),
|
|
})
|
|
if err != nil {
|
|
t.Error("Open connection failed:", err.Error())
|
|
return
|
|
}
|
|
defer conn.buf.transport.Close()
|
|
|
|
headers := []headerStruct{
|
|
{hdrtype: dataStmHdrTransDescr,
|
|
data: transDescrHdr{0, 1}.pack()},
|
|
}
|
|
err = sendSqlBatch72(conn.buf, "select 1", headers)
|
|
if err != nil {
|
|
t.Error("Sending sql batch failed", err.Error())
|
|
return
|
|
}
|
|
|
|
ch := make(chan tokenStruct, 5)
|
|
go processResponse(conn, ch)
|
|
|
|
var lastRow []interface{}
|
|
loop:
|
|
for tok := range ch {
|
|
switch token := tok.(type) {
|
|
case doneStruct:
|
|
break loop
|
|
case []columnStruct:
|
|
conn.columns = token
|
|
case []interface{}:
|
|
lastRow = token
|
|
default:
|
|
fmt.Println("unknown token", tok)
|
|
}
|
|
}
|
|
|
|
switch value := lastRow[0].(type) {
|
|
case int32:
|
|
if value != 1 {
|
|
t.Error("Invalid value returned, should be 1", value)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// makeConnStr builds the connection string used by the database tests. The
// server host, instance, user, password and database are read from the HOST,
// INSTANCE, SQLUSER, SQLPASSWORD and DATABASE environment variables; the
// string always ends with "log=63".
func makeConnStr() string {
	const dsnFormat = "Server=%s\\%s;User Id=%s;Password=%s;Database=%s;log=63"
	return fmt.Sprintf(
		dsnFormat,
		os.Getenv("HOST"),
		os.Getenv("INSTANCE"),
		os.Getenv("SQLUSER"),
		os.Getenv("SQLPASSWORD"),
		os.Getenv("DATABASE"),
	)
}
|
|
|
|
func open(t *testing.T) *sql.DB {
|
|
conn, err := sql.Open("mssql", makeConnStr())
|
|
if err != nil {
|
|
t.Error("Open connection failed:", err.Error())
|
|
return nil
|
|
}
|
|
return conn
|
|
}
|
|
|
|
func TestConnect(t *testing.T) {
|
|
conn, err := sql.Open("mssql", makeConnStr())
|
|
if err != nil {
|
|
t.Error("Open connection failed:", err.Error())
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
}
|
|
|
|
// TestBadConnect checks that Ping fails for each connection string in
// badDsns. The server address comes from the HOST and INSTANCE environment
// variables; the credentials are deliberately wrong.
func TestBadConnect(t *testing.T) {
	badDsns := []string{
		//"Server=badhost",
		fmt.Sprintf("Server=%s\\%s;User ID=baduser;Password=badpwd",
			os.Getenv("HOST"), os.Getenv("INSTANCE")),
	}
	for _, badDsn := range badDsns {
		conn, err := sql.Open("mssql", badDsn)
		if err != nil {
			t.Error("Open connection failed:", err.Error())
			// conn is not usable after a failed Open; move to the next DSN
			// instead of calling methods on it.
			continue
		}
		if err := conn.Ping(); err == nil {
			t.Error("Ping should fail for connection: ", badDsn)
		}
		// Close at the end of each iteration rather than deferring inside
		// the loop, which would keep every handle open until the test
		// returns. The Close error is not relevant to what is being tested.
		conn.Close()
	}
}
|
|
|
|
// simpleQuery prepares the statement "select 1 as a" on conn. When Prepare
// fails the error is recorded on t and nil is returned.
func simpleQuery(conn *sql.DB, t *testing.T) (stmt *sql.Stmt) {
	prepared, prepErr := conn.Prepare("select 1 as a")
	if prepErr == nil {
		return prepared
	}
	t.Error("Prepare failed:", prepErr.Error())
	return nil
}
|
|
|
|
// checkSimpleQuery iterates over rows and records an error on t unless the
// result set consists of exactly one row whose single column scans into the
// integer 1. Scan errors and errors that end the iteration are reported too.
func checkSimpleQuery(rows *sql.Rows, t *testing.T) {
	numrows := 0
	for rows.Next() {
		var val int
		err := rows.Scan(&val)
		if err != nil {
			t.Error("Scan failed:", err.Error())
		}
		if val != 1 {
			t.Error("query should return 1")
		}
		numrows++
	}
	// Next returns false both at the end of the result set and on failure;
	// Err distinguishes the two.
	if err := rows.Err(); err != nil {
		t.Error("Rows iteration failed:", err.Error())
	}
	if numrows != 1 {
		t.Error("query should return 1 row, returned", numrows)
	}
}
|
|
|
|
func TestQuery(t *testing.T) {
|
|
conn := open(t)
|
|
if conn == nil {
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
stmt := simpleQuery(conn, t)
|
|
if stmt == nil {
|
|
return
|
|
}
|
|
defer stmt.Close()
|
|
|
|
rows, err := stmt.Query()
|
|
if err != nil {
|
|
t.Error("Query failed:", err.Error())
|
|
}
|
|
defer rows.Close()
|
|
|
|
columns, err := rows.Columns()
|
|
if err != nil {
|
|
t.Error("getting columns failed", err.Error())
|
|
}
|
|
if len(columns) != 1 && columns[0] != "a" {
|
|
t.Error("returned incorrect columns (expected ['a']):", columns)
|
|
}
|
|
|
|
checkSimpleQuery(rows, t)
|
|
}
|
|
|
|
func TestMultipleQueriesSequentialy(t *testing.T) {
|
|
|
|
conn := open(t)
|
|
defer conn.Close()
|
|
|
|
stmt, err := conn.Prepare("select 1 as a")
|
|
if err != nil {
|
|
t.Error("Prepare failed:", err.Error())
|
|
return
|
|
}
|
|
defer stmt.Close()
|
|
|
|
rows, err := stmt.Query()
|
|
if err != nil {
|
|
t.Error("Query failed:", err.Error())
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
checkSimpleQuery(rows, t)
|
|
|
|
rows, err = stmt.Query()
|
|
if err != nil {
|
|
t.Error("Query failed:", err.Error())
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
checkSimpleQuery(rows, t)
|
|
}
|
|
|
|
func TestMultipleQueryClose(t *testing.T) {
|
|
conn := open(t)
|
|
defer conn.Close()
|
|
|
|
stmt, err := conn.Prepare("select 1 as a")
|
|
if err != nil {
|
|
t.Error("Prepare failed:", err.Error())
|
|
return
|
|
}
|
|
defer stmt.Close()
|
|
|
|
rows, err := stmt.Query()
|
|
if err != nil {
|
|
t.Error("Query failed:", err.Error())
|
|
return
|
|
}
|
|
rows.Close()
|
|
|
|
rows, err = stmt.Query()
|
|
if err != nil {
|
|
t.Error("Query failed:", err.Error())
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
checkSimpleQuery(rows, t)
|
|
}
|
|
|
|
func TestPing(t *testing.T) {
|
|
conn := open(t)
|
|
defer conn.Close()
|
|
conn.Ping()
|
|
}
|
|
|
|
func TestSecureWithInvalidHostName(t *testing.T) {
|
|
dsn := makeConnStr() + ";Encrypt=true;TrustServerCertificate=false;hostNameInCertificate=foo.bar"
|
|
conn, err := sql.Open("mssql", dsn)
|
|
if err != nil {
|
|
t.Fatal("Open connection failed:", err.Error())
|
|
}
|
|
defer conn.Close()
|
|
err = conn.Ping()
|
|
if err == nil {
|
|
t.Fatal("Connected to fake foo.bar server")
|
|
}
|
|
}
|
|
|
|
func TestSecureConnection(t *testing.T) {
|
|
dsn := makeConnStr() + ";Encrypt=true;TrustServerCertificate=true"
|
|
conn, err := sql.Open("mssql", dsn)
|
|
if err != nil {
|
|
t.Fatal("Open connection failed:", err.Error())
|
|
}
|
|
defer conn.Close()
|
|
var msg string
|
|
err = conn.QueryRow("select 'secret'").Scan(&msg)
|
|
if err != nil {
|
|
t.Fatal("cannot scan value", err)
|
|
}
|
|
if msg != "secret" {
|
|
t.Fatal("expected secret, got: ", msg)
|
|
}
|
|
var secure bool
|
|
err = conn.QueryRow("select encrypt_option from sys.dm_exec_connections where session_id=@@SPID").Scan(&secure)
|
|
if err != nil {
|
|
t.Fatal("cannot scan value", err)
|
|
}
|
|
if !secure {
|
|
t.Fatal("connection is not encrypted")
|
|
}
|
|
}
|
|
|
|
func TestParseConnectParamsKeepAlive(t *testing.T) {
|
|
params := parseConnectionString("keepAlive=60")
|
|
parsedParams, err := parseConnectParams(params)
|
|
if err != nil {
|
|
t.Fatal("cannot parse params: ", err)
|
|
}
|
|
|
|
if parsedParams.keepAlive != time.Duration(60)*time.Second {
|
|
t.Fail()
|
|
}
|
|
}
|