checkpoint, working on setting up db

This commit is contained in:
sawka 2022-07-01 12:17:19 -07:00
parent 02029b3948
commit 643f08e584
6 changed files with 349 additions and 7 deletions

View File

@ -1,6 +1,7 @@
package main
import (
"context"
"encoding/json"
"errors"
"fmt"
@ -495,6 +496,22 @@ func main() {
fmt.Printf("[error] %v\n", err)
return
}
numSessions, err := sstore.NumSessions(context.Background())
if err != nil {
fmt.Printf("[error] getting num sessions: %v\n", err)
return
}
err = sstore.EnsureLocalRemote(context.Background())
if err != nil {
fmt.Printf("[error] ensuring local remote: %v\n", err)
return
}
fmt.Printf("[db] sessions count=%d\n", numSessions)
if numSessions == 0 {
sstore.CreateInitialSession(context.Background())
}
return
runnerProc, err := remote.LaunchMShell()
if err != nil {
fmt.Printf("error launching runner-proc: %v\n", err)

View File

@ -39,7 +39,8 @@ CREATE TABLE remote (
remoteid varchar(36) PRIMARY KEY,
remotetype varchar(10) NOT NULL,
remotename varchar(50) NOT NULL,
connectopts varchar(300) NOT NULL
connectopts varchar(300) NOT NULL,
ptyout BLOB NOT NULL
);
CREATE TABLE session_cmd (

View File

@ -37,7 +37,8 @@ CREATE TABLE remote (
remoteid varchar(36) PRIMARY KEY,
remotetype varchar(10) NOT NULL,
remotename varchar(50) NOT NULL,
connectopts varchar(300) NOT NULL
connectopts varchar(300) NOT NULL,
ptyout BLOB NOT NULL
);
CREATE TABLE session_cmd (
sessionid varchar(36) NOT NULL,

160
pkg/remote/remote.go Normal file
View File

@ -0,0 +1,160 @@
package remote
import (
"fmt"
"os/exec"
"strings"
"sync"
"time"
"github.com/scripthaus-dev/mshell/pkg/base"
"github.com/scripthaus-dev/mshell/pkg/packet"
"github.com/scripthaus-dev/mshell/pkg/shexec"
)
const RemoteTypeMShell = "mshell"
type MShellProc struct {
Lock *sync.Mutex
RemoteId string
WindowId string
RemoteName string
Cmd *exec.Cmd
Input *packet.PacketSender
Output *packet.PacketParser
Local bool
DoneCh chan bool
CurDir string
HomeDir string
User string
Host string
Env []string
Connected bool
RpcMap map[string]*RpcEntry
}
type RpcEntry struct {
PacketId string
RespCh chan packet.RpcPacketType
}
func LaunchMShell() (*MShellProc, error) {
msPath, err := base.GetMShellPath()
if err != nil {
return nil, err
}
ecmd := exec.Command(msPath)
inputWriter, err := ecmd.StdinPipe()
if err != nil {
return nil, err
}
outputReader, err := ecmd.StdoutPipe()
if err != nil {
return nil, err
}
ecmd.Stderr = ecmd.Stdout
err = ecmd.Start()
if err != nil {
return nil, err
}
rtn := &MShellProc{Lock: &sync.Mutex{}, Local: true, Cmd: ecmd}
rtn.Output = packet.MakePacketParser(outputReader)
rtn.Input = packet.MakePacketSender(inputWriter)
rtn.RpcMap = make(map[string]*RpcEntry)
rtn.DoneCh = make(chan bool)
go func() {
exitErr := ecmd.Wait()
exitCode := shexec.GetExitCode(exitErr)
fmt.Printf("[error] RUNNER PROC EXITED code[%d]\n", exitCode)
close(rtn.DoneCh)
}()
return rtn, nil
}
func (runner *MShellProc) PacketRpc(pk packet.RpcPacketType, timeout time.Duration) (packet.RpcPacketType, error) {
if pk == nil {
return nil, fmt.Errorf("PacketRpc passed nil packet")
}
id := pk.GetPacketId()
respCh := make(chan packet.RpcPacketType)
runner.Lock.Lock()
runner.RpcMap[id] = &RpcEntry{PacketId: id, RespCh: respCh}
runner.Lock.Unlock()
defer func() {
runner.Lock.Lock()
delete(runner.RpcMap, id)
runner.Lock.Unlock()
}()
runner.Input.SendPacket(pk)
timer := time.NewTimer(timeout)
defer timer.Stop()
select {
case rtnPk := <-respCh:
return rtnPk, nil
case <-timer.C:
return nil, fmt.Errorf("PacketRpc timeout")
}
}
func (runner *MShellProc) ProcessPackets() {
for pk := range runner.Output.MainCh {
if rpcPk, ok := pk.(packet.RpcPacketType); ok {
rpcId := rpcPk.GetPacketId()
runner.Lock.Lock()
entry := runner.RpcMap[rpcId]
if entry != nil {
delete(runner.RpcMap, rpcId)
go func() {
entry.RespCh <- rpcPk
close(entry.RespCh)
}()
}
runner.Lock.Unlock()
}
if pk.GetType() == packet.CmdDataPacketStr {
dataPacket := pk.(*packet.CmdDataPacketType)
fmt.Printf("cmd-data %s pty=%d run=%d\n", dataPacket.CK, len(dataPacket.PtyData), len(dataPacket.RunData))
continue
}
if pk.GetType() == packet.InitPacketStr {
initPacket := pk.(*packet.InitPacketType)
fmt.Printf("runner-init %s user=%s dir=%s\n", initPacket.MShellHomeDir, initPacket.User, initPacket.HomeDir)
runner.Lock.Lock()
runner.Connected = true
runner.User = initPacket.User
runner.CurDir = initPacket.HomeDir
runner.HomeDir = initPacket.HomeDir
runner.Env = initPacket.Env
if runner.Local {
runner.Host = "local"
}
runner.Lock.Unlock()
continue
}
if pk.GetType() == packet.MessagePacketStr {
msgPacket := pk.(*packet.MessagePacketType)
fmt.Printf("# %s\n", msgPacket.Message)
continue
}
if pk.GetType() == packet.RawPacketStr {
rawPacket := pk.(*packet.RawPacketType)
fmt.Printf("stderr> %s\n", rawPacket.Data)
continue
}
fmt.Printf("runner-packet: %v\n", pk)
}
}
func (r *MShellProc) GetPrompt() string {
r.Lock.Lock()
defer r.Lock.Unlock()
var curDir = r.CurDir
if r.CurDir == r.HomeDir {
curDir = "~"
} else if strings.HasPrefix(r.CurDir, r.HomeDir+"/") {
curDir = "~/" + r.CurDir[0:len(r.HomeDir)+1]
}
return fmt.Sprintf("[%s@%s %s]", r.User, r.Host, curDir)
}

35
pkg/scpacket/scpacket.go Normal file
View File

@ -0,0 +1,35 @@
package scpacket
import (
"reflect"
"github.com/scripthaus-dev/mshell/pkg/packet"
)
const FeCommandPacketStr = "fecmd"
type RemoteState struct {
RemoteId string `json:"remoteid"`
RemoteName string `json:"remotename"`
Cwd string `json:"cwd"`
}
type FeCommandPacketType struct {
Type string `json:"type"`
SessionId string `json:"sessionid"`
WindowId string `json:"windowid"`
CmdStr string `json:"cmdstr"`
RemoteState RemoteState `json:"remotestate"`
}
func init() {
packet.RegisterPacketType(FeCommandPacketStr, reflect.TypeOf(&FeCommandPacketType{}))
}
func (*FeCommandPacketType) GetType() string {
return FeCommandPacketStr
}
func MakeFeCommandPacket() *FeCommandPacketType {
return &FeCommandPacketType{Type: FeCommandPacketStr}
}

View File

@ -1,12 +1,16 @@
package sstore
import (
"context"
"database/sql"
"fmt"
"path"
"sync"
"time"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
"github.com/scripthaus-dev/mshell/pkg/base"
"github.com/scripthaus-dev/mshell/pkg/packet"
"github.com/scripthaus-dev/sh2-server/pkg/scbase"
@ -20,17 +24,26 @@ const LineTypeCmd = "cmd"
const LineTypeText = "text"
const DBFileName = "sh2.db"
const DefaultSessionName = "default"
const DefaultWindowName = "default"
const LocalRemoteName = "local"
var globalDBLock = &sync.Mutex{}
var globalDB *sqlx.DB
var globalDBErr error
func GetSessionDBName() string {
scHome := scbase.GetScHomeDir()
return path.Join(scHome, DBFileName)
}
func OpenConnPool() (*sqlx.DB, error) {
connPool, err := sqlx.Open("sqlite3", GetSessionDBName())
if err != nil {
return nil, err
func GetDB() (*sqlx.DB, error) {
globalDBLock.Lock()
defer globalDBLock.Unlock()
if globalDB == nil && globalDBErr == nil {
globalDB, globalDBErr = sqlx.Open("sqlite3", GetSessionDBName())
}
return connPool, nil
return globalDB, globalDBErr
}
type SessionType struct {
@ -71,6 +84,7 @@ type LineType struct {
}
type RemoteType struct {
RowId int64 `json:"rowid"`
RemoteId string `json:"remoteid"`
RemoteType string `json:"remotetype"`
RemoteName string `json:"remotename"`
@ -124,3 +138,117 @@ func GetNextLine() int {
NextLineId++
return rtn
}
func NumSessions(ctx context.Context) (int, error) {
db, err := GetDB()
if err != nil {
return 0, err
}
query := "SELECT count(*) FROM session"
var count int
err = db.GetContext(ctx, &count, query)
if err != nil {
return 0, err
}
return count, nil
}
func GetRemoteById(ctx context.Context, remoteId string) (*RemoteType, error) {
db, err := GetDB()
if err != nil {
return nil, err
}
query := `SELECT rowid, remoteid, remotetype, remotename, connectopts FROM remote WHERE remoteid = ?`
var remote RemoteType
err = db.GetContext(ctx, &remote, query, remoteId)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
return &remote, nil
}
func InsertRemote(ctx context.Context, remote *RemoteType) error {
if remote == nil {
return fmt.Errorf("cannot insert nil remote")
}
if remote.RowId != 0 {
return fmt.Errorf("cannot insert a remote that already has rowid set, rowid=%d", remote.RowId)
}
db, err := GetDB()
if err != nil {
return err
}
query := `INSERT INTO remote (remoteid, remotetype, remotename, connectopts, ptyout) VALUES (:remoteid, :remotetype, :remotename, :connectopts, '')`
result, err := db.NamedExec(query, remote)
if err != nil {
return err
}
remote.RowId, err = result.LastInsertId()
if err != nil {
return fmt.Errorf("cannot get lastinsertid from insert remote: %w", err)
}
return nil
}
func EnsureLocalRemote(ctx context.Context) error {
remoteId, err := base.GetRemoteId()
if err != nil {
return err
}
remote, err := GetRemoteById(ctx, remoteId)
if err != nil {
return err
}
if remote != nil {
return nil
}
// create the local remote
localRemote := &RemoteType{
RemoteId: remoteId,
RemoteType: "ssh",
RemoteName: LocalRemoteName,
}
err = InsertRemote(ctx, localRemote)
if err != nil {
return err
}
return nil
}
func CreateInitialSession(ctx context.Context) error {
db, err := GetDB()
if err != nil {
return err
}
session := &SessionType{
SessionId: uuid.New().String(),
Name: DefaultSessionName,
}
window := &WindowType{
SessionId: session.SessionId,
WindowId: uuid.New().String(),
Name: DefaultWindowName,
CurRemote: LocalRemoteName,
}
remoteId, err := base.GetRemoteId()
if err != nil {
return err
}
localRemote := &RemoteType{
RemoteId: remoteId,
RemoteType: "ssh",
RemoteName: LocalRemoteName,
}
sessRemote := &SessionRemote{
SessionId: session.SessionId,
WindowId: window.WindowId,
RemoteId: remoteId,
RemoteName: localRemote.RemoteName,
Cwd: base.GetHomeDir(),
}
fmt.Printf("db=%v s=%v w=%v r=%v sr=%v\n", db, session, window, localRemote, sessRemote)
return nil
}