harness-drone/vendor/github.com/russross/meddler/meddler.go
2015-09-29 18:21:17 -07:00

347 lines
9.5 KiB
Go

package meddler
import (
"bytes"
"compress/gzip"
"encoding/gob"
"encoding/json"
"fmt"
"reflect"
"time"
)
// Meddler is the interface for a field meddler. Implementations can be
// registered to convert struct fields being loaded and saved in the database.
type Meddler interface {
// PreRead is called before a Scan operation. It is given a pointer to
// the raw struct field, and returns the value that will be given to
// the database driver.
PreRead(fieldAddr interface{}) (scanTarget interface{}, err error)
// PostRead is called after a Scan operation. It is given the value returned
// by PreRead and a pointer to the raw struct field. It is expected to fill
// in the struct field if the two are different.
PostRead(fieldAddr interface{}, scanTarget interface{}) error
// PreWrite is called before an Insert or Update operation. It is given
// a pointer to the raw struct field, and returns the value that will be
// given to the database driver.
PreWrite(field interface{}) (saveValue interface{}, err error)
}
// Register sets up a meddler type. Meddlers get a chance to meddle with the
// data being loaded or saved when a field is annotated with the name of the meddler.
// The registry is global.
func Register(name string, m Meddler) {
if name == "pk" {
panic("meddler.Register: pk cannot be used as a meddler name")
}
registry[name] = m
}
var registry = make(map[string]Meddler)
func init() {
Register("identity", IdentityMeddler(false))
Register("localtime", TimeMeddler{ZeroIsNull: false, Local: true})
Register("localtimez", TimeMeddler{ZeroIsNull: true, Local: true})
Register("utctime", TimeMeddler{ZeroIsNull: false, Local: false})
Register("utctimez", TimeMeddler{ZeroIsNull: true, Local: false})
Register("zeroisnull", ZeroIsNullMeddler(false))
Register("json", JSONMeddler(false))
Register("jsongzip", JSONMeddler(true))
Register("gob", GobMeddler(false))
Register("gobgzip", GobMeddler(true))
}
// IdentityMeddler is the default meddler, and it passes the original value through with
// no changes.
type IdentityMeddler bool
func (elt IdentityMeddler) PreRead(fieldAddr interface{}) (scanTarget interface{}, err error) {
return fieldAddr, nil
}
func (elt IdentityMeddler) PostRead(fieldAddr, scanTarget interface{}) error {
return nil
}
func (elt IdentityMeddler) PreWrite(field interface{}) (saveValue interface{}, err error) {
return field, nil
}
// TimeMeddler provides useful operations on time.Time fields. It can convert the zero time
// to and from a null column, and it can convert the time zone to UTC on save and to Local on load.
type TimeMeddler struct {
ZeroIsNull bool
Local bool
}
func (elt TimeMeddler) PreRead(fieldAddr interface{}) (scanTarget interface{}, err error) {
switch tgt := fieldAddr.(type) {
case *time.Time:
if elt.ZeroIsNull {
return &tgt, nil
}
return fieldAddr, nil
case **time.Time:
if elt.ZeroIsNull {
return nil, fmt.Errorf("meddler.TimeMeddler cannot be used on a *time.Time field, only time.Time")
}
return fieldAddr, nil
default:
return nil, fmt.Errorf("meddler.TimeMeddler.PreRead: unknown struct field type: %T", fieldAddr)
}
}
func (elt TimeMeddler) PostRead(fieldAddr, scanTarget interface{}) error {
switch tgt := fieldAddr.(type) {
case *time.Time:
if elt.ZeroIsNull {
src := scanTarget.(**time.Time)
if *src == nil {
*tgt = time.Time{}
} else if elt.Local {
*tgt = (*src).Local()
} else {
*tgt = (*src).UTC()
}
return nil
}
src := scanTarget.(*time.Time)
if elt.Local {
*tgt = src.Local()
} else {
*tgt = src.UTC()
}
return nil
case **time.Time:
if elt.ZeroIsNull {
return fmt.Errorf("meddler TimeMeddler cannot be used on a *time.Time field, only time.Time")
}
src := scanTarget.(**time.Time)
if *src == nil {
*tgt = nil
} else if elt.Local {
**src = (*src).Local()
*tgt = *src
} else {
**src = (*src).UTC()
*tgt = *src
}
return nil
default:
return fmt.Errorf("meddler.TimeMeddler.PostRead: unknown struct field type: %T", fieldAddr)
}
}
func (elt TimeMeddler) PreWrite(field interface{}) (saveValue interface{}, err error) {
switch tgt := field.(type) {
case time.Time:
if elt.ZeroIsNull && tgt.IsZero() {
return nil, nil
}
return tgt.UTC(), nil
case *time.Time:
if tgt == nil || elt.ZeroIsNull && tgt.IsZero() {
return nil, nil
}
return tgt.UTC(), nil
default:
return nil, fmt.Errorf("meddler.TimeMeddler.PreWrite: unknown struct field type: %T", field)
}
}
// ZeroIsNullMeddler converts zero value fields (integers both signed and unsigned, floats, complex numbers,
// and strings) to and from null database columns.
type ZeroIsNullMeddler bool
func (elt ZeroIsNullMeddler) PreRead(fieldAddr interface{}) (scanTarget interface{}, err error) {
// create a pointer to this element
// the database driver will set it to nil if the column value is null
return reflect.New(reflect.TypeOf(fieldAddr)).Interface(), nil
}
func (elt ZeroIsNullMeddler) PostRead(fieldAddr, scanTarget interface{}) error {
sv := reflect.ValueOf(scanTarget)
fv := reflect.ValueOf(fieldAddr)
if sv.Elem().IsNil() {
// null column, so set target to be zero value
fv.Elem().Set(reflect.Zero(fv.Elem().Type()))
} else {
// copy the value that scan found
fv.Elem().Set(sv.Elem().Elem())
}
return nil
}
func (elt ZeroIsNullMeddler) PreWrite(field interface{}) (saveValue interface{}, err error) {
val := reflect.ValueOf(field)
switch val.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if val.Int() == 0 {
return nil, nil
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if val.Uint() == 0 {
return nil, nil
}
case reflect.Float32, reflect.Float64:
if val.Float() == 0 {
return nil, nil
}
case reflect.Complex64, reflect.Complex128:
if val.Complex() == 0 {
return nil, nil
}
case reflect.String:
if val.String() == "" {
return nil, nil
}
default:
return nil, fmt.Errorf("ZeroIsNullMeddler.PreWrite: unknown struct field type: %T", field)
}
return field, nil
}
type JSONMeddler bool
func (zip JSONMeddler) PreRead(fieldAddr interface{}) (scanTarget interface{}, err error) {
// give a pointer to a byte buffer to grab the raw data
return new([]byte), nil
}
func (zip JSONMeddler) PostRead(fieldAddr, scanTarget interface{}) error {
ptr := scanTarget.(*[]byte)
if ptr == nil {
return fmt.Errorf("JSONMeddler.PostRead: nil pointer")
}
raw := *ptr
if zip {
// un-gzip and decode json
gzipReader, err := gzip.NewReader(bytes.NewReader(raw))
if err != nil {
return fmt.Errorf("Error creating gzip Reader: %v", err)
}
defer gzipReader.Close()
jsonDecoder := json.NewDecoder(gzipReader)
if err := jsonDecoder.Decode(fieldAddr); err != nil {
return fmt.Errorf("JSON decoder/gzip error: %v", err)
}
if err := gzipReader.Close(); err != nil {
return fmt.Errorf("Closing gzip reader: %v", err)
}
return nil
}
// decode json
jsonDecoder := json.NewDecoder(bytes.NewReader(raw))
if err := jsonDecoder.Decode(fieldAddr); err != nil {
return fmt.Errorf("JSON decode error: %v", err)
}
return nil
}
func (zip JSONMeddler) PreWrite(field interface{}) (saveValue interface{}, err error) {
buffer := new(bytes.Buffer)
if zip {
// json encode and gzip
gzipWriter := gzip.NewWriter(buffer)
defer gzipWriter.Close()
jsonEncoder := json.NewEncoder(gzipWriter)
if err := jsonEncoder.Encode(field); err != nil {
return nil, fmt.Errorf("JSON encoding/gzip error: %v", err)
}
if err := gzipWriter.Close(); err != nil {
return nil, fmt.Errorf("Closing gzip writer: %v", err)
}
return buffer.Bytes(), nil
}
// json encode
jsonEncoder := json.NewEncoder(buffer)
if err := jsonEncoder.Encode(field); err != nil {
return nil, fmt.Errorf("JSON encoding error: %v", err)
}
return buffer.Bytes(), nil
}
type GobMeddler bool
func (zip GobMeddler) PreRead(fieldAddr interface{}) (scanTarget interface{}, err error) {
// give a pointer to a byte buffer to grab the raw data
return new([]byte), nil
}
func (zip GobMeddler) PostRead(fieldAddr, scanTarget interface{}) error {
ptr := scanTarget.(*[]byte)
if ptr == nil {
return fmt.Errorf("GobMeddler.PostRead: nil pointer")
}
raw := *ptr
if zip {
// un-gzip and decode gob
gzipReader, err := gzip.NewReader(bytes.NewReader(raw))
if err != nil {
return fmt.Errorf("Error creating gzip Reader: %v", err)
}
defer gzipReader.Close()
gobDecoder := gob.NewDecoder(gzipReader)
if err := gobDecoder.Decode(fieldAddr); err != nil {
return fmt.Errorf("Gob decoder/gzip error: %v", err)
}
if err := gzipReader.Close(); err != nil {
return fmt.Errorf("Closing gzip reader: %v", err)
}
return nil
}
// decode gob
gobDecoder := gob.NewDecoder(bytes.NewReader(raw))
if err := gobDecoder.Decode(fieldAddr); err != nil {
return fmt.Errorf("Gob decode error: %v", err)
}
return nil
}
func (zip GobMeddler) PreWrite(field interface{}) (saveValue interface{}, err error) {
buffer := new(bytes.Buffer)
if zip {
// gob encode and gzip
gzipWriter := gzip.NewWriter(buffer)
defer gzipWriter.Close()
gobEncoder := gob.NewEncoder(gzipWriter)
if err := gobEncoder.Encode(field); err != nil {
return nil, fmt.Errorf("Gob encoding/gzip error: %v", err)
}
if err := gzipWriter.Close(); err != nil {
return nil, fmt.Errorf("Closing gzip writer: %v", err)
}
return buffer.Bytes(), nil
}
// gob encode
gobEncoder := gob.NewEncoder(buffer)
if err := gobEncoder.Encode(field); err != nil {
return nil, fmt.Errorf("Gob encoding error: %v", err)
}
return buffer.Bytes(), nil
}