diff --git a/cli/main.go b/cli/main.go index f019b5cb..4c727a97 100644 --- a/cli/main.go +++ b/cli/main.go @@ -3,10 +3,10 @@ package main import ( "flag" "fmt" - "strings" "os" "os/signal" "strconv" + "strings" "syscall" "time" diff --git a/database/cassandra/cassandra.go b/database/cassandra/cassandra.go index 3dd93793..041a4682 100644 --- a/database/cassandra/cassandra.go +++ b/database/cassandra/cassandra.go @@ -5,10 +5,11 @@ import ( "io" "io/ioutil" nurl "net/url" - "github.com/gocql/gocql" + "strconv" "time" + + "github.com/gocql/gocql" "github.com/mattes/migrate/database" - "strconv" ) func init() { @@ -20,8 +21,8 @@ var DefaultMigrationsTable = "schema_migrations" var dbLocked = false var ( - ErrNilConfig = fmt.Errorf("no config") - ErrNoKeyspace = fmt.Errorf("no keyspace provided") + ErrNilConfig = fmt.Errorf("no config") + ErrNoKeyspace = fmt.Errorf("no keyspace provided") ErrDatabaseDirty = fmt.Errorf("database is dirty") ) @@ -35,7 +36,7 @@ type Cassandra struct { isLocked bool // Open and WithInstance need to guarantee that config is never nil - config *Config + config *Config } func (p *Cassandra) Open(url string) (database.Driver, error) { @@ -111,7 +112,7 @@ func (p *Cassandra) Close() error { } func (p *Cassandra) Lock() error { - if (dbLocked) { + if dbLocked { return database.ErrLocked } dbLocked = true @@ -153,7 +154,6 @@ func (p *Cassandra) SetVersion(version int, dirty bool) error { return nil } - // Return current keyspace version func (p *Cassandra) Version() (version int, dirty bool, err error) { query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1` @@ -191,7 +191,6 @@ func (p *Cassandra) Drop() error { return nil } - // Ensure version table exists func (p *Cassandra) ensureVersionTable() error { err := p.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", p.config.MigrationsTable)).Exec() @@ -204,7 +203,6 @@ func (p *Cassandra) ensureVersionTable() error { return nil } - // ParseConsistency wraps gocql.ParseConsistency // to return an error instead of a panicking. func parseConsistency(consistencyStr string) (consistency gocql.Consistency, err error) { diff --git a/database/mysql/mysql.go b/database/mysql/mysql.go index f00f886e..f4a75808 100644 --- a/database/mysql/mysql.go +++ b/database/mysql/mysql.go @@ -23,10 +23,11 @@ func init() { var DefaultMigrationsTable = "schema_migrations" var ( - ErrDatabaseDirty = fmt.Errorf("database is dirty") - ErrNilConfig = fmt.Errorf("no config") - ErrNoDatabaseName = fmt.Errorf("no database name") - ErrAppendPEM = fmt.Errorf("failed to append PEM") + ErrDatabaseDirty = fmt.Errorf("database is dirty") + ErrNilConfig = fmt.Errorf("no config") + ErrNoDatabaseName = fmt.Errorf("no database name") + ErrAppendPEM = fmt.Errorf("failed to append PEM") + ErrDatabaseCouldNotBeCreated = fmt.Errorf("Database could not be created") ) type Config struct { @@ -89,6 +90,10 @@ func (m *Mysql) Open(url string) (database.Driver, error) { q.Set("multiStatements", "true") purl.RawQuery = q.Encode() + if err = ensureDatabaseExist(purl); err != nil { + return nil, err + } + db, err := sql.Open("mysql", strings.Replace( migrate.FilterCustomQuery(purl).String(), "mysql://", "", 1)) if err != nil { @@ -147,6 +152,28 @@ func (m *Mysql) Open(url string) (database.Driver, error) { return mx, nil } +func ensureDatabaseExist(url *nurl.URL) error { + urlCopy, _ := nurl.Parse(url.String()) // Copy + + database := urlCopy.Path[1:] + urlCopy.Path = "/" + + dbstring := strings.Replace(migrate.FilterCustomQuery(urlCopy).String(), "mysql://", "", 1) + db, err := sql.Open("mysql", dbstring) + if err != nil { + return err + } + + // Unfortunately placeholder parameters did not work here, so there might + // be a risk of SQL injection here if database URL isn't properly + // saniticed. + if _, err = db.Exec("CREATE DATABASE IF NOT EXISTS " + database); err != nil { + return err + } + + return db.Close() +} + func (m *Mysql) Close() error { return m.db.Close() }