mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-01-10 19:58:00 +01:00
checkpoint, working on setting up db
This commit is contained in:
parent
02029b3948
commit
643f08e584
@ -1,6 +1,7 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -495,6 +496,22 @@ func main() {
|
|||||||
fmt.Printf("[error] %v\n", err)
|
fmt.Printf("[error] %v\n", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
numSessions, err := sstore.NumSessions(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("[error] getting num sessions: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = sstore.EnsureLocalRemote(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("[error] ensuring local remote: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fmt.Printf("[db] sessions count=%d\n", numSessions)
|
||||||
|
if numSessions == 0 {
|
||||||
|
sstore.CreateInitialSession(context.Background())
|
||||||
|
}
|
||||||
|
return
|
||||||
|
|
||||||
runnerProc, err := remote.LaunchMShell()
|
runnerProc, err := remote.LaunchMShell()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("error launching runner-proc: %v\n", err)
|
fmt.Printf("error launching runner-proc: %v\n", err)
|
||||||
|
@ -39,7 +39,8 @@ CREATE TABLE remote (
|
|||||||
remoteid varchar(36) PRIMARY KEY,
|
remoteid varchar(36) PRIMARY KEY,
|
||||||
remotetype varchar(10) NOT NULL,
|
remotetype varchar(10) NOT NULL,
|
||||||
remotename varchar(50) NOT NULL,
|
remotename varchar(50) NOT NULL,
|
||||||
connectopts varchar(300) NOT NULL
|
connectopts varchar(300) NOT NULL,
|
||||||
|
ptyout BLOB NOT NULL
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE TABLE session_cmd (
|
CREATE TABLE session_cmd (
|
||||||
|
@ -37,7 +37,8 @@ CREATE TABLE remote (
|
|||||||
remoteid varchar(36) PRIMARY KEY,
|
remoteid varchar(36) PRIMARY KEY,
|
||||||
remotetype varchar(10) NOT NULL,
|
remotetype varchar(10) NOT NULL,
|
||||||
remotename varchar(50) NOT NULL,
|
remotename varchar(50) NOT NULL,
|
||||||
connectopts varchar(300) NOT NULL
|
connectopts varchar(300) NOT NULL,
|
||||||
|
ptyout BLOB NOT NULL
|
||||||
);
|
);
|
||||||
CREATE TABLE session_cmd (
|
CREATE TABLE session_cmd (
|
||||||
sessionid varchar(36) NOT NULL,
|
sessionid varchar(36) NOT NULL,
|
||||||
|
160
pkg/remote/remote.go
Normal file
160
pkg/remote/remote.go
Normal file
@ -0,0 +1,160 @@
|
|||||||
|
package remote
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/scripthaus-dev/mshell/pkg/base"
|
||||||
|
"github.com/scripthaus-dev/mshell/pkg/packet"
|
||||||
|
"github.com/scripthaus-dev/mshell/pkg/shexec"
|
||||||
|
)
|
||||||
|
|
||||||
|
const RemoteTypeMShell = "mshell"
|
||||||
|
|
||||||
|
type MShellProc struct {
|
||||||
|
Lock *sync.Mutex
|
||||||
|
RemoteId string
|
||||||
|
WindowId string
|
||||||
|
RemoteName string
|
||||||
|
Cmd *exec.Cmd
|
||||||
|
Input *packet.PacketSender
|
||||||
|
Output *packet.PacketParser
|
||||||
|
Local bool
|
||||||
|
DoneCh chan bool
|
||||||
|
CurDir string
|
||||||
|
HomeDir string
|
||||||
|
User string
|
||||||
|
Host string
|
||||||
|
Env []string
|
||||||
|
Connected bool
|
||||||
|
RpcMap map[string]*RpcEntry
|
||||||
|
}
|
||||||
|
|
||||||
|
type RpcEntry struct {
|
||||||
|
PacketId string
|
||||||
|
RespCh chan packet.RpcPacketType
|
||||||
|
}
|
||||||
|
|
||||||
|
func LaunchMShell() (*MShellProc, error) {
|
||||||
|
msPath, err := base.GetMShellPath()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ecmd := exec.Command(msPath)
|
||||||
|
inputWriter, err := ecmd.StdinPipe()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
outputReader, err := ecmd.StdoutPipe()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ecmd.Stderr = ecmd.Stdout
|
||||||
|
err = ecmd.Start()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
rtn := &MShellProc{Lock: &sync.Mutex{}, Local: true, Cmd: ecmd}
|
||||||
|
rtn.Output = packet.MakePacketParser(outputReader)
|
||||||
|
rtn.Input = packet.MakePacketSender(inputWriter)
|
||||||
|
rtn.RpcMap = make(map[string]*RpcEntry)
|
||||||
|
rtn.DoneCh = make(chan bool)
|
||||||
|
go func() {
|
||||||
|
exitErr := ecmd.Wait()
|
||||||
|
exitCode := shexec.GetExitCode(exitErr)
|
||||||
|
fmt.Printf("[error] RUNNER PROC EXITED code[%d]\n", exitCode)
|
||||||
|
close(rtn.DoneCh)
|
||||||
|
}()
|
||||||
|
return rtn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (runner *MShellProc) PacketRpc(pk packet.RpcPacketType, timeout time.Duration) (packet.RpcPacketType, error) {
|
||||||
|
if pk == nil {
|
||||||
|
return nil, fmt.Errorf("PacketRpc passed nil packet")
|
||||||
|
}
|
||||||
|
id := pk.GetPacketId()
|
||||||
|
respCh := make(chan packet.RpcPacketType)
|
||||||
|
runner.Lock.Lock()
|
||||||
|
runner.RpcMap[id] = &RpcEntry{PacketId: id, RespCh: respCh}
|
||||||
|
runner.Lock.Unlock()
|
||||||
|
defer func() {
|
||||||
|
runner.Lock.Lock()
|
||||||
|
delete(runner.RpcMap, id)
|
||||||
|
runner.Lock.Unlock()
|
||||||
|
}()
|
||||||
|
runner.Input.SendPacket(pk)
|
||||||
|
timer := time.NewTimer(timeout)
|
||||||
|
defer timer.Stop()
|
||||||
|
select {
|
||||||
|
case rtnPk := <-respCh:
|
||||||
|
return rtnPk, nil
|
||||||
|
|
||||||
|
case <-timer.C:
|
||||||
|
return nil, fmt.Errorf("PacketRpc timeout")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (runner *MShellProc) ProcessPackets() {
|
||||||
|
for pk := range runner.Output.MainCh {
|
||||||
|
if rpcPk, ok := pk.(packet.RpcPacketType); ok {
|
||||||
|
rpcId := rpcPk.GetPacketId()
|
||||||
|
runner.Lock.Lock()
|
||||||
|
entry := runner.RpcMap[rpcId]
|
||||||
|
if entry != nil {
|
||||||
|
delete(runner.RpcMap, rpcId)
|
||||||
|
go func() {
|
||||||
|
entry.RespCh <- rpcPk
|
||||||
|
close(entry.RespCh)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
runner.Lock.Unlock()
|
||||||
|
|
||||||
|
}
|
||||||
|
if pk.GetType() == packet.CmdDataPacketStr {
|
||||||
|
dataPacket := pk.(*packet.CmdDataPacketType)
|
||||||
|
fmt.Printf("cmd-data %s pty=%d run=%d\n", dataPacket.CK, len(dataPacket.PtyData), len(dataPacket.RunData))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if pk.GetType() == packet.InitPacketStr {
|
||||||
|
initPacket := pk.(*packet.InitPacketType)
|
||||||
|
fmt.Printf("runner-init %s user=%s dir=%s\n", initPacket.MShellHomeDir, initPacket.User, initPacket.HomeDir)
|
||||||
|
runner.Lock.Lock()
|
||||||
|
runner.Connected = true
|
||||||
|
runner.User = initPacket.User
|
||||||
|
runner.CurDir = initPacket.HomeDir
|
||||||
|
runner.HomeDir = initPacket.HomeDir
|
||||||
|
runner.Env = initPacket.Env
|
||||||
|
if runner.Local {
|
||||||
|
runner.Host = "local"
|
||||||
|
}
|
||||||
|
runner.Lock.Unlock()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if pk.GetType() == packet.MessagePacketStr {
|
||||||
|
msgPacket := pk.(*packet.MessagePacketType)
|
||||||
|
fmt.Printf("# %s\n", msgPacket.Message)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if pk.GetType() == packet.RawPacketStr {
|
||||||
|
rawPacket := pk.(*packet.RawPacketType)
|
||||||
|
fmt.Printf("stderr> %s\n", rawPacket.Data)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fmt.Printf("runner-packet: %v\n", pk)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *MShellProc) GetPrompt() string {
|
||||||
|
r.Lock.Lock()
|
||||||
|
defer r.Lock.Unlock()
|
||||||
|
var curDir = r.CurDir
|
||||||
|
if r.CurDir == r.HomeDir {
|
||||||
|
curDir = "~"
|
||||||
|
} else if strings.HasPrefix(r.CurDir, r.HomeDir+"/") {
|
||||||
|
curDir = "~/" + r.CurDir[0:len(r.HomeDir)+1]
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("[%s@%s %s]", r.User, r.Host, curDir)
|
||||||
|
}
|
35
pkg/scpacket/scpacket.go
Normal file
35
pkg/scpacket/scpacket.go
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
package scpacket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
|
||||||
|
"github.com/scripthaus-dev/mshell/pkg/packet"
|
||||||
|
)
|
||||||
|
|
||||||
|
const FeCommandPacketStr = "fecmd"
|
||||||
|
|
||||||
|
type RemoteState struct {
|
||||||
|
RemoteId string `json:"remoteid"`
|
||||||
|
RemoteName string `json:"remotename"`
|
||||||
|
Cwd string `json:"cwd"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type FeCommandPacketType struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
SessionId string `json:"sessionid"`
|
||||||
|
WindowId string `json:"windowid"`
|
||||||
|
CmdStr string `json:"cmdstr"`
|
||||||
|
RemoteState RemoteState `json:"remotestate"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
packet.RegisterPacketType(FeCommandPacketStr, reflect.TypeOf(&FeCommandPacketType{}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*FeCommandPacketType) GetType() string {
|
||||||
|
return FeCommandPacketStr
|
||||||
|
}
|
||||||
|
|
||||||
|
func MakeFeCommandPacket() *FeCommandPacketType {
|
||||||
|
return &FeCommandPacketType{Type: FeCommandPacketStr}
|
||||||
|
}
|
@ -1,12 +1,16 @@
|
|||||||
package sstore
|
package sstore
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
"path"
|
"path"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/jmoiron/sqlx"
|
"github.com/jmoiron/sqlx"
|
||||||
|
"github.com/scripthaus-dev/mshell/pkg/base"
|
||||||
"github.com/scripthaus-dev/mshell/pkg/packet"
|
"github.com/scripthaus-dev/mshell/pkg/packet"
|
||||||
"github.com/scripthaus-dev/sh2-server/pkg/scbase"
|
"github.com/scripthaus-dev/sh2-server/pkg/scbase"
|
||||||
|
|
||||||
@ -20,17 +24,26 @@ const LineTypeCmd = "cmd"
|
|||||||
const LineTypeText = "text"
|
const LineTypeText = "text"
|
||||||
const DBFileName = "sh2.db"
|
const DBFileName = "sh2.db"
|
||||||
|
|
||||||
|
const DefaultSessionName = "default"
|
||||||
|
const DefaultWindowName = "default"
|
||||||
|
const LocalRemoteName = "local"
|
||||||
|
|
||||||
|
var globalDBLock = &sync.Mutex{}
|
||||||
|
var globalDB *sqlx.DB
|
||||||
|
var globalDBErr error
|
||||||
|
|
||||||
func GetSessionDBName() string {
|
func GetSessionDBName() string {
|
||||||
scHome := scbase.GetScHomeDir()
|
scHome := scbase.GetScHomeDir()
|
||||||
return path.Join(scHome, DBFileName)
|
return path.Join(scHome, DBFileName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func OpenConnPool() (*sqlx.DB, error) {
|
func GetDB() (*sqlx.DB, error) {
|
||||||
connPool, err := sqlx.Open("sqlite3", GetSessionDBName())
|
globalDBLock.Lock()
|
||||||
if err != nil {
|
defer globalDBLock.Unlock()
|
||||||
return nil, err
|
if globalDB == nil && globalDBErr == nil {
|
||||||
|
globalDB, globalDBErr = sqlx.Open("sqlite3", GetSessionDBName())
|
||||||
}
|
}
|
||||||
return connPool, nil
|
return globalDB, globalDBErr
|
||||||
}
|
}
|
||||||
|
|
||||||
type SessionType struct {
|
type SessionType struct {
|
||||||
@ -71,6 +84,7 @@ type LineType struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type RemoteType struct {
|
type RemoteType struct {
|
||||||
|
RowId int64 `json:"rowid"`
|
||||||
RemoteId string `json:"remoteid"`
|
RemoteId string `json:"remoteid"`
|
||||||
RemoteType string `json:"remotetype"`
|
RemoteType string `json:"remotetype"`
|
||||||
RemoteName string `json:"remotename"`
|
RemoteName string `json:"remotename"`
|
||||||
@ -124,3 +138,117 @@ func GetNextLine() int {
|
|||||||
NextLineId++
|
NextLineId++
|
||||||
return rtn
|
return rtn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NumSessions(ctx context.Context) (int, error) {
|
||||||
|
db, err := GetDB()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
query := "SELECT count(*) FROM session"
|
||||||
|
var count int
|
||||||
|
err = db.GetContext(ctx, &count, query)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return count, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetRemoteById(ctx context.Context, remoteId string) (*RemoteType, error) {
|
||||||
|
db, err := GetDB()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
query := `SELECT rowid, remoteid, remotetype, remotename, connectopts FROM remote WHERE remoteid = ?`
|
||||||
|
var remote RemoteType
|
||||||
|
err = db.GetContext(ctx, &remote, query, remoteId)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &remote, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func InsertRemote(ctx context.Context, remote *RemoteType) error {
|
||||||
|
if remote == nil {
|
||||||
|
return fmt.Errorf("cannot insert nil remote")
|
||||||
|
}
|
||||||
|
if remote.RowId != 0 {
|
||||||
|
return fmt.Errorf("cannot insert a remote that already has rowid set, rowid=%d", remote.RowId)
|
||||||
|
}
|
||||||
|
db, err := GetDB()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
query := `INSERT INTO remote (remoteid, remotetype, remotename, connectopts, ptyout) VALUES (:remoteid, :remotetype, :remotename, :connectopts, '')`
|
||||||
|
result, err := db.NamedExec(query, remote)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
remote.RowId, err = result.LastInsertId()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot get lastinsertid from insert remote: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func EnsureLocalRemote(ctx context.Context) error {
|
||||||
|
remoteId, err := base.GetRemoteId()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
remote, err := GetRemoteById(ctx, remoteId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if remote != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// create the local remote
|
||||||
|
localRemote := &RemoteType{
|
||||||
|
RemoteId: remoteId,
|
||||||
|
RemoteType: "ssh",
|
||||||
|
RemoteName: LocalRemoteName,
|
||||||
|
}
|
||||||
|
err = InsertRemote(ctx, localRemote)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateInitialSession(ctx context.Context) error {
|
||||||
|
db, err := GetDB()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
session := &SessionType{
|
||||||
|
SessionId: uuid.New().String(),
|
||||||
|
Name: DefaultSessionName,
|
||||||
|
}
|
||||||
|
window := &WindowType{
|
||||||
|
SessionId: session.SessionId,
|
||||||
|
WindowId: uuid.New().String(),
|
||||||
|
Name: DefaultWindowName,
|
||||||
|
CurRemote: LocalRemoteName,
|
||||||
|
}
|
||||||
|
remoteId, err := base.GetRemoteId()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
localRemote := &RemoteType{
|
||||||
|
RemoteId: remoteId,
|
||||||
|
RemoteType: "ssh",
|
||||||
|
RemoteName: LocalRemoteName,
|
||||||
|
}
|
||||||
|
sessRemote := &SessionRemote{
|
||||||
|
SessionId: session.SessionId,
|
||||||
|
WindowId: window.WindowId,
|
||||||
|
RemoteId: remoteId,
|
||||||
|
RemoteName: localRemote.RemoteName,
|
||||||
|
Cwd: base.GetHomeDir(),
|
||||||
|
}
|
||||||
|
fmt.Printf("db=%v s=%v w=%v r=%v sr=%v\n", db, session, window, localRemote, sessRemote)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user