mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-01-21 21:32:13 +01:00
working on server mode. extract fdcontext as interface. create packet writer/reader for mpio. hook up to serverFdContext.
This commit is contained in:
parent
d7eb2526f0
commit
1d44afc10e
@ -376,7 +376,7 @@ func parseClientOpts() (*shexec.ClientOpts, error) {
|
||||
return nil, fmt.Errorf("'--sudo-with-password [pw]', missing password")
|
||||
}
|
||||
opts.Sudo = true
|
||||
opts.SSHOpts.SudoWithPass = true
|
||||
opts.SudoWithPass = true
|
||||
opts.SudoPw = iter.Next()
|
||||
continue
|
||||
}
|
||||
@ -385,7 +385,7 @@ func parseClientOpts() (*shexec.ClientOpts, error) {
|
||||
return nil, fmt.Errorf("'--sudo-with-passfile [file]', missing file")
|
||||
}
|
||||
opts.Sudo = true
|
||||
opts.SSHOpts.SudoWithPass = true
|
||||
opts.SudoWithPass = true
|
||||
fileName := iter.Next()
|
||||
contents, err := os.ReadFile(fileName)
|
||||
if err != nil {
|
||||
@ -433,7 +433,7 @@ func handleClient() (int, error) {
|
||||
if err != nil {
|
||||
return 1, err
|
||||
}
|
||||
donePacket, err := shexec.RunClientSSHCommandAndWait(runPacket, opts.SSHOpts, opts.Debug)
|
||||
donePacket, err := shexec.RunClientSSHCommandAndWait(runPacket, shexec.StdContext{}, opts.SSHOpts, opts.Debug)
|
||||
if err != nil {
|
||||
return 1, err
|
||||
}
|
||||
|
@ -8,7 +8,6 @@ package mpio
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/scripthaus-dev/mshell/pkg/packet"
|
||||
@ -18,13 +17,13 @@ type FdReader struct {
|
||||
CVar *sync.Cond
|
||||
M *Multiplexer
|
||||
FdNum int
|
||||
Fd *os.File
|
||||
Fd io.ReadCloser
|
||||
BufSize int
|
||||
Closed bool
|
||||
ShouldCloseFd bool
|
||||
}
|
||||
|
||||
func MakeFdReader(m *Multiplexer, fd *os.File, fdNum int, shouldCloseFd bool) *FdReader {
|
||||
func MakeFdReader(m *Multiplexer, fd io.ReadCloser, fdNum int, shouldCloseFd bool) *FdReader {
|
||||
fr := &FdReader{
|
||||
CVar: sync.NewCond(&sync.Mutex{}),
|
||||
M: m,
|
||||
|
@ -8,7 +8,7 @@ package mpio
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
@ -17,13 +17,13 @@ type FdWriter struct {
|
||||
M *Multiplexer
|
||||
FdNum int
|
||||
Buffer []byte
|
||||
Fd *os.File
|
||||
Fd io.WriteCloser
|
||||
Eof bool
|
||||
Closed bool
|
||||
ShouldCloseFd bool
|
||||
}
|
||||
|
||||
func MakeFdWriter(m *Multiplexer, fd *os.File, fdNum int, shouldCloseFd bool) *FdWriter {
|
||||
func MakeFdWriter(m *Multiplexer, fd io.WriteCloser, fdNum int, shouldCloseFd bool) *FdWriter {
|
||||
fw := &FdWriter{
|
||||
CVar: sync.NewCond(&sync.Mutex{}),
|
||||
Fd: fd,
|
||||
|
@ -9,6 +9,7 @@ package mpio
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
@ -111,13 +112,13 @@ func (m *Multiplexer) MakeStringFdReader(fdNum int, contents string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Multiplexer) MakeRawFdReader(fdNum int, fd *os.File, shouldClose bool) {
|
||||
func (m *Multiplexer) MakeRawFdReader(fdNum int, fd io.ReadCloser, shouldClose bool) {
|
||||
m.Lock.Lock()
|
||||
defer m.Lock.Unlock()
|
||||
m.FdReaders[fdNum] = MakeFdReader(m, fd, fdNum, shouldClose)
|
||||
}
|
||||
|
||||
func (m *Multiplexer) MakeRawFdWriter(fdNum int, fd *os.File, shouldClose bool) {
|
||||
func (m *Multiplexer) MakeRawFdWriter(fdNum int, fd io.WriteCloser, shouldClose bool) {
|
||||
m.Lock.Lock()
|
||||
defer m.Lock.Unlock()
|
||||
m.FdWriters[fdNum] = MakeFdWriter(m, fd, fdNum, shouldClose)
|
||||
|
96
pkg/mpio/packetreader.go
Normal file
96
pkg/mpio/packetreader.go
Normal file
@ -0,0 +1,96 @@
|
||||
// Copyright 2022 Dashborg Inc
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mpio
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/scripthaus-dev/mshell/pkg/packet"
|
||||
)
|
||||
|
||||
type PacketReader struct {
|
||||
CVar *sync.Cond
|
||||
FdNum int
|
||||
Buf []byte
|
||||
Eof bool
|
||||
Err error
|
||||
}
|
||||
|
||||
func MakePacketReader(fdNum int) *PacketReader {
|
||||
return &PacketReader{
|
||||
CVar: sync.NewCond(&sync.Mutex{}),
|
||||
FdNum: fdNum,
|
||||
}
|
||||
}
|
||||
|
||||
func (pr *PacketReader) AddData(pk *packet.DataPacketType) {
|
||||
pr.CVar.L.Lock()
|
||||
defer pr.CVar.L.Unlock()
|
||||
defer pr.CVar.Broadcast()
|
||||
if pr.Eof || pr.Err != nil {
|
||||
return
|
||||
}
|
||||
if pk.Data64 != "" {
|
||||
realData, err := base64.StdEncoding.DecodeString(pk.Data64)
|
||||
if err != nil {
|
||||
pr.Err = err
|
||||
return
|
||||
}
|
||||
pr.Buf = append(pr.Buf, realData...)
|
||||
}
|
||||
pr.Eof = pk.Eof
|
||||
if pk.Error != "" {
|
||||
pr.Err = errors.New(pk.Error)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (pr *PacketReader) Read(buf []byte) (int, error) {
|
||||
pr.CVar.L.Lock()
|
||||
defer pr.CVar.L.Unlock()
|
||||
for {
|
||||
if pr.Err != nil {
|
||||
return 0, pr.Err
|
||||
}
|
||||
if pr.Eof {
|
||||
return 0, io.EOF
|
||||
}
|
||||
if len(pr.Buf) == 0 {
|
||||
pr.CVar.Wait()
|
||||
continue
|
||||
}
|
||||
nr := copy(buf, pr.Buf)
|
||||
pr.Buf = pr.Buf[nr:]
|
||||
if len(pr.Buf) == 0 {
|
||||
pr.Buf = nil
|
||||
}
|
||||
return nr, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (pr *PacketReader) Close() error {
|
||||
pr.CVar.L.Lock()
|
||||
defer pr.CVar.L.Unlock()
|
||||
defer pr.CVar.Broadcast()
|
||||
if pr.Err == nil {
|
||||
pr.Err = io.ErrClosedPipe
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type NullReader struct{}
|
||||
|
||||
func (NullReader) Read(buf []byte) (int, error) {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
func (NullReader) Close() error {
|
||||
return nil
|
||||
}
|
40
pkg/mpio/packetwriter.go
Normal file
40
pkg/mpio/packetwriter.go
Normal file
@ -0,0 +1,40 @@
|
||||
// Copyright 2022 Dashborg Inc
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mpio
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
|
||||
"github.com/scripthaus-dev/mshell/pkg/base"
|
||||
"github.com/scripthaus-dev/mshell/pkg/packet"
|
||||
)
|
||||
|
||||
type PacketWriter struct {
|
||||
FdNum int
|
||||
Sender *packet.PacketSender
|
||||
CK base.CommandKey
|
||||
}
|
||||
|
||||
func MakePacketWriter(fdNum int, sender *packet.PacketSender, ck base.CommandKey) *PacketWriter {
|
||||
return &PacketWriter{FdNum: fdNum, Sender: sender, CK: ck}
|
||||
}
|
||||
|
||||
func (pw *PacketWriter) Write(data []byte) (int, error) {
|
||||
pk := packet.MakeDataPacket()
|
||||
pk.CK = pw.CK
|
||||
pk.FdNum = pw.FdNum
|
||||
pk.Data64 = base64.StdEncoding.EncodeToString(data)
|
||||
return len(data), pw.Sender.SendPacket(pk)
|
||||
}
|
||||
|
||||
func (pw *PacketWriter) Close() error {
|
||||
pk := packet.MakeDataPacket()
|
||||
pk.CK = pw.CK
|
||||
pk.FdNum = pw.FdNum
|
||||
pk.Eof = true
|
||||
return pw.Sender.SendPacket(pk)
|
||||
}
|
@ -8,17 +8,21 @@ package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/scripthaus-dev/mshell/pkg/base"
|
||||
"github.com/scripthaus-dev/mshell/pkg/mpio"
|
||||
"github.com/scripthaus-dev/mshell/pkg/packet"
|
||||
"github.com/scripthaus-dev/mshell/pkg/shexec"
|
||||
)
|
||||
|
||||
type MServer struct {
|
||||
Lock *sync.Mutex
|
||||
MainInput *packet.PacketParser
|
||||
Sender *packet.PacketSender
|
||||
FdContext *serverFdContext
|
||||
}
|
||||
|
||||
func (m *MServer) Close() {
|
||||
@ -26,13 +30,73 @@ func (m *MServer) Close() {
|
||||
m.Sender.WaitForDone()
|
||||
}
|
||||
|
||||
type serverFdContext struct {
|
||||
M *MServer
|
||||
Lock *sync.Mutex
|
||||
Sender *packet.PacketSender
|
||||
CK base.CommandKey
|
||||
Readers map[int]*mpio.PacketReader
|
||||
}
|
||||
|
||||
func (m *MServer) MakeServerFdContext(ck base.CommandKey) *serverFdContext {
|
||||
rtn := &serverFdContext{
|
||||
M: m,
|
||||
Lock: &sync.Mutex{},
|
||||
Sender: m.Sender,
|
||||
CK: ck,
|
||||
Readers: make(map[int]*mpio.PacketReader),
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (c *serverFdContext) processDataPacket(pk *packet.DataPacketType) {
|
||||
c.Lock.Lock()
|
||||
reader := c.Readers[pk.FdNum]
|
||||
c.Lock.Unlock()
|
||||
if reader == nil {
|
||||
ackPacket := packet.MakeDataAckPacket()
|
||||
ackPacket.CK = c.CK
|
||||
ackPacket.FdNum = pk.FdNum
|
||||
ackPacket.Error = "write to closed file (no fd)"
|
||||
c.M.Sender.SendPacket(ackPacket)
|
||||
return
|
||||
}
|
||||
reader.AddData(pk)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *serverFdContext) GetWriter(fdNum int) io.WriteCloser {
|
||||
return mpio.MakePacketWriter(fdNum, c.Sender, c.CK)
|
||||
}
|
||||
|
||||
func (c *serverFdContext) GetReader(fdNum int) io.ReadCloser {
|
||||
c.Lock.Lock()
|
||||
defer c.Lock.Unlock()
|
||||
reader := mpio.MakePacketReader(fdNum)
|
||||
c.Readers[fdNum] = reader
|
||||
return reader
|
||||
}
|
||||
|
||||
func (m *MServer) runCommand(runPacket *packet.RunPacketType) {
|
||||
fdContext := m.MakeServerFdContext(runPacket.CK)
|
||||
m.Lock.Lock()
|
||||
m.FdContext = fdContext
|
||||
m.Lock.Unlock()
|
||||
go func() {
|
||||
donePk, err := shexec.RunClientSSHCommandAndWait(runPacket, fdContext, shexec.SSHOpts{}, true)
|
||||
fmt.Printf("done: err:%v, %v\n", err, donePk)
|
||||
}()
|
||||
}
|
||||
|
||||
func RunServer() (int, error) {
|
||||
server := &MServer{
|
||||
Lock: &sync.Mutex{},
|
||||
}
|
||||
packet.GlobalDebug = true
|
||||
server.MainInput = packet.MakePacketParser(os.Stdin)
|
||||
server.Sender = packet.MakePacketSender(os.Stdout)
|
||||
defer server.Close()
|
||||
defer fmt.Printf("runserver done\n")
|
||||
initPacket := packet.MakeInitPacket()
|
||||
initPacket.Version = base.MShellVersion
|
||||
server.Sender.SendPacket(initPacket)
|
||||
@ -43,7 +107,12 @@ func RunServer() (int, error) {
|
||||
}
|
||||
if pk.GetType() == packet.RunPacketStr {
|
||||
runPacket := pk.(*packet.RunPacketType)
|
||||
fmt.Printf("RUN> %s\n", runPacket)
|
||||
server.runCommand(runPacket)
|
||||
continue
|
||||
}
|
||||
if pk.GetType() == packet.DataPacketStr {
|
||||
dataPacket := pk.(*packet.DataPacketType)
|
||||
server.FdContext.processDataPacket(dataPacket)
|
||||
continue
|
||||
}
|
||||
server.Sender.SendErrorPacket(fmt.Sprintf("invalid packet '%s' sent to mshell", packet.AsExtType(pk)))
|
||||
|
@ -64,6 +64,41 @@ type ShExecType struct {
|
||||
Multiplexer *mpio.Multiplexer
|
||||
}
|
||||
|
||||
type StdContext struct{}
|
||||
|
||||
func (StdContext) GetWriter(fdNum int) io.WriteCloser {
|
||||
if fdNum == 0 {
|
||||
return os.Stdin
|
||||
}
|
||||
if fdNum == 1 {
|
||||
return os.Stdout
|
||||
}
|
||||
if fdNum == 2 {
|
||||
return os.Stderr
|
||||
}
|
||||
fd := os.NewFile(uintptr(fdNum), fmt.Sprintf("/dev/fd/%d", fdNum))
|
||||
return fd
|
||||
}
|
||||
|
||||
func (StdContext) GetReader(fdNum int) io.ReadCloser {
|
||||
if fdNum == 0 {
|
||||
return os.Stdin
|
||||
}
|
||||
if fdNum == 1 {
|
||||
return os.Stdout
|
||||
}
|
||||
if fdNum == 2 {
|
||||
return os.Stdout
|
||||
}
|
||||
fd := os.NewFile(uintptr(fdNum), fmt.Sprintf("/dev/fd/%d", fdNum))
|
||||
return fd
|
||||
}
|
||||
|
||||
type FdContext interface {
|
||||
GetWriter(fdNum int) io.WriteCloser
|
||||
GetReader(fdNum int) io.ReadCloser
|
||||
}
|
||||
|
||||
func MakeShExec(ck base.CommandKey) *ShExecType {
|
||||
return &ShExecType{
|
||||
Lock: &sync.Mutex{},
|
||||
@ -241,7 +276,6 @@ type SSHOpts struct {
|
||||
SSHOptsStr string
|
||||
SSHIdentity string
|
||||
SSHUser string
|
||||
SudoWithPass bool
|
||||
}
|
||||
|
||||
type InstallOpts struct {
|
||||
@ -258,6 +292,7 @@ type ClientOpts struct {
|
||||
Cwd string
|
||||
Debug bool
|
||||
Sudo bool
|
||||
SudoWithPass bool
|
||||
SudoPw string
|
||||
CommandStdinFdNum int
|
||||
Detach bool
|
||||
@ -316,7 +351,7 @@ func (opts *ClientOpts) MakeRunPacket() (*packet.RunPacketType, error) {
|
||||
runPacket.Command = fmt.Sprintf(RunCommandFmt, opts.Command)
|
||||
return runPacket, nil
|
||||
}
|
||||
if opts.SSHOpts.SudoWithPass {
|
||||
if opts.SudoWithPass {
|
||||
pwFdNum, err := opts.NextFreeFdNum()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -480,7 +515,16 @@ func RunInstallSSHCommand(opts *InstallOpts) error {
|
||||
return fmt.Errorf("did not receive version string from client, install not successful")
|
||||
}
|
||||
|
||||
func RunClientSSHCommandAndWait(runPacket *packet.RunPacketType, sshOpts SSHOpts, debug bool) (*packet.CmdDonePacketType, error) {
|
||||
func HasDupStdin(fds []packet.RemoteFd) bool {
|
||||
for _, rfd := range fds {
|
||||
if rfd.Read && rfd.DupStdin {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func RunClientSSHCommandAndWait(runPacket *packet.RunPacketType, fdContext FdContext, sshOpts SSHOpts, debug bool) (*packet.CmdDonePacketType, error) {
|
||||
cmd := MakeShExec("")
|
||||
ecmd := sshOpts.MakeSSHExecCmd(ClientCommand)
|
||||
cmd.Cmd = ecmd
|
||||
@ -496,11 +540,11 @@ func RunClientSSHCommandAndWait(runPacket *packet.RunPacketType, sshOpts SSHOpts
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating stderr pipe: %v", err)
|
||||
}
|
||||
if !sshOpts.SudoWithPass {
|
||||
cmd.Multiplexer.MakeRawFdReader(0, os.Stdin, false)
|
||||
if !HasDupStdin(runPacket.Fds) {
|
||||
cmd.Multiplexer.MakeRawFdReader(0, fdContext.GetReader(0), false)
|
||||
}
|
||||
cmd.Multiplexer.MakeRawFdWriter(1, os.Stdout, false)
|
||||
cmd.Multiplexer.MakeRawFdWriter(2, os.Stderr, false)
|
||||
cmd.Multiplexer.MakeRawFdWriter(1, fdContext.GetWriter(1), false)
|
||||
cmd.Multiplexer.MakeRawFdWriter(2, fdContext.GetWriter(2), false)
|
||||
for _, rfd := range runPacket.Fds {
|
||||
if rfd.Read && rfd.Content != "" {
|
||||
err = cmd.Multiplexer.MakeStringFdReader(rfd.FdNum, rfd.Content)
|
||||
@ -510,16 +554,14 @@ func RunClientSSHCommandAndWait(runPacket *packet.RunPacketType, sshOpts SSHOpts
|
||||
continue
|
||||
}
|
||||
if rfd.Read && rfd.DupStdin {
|
||||
cmd.Multiplexer.MakeRawFdReader(rfd.FdNum, os.Stdin, false)
|
||||
cmd.Multiplexer.MakeRawFdReader(rfd.FdNum, fdContext.GetReader(0), false)
|
||||
continue
|
||||
}
|
||||
fd := os.NewFile(uintptr(rfd.FdNum), fmt.Sprintf("/dev/fd/%d", rfd.FdNum))
|
||||
if fd == nil {
|
||||
return nil, fmt.Errorf("cannot open fd %d", rfd.FdNum)
|
||||
}
|
||||
if rfd.Read {
|
||||
cmd.Multiplexer.MakeRawFdReader(rfd.FdNum, fd, true)
|
||||
fd := fdContext.GetReader(rfd.FdNum)
|
||||
cmd.Multiplexer.MakeRawFdReader(rfd.FdNum, fd, false)
|
||||
} else if rfd.Write {
|
||||
fd := fdContext.GetWriter(rfd.FdNum)
|
||||
cmd.Multiplexer.MakeRawFdWriter(rfd.FdNum, fd, true)
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user