waveterm/pkg/util/migrateutil/migrateutil.go

68 lines
1.9 KiB
Go

// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package migrateutil
import (
"database/sql"
"fmt"
"io/fs"
"log"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/source/iofs"
sqlite3migrate "github.com/golang-migrate/migrate/v4/database/sqlite3"
)
// GetMigrateVersion reports the migration version and dirty flag that
// m.Version returns.
//
// A migrate.ErrNilVersion result is not surfaced as an error: it is reported
// as version 0, not dirty, with a nil error. Any other outcome (including
// other errors) is passed through exactly as m.Version returned it.
func GetMigrateVersion(m *migrate.Migrate) (uint, bool, error) {
	version, isDirty, verErr := m.Version()
	if verErr != migrate.ErrNilVersion {
		return version, isDirty, verErr
	}
	return 0, false, nil
}
// MakeMigrate constructs a *migrate.Migrate for db.
//
// Migrations are read from the migrationsName path inside migrationFS through
// the iofs source driver, and db is wrapped with the golang-migrate sqlite3
// database driver using an empty (default) config. storeName is only used to
// label the error messages of the driver and instance construction steps.
//
// On any failure the returned *migrate.Migrate is nil and the underlying
// error is wrapped with %w.
func MakeMigrate(storeName string, db *sql.DB, migrationFS fs.FS, migrationsName string) (*migrate.Migrate, error) {
	// Source side: where the migration files come from.
	srcDriver, srcErr := iofs.New(migrationFS, migrationsName)
	if srcErr != nil {
		return nil, fmt.Errorf("opening fs: %w", srcErr)
	}

	// Database side: reuse the caller's already-open *sql.DB.
	dbDriver, drvErr := sqlite3migrate.WithInstance(db, &sqlite3migrate.Config{})
	if drvErr != nil {
		return nil, fmt.Errorf("making %s migration driver: %w", storeName, drvErr)
	}

	instance, newErr := migrate.NewWithInstance("iofs", srcDriver, "sqlite3", dbDriver)
	if newErr != nil {
		return nil, fmt.Errorf("making %s migration: %w", storeName, newErr)
	}
	return instance, nil
}
// Migrate applies all pending "up" migrations found under migrationsName in
// migrationFS to db.
//
// storeName labels log lines and error messages. A log line is written on
// entry, and another one after the run when the recorded version changed.
//
// Migrate returns an error when:
//   - the migrate instance cannot be built (error from MakeMigrate, returned
//     as-is),
//   - the current version cannot be read,
//   - the database is flagged dirty before migrating,
//   - m.Up fails with anything other than migrate.ErrNoChange, or
//   - the version cannot be read back after migrating.
//
// Underlying errors are wrapped with %w so callers can inspect them with
// errors.Is / errors.As.
func Migrate(storeName string, db *sql.DB, migrationFS fs.FS, migrationsName string) error {
	log.Printf("migrate %s\n", storeName)
	m, err := MakeMigrate(storeName, db, migrationFS, migrationsName)
	if err != nil {
		return err
	}
	// NOTE(review): m is never closed in this function. Presumably deliberate,
	// since closing the migrate instance is expected to close the database
	// driver and with it the caller-owned db — confirm before adding m.Close().

	curVersion, dirty, err := GetMigrateVersion(m)
	// Check err before dirty: the version and dirty flag are only meaningful
	// when the lookup itself succeeded.
	if err != nil {
		return fmt.Errorf("%s, cannot get current migration version: %w", storeName, err)
	}
	if dirty {
		return fmt.Errorf("%s, migrate up, database is dirty", storeName)
	}

	// ErrNoChange only means the database was already up to date.
	err = m.Up()
	if err != nil && err != migrate.ErrNoChange {
		return fmt.Errorf("migrating %s: %w", storeName, err)
	}

	// The dirty flag is not re-checked here; only the version is needed for
	// the log line below.
	newVersion, _, err := GetMigrateVersion(m)
	if err != nil {
		return fmt.Errorf("%s, cannot get new migration version: %w", storeName, err)
	}
	if newVersion != curVersion {
		log.Printf("[db] %s migration done, version %d -> %d\n", storeName, curVersion, newVersion)
	}
	return nil
}