updates for getting pty updates sent again

This commit is contained in:
sawka 2022-07-13 14:16:08 -07:00
parent db841f2951
commit c1ace6f5d6
9 changed files with 364 additions and 177 deletions

View File

@ -16,11 +16,11 @@ import (
"github.com/google/uuid"
"github.com/gorilla/mux"
"github.com/scripthaus-dev/mshell/pkg/cmdtail"
"github.com/scripthaus-dev/mshell/pkg/packet"
"github.com/scripthaus-dev/sh2-server/pkg/remote"
"github.com/scripthaus-dev/sh2-server/pkg/scbase"
"github.com/scripthaus-dev/sh2-server/pkg/scpacket"
"github.com/scripthaus-dev/sh2-server/pkg/scws"
"github.com/scripthaus-dev/sh2-server/pkg/sstore"
"github.com/scripthaus-dev/sh2-server/pkg/wsshell"
)
@ -35,18 +35,16 @@ const MainServerAddr = "localhost:8080"
const WSStateReconnectTime = 30 * time.Second
const WSStatePacketChSize = 20
const MaxInputDataSize = 1000
var GlobalLock = &sync.Mutex{}
var WSStateMap = make(map[string]*WSState) // clientid -> WsState
var WSStateMap = make(map[string]*scws.WSState) // clientid -> WsState
func setWSState(state *WSState) {
func setWSState(state *scws.WSState) {
GlobalLock.Lock()
defer GlobalLock.Unlock()
WSStateMap[state.ClientId] = state
}
func getWSState(clientId string) *WSState {
func getWSState(clientId string) *scws.WSState {
GlobalLock.Lock()
defer GlobalLock.Unlock()
return WSStateMap[clientId]
@ -62,103 +60,10 @@ func removeWSStateAfterTimeout(clientId string, connectTime time.Time, waitDurat
return
}
delete(WSStateMap, clientId)
err := state.CloseTailer()
if err != nil {
fmt.Printf("[error] closing tailer on ws %v\n", err)
}
state.UnWatchScreen()
}()
}
type WSState struct {
Lock *sync.Mutex
ClientId string
ConnectTime time.Time
Shell *wsshell.WSShell
Tailer *cmdtail.Tailer
PacketCh chan packet.PacketType
}
func MakeWSState(clientId string) (*WSState, error) {
var err error
rtn := &WSState{}
rtn.Lock = &sync.Mutex{}
rtn.ClientId = clientId
rtn.ConnectTime = time.Now()
rtn.PacketCh = make(chan packet.PacketType, WSStatePacketChSize)
chSender := packet.MakeChannelPacketSender(rtn.PacketCh)
gen := scbase.ScFileNameGenerator{ScHome: scbase.GetScHomeDir()}
rtn.Tailer, err = cmdtail.MakeTailer(chSender, gen)
if err != nil {
return nil, err
}
go func() {
defer close(rtn.PacketCh)
rtn.Tailer.Run()
}()
go rtn.runTailerToWS()
return rtn, nil
}
func (ws *WSState) CloseTailer() error {
return ws.Tailer.Close()
}
func (ws *WSState) getShell() *wsshell.WSShell {
ws.Lock.Lock()
defer ws.Lock.Unlock()
return ws.Shell
}
func (ws *WSState) runTailerToWS() {
for pk := range ws.PacketCh {
if pk.GetType() == "cmddata" {
dataPacket := pk.(*packet.CmdDataPacketType)
err := ws.writePacket(dataPacket)
if err != nil {
fmt.Printf("[error] writing packet to ws: %v\n", err)
}
continue
}
fmt.Printf("tailer-to-ws, bad packet %v\n", pk.GetType())
}
}
func (ws *WSState) writePacket(pk packet.PacketType) error {
shell := ws.getShell()
if shell == nil || shell.IsClosed() {
return fmt.Errorf("cannot write packet, empty or closed wsshell")
}
err := shell.WriteJson(pk)
if err != nil {
return err
}
return nil
}
func (ws *WSState) getConnectTime() time.Time {
ws.Lock.Lock()
defer ws.Lock.Unlock()
return ws.ConnectTime
}
func (ws *WSState) updateConnectTime() {
ws.Lock.Lock()
defer ws.Lock.Unlock()
ws.ConnectTime = time.Now()
}
func (ws *WSState) replaceExistingShell(shell *wsshell.WSShell) {
ws.Lock.Lock()
defer ws.Lock.Unlock()
if ws.Shell == nil {
ws.Shell = shell
return
}
ws.Shell.Conn.Close()
ws.Shell = shell
return
}
func HandleWs(w http.ResponseWriter, r *http.Request) {
shell, err := wsshell.StartWS(w, r)
if err != nil {
@ -173,57 +78,20 @@ func HandleWs(w http.ResponseWriter, r *http.Request) {
}
state := getWSState(clientId)
if state == nil {
state, err = MakeWSState(clientId)
if err != nil {
fmt.Printf("cannot make wsstate: %v\n", err)
close(shell.WriteChan)
return
}
state.replaceExistingShell(shell)
state = scws.MakeWSState(clientId)
state.ReplaceShell(shell)
setWSState(state)
} else {
state.updateConnectTime()
state.replaceExistingShell(shell)
state.UpdateConnectTime()
state.ReplaceShell(shell)
}
stateConnectTime := state.getConnectTime()
stateConnectTime := state.GetConnectTime()
defer func() {
removeWSStateAfterTimeout(clientId, stateConnectTime, WSStateReconnectTime)
}()
shell.WriteJson(map[string]interface{}{"type": "hello"}) // let client know we accepted this connection, ignore error
fmt.Printf("WebSocket opened %s %s\n", shell.RemoteAddr, state.ClientId)
for msgBytes := range shell.ReadChan {
pk, err := packet.ParseJsonPacket(msgBytes)
if err != nil {
fmt.Printf("error unmarshalling ws message: %v\n", err)
continue
}
if pk.GetType() == "getcmd" {
getPk := pk.(*packet.GetCmdPacketType)
done, err := state.Tailer.AddWatch(getPk)
if err != nil {
// TODO: send responseerror
respPk := packet.MakeErrorResponsePacket(getPk.ReqId, err)
fmt.Printf("[error] adding watch to tailer: %v\n", err)
fmt.Printf("%v\n", respPk)
}
if done {
respPk := packet.MakeResponsePacket(getPk.ReqId, true)
fmt.Printf("%v\n", respPk)
// TODO: send response
}
continue
}
if pk.GetType() == "input" {
go func() {
err = sendCmdInput(pk.(*packet.InputPacketType))
if err != nil {
fmt.Printf("[error] sending command input: %v\n", err)
}
}()
continue
}
fmt.Printf("got ws bad message: %v\n", pk.GetType())
}
fmt.Printf("WebSocket opened %s %s\n", state.ClientId, shell.RemoteAddr)
state.RunWSRead()
}
// todo: sync multiple writes to the same fifoName into a single go-routine and do liveness checking on fifo
@ -248,28 +116,6 @@ func writeToFifo(fifoName string, data []byte) error {
return nil
}
func sendCmdInput(pk *packet.InputPacketType) error {
err := pk.CK.Validate("input packet")
if err != nil {
return err
}
if pk.RemoteId == "" {
return fmt.Errorf("input must set remoteid")
}
if len(pk.InputData64) == 0 && pk.SigNum == 0 {
return fmt.Errorf("empty input packet")
}
inputLen := packet.B64DecodedLen(pk.InputData64)
if inputLen > MaxInputDataSize {
return fmt.Errorf("input data size too large, len=%d (max=%d)", inputLen, MaxInputDataSize)
}
msh := remote.GetRemoteById(pk.RemoteId)
if msh == nil {
return fmt.Errorf("cannot connect to remote")
}
return msh.SendInput(pk)
}
// params: sessionid
func HandleGetSession(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))

View File

@ -251,7 +251,7 @@ func RunCommand(ctx context.Context, pk *scpacket.FeCommandPacketType, cmdId str
DonePk: nil,
RunOut: nil,
}
err = sstore.AppendToCmdPtyBlob(ctx, cmd.SessionId, cmd.CmdId, nil)
err = sstore.AppendToCmdPtyBlob(ctx, cmd.SessionId, cmd.CmdId, nil, sstore.PosAppend)
if err != nil {
return nil, err
}
@ -358,7 +358,7 @@ func (runner *MShellProc) ProcessPackets() {
}
var ack *packet.DataAckPacketType
if len(realData) > 0 {
err = sstore.AppendToCmdPtyBlob(context.Background(), dataPk.CK.GetSessionId(), dataPk.CK.GetCmdId(), realData)
err = sstore.AppendToCmdPtyBlob(context.Background(), dataPk.CK.GetSessionId(), dataPk.CK.GetCmdId(), realData, sstore.PosAppend)
if err != nil {
ack = makeDataAckPacket(dataPk.CK, dataPk.FdNum, 0, err)
} else {

View File

@ -7,6 +7,7 @@ import (
)
const FeCommandPacketStr = "fecmd"
const WatchScreenPacketStr = "watchscreen"
type RemoteState struct {
RemoteId string `json:"remoteid"`
@ -24,7 +25,8 @@ type FeCommandPacketType struct {
}
func init() {
packet.RegisterPacketType(FeCommandPacketStr, reflect.TypeOf(&FeCommandPacketType{}))
packet.RegisterPacketType(FeCommandPacketStr, reflect.TypeOf(FeCommandPacketType{}))
packet.RegisterPacketType(WatchScreenPacketStr, reflect.TypeOf(WatchScreenPacketType{}))
}
func (*FeCommandPacketType) GetType() string {
@ -34,3 +36,17 @@ func (*FeCommandPacketType) GetType() string {
func MakeFeCommandPacket() *FeCommandPacketType {
return &FeCommandPacketType{Type: FeCommandPacketStr}
}
type WatchScreenPacketType struct {
Type string `json:"type"`
SessionId string `json:"sessionid"`
ScreenId string `json:"screenid"`
}
func (*WatchScreenPacketType) GetType() string {
return WatchScreenPacketStr
}
func MakeWatchScreenPacket() *WatchScreenPacketType {
return &WatchScreenPacketType{Type: WatchScreenPacketStr}
}

178
pkg/scws/scws.go Normal file
View File

@ -0,0 +1,178 @@
package scws
import (
"fmt"
"sync"
"time"
"github.com/google/uuid"
"github.com/scripthaus-dev/mshell/pkg/packet"
"github.com/scripthaus-dev/sh2-server/pkg/remote"
"github.com/scripthaus-dev/sh2-server/pkg/scpacket"
"github.com/scripthaus-dev/sh2-server/pkg/sstore"
"github.com/scripthaus-dev/sh2-server/pkg/wsshell"
)
const WSStatePacketChSize = 20
const MaxInputDataSize = 1000
type WSState struct {
Lock *sync.Mutex
ClientId string
ConnectTime time.Time
Shell *wsshell.WSShell
UpdateCh chan interface{}
UpdateQueue []interface{}
SessionId string
ScreenId string
}
func MakeWSState(clientId string) *WSState {
rtn := &WSState{}
rtn.Lock = &sync.Mutex{}
rtn.ClientId = clientId
rtn.ConnectTime = time.Now()
return rtn
}
func (ws *WSState) GetShell() *wsshell.WSShell {
ws.Lock.Lock()
defer ws.Lock.Unlock()
return ws.Shell
}
func (ws *WSState) WriteUpdate(update interface{}) error {
shell := ws.GetShell()
if shell == nil {
return fmt.Errorf("cannot write update, empty shell")
}
err := shell.WriteJson(update)
if err != nil {
return err
}
return nil
}
func (ws *WSState) UpdateConnectTime() {
ws.Lock.Lock()
defer ws.Lock.Unlock()
ws.ConnectTime = time.Now()
}
func (ws *WSState) GetConnectTime() time.Time {
ws.Lock.Lock()
defer ws.Lock.Unlock()
return ws.ConnectTime
}
func (ws *WSState) WatchScreen(sessionId string, screenId string) {
ws.Lock.Lock()
defer ws.Lock.Unlock()
if ws.SessionId == sessionId && ws.ScreenId == screenId {
return
}
ws.SessionId = sessionId
ws.ScreenId = screenId
ws.UpdateCh = sstore.MainBus.RegisterChannel(ws.ClientId, ws.SessionId)
go ws.RunUpdates()
}
func (ws *WSState) UnWatchScreen() {
ws.Lock.Lock()
defer ws.Lock.Unlock()
sstore.MainBus.UnregisterChannel(ws.ClientId)
ws.SessionId = ""
ws.ScreenId = ""
}
func (ws *WSState) getUpdateCh() chan interface{} {
ws.Lock.Lock()
defer ws.Lock.Unlock()
return ws.UpdateCh
}
func (ws *WSState) RunUpdates() {
updateCh := ws.getUpdateCh()
if updateCh == nil {
return
}
for update := range updateCh {
shell := ws.GetShell()
if shell != nil {
shell.WriteJson(update)
}
}
}
func (ws *WSState) ReplaceShell(shell *wsshell.WSShell) {
ws.Lock.Lock()
defer ws.Lock.Unlock()
if ws.Shell == nil {
ws.Shell = shell
return
}
ws.Shell.Conn.Close()
ws.Shell = shell
return
}
func (ws *WSState) RunWSRead() {
shell := ws.GetShell()
if shell == nil {
return
}
for msgBytes := range shell.ReadChan {
pk, err := packet.ParseJsonPacket(msgBytes)
if err != nil {
fmt.Printf("error unmarshalling ws message: %v\n", err)
continue
}
if pk.GetType() == "input" {
go func() {
err = sendCmdInput(pk.(*packet.InputPacketType))
if err != nil {
fmt.Printf("[error] sending command input: %v\n", err)
}
}()
continue
}
if pk.GetType() == "watchscreen" {
wsPk := pk.(*scpacket.WatchScreenPacketType)
if _, err := uuid.Parse(wsPk.SessionId); err != nil {
fmt.Printf("[error] invalid watchscreen sessionid: %v\n", err)
continue
}
if _, err := uuid.Parse(wsPk.ScreenId); err != nil {
fmt.Printf("[error] invalid watchscreen screenid: %v\n", err)
continue
}
ws.WatchScreen(wsPk.SessionId, wsPk.ScreenId)
fmt.Printf("[ws] watch screen clientid=%s %s/%s\n", ws.ClientId, wsPk.SessionId, wsPk.ScreenId)
continue
}
fmt.Printf("got ws bad message: %v\n", pk.GetType())
}
}
func sendCmdInput(pk *packet.InputPacketType) error {
err := pk.CK.Validate("input packet")
if err != nil {
return err
}
if pk.RemoteId == "" {
return fmt.Errorf("input must set remoteid")
}
if len(pk.InputData64) == 0 && pk.SigNum == 0 {
return fmt.Errorf("empty input packet")
}
inputLen := packet.B64DecodedLen(pk.InputData64)
if inputLen > MaxInputDataSize {
return fmt.Errorf("input data size too large, len=%d (max=%d)", inputLen, MaxInputDataSize)
}
msh := remote.GetRemoteById(pk.RemoteId)
if msh == nil {
return fmt.Errorf("cannot connect to remote")
}
return msh.SendInput(pk)
}

View File

@ -2,19 +2,48 @@ package sstore
import (
"context"
"encoding/base64"
"fmt"
"os"
"github.com/scripthaus-dev/sh2-server/pkg/scbase"
)
func AppendToCmdPtyBlob(ctx context.Context, sessionId string, cmdId string, data []byte) error {
const PosAppend = -1
// when calling with PosAppend, this is not multithread safe (since file could be modified).
// we need to know the real position of the write to send a proper pty update to the frontends
// in practice this is fine since we only use PosAppend in non-detached mode where
// we are reading/writing a stream in order with a single goroutine
func AppendToCmdPtyBlob(ctx context.Context, sessionId string, cmdId string, data []byte, pos int64) error {
ptyOutFileName, err := scbase.PtyOutFile(sessionId, cmdId)
if err != nil {
return err
}
fd, err := os.OpenFile(ptyOutFileName, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0600)
if err != nil {
return err
var fd *os.File
var realPos int64
if pos == PosAppend {
fd, err = os.OpenFile(ptyOutFileName, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0600)
if err != nil {
return err
}
finfo, err := fd.Stat()
if err != nil {
return err
}
realPos = finfo.Size()
} else {
fd, err = os.OpenFile(ptyOutFileName, os.O_WRONLY|os.O_CREATE, 0600)
if err != nil {
return err
}
realPos, err = fd.Seek(pos, 0)
if err != nil {
return err
}
if realPos != pos {
return fmt.Errorf("could not seek to pos:%d (realpos=%d)", pos, realPos)
}
}
defer fd.Close()
if len(data) == 0 {
@ -24,5 +53,14 @@ func AppendToCmdPtyBlob(ctx context.Context, sessionId string, cmdId string, dat
if err != nil {
return err
}
data64 := base64.StdEncoding.EncodeToString(data)
update := &PtyDataUpdate{
SessionId: sessionId,
CmdId: cmdId,
PtyPos: realPos,
PtyData64: data64,
PtyDataLen: int64(len(data)),
}
MainBus.SendUpdate(sessionId, update)
return nil
}

View File

@ -58,14 +58,14 @@ func quickSetJson(ptr interface{}, m map[string]interface{}, name string) {
return
}
if str == "" {
return
str = "{}"
}
json.Unmarshal([]byte(str), ptr)
}
func quickJson(v interface{}) string {
if v == nil {
return ""
return "{}"
}
barr, _ := json.Marshal(v)
return string(barr)
@ -81,14 +81,14 @@ func quickScanJson(ptr interface{}, val interface{}) error {
barrVal = []byte(strVal)
}
if len(barrVal) == 0 {
return nil
barrVal = []byte("{}")
}
return json.Unmarshal(barrVal, ptr)
}
func quickValueJson(v interface{}) (driver.Value, error) {
if v == nil {
return "", nil
return "{}", nil
}
barr, err := json.Marshal(v)
if err != nil {

View File

@ -147,6 +147,7 @@ type ScreenWindowType struct {
type HistoryItemType struct {
CmdStr string `json:"cmdstr"`
Remove bool `json:"remove"`
}
type RemoteState struct {
@ -194,6 +195,7 @@ type LineType struct {
LineType string `json:"linetype"`
Text string `json:"text,omitempty"`
CmdId string `json:"cmdid,omitempty"`
Remove bool `json:"remove,omitempty"`
}
type SSHOpts struct {
@ -239,6 +241,7 @@ type CmdType struct {
DonePk *packet.CmdDonePacketType `json:"donepk"`
UsedRows int64 `json:"usedrows"`
RunOut []packet.PacketType `json:"runout"`
Remove bool `json:"remove"`
}
func (r *RemoteType) ToMap() map[string]interface{} {

102
pkg/sstore/updatebus.go Normal file
View File

@ -0,0 +1,102 @@
package sstore
import "sync"
var MainBus *UpdateBus = MakeUpdateBus()
type UpdateCmd struct {
CmdId string
Status string
}
type PtyDataUpdate struct {
SessionId string `json:"sessionid"`
CmdId string `json:"cmdid"`
PtyPos int64 `json:"ptypos"`
PtyData64 string `json:"ptydata64"`
PtyDataLen int64 `json:"ptydatalen"`
}
type WindowUpdate struct {
Window WindowType `json:"window"`
Remove bool `json:"remove,omitempty"`
}
type SessionUpdate struct {
Session SessionType `json:"session"`
Remove bool `json:"remove,omitempty"`
}
type CmdUpdate struct {
Cmd CmdType `json:"cmd"`
Remove bool `json:"remove,omitempty"`
}
type ScreenUpdate struct {
Screen CmdType `json:"screen"`
Remove bool `json:"remove,omitempty"`
}
type UpdateChannel struct {
SessionId string
ClientId string
Ch chan interface{}
}
func (uch UpdateChannel) Match(sessionId string) bool {
if sessionId == "" {
return true
}
return sessionId == uch.SessionId
}
type UpdateBus struct {
Lock *sync.Mutex
Channels map[string]UpdateChannel
}
func MakeUpdateBus() *UpdateBus {
return &UpdateBus{
Lock: &sync.Mutex{},
Channels: make(map[string]UpdateChannel),
}
}
func (bus *UpdateBus) RegisterChannel(clientId string, sessionId string) chan interface{} {
bus.Lock.Lock()
defer bus.Lock.Unlock()
uch, found := bus.Channels[clientId]
if found {
close(uch.Ch)
uch.SessionId = sessionId
uch.Ch = make(chan interface{})
} else {
uch = UpdateChannel{
ClientId: clientId,
SessionId: sessionId,
Ch: make(chan interface{}),
}
}
bus.Channels[clientId] = uch
return uch.Ch
}
func (bus *UpdateBus) UnregisterChannel(clientId string) {
bus.Lock.Lock()
defer bus.Lock.Unlock()
uch, found := bus.Channels[clientId]
if found {
close(uch.Ch)
delete(bus.Channels, clientId)
}
}
func (bus *UpdateBus) SendUpdate(sessionId string, update interface{}) {
bus.Lock.Lock()
defer bus.Lock.Unlock()
for _, uch := range bus.Channels {
if uch.Match(sessionId) {
uch.Ch <- update
}
}
}

View File

@ -2,6 +2,7 @@ package wsshell
import (
"encoding/json"
"fmt"
"log"
"net/http"
"net/url"
@ -65,6 +66,9 @@ func (ws *WSShell) WritePing() error {
}
func (ws *WSShell) WriteJson(val interface{}) error {
if ws.IsClosed() {
return fmt.Errorf("cannot write packet, empty or closed wsshell")
}
barr, err := json.Marshal(val)
if err != nil {
return err