waveterm/pkg/server/server.go

168 lines
4.2 KiB
Go

// Copyright 2022 Dashborg Inc
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
package server
import (
"fmt"
"io"
"os"
"sync"
"github.com/scripthaus-dev/mshell/pkg/base"
"github.com/scripthaus-dev/mshell/pkg/mpio"
"github.com/scripthaus-dev/mshell/pkg/packet"
"github.com/scripthaus-dev/mshell/pkg/shexec"
)
type MServer struct {
Lock *sync.Mutex
MainInput *packet.PacketParser
Sender *packet.PacketSender
FdContextMap map[base.CommandKey]*serverFdContext
Debug bool
}
func (m *MServer) Close() {
m.Sender.Close()
m.Sender.WaitForDone()
}
type serverFdContext struct {
M *MServer
Lock *sync.Mutex
Sender *packet.PacketSender
CK base.CommandKey
Readers map[int]*mpio.PacketReader
}
func (c *serverFdContext) processDataPacket(pk *packet.DataPacketType) {
c.Lock.Lock()
reader := c.Readers[pk.FdNum]
c.Lock.Unlock()
if reader == nil {
ackPacket := packet.MakeDataAckPacket()
ackPacket.CK = c.CK
ackPacket.FdNum = pk.FdNum
ackPacket.Error = "write to closed file (no fd)"
c.M.Sender.SendPacket(ackPacket)
return
}
reader.AddData(pk)
}
func (m *MServer) MakeServerFdContext(ck base.CommandKey) *serverFdContext {
rtn := &serverFdContext{
M: m,
Lock: &sync.Mutex{},
Sender: m.Sender,
CK: ck,
Readers: make(map[int]*mpio.PacketReader),
}
return rtn
}
func (m *MServer) ProcessCommandPacket(pk packet.CommandPacketType) {
ck := pk.GetCK()
if ck == "" {
m.Sender.SendErrorPacket(fmt.Sprintf("received '%s' packet without ck", pk.GetType()))
return
}
m.Lock.Lock()
fdContext := m.FdContextMap[ck]
m.Lock.Unlock()
if fdContext == nil {
m.Sender.SendCKErrorPacket(ck, fmt.Sprintf("no server context for ck '%s'", ck))
return
}
if pk.GetType() == packet.DataPacketStr {
dataPacket := pk.(*packet.DataPacketType)
fdContext.processDataPacket(dataPacket)
return
} else if pk.GetType() == packet.DataAckPacketStr {
m.Sender.SendPacket(pk)
return
} else {
m.Sender.SendCKErrorPacket(ck, fmt.Sprintf("invalid packet '%s' received", packet.AsExtType(pk)))
return
}
}
func (c *serverFdContext) GetWriter(fdNum int) io.WriteCloser {
return mpio.MakePacketWriter(fdNum, c.Sender, c.CK)
}
func (c *serverFdContext) GetReader(fdNum int) io.ReadCloser {
c.Lock.Lock()
defer c.Lock.Unlock()
reader := mpio.MakePacketReader(fdNum)
c.Readers[fdNum] = reader
return reader
}
func (m *MServer) runCommand(runPacket *packet.RunPacketType) {
if err := runPacket.CK.Validate("packet"); err != nil {
m.Sender.SendErrorPacket(fmt.Sprintf("server run packets require valid ck: %s", err))
return
}
fdContext := m.MakeServerFdContext(runPacket.CK)
m.Lock.Lock()
m.FdContextMap[runPacket.CK] = fdContext
m.Lock.Unlock()
go func() {
donePk, err := shexec.RunClientSSHCommandAndWait(runPacket, fdContext, shexec.SSHOpts{}, m, m.Debug)
if donePk != nil {
m.Sender.SendPacket(donePk)
}
if err != nil {
m.Sender.SendCKErrorPacket(runPacket.CK, err.Error())
}
}()
}
func (m *MServer) UnknownPacket(pk packet.PacketType) {
m.Sender.SendPacket(pk)
}
func RunServer() (int, error) {
debug := false
if len(os.Args) >= 3 && os.Args[2] == "--debug" {
debug = true
}
server := &MServer{
Lock: &sync.Mutex{},
FdContextMap: make(map[base.CommandKey]*serverFdContext),
Debug: debug,
}
if debug {
packet.GlobalDebug = true
}
server.MainInput = packet.MakePacketParser(os.Stdin)
server.Sender = packet.MakePacketSender(os.Stdout)
defer server.Close()
defer fmt.Printf("runserver done\n")
initPacket := packet.MakeInitPacket()
initPacket.Version = base.MShellVersion
server.Sender.SendPacket(initPacket)
for pk := range server.MainInput.MainCh {
if server.Debug {
fmt.Printf("PK> %s\n", packet.AsString(pk))
}
if pk.GetType() == packet.RunPacketStr {
runPacket := pk.(*packet.RunPacketType)
server.runCommand(runPacket)
continue
}
if cmdPk, ok := pk.(packet.CommandPacketType); ok {
server.ProcessCommandPacket(cmdPk)
continue
}
server.Sender.SendErrorPacket(fmt.Sprintf("invalid packet '%s' sent to mshell", packet.AsExtType(pk)))
continue
}
return 0, nil
}