checkpoint, working on setting up db

This commit is contained in:
sawka 2022-07-01 12:17:19 -07:00
parent 02029b3948
commit 643f08e584
6 changed files with 349 additions and 7 deletions

View File

@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -495,6 +496,22 @@ func main() {
fmt.Printf("[error] %v\n", err) fmt.Printf("[error] %v\n", err)
return return
} }
numSessions, err := sstore.NumSessions(context.Background())
if err != nil {
fmt.Printf("[error] getting num sessions: %v\n", err)
return
}
err = sstore.EnsureLocalRemote(context.Background())
if err != nil {
fmt.Printf("[error] ensuring local remote: %v\n", err)
return
}
fmt.Printf("[db] sessions count=%d\n", numSessions)
if numSessions == 0 {
sstore.CreateInitialSession(context.Background())
}
return
runnerProc, err := remote.LaunchMShell() runnerProc, err := remote.LaunchMShell()
if err != nil { if err != nil {
fmt.Printf("error launching runner-proc: %v\n", err) fmt.Printf("error launching runner-proc: %v\n", err)

View File

@ -39,7 +39,8 @@ CREATE TABLE remote (
remoteid varchar(36) PRIMARY KEY, remoteid varchar(36) PRIMARY KEY,
remotetype varchar(10) NOT NULL, remotetype varchar(10) NOT NULL,
remotename varchar(50) NOT NULL, remotename varchar(50) NOT NULL,
connectopts varchar(300) NOT NULL connectopts varchar(300) NOT NULL,
ptyout BLOB NOT NULL
); );
CREATE TABLE session_cmd ( CREATE TABLE session_cmd (

View File

@ -37,7 +37,8 @@ CREATE TABLE remote (
remoteid varchar(36) PRIMARY KEY, remoteid varchar(36) PRIMARY KEY,
remotetype varchar(10) NOT NULL, remotetype varchar(10) NOT NULL,
remotename varchar(50) NOT NULL, remotename varchar(50) NOT NULL,
connectopts varchar(300) NOT NULL connectopts varchar(300) NOT NULL,
ptyout BLOB NOT NULL
); );
CREATE TABLE session_cmd ( CREATE TABLE session_cmd (
sessionid varchar(36) NOT NULL, sessionid varchar(36) NOT NULL,

160
pkg/remote/remote.go Normal file
View File

@ -0,0 +1,160 @@
package remote
import (
"fmt"
"os/exec"
"strings"
"sync"
"time"
"github.com/scripthaus-dev/mshell/pkg/base"
"github.com/scripthaus-dev/mshell/pkg/packet"
"github.com/scripthaus-dev/mshell/pkg/shexec"
)
const RemoteTypeMShell = "mshell"
type MShellProc struct {
Lock *sync.Mutex
RemoteId string
WindowId string
RemoteName string
Cmd *exec.Cmd
Input *packet.PacketSender
Output *packet.PacketParser
Local bool
DoneCh chan bool
CurDir string
HomeDir string
User string
Host string
Env []string
Connected bool
RpcMap map[string]*RpcEntry
}
type RpcEntry struct {
PacketId string
RespCh chan packet.RpcPacketType
}
func LaunchMShell() (*MShellProc, error) {
msPath, err := base.GetMShellPath()
if err != nil {
return nil, err
}
ecmd := exec.Command(msPath)
inputWriter, err := ecmd.StdinPipe()
if err != nil {
return nil, err
}
outputReader, err := ecmd.StdoutPipe()
if err != nil {
return nil, err
}
ecmd.Stderr = ecmd.Stdout
err = ecmd.Start()
if err != nil {
return nil, err
}
rtn := &MShellProc{Lock: &sync.Mutex{}, Local: true, Cmd: ecmd}
rtn.Output = packet.MakePacketParser(outputReader)
rtn.Input = packet.MakePacketSender(inputWriter)
rtn.RpcMap = make(map[string]*RpcEntry)
rtn.DoneCh = make(chan bool)
go func() {
exitErr := ecmd.Wait()
exitCode := shexec.GetExitCode(exitErr)
fmt.Printf("[error] RUNNER PROC EXITED code[%d]\n", exitCode)
close(rtn.DoneCh)
}()
return rtn, nil
}
func (runner *MShellProc) PacketRpc(pk packet.RpcPacketType, timeout time.Duration) (packet.RpcPacketType, error) {
if pk == nil {
return nil, fmt.Errorf("PacketRpc passed nil packet")
}
id := pk.GetPacketId()
respCh := make(chan packet.RpcPacketType)
runner.Lock.Lock()
runner.RpcMap[id] = &RpcEntry{PacketId: id, RespCh: respCh}
runner.Lock.Unlock()
defer func() {
runner.Lock.Lock()
delete(runner.RpcMap, id)
runner.Lock.Unlock()
}()
runner.Input.SendPacket(pk)
timer := time.NewTimer(timeout)
defer timer.Stop()
select {
case rtnPk := <-respCh:
return rtnPk, nil
case <-timer.C:
return nil, fmt.Errorf("PacketRpc timeout")
}
}
func (runner *MShellProc) ProcessPackets() {
for pk := range runner.Output.MainCh {
if rpcPk, ok := pk.(packet.RpcPacketType); ok {
rpcId := rpcPk.GetPacketId()
runner.Lock.Lock()
entry := runner.RpcMap[rpcId]
if entry != nil {
delete(runner.RpcMap, rpcId)
go func() {
entry.RespCh <- rpcPk
close(entry.RespCh)
}()
}
runner.Lock.Unlock()
}
if pk.GetType() == packet.CmdDataPacketStr {
dataPacket := pk.(*packet.CmdDataPacketType)
fmt.Printf("cmd-data %s pty=%d run=%d\n", dataPacket.CK, len(dataPacket.PtyData), len(dataPacket.RunData))
continue
}
if pk.GetType() == packet.InitPacketStr {
initPacket := pk.(*packet.InitPacketType)
fmt.Printf("runner-init %s user=%s dir=%s\n", initPacket.MShellHomeDir, initPacket.User, initPacket.HomeDir)
runner.Lock.Lock()
runner.Connected = true
runner.User = initPacket.User
runner.CurDir = initPacket.HomeDir
runner.HomeDir = initPacket.HomeDir
runner.Env = initPacket.Env
if runner.Local {
runner.Host = "local"
}
runner.Lock.Unlock()
continue
}
if pk.GetType() == packet.MessagePacketStr {
msgPacket := pk.(*packet.MessagePacketType)
fmt.Printf("# %s\n", msgPacket.Message)
continue
}
if pk.GetType() == packet.RawPacketStr {
rawPacket := pk.(*packet.RawPacketType)
fmt.Printf("stderr> %s\n", rawPacket.Data)
continue
}
fmt.Printf("runner-packet: %v\n", pk)
}
}
func (r *MShellProc) GetPrompt() string {
r.Lock.Lock()
defer r.Lock.Unlock()
var curDir = r.CurDir
if r.CurDir == r.HomeDir {
curDir = "~"
} else if strings.HasPrefix(r.CurDir, r.HomeDir+"/") {
curDir = "~/" + r.CurDir[0:len(r.HomeDir)+1]
}
return fmt.Sprintf("[%s@%s %s]", r.User, r.Host, curDir)
}

35
pkg/scpacket/scpacket.go Normal file
View File

@ -0,0 +1,35 @@
package scpacket
import (
"reflect"
"github.com/scripthaus-dev/mshell/pkg/packet"
)
const FeCommandPacketStr = "fecmd"
type RemoteState struct {
RemoteId string `json:"remoteid"`
RemoteName string `json:"remotename"`
Cwd string `json:"cwd"`
}
type FeCommandPacketType struct {
Type string `json:"type"`
SessionId string `json:"sessionid"`
WindowId string `json:"windowid"`
CmdStr string `json:"cmdstr"`
RemoteState RemoteState `json:"remotestate"`
}
func init() {
packet.RegisterPacketType(FeCommandPacketStr, reflect.TypeOf(&FeCommandPacketType{}))
}
func (*FeCommandPacketType) GetType() string {
return FeCommandPacketStr
}
func MakeFeCommandPacket() *FeCommandPacketType {
return &FeCommandPacketType{Type: FeCommandPacketStr}
}

View File

@ -1,12 +1,16 @@
package sstore package sstore
import ( import (
"context"
"database/sql"
"fmt"
"path" "path"
"sync" "sync"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"github.com/scripthaus-dev/mshell/pkg/base"
"github.com/scripthaus-dev/mshell/pkg/packet" "github.com/scripthaus-dev/mshell/pkg/packet"
"github.com/scripthaus-dev/sh2-server/pkg/scbase" "github.com/scripthaus-dev/sh2-server/pkg/scbase"
@ -20,17 +24,26 @@ const LineTypeCmd = "cmd"
const LineTypeText = "text" const LineTypeText = "text"
const DBFileName = "sh2.db" const DBFileName = "sh2.db"
const DefaultSessionName = "default"
const DefaultWindowName = "default"
const LocalRemoteName = "local"
var globalDBLock = &sync.Mutex{}
var globalDB *sqlx.DB
var globalDBErr error
func GetSessionDBName() string { func GetSessionDBName() string {
scHome := scbase.GetScHomeDir() scHome := scbase.GetScHomeDir()
return path.Join(scHome, DBFileName) return path.Join(scHome, DBFileName)
} }
func OpenConnPool() (*sqlx.DB, error) { func GetDB() (*sqlx.DB, error) {
connPool, err := sqlx.Open("sqlite3", GetSessionDBName()) globalDBLock.Lock()
if err != nil { defer globalDBLock.Unlock()
return nil, err if globalDB == nil && globalDBErr == nil {
globalDB, globalDBErr = sqlx.Open("sqlite3", GetSessionDBName())
} }
return connPool, nil return globalDB, globalDBErr
} }
type SessionType struct { type SessionType struct {
@ -71,6 +84,7 @@ type LineType struct {
} }
type RemoteType struct { type RemoteType struct {
RowId int64 `json:"rowid"`
RemoteId string `json:"remoteid"` RemoteId string `json:"remoteid"`
RemoteType string `json:"remotetype"` RemoteType string `json:"remotetype"`
RemoteName string `json:"remotename"` RemoteName string `json:"remotename"`
@ -124,3 +138,117 @@ func GetNextLine() int {
NextLineId++ NextLineId++
return rtn return rtn
} }
func NumSessions(ctx context.Context) (int, error) {
db, err := GetDB()
if err != nil {
return 0, err
}
query := "SELECT count(*) FROM session"
var count int
err = db.GetContext(ctx, &count, query)
if err != nil {
return 0, err
}
return count, nil
}
func GetRemoteById(ctx context.Context, remoteId string) (*RemoteType, error) {
db, err := GetDB()
if err != nil {
return nil, err
}
query := `SELECT rowid, remoteid, remotetype, remotename, connectopts FROM remote WHERE remoteid = ?`
var remote RemoteType
err = db.GetContext(ctx, &remote, query, remoteId)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
return &remote, nil
}
func InsertRemote(ctx context.Context, remote *RemoteType) error {
if remote == nil {
return fmt.Errorf("cannot insert nil remote")
}
if remote.RowId != 0 {
return fmt.Errorf("cannot insert a remote that already has rowid set, rowid=%d", remote.RowId)
}
db, err := GetDB()
if err != nil {
return err
}
query := `INSERT INTO remote (remoteid, remotetype, remotename, connectopts, ptyout) VALUES (:remoteid, :remotetype, :remotename, :connectopts, '')`
result, err := db.NamedExec(query, remote)
if err != nil {
return err
}
remote.RowId, err = result.LastInsertId()
if err != nil {
return fmt.Errorf("cannot get lastinsertid from insert remote: %w", err)
}
return nil
}
func EnsureLocalRemote(ctx context.Context) error {
remoteId, err := base.GetRemoteId()
if err != nil {
return err
}
remote, err := GetRemoteById(ctx, remoteId)
if err != nil {
return err
}
if remote != nil {
return nil
}
// create the local remote
localRemote := &RemoteType{
RemoteId: remoteId,
RemoteType: "ssh",
RemoteName: LocalRemoteName,
}
err = InsertRemote(ctx, localRemote)
if err != nil {
return err
}
return nil
}
func CreateInitialSession(ctx context.Context) error {
db, err := GetDB()
if err != nil {
return err
}
session := &SessionType{
SessionId: uuid.New().String(),
Name: DefaultSessionName,
}
window := &WindowType{
SessionId: session.SessionId,
WindowId: uuid.New().String(),
Name: DefaultWindowName,
CurRemote: LocalRemoteName,
}
remoteId, err := base.GetRemoteId()
if err != nil {
return err
}
localRemote := &RemoteType{
RemoteId: remoteId,
RemoteType: "ssh",
RemoteName: LocalRemoteName,
}
sessRemote := &SessionRemote{
SessionId: session.SessionId,
WindowId: window.WindowId,
RemoteId: remoteId,
RemoteName: localRemote.RemoteName,
Cwd: base.GetHomeDir(),
}
fmt.Printf("db=%v s=%v w=%v r=%v sr=%v\n", db, session, window, localRemote, sessRemote)
return nil
}