mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-01-08 19:38:51 +01:00
195 lines
4.7 KiB
Go
195 lines
4.7 KiB
Go
// Copyright 2022 Dashborg Inc
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
|
|
|
package server
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"sync"
|
|
|
|
"github.com/scripthaus-dev/mshell/pkg/base"
|
|
"github.com/scripthaus-dev/mshell/pkg/mpio"
|
|
"github.com/scripthaus-dev/mshell/pkg/packet"
|
|
"github.com/scripthaus-dev/mshell/pkg/shexec"
|
|
)
|
|
|
|
// MServer is the state of an mshell process running in server mode
// (see RunServer): the packet input/output streams plus the fd contexts of
// commands it has started.
type MServer struct {
	// Lock guards FdContextMap; every read, insert, and delete of the map in
	// this file happens while holding it.
	Lock *sync.Mutex
	// MainInput parses incoming packets (RunServer builds it over os.Stdin).
	MainInput *packet.PacketParser
	// Sender emits outgoing packets (RunServer builds it over os.Stdout).
	Sender *packet.PacketSender
	// FdContextMap maps a command key to that command's fd context. Entries are
	// added by MakeServerFdContext and removed by RemoveFdContext.
	FdContextMap map[base.CommandKey]*serverFdContext
	// Debug is set from the "--debug" flag in RunServer. It enables trace output
	// and is forwarded to shexec.RunClientSSHCommandAndWait.
	Debug bool
}
|
|
|
|
func (m *MServer) Close() {
|
|
m.Sender.Close()
|
|
m.Sender.WaitForDone()
|
|
}
|
|
|
|
// serverFdContext holds the per-command fd plumbing for a command started by
// the server. runCommand passes it to shexec.RunClientSSHCommandAndWait, which
// obtains readers and writers through GetReader and GetWriter.
type serverFdContext struct {
	// M is the server that created this context.
	M *MServer
	// Lock guards Readers.
	Lock *sync.Mutex
	// Sender is the server's packet sender (copied from M.Sender in
	// MakeServerFdContext); GetWriter builds packet writers on it.
	Sender *packet.PacketSender
	// CK identifies the command this context belongs to.
	CK base.CommandKey
	// Readers maps an fd number to the reader registered by GetReader.
	// processDataPacket feeds incoming data packets into these readers.
	Readers map[int]*mpio.PacketReader
}
|
|
|
|
func (c *serverFdContext) processDataPacket(pk *packet.DataPacketType) {
|
|
c.Lock.Lock()
|
|
reader := c.Readers[pk.FdNum]
|
|
c.Lock.Unlock()
|
|
if reader == nil {
|
|
ackPacket := packet.MakeDataAckPacket()
|
|
ackPacket.CK = c.CK
|
|
ackPacket.FdNum = pk.FdNum
|
|
ackPacket.Error = "write to closed file (no fd)"
|
|
c.M.Sender.SendPacket(ackPacket)
|
|
return
|
|
}
|
|
reader.AddData(pk)
|
|
}
|
|
|
|
func (m *MServer) MakeServerFdContext(ck base.CommandKey) *serverFdContext {
|
|
m.Lock.Lock()
|
|
defer m.Lock.Unlock()
|
|
rtn := &serverFdContext{
|
|
M: m,
|
|
Lock: &sync.Mutex{},
|
|
Sender: m.Sender,
|
|
CK: ck,
|
|
Readers: make(map[int]*mpio.PacketReader),
|
|
}
|
|
m.FdContextMap[ck] = rtn
|
|
return rtn
|
|
}
|
|
|
|
func (m *MServer) ProcessCommandPacket(pk packet.CommandPacketType) {
|
|
ck := pk.GetCK()
|
|
if ck == "" {
|
|
m.Sender.SendMessage(fmt.Sprintf("received '%s' packet without ck", pk.GetType()))
|
|
return
|
|
}
|
|
m.Lock.Lock()
|
|
fdContext := m.FdContextMap[ck]
|
|
m.Lock.Unlock()
|
|
if fdContext == nil {
|
|
m.Sender.SendCmdError(ck, fmt.Errorf("no server context for ck '%s'", ck))
|
|
return
|
|
}
|
|
if pk.GetType() == packet.DataPacketStr {
|
|
dataPacket := pk.(*packet.DataPacketType)
|
|
fdContext.processDataPacket(dataPacket)
|
|
return
|
|
} else if pk.GetType() == packet.DataAckPacketStr {
|
|
m.Sender.SendPacket(pk)
|
|
return
|
|
} else {
|
|
m.Sender.SendCmdError(ck, fmt.Errorf("invalid packet '%s' received", packet.AsExtType(pk)))
|
|
return
|
|
}
|
|
}
|
|
|
|
func (c *serverFdContext) GetWriter(fdNum int) io.WriteCloser {
|
|
return mpio.MakePacketWriter(fdNum, c.Sender, c.CK)
|
|
}
|
|
|
|
func (c *serverFdContext) GetReader(fdNum int) io.ReadCloser {
|
|
c.Lock.Lock()
|
|
defer c.Lock.Unlock()
|
|
reader := mpio.MakePacketReader(fdNum)
|
|
c.Readers[fdNum] = reader
|
|
return reader
|
|
}
|
|
|
|
func (m *MServer) RemoveFdContext(ck base.CommandKey) {
|
|
m.Lock.Lock()
|
|
defer m.Lock.Unlock()
|
|
delete(m.FdContextMap, ck)
|
|
}
|
|
|
|
func (m *MServer) runCommand(runPacket *packet.RunPacketType) {
|
|
if err := runPacket.CK.Validate("packet"); err != nil {
|
|
m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("server run packets require valid ck: %s", err))
|
|
return
|
|
}
|
|
fdContext := m.MakeServerFdContext(runPacket.CK)
|
|
go func() {
|
|
defer m.RemoveFdContext(runPacket.CK)
|
|
donePk, err := shexec.RunClientSSHCommandAndWait(runPacket, fdContext, shexec.SSHOpts{}, m, m.Debug)
|
|
if donePk != nil && !runPacket.Detached {
|
|
m.Sender.SendPacket(donePk)
|
|
}
|
|
if err != nil {
|
|
m.Sender.SendErrorResponse(runPacket.ReqId, err)
|
|
}
|
|
}()
|
|
}
|
|
|
|
func (m *MServer) UnknownPacket(pk packet.PacketType) {
|
|
m.Sender.SendPacket(pk)
|
|
}
|
|
|
|
func RunServer() (int, error) {
|
|
debug := false
|
|
if len(os.Args) >= 3 && os.Args[2] == "--debug" {
|
|
debug = true
|
|
}
|
|
server := &MServer{
|
|
Lock: &sync.Mutex{},
|
|
FdContextMap: make(map[base.CommandKey]*serverFdContext),
|
|
Debug: debug,
|
|
}
|
|
if debug {
|
|
packet.GlobalDebug = true
|
|
}
|
|
server.MainInput = packet.MakePacketParser(os.Stdin)
|
|
server.Sender = packet.MakePacketSender(os.Stdout)
|
|
defer server.Close()
|
|
var err error
|
|
initPacket, err := shexec.MakeServerInitPacket()
|
|
if err != nil {
|
|
return 1, err
|
|
}
|
|
server.Sender.SendPacket(initPacket)
|
|
builder := packet.MakeRunPacketBuilder()
|
|
for pk := range server.MainInput.MainCh {
|
|
if server.Debug {
|
|
fmt.Printf("PK> %s\n", packet.AsString(pk))
|
|
}
|
|
|
|
// run-start combo
|
|
ok, runPacket := builder.ProcessPacket(pk)
|
|
if server.Debug {
|
|
fmt.Printf("PP> %s | %v\n", pk.GetType(), ok)
|
|
}
|
|
if ok {
|
|
if runPacket != nil {
|
|
server.runCommand(runPacket)
|
|
continue
|
|
}
|
|
continue
|
|
}
|
|
if startPk, ok := pk.(*packet.CmdStartPacketType); ok {
|
|
if server.Debug {
|
|
fmt.Printf("START> %v", startPk)
|
|
}
|
|
server.Sender.SendPacket(startPk)
|
|
continue
|
|
}
|
|
|
|
// command packet
|
|
if cmdPk, ok := pk.(packet.CommandPacketType); ok {
|
|
server.ProcessCommandPacket(cmdPk)
|
|
continue
|
|
}
|
|
server.Sender.SendMessage(fmt.Sprintf("invalid packet '%s' sent to mshell", packet.AsString(pk)))
|
|
continue
|
|
}
|
|
return 0, nil
|
|
}
|