add migration commands into sh2-server (because migrate cli doesn't ship with sqlite3)

This commit is contained in:
sawka 2022-07-01 10:48:14 -07:00
parent 7340d89089
commit 3f01ff44c3
7 changed files with 2069 additions and 157 deletions

View File

@ -8,7 +8,6 @@ import (
"io/fs"
"net/http"
"os"
"os/exec"
"strings"
"sync"
"time"
@ -19,7 +18,7 @@ import (
"github.com/scripthaus-dev/mshell/pkg/base"
"github.com/scripthaus-dev/mshell/pkg/cmdtail"
"github.com/scripthaus-dev/mshell/pkg/packet"
"github.com/scripthaus-dev/mshell/pkg/shexec"
"github.com/scripthaus-dev/sh2-server/pkg/remote"
"github.com/scripthaus-dev/sh2-server/pkg/sstore"
"github.com/scripthaus-dev/sh2-server/pkg/wsshell"
)
@ -36,7 +35,7 @@ const WSStatePacketChSize = 20
const MaxInputDataSize = 1000
var GlobalMShellProc *MShellProc
var GlobalMShellProc *remote.MShellProc
var GlobalLock = &sync.Mutex{}
var WSStateMap = make(map[string]*WSState) // clientid -> WsState
@ -158,39 +157,6 @@ func (ws *WSState) replaceExistingShell(shell *wsshell.WSShell) {
return
}
type RpcEntry struct {
PacketId string
RespCh chan packet.RpcPacketType
}
type MShellProc struct {
Lock *sync.Mutex
Cmd *exec.Cmd
Input *packet.PacketSender
Output *packet.PacketParser
Local bool
DoneCh chan bool
CurDir string
HomeDir string
User string
Host string
Env []string
Initialized bool
RpcMap map[string]*RpcEntry
}
func (r *MShellProc) GetPrompt() string {
r.Lock.Lock()
defer r.Lock.Unlock()
var curDir = r.CurDir
if r.CurDir == r.HomeDir {
curDir = "~"
} else if strings.HasPrefix(r.CurDir, r.HomeDir+"/") {
curDir = "~/" + r.CurDir[0:len(r.HomeDir)+1]
}
return fmt.Sprintf("[%s@%s %s]", r.User, r.Host, curDir)
}
func HandleWs(w http.ResponseWriter, r *http.Request) {
shell, err := wsshell.StartWS(w, r)
if err != nil {
@ -401,7 +367,7 @@ func HandleRunCommand(w http.ResponseWriter, r *http.Request) {
return
}
rtnLine := sstore.MakeNewLineCmd(params.SessionId, params.WindowId)
rtnLine.CmdText = commandStr
// rtnLine.CmdText = commandStr
runPacket := packet.MakeRunPacket()
runPacket.CK = base.MakeCommandKey(params.SessionId, rtnLine.CmdId)
runPacket.Cwd = ""
@ -498,115 +464,6 @@ func HandleRunCommand(w http.ResponseWriter, r *http.Request) {
// startcmd will figure out the correct
//
func LaunchMShell() (*MShellProc, error) {
msPath, err := base.GetMShellPath()
if err != nil {
return nil, err
}
ecmd := exec.Command(msPath)
inputWriter, err := ecmd.StdinPipe()
if err != nil {
return nil, err
}
outputReader, err := ecmd.StdoutPipe()
if err != nil {
return nil, err
}
ecmd.Stderr = ecmd.Stdout
err = ecmd.Start()
if err != nil {
return nil, err
}
rtn := &MShellProc{Lock: &sync.Mutex{}, Local: true, Cmd: ecmd}
rtn.Output = packet.MakePacketParser(outputReader)
rtn.Input = packet.MakePacketSender(inputWriter)
rtn.RpcMap = make(map[string]*RpcEntry)
rtn.DoneCh = make(chan bool)
go func() {
exitErr := ecmd.Wait()
exitCode := shexec.GetExitCode(exitErr)
fmt.Printf("[error] RUNNER PROC EXITED code[%d]\n", exitCode)
close(rtn.DoneCh)
}()
return rtn, nil
}
func (runner *MShellProc) PacketRpc(pk packet.RpcPacketType, timeout time.Duration) (packet.RpcPacketType, error) {
if pk == nil {
return nil, fmt.Errorf("PacketRpc passed nil packet")
}
id := pk.GetPacketId()
respCh := make(chan packet.RpcPacketType)
runner.Lock.Lock()
runner.RpcMap[id] = &RpcEntry{PacketId: id, RespCh: respCh}
runner.Lock.Unlock()
defer func() {
runner.Lock.Lock()
delete(runner.RpcMap, id)
runner.Lock.Unlock()
}()
runner.Input.SendPacket(pk)
timer := time.NewTimer(timeout)
defer timer.Stop()
select {
case rtnPk := <-respCh:
return rtnPk, nil
case <-timer.C:
return nil, fmt.Errorf("PacketRpc timeout")
}
}
func (runner *MShellProc) ProcessPackets() {
for pk := range runner.Output.MainCh {
if rpcPk, ok := pk.(packet.RpcPacketType); ok {
rpcId := rpcPk.GetPacketId()
runner.Lock.Lock()
entry := runner.RpcMap[rpcId]
if entry != nil {
delete(runner.RpcMap, rpcId)
go func() {
entry.RespCh <- rpcPk
close(entry.RespCh)
}()
}
runner.Lock.Unlock()
}
if pk.GetType() == packet.CmdDataPacketStr {
dataPacket := pk.(*packet.CmdDataPacketType)
fmt.Printf("cmd-data %s pty=%d run=%d\n", dataPacket.CK, len(dataPacket.PtyData), len(dataPacket.RunData))
continue
}
if pk.GetType() == packet.InitPacketStr {
initPacket := pk.(*packet.InitPacketType)
fmt.Printf("runner-init %s user=%s dir=%s\n", initPacket.MShellHomeDir, initPacket.User, initPacket.HomeDir)
runner.Lock.Lock()
runner.Initialized = true
runner.User = initPacket.User
runner.CurDir = initPacket.HomeDir
runner.HomeDir = initPacket.HomeDir
runner.Env = initPacket.Env
if runner.Local {
runner.Host = "local"
}
runner.Lock.Unlock()
continue
}
if pk.GetType() == packet.MessagePacketStr {
msgPacket := pk.(*packet.MessagePacketType)
fmt.Printf("# %s\n", msgPacket.Message)
continue
}
if pk.GetType() == packet.RawPacketStr {
rawPacket := pk.(*packet.RawPacketType)
fmt.Printf("stderr> %s\n", rawPacket.Data)
continue
}
fmt.Printf("runner-packet: %v\n", pk)
}
}
func runWebSocketServer() {
gr := mux.NewRouter()
gr.HandleFunc("/ws", HandleWs)
@ -626,7 +483,19 @@ func runWebSocketServer() {
}
func main() {
runnerProc, err := LaunchMShell()
if len(os.Args) >= 2 && strings.HasPrefix(os.Args[1], "--migrate") {
err := sstore.MigrateCommandOpts(os.Args[1:])
if err != nil {
fmt.Printf("[error] %v\n", err)
}
return
}
err := sstore.TryMigrateUp()
if err != nil {
fmt.Printf("[error] %v\n", err)
return
}
runnerProc, err := remote.LaunchMShell()
if err != nil {
fmt.Printf("error launching runner-proc: %v\n", err)
return

View File

@ -0,0 +1,7 @@
DROP TABLE session;
DROP TABLE window;
DROP TABLE session_remote;
DROP TABLE line;
DROP TABLE remote;
DROP TABLE session_cmd;
DROP TABLE history;

View File

@ -0,0 +1,67 @@
CREATE TABLE session (
sessionid varchar(36) PRIMARY KEY,
name varchar(50) NOT NULL
);
CREATE UNIQUE INDEX session_name_unique ON session(name);
CREATE TABLE window (
sessionid varchar(36) NOT NULL,
windowid varchar(36) NOT NULL,
name varchar(50) NOT NULL,
curremote varchar(50) NOT NULL,
version int NOT NULL,
PRIMARY KEY (sessionid, windowid)
);
CREATE UNIQUE INDEX window_name_unique ON window(sessionid, name);
CREATE TABLE session_remote (
sessionid varchar(36) NOT NULL,
windowid varchar(36) NOT NULL,
remotename varchar(50) NOT NULL,
remoteid varchar(36) NOT NULL,
cwd varchar(300) NOT NULL,
PRIMARY KEY (sessionid, windowid, remotename)
);
CREATE TABLE line (
sessionid varchar(36) NOT NULL,
windowid varchar(36) NOT NULL,
lineid int NOT NULL,
userid varchar(36) NOT NULL,
ts bigint NOT NULL,
linetype varchar(10) NOT NULL,
text text NOT NULL,
cmdid varchar(36) NOT NULL,
PRIMARY KEY (sessionid, windowid, lineid)
);
CREATE TABLE remote (
remoteid varchar(36) PRIMARY KEY,
remotetype varchar(10) NOT NULL,
remotename varchar(50) NOT NULL,
connectopts varchar(300) NOT NULL
);
CREATE TABLE session_cmd (
sessionid varchar(36) NOT NULL,
cmdid varchar(36) NOT NULL,
remoteid varchar(36) NOT NULL,
status varchar(10) NOT NULL,
startts bigint NOT NULL,
pid int NOT NULL,
runnerpid int NOT NULL,
donets bigint NOT NULL,
exitcode int NOT NULL,
ptyout BLOB NOT NULL,
runout BLOB NOT NULL,
PRIMARY KEY (sessionid, cmdid)
);
CREATE TABLE history (
sessionid varchar(36) NOT NULL,
windowid varchar(36) NOT NULL,
userid varchar(36) NOT NULL,
ts int64 NOT NULL,
lineid varchar(36) NOT NULL,
PRIMARY KEY (sessionid, windowid, lineid)
);

6
go.mod
View File

@ -3,9 +3,12 @@ module github.com/scripthaus-dev/sh2-server
go 1.17
require (
github.com/golang-migrate/migrate/v4 v4.15.2
github.com/google/uuid v1.3.0
github.com/gorilla/mux v1.8.0
github.com/gorilla/websocket v1.5.0
github.com/jmoiron/sqlx v1.3.5
github.com/mattn/go-sqlite3 v1.14.14
github.com/scripthaus-dev/mshell v0.0.0
)
@ -13,6 +16,9 @@ require (
github.com/alessio/shellescape v1.4.1 // indirect
github.com/creack/pty v1.1.18 // indirect
github.com/fsnotify/fsnotify v1.5.4 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-multierror v1.1.1 // indirect
go.uber.org/atomic v1.7.0 // indirect
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad // indirect
)

1798
go.sum

File diff suppressed because it is too large Load Diff

118
pkg/sstore/migrate.go Normal file
View File

@ -0,0 +1,118 @@
package sstore
import (
"fmt"
"os"
"path"
"strconv"
_ "github.com/golang-migrate/migrate/v4/database/sqlite3"
_ "github.com/golang-migrate/migrate/v4/source/file"
_ "github.com/mattn/go-sqlite3"
"github.com/golang-migrate/migrate/v4"
)
func MakeMigrate() (*migrate.Migrate, error) {
wd, err := os.Getwd()
if err != nil {
return nil, err
}
migrationPathUrl := fmt.Sprintf("file://%s", path.Join(wd, "db", "migrations"))
dbUrl := fmt.Sprintf("sqlite3://%s", GetSessionDBName())
m, err := migrate.New(migrationPathUrl, dbUrl)
if err != nil {
return nil, err
}
return m, nil
}
func MigrateUp() error {
m, err := MakeMigrate()
if err != nil {
return err
}
err = m.Up()
if err != nil {
return err
}
return nil
}
func MigrateVersion() (uint, bool, error) {
m, err := MakeMigrate()
if err != nil {
return 0, false, err
}
return m.Version()
}
func MigrateDown() error {
m, err := MakeMigrate()
if err != nil {
return err
}
err = m.Down()
if err != nil {
return err
}
return nil
}
func MigrateGoto(n uint) error {
m, err := MakeMigrate()
if err != nil {
return err
}
err = m.Migrate(n)
if err != nil {
return err
}
return nil
}
func TryMigrateUp() error {
err := MigrateUp()
if err != nil && err.Error() == migrate.ErrNoChange.Error() {
err = nil
}
if err != nil {
return err
}
return MigratePrintVersion()
}
func MigratePrintVersion() error {
version, dirty, err := MigrateVersion()
if err != nil {
return fmt.Errorf("error getting db version: %v", err)
}
if dirty {
return fmt.Errorf("error db is dirty, version=%d", version)
}
fmt.Printf("[db] version=%d\n", version)
return nil
}
func MigrateCommandOpts(opts []string) error {
var err error
if opts[0] == "--migrate-up" {
err = MigrateUp()
} else if opts[0] == "--migrate-down" {
err = MigrateDown()
} else if opts[0] == "--migrate-goto" {
n, err := strconv.Atoi(opts[1])
if err == nil {
err = MigrateGoto(uint(n))
}
} else {
err = fmt.Errorf("invalid migration command")
}
if err != nil && err.Error() == migrate.ErrNoChange.Error() {
err = nil
}
if err != nil {
return err
}
return MigratePrintVersion()
}

View File

@ -6,7 +6,11 @@ import (
"time"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
"github.com/scripthaus-dev/mshell/pkg/packet"
"github.com/scripthaus-dev/sh2-server/pkg/scbase"
_ "github.com/mattn/go-sqlite3"
)
var NextLineId = 10
@ -14,22 +18,45 @@ var NextLineLock = &sync.Mutex{}
const LineTypeCmd = "cmd"
const LineTypeText = "text"
const DBFileName = "scripthaus.db"
const DBFileName = "sh2.db"
func GetSessionDBName(sessionId string) string {
func GetSessionDBName() string {
scHome := scbase.GetScHomeDir()
return path.Join(scHome, DBFileName)
}
func OpenConnPool() (*sqlx.DB, error) {
connPool, err := sqlx.Open("sqlite3", GetSessionDBName())
if err != nil {
return nil, err
}
return connPool, nil
}
type SessionType struct {
SessionId string `json:"sessionid"`
Remote string `json:"remote"`
Cwd string `json:"cwd"`
SessionId string `json:"sessionid"`
Remote string `json:"remote"`
Name string `json:"name"`
Windows []*WindowType `json:"windows"`
Cmds []*CmdType `json:"cmds"`
}
type WindowType struct {
SessionId string `json:"sessionid"`
WindowId string `json:"windowid"`
SessionId string `json:"sessionid"`
WindowId string `json:"windowid"`
Name string `json:"name"`
CurRemote string `json:"curremote"`
Remotes []*SessionRemote `json:"remotes"`
Lines []*LineType `json:"lines"`
Version int `json:"version"`
}
type SessionRemote struct {
SessionId string `json:"sessionid"`
WindowId string `json:"windowid"`
RemoteId string `json"remoteid"`
RemoteName string `json:"name"`
Cwd string `json:"cwd"`
}
type LineType struct {
@ -41,9 +68,29 @@ type LineType struct {
LineType string `json:"linetype"`
Text string `json:"text,omitempty"`
CmdId string `json:"cmdid,omitempty"`
CmdText string `json:"cmdtext,omitempty"`
CmdRemote string `json:"cmdremote,omitempty"`
CmdCwd string `json:"cmdcwd,omitempty"`
}
type RemoteType struct {
RemoteId string `json:"remoteid"`
RemoteType string `json:"remotetype"`
RemoteName string `json:"remotename"`
ConnectOpts string `json:"connectopts"`
Connected bool `json:"connected"`
}
type CmdType struct {
RowId int64 `json:"rowid"`
SessionId string `json:"sessionid"`
CmdId string `json:"cmdid"`
RemoteId string `json:"remoteid"`
Status string `json:"status"`
StartTs int64 `json:"startts"`
DoneTs int64 `json:"donets"`
Pid int `json:"pid"`
RunnerPid int `json:"runnerpid"`
ExitCode int `json:"exitcode"`
RunOut packet.PacketType `json:"runout"`
}
func MakeNewLineCmd(sessionId string, windowId string) *LineType {