Integrate MigrationDriver to migrate.go

This commit is contained in:
Nurahmadie 2014-02-15 20:16:54 +07:00
parent d2eed594ea
commit 54a9544044

View file

@ -49,17 +49,41 @@ const deleteRevisionStmt = `
DELETE FROM migration where revision = ? DELETE FROM migration where revision = ?
` `
// Operation interface covers basic migration operations.
// Implementation details is specific for each database,
// see migrate/sqlite.go for implementation reference.
type Operation interface {
CreateTable(tableName string, args []string) (sql.Result, error)
RenameTable(tableName, newName string) (sql.Result, error)
DropTable(tableName string) (sql.Result, error)
AddColumn(tableName, columnSpec string) (sql.Result, error)
DropColumns(tableName string, columnsToDrop []string) (sql.Result, error)
RenameColumns(tableName string, columnChanges map[string]string) (sql.Result, error)
}
type Revision interface { type Revision interface {
Up(tx *sql.Tx) error Up(op Operation) error
Down(tx *sql.Tx) error Down(op Operation) error
Revision() int64 Revision() int64
} }
type MigrationDriver struct {
Tx *sql.Tx
}
type Migration struct { type Migration struct {
db *sql.DB db *sql.DB
revs []Revision revs []Revision
} }
var Driver func(tx *sql.Tx) Operation
func New(db *sql.DB) *Migration { func New(db *sql.DB) *Migration {
return &Migration{db: db} return &Migration{db: db}
} }
@ -119,12 +143,14 @@ func (m *Migration) up(target, current int64) error {
return err return err
} }
op := Driver(tx)
// loop through and execute revisions // loop through and execute revisions
for _, rev := range m.revs { for _, rev := range m.revs {
if rev.Revision() >= target { if rev.Revision() >= target {
current = rev.Revision() current = rev.Revision()
// execute the revision Upgrade. // execute the revision Upgrade.
if err := rev.Up(tx); err != nil { if err := rev.Up(op); err != nil {
log.Printf("Failed to upgrade to Revision Number %v\n", current) log.Printf("Failed to upgrade to Revision Number %v\n", current)
log.Println(err) log.Println(err)
return tx.Rollback() return tx.Rollback()
@ -150,6 +176,8 @@ func (m *Migration) down(target, current int64) error {
return err return err
} }
op := Driver(tx)
// reverse the list of revisions // reverse the list of revisions
revs := []Revision{} revs := []Revision{}
for _, rev := range m.revs { for _, rev := range m.revs {
@ -162,7 +190,7 @@ func (m *Migration) down(target, current int64) error {
if rev.Revision() > target { if rev.Revision() > target {
current = rev.Revision() current = rev.Revision()
// execute the revision Upgrade. // execute the revision Upgrade.
if err := rev.Down(tx); err != nil { if err := rev.Down(op); err != nil {
log.Printf("Failed to downgrade to Revision Number %v\n", current) log.Printf("Failed to downgrade to Revision Number %v\n", current)
log.Println(err) log.Println(err)
return tx.Rollback() return tx.Rollback()