From 7b4c698e79cc38698b8ed0e3721535ffc9383908 Mon Sep 17 00:00:00 2001 From: Nurahmadie Date: Wed, 19 Feb 2014 07:38:53 +0700 Subject: [PATCH] Refactor columns matching method --- pkg/database/migrate/sqlite.go | 48 +++++++++++++++----- pkg/database/migrate/sqlite_test.go | 70 +++++++++++++++++++++++++++++ 2 files changed, 106 insertions(+), 12 deletions(-) diff --git a/pkg/database/migrate/sqlite.go b/pkg/database/migrate/sqlite.go index cbd8c776..2cec5a02 100644 --- a/pkg/database/migrate/sqlite.go +++ b/pkg/database/migrate/sqlite.go @@ -87,13 +87,25 @@ func (s *SQLiteDriver) DropColumns(tableName string, columnsToDrop []string) (sq return nil, err } - var indices []string + var oldIdxColumns [][]string for _, idx := range oldSQLIndices { + idxCols, err := fetchColumns(idx) + if err != nil { + return nil, err + } + oldIdxColumns = append(oldIdxColumns, idxCols) + } + + var indices []string + for k, idx := range oldSQLIndices { listed := false - for _, cols := range columnsToDrop { - if strings.Contains(idx, cols) { - listed = true - break + OIdxLoop: + for _, oidx := range oldIdxColumns[k] { + for _, cols := range columnsToDrop { + if oidx == cols { + listed = true + break OIdxLoop + } } } if !listed { @@ -173,15 +185,27 @@ func (s *SQLiteDriver) RenameColumns(tableName string, columnChanges map[string] return nil, err } - var indices []string + var idxColumns [][]string for _, idx := range oldSQLIndices { + idxCols, err := fetchColumns(idx) + if err != nil { + return nil, err + } + idxColumns = append(idxColumns, idxCols) + } + + var indices []string + for k, idx := range oldSQLIndices { added := false - for Old, New := range columnChanges { - if strings.Contains(idx, Old) { - indx := strings.Replace(idx, Old, New, 2) - indices = append(indices, indx) - added = true - break + IdcLoop: + for _, oldIdx := range idxColumns[k] { + for Old, New := range columnChanges { + if oldIdx == Old { + indx := strings.Replace(idx, Old, New, 2) + indices = append(indices, indx) + added = true + break IdcLoop + } } } if !added { diff --git a/pkg/database/migrate/sqlite_test.go b/pkg/database/migrate/sqlite_test.go index 651ecbbc..c26d5b9b 100644 --- a/pkg/database/migrate/sqlite_test.go +++ b/pkg/database/migrate/sqlite_test.go @@ -3,6 +3,7 @@ package migrate import ( "database/sql" "os" + "strings" "testing" "github.com/russross/meddler" @@ -184,6 +185,51 @@ func (r *revision7) Revision() int64 { // ---------- end of revision 7 +// ---------- revision 8 +type revision8 struct{} + +func (r *revision8) Up(op Operation) error { + if _, err := op.AddColumn("samples", "repo_id INTEGER"); err != nil { + return err + } + _, err := op.AddColumn("samples", "repo VARCHAR(255)") + return err +} + +func (r *revision8) Down(op Operation) error { + _, err := op.DropColumns("samples", []string{"repo", "repo_id"}) + return err +} + +func (r *revision8) Revision() int64 { + return 8 +} + +// ---------- end of revision 8 + +// ---------- revision 9 +type revision9 struct{} + +func (r *revision9) Up(op Operation) error { + _, err := op.RenameColumns("samples", map[string]string{ + "repo": "repository", + }) + return err +} + +func (r *revision9) Down(op Operation) error { + _, err := op.RenameColumns("samples", map[string]string{ + "repository": "repo", + }) + return err +} + +func (r *revision9) Revision() int64 { + return 9 +} + +// ---------- end of revision 9 + var db *sql.DB var testSchema = ` @@ -430,6 +476,30 @@ func TestIndexOperations(t *testing.T) { } } +func TestColumnRedundancy(t *testing.T) { + defer tearDown() + if err := setUp(); err != nil { + t.Fatalf("Error preparing database: %q", err) + } + + Driver = SQLite + + migr := New(db) + if err := migr.Add(&revision1{}, &revision8{}, &revision9{}).Migrate(); err != nil { + t.Errorf("Can not migrate: %q", err) + } + + var tableSql string + query := `SELECT sql FROM sqlite_master where type='table' and name='samples'` + if err := db.QueryRow(query).Scan(&tableSql); err != nil { + t.Errorf("Can not query sqlite_master: %q", err) + } + + if !strings.Contains(tableSql, "repository ") { + t.Errorf("Expect column with name repository") + } +} + func setUp() error { var err error db, err = sql.Open("sqlite3", "migration_tests.sqlite")