Fix migration step not checked against current version.
Add tests for DropColumns.
This commit is contained in:
parent
54a9544044
commit
4465b2654d
4 changed files with 79 additions and 11 deletions
|
@ -53,7 +53,6 @@ DELETE FROM migration where revision = ?
|
|||
// Implementation details is specific for each database,
|
||||
// see migrate/sqlite.go for implementation reference.
|
||||
type Operation interface {
|
||||
|
||||
CreateTable(tableName string, args []string) (sql.Result, error)
|
||||
|
||||
RenameTable(tableName, newName string) (sql.Result, error)
|
||||
|
@ -147,7 +146,7 @@ func (m *Migration) up(target, current int64) error {
|
|||
|
||||
// loop through and execute revisions
|
||||
for _, rev := range m.revs {
|
||||
if rev.Revision() >= target {
|
||||
if rev.Revision() > current {
|
||||
current = rev.Revision()
|
||||
// execute the revision Upgrade.
|
||||
if err := rev.Up(op); err != nil {
|
||||
|
@ -191,7 +190,7 @@ func (m *Migration) down(target, current int64) error {
|
|||
current = rev.Revision()
|
||||
// execute the revision Upgrade.
|
||||
if err := rev.Down(op); err != nil {
|
||||
log.Printf("Failed to downgrade to Revision Number %v\n", current)
|
||||
log.Printf("Failed to downgrade from Revision Number %v\n", current)
|
||||
log.Println(err)
|
||||
return tx.Rollback()
|
||||
}
|
||||
|
@ -202,7 +201,7 @@ func (m *Migration) down(target, current int64) error {
|
|||
return tx.Rollback()
|
||||
}
|
||||
|
||||
log.Printf("Successfully downgraded to Revision %v\n", current)
|
||||
log.Printf("Successfully downgraded from Revision %v\n", current)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -5,8 +5,8 @@ import (
|
|||
"fmt"
|
||||
"strings"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/dchest/uniuri"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
type SQLiteDriver MigrationDriver
|
||||
|
@ -48,7 +48,8 @@ func (s *SQLiteDriver) DropColumns(tableName string, columnsToDrop []string) (sq
|
|||
}
|
||||
|
||||
columnNames := selectName(columns)
|
||||
preparedColumns := make([]string, len(columnNames)-len(columnsToDrop))
|
||||
|
||||
var preparedColumns []string
|
||||
for k, column := range columnNames {
|
||||
listed := false
|
||||
for _, dropped := range columnsToDrop {
|
||||
|
@ -98,8 +99,8 @@ func (s *SQLiteDriver) RenameColumns(tableName string, columnChanges map[string]
|
|||
return nil, err
|
||||
}
|
||||
|
||||
oldColumns := make([]string, len(columnChanges))
|
||||
newColumns := make([]string, len(columnChanges))
|
||||
var oldColumns []string
|
||||
var newColumns []string
|
||||
for k, column := range selectName(columns) {
|
||||
for Old, New := range columnChanges {
|
||||
if column == Old {
|
||||
|
@ -126,7 +127,7 @@ func (s *SQLiteDriver) RenameColumns(tableName string, columnChanges map[string]
|
|||
|
||||
func (s *SQLiteDriver) getDDLFromTable(tableName string) (string, error) {
|
||||
var sql string
|
||||
query := `SELECT sql FROM sqlite_master WHERE type='table' and name='?';`
|
||||
query := `SELECT sql FROM sqlite_master WHERE type='table' and name=?;`
|
||||
err := s.Tx.QueryRow(query, tableName).Scan(&sql)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
|
|
@ -76,6 +76,29 @@ func (r *revision2) Revision() int64 {
|
|||
|
||||
// ---------- end of revision 2
|
||||
|
||||
// ---------- revision 3
|
||||
|
||||
type revision3 struct{}
|
||||
|
||||
func (r *revision3) Up(op Operation) error {
|
||||
if _, err := op.AddColumn("samples", "url VARCHAR(255)"); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := op.AddColumn("samples", "likes INTEGER")
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *revision3) Down(op Operation) error {
|
||||
_, err := op.DropColumns("samples", []string{"likes", "url"})
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *revision3) Revision() int64 {
|
||||
return 3
|
||||
}
|
||||
|
||||
// ---------- end of revision 3
|
||||
|
||||
var db *sql.DB
|
||||
|
||||
var testSchema = `
|
||||
|
@ -144,6 +167,51 @@ func TestMigrateRenameTable(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
type TableInfo struct {
|
||||
CID int64 `meddler:"cid,pk"`
|
||||
Name string `meddler:"name"`
|
||||
Type string `meddler:"type"`
|
||||
Notnull bool `meddler:"notnull"`
|
||||
DfltValue interface{} `meddler:"dflt_value"`
|
||||
PK bool `meddler:"pk"`
|
||||
}
|
||||
|
||||
func TestMigrateAddRemoveColumns(t *testing.T) {
|
||||
defer tearDown()
|
||||
if err := setUp(); err != nil {
|
||||
t.Fatalf("Error preparing database: %q", err)
|
||||
}
|
||||
|
||||
Driver = SQLite
|
||||
|
||||
mgr := New(db)
|
||||
if err := mgr.Add(&revision1{}).Add(&revision3{}).Migrate(); err != nil {
|
||||
t.Errorf("Can not migrate: %q", err)
|
||||
}
|
||||
|
||||
var columns []*TableInfo
|
||||
if err := meddler.QueryAll(db, &columns, `PRAGMA table_info(samples);`); err != nil {
|
||||
t.Errorf("Can not access table info: %q", err)
|
||||
}
|
||||
|
||||
if len(columns) < 5 {
|
||||
t.Errorf("Expect length columns: %d\nGot: %d", 5, len(columns))
|
||||
}
|
||||
|
||||
if err := mgr.MigrateTo(1); err != nil {
|
||||
t.Errorf("Can not migrate: %q", err)
|
||||
}
|
||||
|
||||
var another_columns []*TableInfo
|
||||
if err := meddler.QueryAll(db, &another_columns, `PRAGMA table_info(samples);`); err != nil {
|
||||
t.Errorf("Can not access table info: %q", err)
|
||||
}
|
||||
|
||||
if len(another_columns) != 3 {
|
||||
t.Errorf("Expect length columns: %d\nGot: %d", 3, len(columns))
|
||||
}
|
||||
}
|
||||
|
||||
func setUp() error {
|
||||
var err error
|
||||
db, err = sql.Open("sqlite3", "migration_tests.sqlite")
|
||||
|
|
|
@ -15,7 +15,7 @@ func fetchColumns(sql string) ([]string, error) {
|
|||
}
|
||||
|
||||
func selectName(columns []string) []string {
|
||||
results := make([]string, len(columns))
|
||||
var results []string
|
||||
for _, column := range columns {
|
||||
col := strings.SplitN(strings.Trim(column, " \n\t"), " ", 2)
|
||||
results = append(results, col[0])
|
||||
|
@ -24,7 +24,7 @@ func selectName(columns []string) []string {
|
|||
}
|
||||
|
||||
func setForUpdate(left []string, right []string) string {
|
||||
results := make([]string, len(left))
|
||||
var results []string
|
||||
for k, str := range left {
|
||||
results = append(results, fmt.Sprintf("%s = %s", str, right[k]))
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue