476 lines
12 KiB
Go
476 lines
12 KiB
Go
package meddler
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
_ "github.com/mattn/go-sqlite3"
|
|
"reflect"
|
|
"sort"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// Shared state for the tests in this file.
var (
	// once guards the single call to setup that creates the test database.
	once sync.Once

	// db is the database handle that setup opens.
	db *sql.DB

	// when is the fixed reference instant, in UTC, used for the time columns.
	when = time.Date(2013, 6, 23, 15, 30, 12, 0, time.UTC)
)
|
|
|
|
// Person is the main fixture type. It combines tagged columns, an unexported
// field, an untagged exported field (Email), a field excluded with "-", a tag
// with an empty column name (Age), the zeroisnull/utctime/utctimez/localtime
// meddlers, and nullable pointer columns.
type Person struct {
	ID        int64  `meddler:"id,pk"`
	Name      string `meddler:"name"`
	private   int
	Email     string
	Ephemeral int        `meddler:"-"`
	Age       int        `meddler:",zeroisnull"`
	Opened    time.Time  `meddler:"opened,utctime"`
	Closed    time.Time  `meddler:"closed,utctimez"`
	Updated   *time.Time `meddler:"updated,localtime"`
	Height    *int       `meddler:"height"`
}
|
|
|
|
// HalfPerson maps only some of the person table's columns (id, Age, closed,
// updated). TestThrowAway scans a full "select *" row into it.
type HalfPerson struct {
	ID        int64 `meddler:"id,pk"`
	private   int
	Ephemeral int        `meddler:"-"`
	Age       int        `meddler:",zeroisnull"`
	Closed    time.Time  `meddler:"closed,utctimez"`
	Updated   *time.Time `meddler:"updated,localtime"`
}
|
|
|
|
// UintPerson mirrors Person field for field, except that its primary key is
// unsigned (uint64). It is used to test the primary-key helpers on an
// unsigned key type.
type UintPerson struct {
	ID        uint64 `meddler:"id,pk"`
	Name      string `meddler:"name"`
	private   int
	Email     string
	Ephemeral int        `meddler:"-"`
	Age       int        `meddler:",zeroisnull"`
	Opened    time.Time  `meddler:"opened,utctime"`
	Closed    time.Time  `meddler:"closed,utctimez"`
	Updated   *time.Time `meddler:"updated,localtime"`
	Height    *int       `meddler:"height"`
}
|
|
|
|
// DDL executed by setup. schema1's columns are the ones Person maps (id,
// name, Email, Age, opened, closed, updated, height). schema2 defines an item
// table with a text column and a blob column.
const (
	schema1 = `create table person (
id integer primary key,
name text not null,
Email text not null,
Age integer,
opened datetime not null,
closed datetime,
updated datetime,
height integer
)`

	schema2 = `create table item (
id integer primary key,
stuff text not null,
stuffz blob not null
)`
)
|
|
|
|
var aliceHeight int = 65
|
|
var alice = &Person{
|
|
Name: "Alice",
|
|
Email: "alice@alice.com",
|
|
Ephemeral: 12,
|
|
Age: 32,
|
|
Opened: when.Local(),
|
|
Closed: when,
|
|
Updated: &when,
|
|
Height: &aliceHeight,
|
|
}
|
|
|
|
var bob = &Person{
|
|
Name: "Bob",
|
|
Email: "bob@bob.com",
|
|
Opened: when,
|
|
}
|
|
|
|
func setup() {
|
|
var err error
|
|
|
|
// create the database
|
|
db, err = sql.Open("sqlite3", ":memory:")
|
|
if err != nil {
|
|
panic("error creating test database: " + err.Error())
|
|
}
|
|
|
|
// create the tables
|
|
if _, err = db.Exec(schema1); err != nil {
|
|
panic("error creating person table: " + err.Error())
|
|
}
|
|
if _, err = db.Exec(schema2); err != nil {
|
|
panic("error creating item table: " + err.Error())
|
|
}
|
|
}
|
|
|
|
func structFieldEqual(t *testing.T, elt *structField, ref *structField) {
|
|
if elt == nil {
|
|
t.Errorf("Missing field for %s", ref.column)
|
|
return
|
|
}
|
|
if elt.column != ref.column {
|
|
t.Errorf("Column %s column found as %v", ref.column, elt.column)
|
|
}
|
|
if elt.primaryKey != ref.primaryKey {
|
|
t.Errorf("Column %s primaryKey found as %v", ref.column, elt.primaryKey)
|
|
}
|
|
if elt.index != ref.index {
|
|
t.Errorf("Column %s index found as %v", ref.column, elt.index)
|
|
}
|
|
if elt.meddler != ref.meddler {
|
|
t.Errorf("Column %s meddler mismatch", ref.column)
|
|
}
|
|
}
|
|
|
|
func TestGetFields(t *testing.T) {
|
|
data, err := getFields(reflect.TypeOf((*Person)(nil)))
|
|
if err != nil {
|
|
t.Errorf("Error in getFields: %v", err)
|
|
return
|
|
}
|
|
|
|
// see if everything checks out
|
|
if len(data.fields) != 8 || len(data.columns) != 8 {
|
|
t.Errorf("Found %d/%d fields, expected 8", len(data.fields), len(data.columns))
|
|
}
|
|
structFieldEqual(t, data.fields[data.columns[0]], &structField{"id", 0, true, registry["identity"]})
|
|
structFieldEqual(t, data.fields[data.columns[1]], &structField{"name", 1, false, registry["identity"]})
|
|
structFieldEqual(t, data.fields[data.columns[2]], &structField{"Email", 3, false, registry["identity"]})
|
|
structFieldEqual(t, data.fields[data.columns[3]], &structField{"Age", 5, false, registry["zeroisnull"]})
|
|
structFieldEqual(t, data.fields[data.columns[4]], &structField{"opened", 6, false, registry["utctime"]})
|
|
structFieldEqual(t, data.fields[data.columns[5]], &structField{"closed", 7, false, registry["utctimez"]})
|
|
structFieldEqual(t, data.fields[data.columns[6]], &structField{"updated", 8, false, registry["localtime"]})
|
|
structFieldEqual(t, data.fields[data.columns[7]], &structField{"height", 9, false, registry["identity"]})
|
|
}
|
|
|
|
func personEqual(t *testing.T, elt *Person, ref *Person) {
|
|
if elt == nil {
|
|
t.Errorf("Person %s is nil", ref.Name)
|
|
return
|
|
}
|
|
if elt.ID != ref.ID {
|
|
t.Errorf("Person %s ID is %v", ref.Name, elt.ID)
|
|
}
|
|
if elt.Name != ref.Name {
|
|
t.Errorf("Person %s Name is %v", ref.Name, elt.Name)
|
|
}
|
|
if elt.private != ref.private {
|
|
t.Errorf("Person %s private is %v", ref.Name, elt.private)
|
|
}
|
|
if elt.Email != ref.Email {
|
|
t.Errorf("Person %s Email is %v", ref.Name, elt.Email)
|
|
}
|
|
if elt.Ephemeral != ref.Ephemeral {
|
|
t.Errorf("Person %s Ephemeral is %v", ref.Ephemeral, elt.Ephemeral)
|
|
}
|
|
if elt.Age != ref.Age {
|
|
t.Errorf("Person %s Age is %v", ref.Name, elt.Age)
|
|
}
|
|
if !elt.Opened.Equal(ref.Opened) {
|
|
t.Errorf("Person %s Opened is %v", ref.Name, elt.Opened)
|
|
}
|
|
if !elt.Closed.Equal(ref.Closed) {
|
|
t.Errorf("Person %s Closed is %v", ref.Name, elt.Closed)
|
|
}
|
|
if (elt.Updated == nil) != (ref.Updated == nil) {
|
|
t.Errorf("Person %s Updated == nil is %v", ref.Name, elt.Updated == nil)
|
|
} else if elt.Updated != nil && !elt.Updated.Equal(*ref.Updated) {
|
|
t.Errorf("Person %s Updated is %v", ref.Name, *elt.Updated)
|
|
}
|
|
if elt.Updated != nil {
|
|
zone, _ := elt.Updated.Zone()
|
|
local, _ := when.Local().Zone()
|
|
if zone != local {
|
|
t.Errorf("Person %s Updated in time zone %v, expected %v", ref.Name, zone, local)
|
|
}
|
|
}
|
|
if (elt.Height == nil) != (ref.Height == nil) {
|
|
t.Errorf("Person %s Height == nil is %v", ref.Name, elt.Height == nil)
|
|
} else if elt.Height != nil && *elt.Height != *ref.Height {
|
|
t.Errorf("Person %s Height is %v", ref.Name, *elt.Height)
|
|
}
|
|
}
|
|
|
|
func insertAliceBob(t *testing.T) {
|
|
// insert Alice as row #1
|
|
alice.ID = 0
|
|
if err := Insert(db, "person", alice); err != nil {
|
|
t.Errorf("Error inserting Alice: %v", err)
|
|
}
|
|
if alice.ID != 1 {
|
|
t.Errorf("Alice ID is %d, expecting 1", alice.ID)
|
|
}
|
|
|
|
// insert Bob as row #2
|
|
bob.ID = 0
|
|
if err := Insert(db, "person", bob); err != nil {
|
|
t.Errorf("Error inserting Bob: %v", err)
|
|
}
|
|
if bob.ID != 2 {
|
|
t.Errorf("Bob ID is %d, expecting 2", bob.ID)
|
|
}
|
|
}
|
|
|
|
func TestColumns(t *testing.T) {
|
|
once.Do(setup)
|
|
|
|
p := new(Person)
|
|
names, err := Columns(p, true)
|
|
if err != nil {
|
|
t.Errorf("Error getting Columns: %v", err)
|
|
}
|
|
|
|
expected := []string{"id", "name", "Email", "Age", "opened", "closed", "updated", "height"}
|
|
sort.Strings(expected)
|
|
|
|
if len(names) != len(expected) {
|
|
t.Errorf("Expected %d columns, got %d", len(expected), len(names))
|
|
}
|
|
sort.Strings(names)
|
|
for i := 0; i < len(expected); i++ {
|
|
if expected[i] != names[i] {
|
|
t.Errorf("Expected %s at position %d, got %s", expected[i], i, names[i])
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestColumnsQuoted(t *testing.T) {
|
|
once.Do(setup)
|
|
|
|
p := new(Person)
|
|
names, err := ColumnsQuoted(p, true)
|
|
if err != nil {
|
|
t.Errorf("Error getting ColumnsQuoted: %v", err)
|
|
}
|
|
|
|
lst := []string{"id", "name", "Email", "Age", "opened", "closed", "updated", "height"}
|
|
sort.Strings(lst)
|
|
for i, orig := range lst {
|
|
lst[i] = Default.quoted(orig)
|
|
}
|
|
expected := strings.Join(lst, ",")
|
|
|
|
if len(names) != len(expected) {
|
|
t.Errorf("Length mismatch: expected %d, got %d", len(expected), len(names))
|
|
}
|
|
|
|
fields := strings.Split(names, ",")
|
|
sort.Strings(fields)
|
|
names = strings.Join(fields, ",")
|
|
|
|
if expected != names {
|
|
t.Errorf("Mismatch: expected %s, got %s", expected, names)
|
|
}
|
|
}
|
|
|
|
func TestPrimaryKey(t *testing.T) {
|
|
p := new(Person)
|
|
p.ID = 56
|
|
name, val, err := PrimaryKey(p)
|
|
if err != nil {
|
|
t.Errorf("Error getting PrimaryKey: %v", err)
|
|
}
|
|
if name != "id" {
|
|
t.Errorf("Expected pk name to be id, found %s", name)
|
|
}
|
|
if val != 56 {
|
|
t.Errorf("Expected pk value to be 56, found %d", val)
|
|
}
|
|
|
|
p2 := new(UintPerson)
|
|
p2.ID = 56
|
|
name, val, err = PrimaryKey(p2)
|
|
if err != nil {
|
|
t.Errorf("Error getting PrimaryKey: %v", err)
|
|
}
|
|
if name != "id" {
|
|
t.Errorf("Expected pk name to be id, found %s", name)
|
|
}
|
|
if val != 56 {
|
|
t.Errorf("Expected pk value to be 56, found %d", val)
|
|
}
|
|
}
|
|
|
|
func TestSetPrimaryKey(t *testing.T) {
|
|
p := new(Person)
|
|
err := SetPrimaryKey(p, 14)
|
|
if err != nil {
|
|
t.Errorf("Error in SetPrimaryKey: %v", err)
|
|
}
|
|
if p.ID != 14 {
|
|
t.Errorf("Expected id to be 14, found %d", p.ID)
|
|
}
|
|
|
|
p2 := new(Person)
|
|
err = SetPrimaryKey(p2, 14)
|
|
if err != nil {
|
|
t.Errorf("Error in SetPrimaryKey: %v", err)
|
|
}
|
|
if p2.ID != 14 {
|
|
t.Errorf("Expected id to be 14, found %d", p2.ID)
|
|
}
|
|
}
|
|
|
|
func TestValues(t *testing.T) {
|
|
alice.ID = 15
|
|
lst, err := Values(alice, true)
|
|
if err != nil {
|
|
t.Errorf("Values error: %v", err)
|
|
}
|
|
|
|
if lst[0] != int64(15) {
|
|
t.Errorf("expected 15, got %v", lst[0])
|
|
}
|
|
if lst[1] != "Alice" {
|
|
t.Errorf("Expected Alice, got %v", lst[1])
|
|
}
|
|
if lst[2] != "alice@alice.com" {
|
|
t.Errorf("Expected alice@alice.com, got %v", lst[2])
|
|
}
|
|
if lst[3] != 32 {
|
|
t.Errorf("Expected 32, got %v", lst[3])
|
|
}
|
|
if lst[4] != when.UTC() {
|
|
t.Errorf("Expected %v, got %v", when.UTC(), lst[4])
|
|
}
|
|
if lst[5] != when.UTC() {
|
|
t.Errorf("Expected %v, got %v", when.UTC(), lst[5])
|
|
}
|
|
if lst[6] != when.UTC() {
|
|
t.Errorf("Expected %v, got %v", when.UTC(), lst[6])
|
|
}
|
|
if *(lst[7].(*int)) != aliceHeight {
|
|
t.Errorf("Expected %d, got %v", aliceHeight, lst[7])
|
|
}
|
|
|
|
lst, err = Values(alice, false)
|
|
if err != nil {
|
|
t.Errorf("Values error: %v", err)
|
|
}
|
|
if lst[0] != "Alice" {
|
|
t.Errorf("Expected Alice, got %v", lst[0])
|
|
}
|
|
}
|
|
|
|
func TestPlaceholders(t *testing.T) {
|
|
lst, err := MySQL.Placeholders(alice, true)
|
|
if err != nil {
|
|
t.Errorf("Error in Placeholders: %v", err)
|
|
}
|
|
if len(lst) != 8 {
|
|
t.Errorf("expected 8 items, found %d", len(lst))
|
|
}
|
|
for _, elt := range lst {
|
|
if elt != MySQL.Placeholder {
|
|
t.Errorf("expected %s, found %s", MySQL.Placeholder, elt)
|
|
}
|
|
}
|
|
|
|
lst, err = PostgreSQL.Placeholders(alice, false)
|
|
if err != nil {
|
|
t.Errorf("Error in Placeholders: %v", err)
|
|
}
|
|
if len(lst) != 7 {
|
|
t.Errorf("expected 7 items, found %d", len(lst))
|
|
}
|
|
for i, elt := range lst {
|
|
expected := fmt.Sprintf("$%d", i+1)
|
|
if expected != elt {
|
|
t.Errorf("expected %s, found %s", expected, elt)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestPlaceholdersString(t *testing.T) {
|
|
s, err := SQLite.PlaceholdersString(alice, false)
|
|
if err != nil {
|
|
t.Errorf("Error in PlaceholdersString: %v", err)
|
|
}
|
|
expected := "?,?,?,?,?,?,?"
|
|
if s != expected {
|
|
t.Errorf("expected %s, found %s", expected, s)
|
|
}
|
|
|
|
s, err = PostgreSQL.PlaceholdersString(alice, true)
|
|
if err != nil {
|
|
t.Errorf("Error in PlaceholdersString: %v", err)
|
|
}
|
|
expected = "$1,$2,$3,$4,$5,$6,$7,$8"
|
|
if s != expected {
|
|
t.Errorf("expected %s, found %s", expected, s)
|
|
}
|
|
}
|
|
|
|
func TestScanRow(t *testing.T) {
|
|
once.Do(setup)
|
|
insertAliceBob(t)
|
|
|
|
rows, err := db.Query("select * from person where id in (1,2) order by id")
|
|
if err != nil {
|
|
t.Errorf("DB error on query: %v", err)
|
|
return
|
|
}
|
|
|
|
alice := new(Person)
|
|
if err = Scan(rows, alice); err != nil {
|
|
t.Errorf("Scan error on Alice: %v", err)
|
|
return
|
|
}
|
|
|
|
bob := new(Person)
|
|
bob.Age = 50
|
|
bob.Closed = time.Now()
|
|
bob.private = 14
|
|
bob.Ephemeral = 16
|
|
if err = ScanRow(rows, bob); err != nil {
|
|
t.Errorf("ScanRow error on Bob: %v", err)
|
|
return
|
|
}
|
|
|
|
height := 65
|
|
personEqual(t, alice, &Person{1, "Alice", 0, "alice@alice.com", 0, 32, when, when, &when, &height})
|
|
personEqual(t, bob, &Person{2, "Bob", 14, "bob@bob.com", 16, 0, when, time.Time{}, nil, nil})
|
|
db.Exec("delete from person")
|
|
}
|
|
|
|
func TestScanAll(t *testing.T) {
|
|
once.Do(setup)
|
|
insertAliceBob(t)
|
|
|
|
rows, err := db.Query("select * from person order by id")
|
|
if err != nil {
|
|
t.Errorf("DB error on query: %v", err)
|
|
return
|
|
}
|
|
|
|
var lst []*Person
|
|
if err = ScanAll(rows, &lst); err != nil {
|
|
t.Errorf("ScanAll error: %v", err)
|
|
return
|
|
}
|
|
|
|
if len(lst) != 2 {
|
|
t.Errorf("ScanAll found %d rows, expected 2", len(lst))
|
|
return
|
|
}
|
|
|
|
height := 65
|
|
personEqual(t, lst[0], &Person{1, "Alice", 0, "alice@alice.com", 0, 32, when, when, &when, &height})
|
|
personEqual(t, lst[1], &Person{2, "Bob", 0, "bob@bob.com", 0, 0, when, time.Time{}, nil, nil})
|
|
db.Exec("delete from person")
|
|
}
|
|
|
|
func TestThrowAway(t *testing.T) {
|
|
once.Do(setup)
|
|
insertAliceBob(t)
|
|
|
|
Debug = false
|
|
hp := new(HalfPerson)
|
|
err := QueryRow(db, hp, "select * from person where id = 1")
|
|
if err != nil {
|
|
t.Errorf("QueryRow error: %v", err)
|
|
}
|
|
Debug = true
|
|
db.Exec("delete from person")
|
|
}
|