mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-03-02 04:02:13 +01:00
merge waveshell into waveterm repo
This commit is contained in:
commit
a8055489f8
6
.gitignore
vendored
6
.gitignore
vendored
@ -3,13 +3,17 @@ dist-dev/
|
||||
node_modules/
|
||||
*~
|
||||
*.log
|
||||
*.out
|
||||
out/
|
||||
.DS_Store
|
||||
bin/
|
||||
waveshell/bin/
|
||||
wavesrv/bin/
|
||||
dev-bin
|
||||
local-server-bin
|
||||
*.pw
|
||||
build/
|
||||
*.dmg
|
||||
webshare/dist/
|
||||
webshare/dist-dev/
|
||||
webshare/dist-dev/
|
||||
|
||||
|
13
waveshell/go.mod
Normal file
13
waveshell/go.mod
Normal file
@ -0,0 +1,13 @@
|
||||
module github.com/commandlinedev/apishell
|
||||
|
||||
go 1.18
|
||||
|
||||
require (
|
||||
github.com/alessio/shellescape v1.4.1
|
||||
github.com/creack/pty v1.1.18
|
||||
github.com/fsnotify/fsnotify v1.6.0
|
||||
github.com/google/uuid v1.3.0
|
||||
golang.org/x/mod v0.5.1
|
||||
golang.org/x/sys v0.10.0
|
||||
mvdan.cc/sh/v3 v3.7.0
|
||||
)
|
20
waveshell/go.sum
Normal file
20
waveshell/go.sum
Normal file
@ -0,0 +1,20 @@
|
||||
github.com/alessio/shellescape v1.4.1 h1:V7yhSDDn8LP4lc4jS8pFkt0zCnzVJlG5JXy9BVKJUX0=
|
||||
github.com/alessio/shellescape v1.4.1/go.mod h1:PZAiSCk0LJaZkiCSkPv8qIobYglO3FPpyFjDCtHLS30=
|
||||
github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
|
||||
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
|
||||
github.com/frankban/quicktest v1.14.5 h1:dfYrrRyLtiqT9GyKXgdh+k4inNeTvmGbuSgZ3lx3GhA=
|
||||
github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY=
|
||||
github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
|
||||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/rogpeppe/go-internal v1.10.1-0.20230524175051-ec119421bb97 h1:3RPlVWzZ/PDqmVuf/FKHARG5EMid/tl7cv54Sw/QRVY=
|
||||
golang.org/x/mod v0.5.1 h1:OJxoQ/rynoF0dcCdI7cLPktw/hR2cueqYfjm43oqK38=
|
||||
golang.org/x/mod v0.5.1/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro=
|
||||
golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
|
||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
mvdan.cc/sh/v3 v3.7.0 h1:lSTjdP/1xsddtaKfGg7Myu7DnlHItd3/M2tomOcNNBg=
|
||||
mvdan.cc/sh/v3 v3.7.0/go.mod h1:K2gwkaesF/D7av7Kxl0HbF5kGOd2ArupNTX3X44+8l8=
|
586
waveshell/main-waveshell.go
Normal file
586
waveshell/main-waveshell.go
Normal file
@ -0,0 +1,586 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/commandlinedev/apishell/pkg/base"
|
||||
"github.com/commandlinedev/apishell/pkg/packet"
|
||||
"github.com/commandlinedev/apishell/pkg/server"
|
||||
"github.com/commandlinedev/apishell/pkg/shexec"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
var BuildTime = "0"
|
||||
|
||||
// func doMainRun(pk *packet.RunPacketType, sender *packet.PacketSender) {
|
||||
// err := shexec.ValidateRunPacket(pk)
|
||||
// if err != nil {
|
||||
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("invalid run packet: %v", err))
|
||||
// return
|
||||
// }
|
||||
// fileNames, err := base.GetCommandFileNames(pk.CK)
|
||||
// if err != nil {
|
||||
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot get command file names: %v", err))
|
||||
// return
|
||||
// }
|
||||
// cmd, err := shexec.MakeRunnerExec(pk.CK)
|
||||
// if err != nil {
|
||||
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot make mshell command: %v", err))
|
||||
// return
|
||||
// }
|
||||
// cmdStdin, err := cmd.StdinPipe()
|
||||
// if err != nil {
|
||||
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot pipe stdin to command: %v", err))
|
||||
// return
|
||||
// }
|
||||
// // touch ptyout file (should exist for tailer to work correctly)
|
||||
// ptyOutFd, err := os.OpenFile(fileNames.PtyOutFile, os.O_CREATE|os.O_TRUNC|os.O_APPEND|os.O_WRONLY, 0600)
|
||||
// if err != nil {
|
||||
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot open pty out file '%s': %v", fileNames.PtyOutFile, err))
|
||||
// return
|
||||
// }
|
||||
// ptyOutFd.Close() // just opened to create the file, can close right after
|
||||
// runnerOutFd, err := os.OpenFile(fileNames.RunnerOutFile, os.O_CREATE|os.O_TRUNC|os.O_APPEND|os.O_WRONLY, 0600)
|
||||
// if err != nil {
|
||||
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot open runner out file '%s': %v", fileNames.RunnerOutFile, err))
|
||||
// return
|
||||
// }
|
||||
// defer runnerOutFd.Close()
|
||||
// cmd.Stdout = runnerOutFd
|
||||
// cmd.Stderr = runnerOutFd
|
||||
// err = cmd.Start()
|
||||
// if err != nil {
|
||||
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("error starting command: %v", err))
|
||||
// return
|
||||
// }
|
||||
// go func() {
|
||||
// err = packet.SendPacket(cmdStdin, pk)
|
||||
// if err != nil {
|
||||
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("error sending forked runner command: %v", err))
|
||||
// return
|
||||
// }
|
||||
// cmdStdin.Close()
|
||||
|
||||
// // clean up zombies
|
||||
// cmd.Wait()
|
||||
// }()
|
||||
// }
|
||||
|
||||
// func doGetCmd(tailer *cmdtail.Tailer, pk *packet.GetCmdPacketType, sender *packet.PacketSender) error {
|
||||
// err := tailer.AddWatch(pk)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// func doMain() {
|
||||
// homeDir := base.GetHomeDir()
|
||||
// err := os.Chdir(homeDir)
|
||||
// if err != nil {
|
||||
// packet.SendErrorPacket(os.Stdout, fmt.Sprintf("cannot change directory to $HOME '%s': %v", homeDir, err))
|
||||
// return
|
||||
// }
|
||||
// _, err = base.GetMShellPath()
|
||||
// if err != nil {
|
||||
// packet.SendErrorPacket(os.Stdout, err.Error())
|
||||
// return
|
||||
// }
|
||||
// packetParser := packet.MakePacketParser(os.Stdin)
|
||||
// sender := packet.MakePacketSender(os.Stdout)
|
||||
// tailer, err := cmdtail.MakeTailer(sender)
|
||||
// if err != nil {
|
||||
// packet.SendErrorPacket(os.Stdout, err.Error())
|
||||
// return
|
||||
// }
|
||||
// go tailer.Run()
|
||||
// initPacket := shexec.MakeInitPacket()
|
||||
// sender.SendPacket(initPacket)
|
||||
// for pk := range packetParser.MainCh {
|
||||
// if pk.GetType() == packet.RunPacketStr {
|
||||
// doMainRun(pk.(*packet.RunPacketType), sender)
|
||||
// continue
|
||||
// }
|
||||
// if pk.GetType() == packet.GetCmdPacketStr {
|
||||
// err = doGetCmd(tailer, pk.(*packet.GetCmdPacketType), sender)
|
||||
// if err != nil {
|
||||
// errPk := packet.MakeErrorPacket(err.Error())
|
||||
// sender.SendPacket(errPk)
|
||||
// continue
|
||||
// }
|
||||
// continue
|
||||
// }
|
||||
// if pk.GetType() == packet.CdPacketStr {
|
||||
// cdPacket := pk.(*packet.CdPacketType)
|
||||
// err := os.Chdir(cdPacket.Dir)
|
||||
// resp := packet.MakeResponsePacket(cdPacket.ReqId)
|
||||
// if err != nil {
|
||||
// resp.Error = err.Error()
|
||||
// } else {
|
||||
// resp.Success = true
|
||||
// }
|
||||
// sender.SendPacket(resp)
|
||||
// continue
|
||||
// }
|
||||
// if pk.GetType() == packet.ErrorPacketStr {
|
||||
// errPk := pk.(*packet.ErrorPacketType)
|
||||
// errPk.Error = "invalid packet sent to mshell: " + errPk.Error
|
||||
// sender.SendPacket(errPk)
|
||||
// continue
|
||||
// }
|
||||
// sender.SendErrorPacket(fmt.Sprintf("invalid packet '%s' sent to mshell", pk.GetType()))
|
||||
// }
|
||||
// }
|
||||
|
||||
func readFullRunPacket(packetParser *packet.PacketParser) (*packet.RunPacketType, error) {
|
||||
rpb := packet.MakeRunPacketBuilder()
|
||||
for pk := range packetParser.MainCh {
|
||||
ok, runPacket := rpb.ProcessPacket(pk)
|
||||
if runPacket != nil {
|
||||
return runPacket, nil
|
||||
}
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid packet '%s' sent to mshell", pk.GetType())
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("no run packet received")
|
||||
}
|
||||
|
||||
func handleSingle(fromServer bool) {
|
||||
packetParser := packet.MakePacketParser(os.Stdin, false)
|
||||
sender := packet.MakePacketSender(os.Stdout, nil)
|
||||
defer func() {
|
||||
sender.Close()
|
||||
sender.WaitForDone()
|
||||
}()
|
||||
initPacket := shexec.MakeInitPacket()
|
||||
sender.SendPacket(initPacket)
|
||||
if len(os.Args) >= 3 && os.Args[2] == "--version" {
|
||||
return
|
||||
}
|
||||
runPacket, err := readFullRunPacket(packetParser)
|
||||
if err != nil {
|
||||
sender.SendErrorResponse(runPacket.ReqId, err)
|
||||
return
|
||||
}
|
||||
err = shexec.ValidateRunPacket(runPacket)
|
||||
if err != nil {
|
||||
sender.SendErrorResponse(runPacket.ReqId, err)
|
||||
return
|
||||
}
|
||||
if fromServer {
|
||||
err = runPacket.CK.Validate("run packet")
|
||||
if err != nil {
|
||||
sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("run packets from server must have a CK: %v", err))
|
||||
}
|
||||
}
|
||||
if runPacket.Detached {
|
||||
cmd, startPk, err := shexec.RunCommandDetached(runPacket, sender)
|
||||
if err != nil {
|
||||
sender.SendErrorResponse(runPacket.ReqId, err)
|
||||
return
|
||||
}
|
||||
sender.SendPacket(startPk)
|
||||
sender.Close()
|
||||
sender.WaitForDone()
|
||||
cmd.DetachedWait(startPk)
|
||||
return
|
||||
} else {
|
||||
shexec.IgnoreSigPipe()
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
go func() {
|
||||
for range ticker.C {
|
||||
// this will let the command detect when the server has gone away
|
||||
// that will then trigger cmd.SendHup() to send SIGHUP to the exec'ed process
|
||||
sender.SendPacket(packet.MakePingPacket())
|
||||
}
|
||||
}()
|
||||
defer ticker.Stop()
|
||||
cmd, err := shexec.RunCommandSimple(runPacket, sender, true)
|
||||
if err != nil {
|
||||
sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("error running command: %w", err))
|
||||
return
|
||||
}
|
||||
defer cmd.Close()
|
||||
startPacket := cmd.MakeCmdStartPacket(runPacket.ReqId)
|
||||
sender.SendPacket(startPacket)
|
||||
go func() {
|
||||
exitErr := sender.WaitForDone()
|
||||
if exitErr != nil {
|
||||
base.Logf("I/O error talking to server, sending SIGHUP to children\n")
|
||||
cmd.SendSignal(syscall.SIGHUP)
|
||||
}
|
||||
}()
|
||||
cmd.RunRemoteIOAndWait(packetParser, sender)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func detectOpenFds() ([]packet.RemoteFd, error) {
|
||||
var fds []packet.RemoteFd
|
||||
for fdNum := 3; fdNum <= 64; fdNum++ {
|
||||
flags, err := unix.FcntlInt(uintptr(fdNum), unix.F_GETFL, 0)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
flags = flags & 3
|
||||
rfd := packet.RemoteFd{FdNum: fdNum}
|
||||
if flags&2 == 2 {
|
||||
return nil, fmt.Errorf("invalid fd=%d, mshell does not support fds open for reading and writing", fdNum)
|
||||
}
|
||||
if flags&1 == 1 {
|
||||
rfd.Write = true
|
||||
} else {
|
||||
rfd.Read = true
|
||||
}
|
||||
fds = append(fds, rfd)
|
||||
}
|
||||
return fds, nil
|
||||
}
|
||||
|
||||
func parseInstallOpts() (*shexec.InstallOpts, error) {
|
||||
opts := &shexec.InstallOpts{}
|
||||
iter := base.MakeOptsIter(os.Args[2:]) // first arg is --install
|
||||
for iter.HasNext() {
|
||||
argStr := iter.Next()
|
||||
found, err := tryParseSSHOpt(iter, &opts.SSHOpts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if found {
|
||||
continue
|
||||
}
|
||||
if argStr == "--detect" {
|
||||
opts.Detect = true
|
||||
continue
|
||||
}
|
||||
if base.IsOption(argStr) {
|
||||
return nil, fmt.Errorf("invalid option '%s' passed to mshell --install", argStr)
|
||||
}
|
||||
opts.ArchStr = argStr
|
||||
break
|
||||
}
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
func tryParseSSHOpt(iter *base.OptsIter, sshOpts *shexec.SSHOpts) (bool, error) {
|
||||
argStr := iter.Current()
|
||||
if argStr == "--ssh" {
|
||||
if !iter.IsNextPlain() {
|
||||
return false, fmt.Errorf("'--ssh [user@host]' missing host")
|
||||
}
|
||||
sshOpts.SSHHost = iter.Next()
|
||||
return true, nil
|
||||
}
|
||||
if argStr == "--ssh-opts" {
|
||||
if !iter.HasNext() {
|
||||
return false, fmt.Errorf("'--ssh-opts [options]' missing options")
|
||||
}
|
||||
sshOpts.SSHOptsStr = iter.Next()
|
||||
return true, nil
|
||||
}
|
||||
if argStr == "-i" {
|
||||
if !iter.IsNextPlain() {
|
||||
return false, fmt.Errorf("-i [identity-file]' missing file")
|
||||
}
|
||||
sshOpts.SSHIdentity = iter.Next()
|
||||
return true, nil
|
||||
}
|
||||
if argStr == "-l" {
|
||||
if !iter.IsNextPlain() {
|
||||
return false, fmt.Errorf("-l [user]' missing user")
|
||||
}
|
||||
sshOpts.SSHUser = iter.Next()
|
||||
return true, nil
|
||||
}
|
||||
if argStr == "-p" {
|
||||
if !iter.IsNextPlain() {
|
||||
return false, fmt.Errorf("-p [port]' missing port")
|
||||
}
|
||||
nextArgStr := iter.Next()
|
||||
portVal, err := strconv.Atoi(nextArgStr)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("-p [port]' invalid port: %v", err)
|
||||
}
|
||||
if portVal <= 0 {
|
||||
return false, fmt.Errorf("-p [port]' invalid port: %d", portVal)
|
||||
}
|
||||
sshOpts.SSHPort = portVal
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func parseClientOpts() (*shexec.ClientOpts, error) {
|
||||
opts := &shexec.ClientOpts{}
|
||||
iter := base.MakeOptsIter(os.Args[1:])
|
||||
for iter.HasNext() {
|
||||
argStr := iter.Next()
|
||||
found, err := tryParseSSHOpt(iter, &opts.SSHOpts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if found {
|
||||
continue
|
||||
}
|
||||
if argStr == "--cwd" {
|
||||
if !iter.IsNextPlain() {
|
||||
return nil, fmt.Errorf("'--cwd [dir]' missing directory")
|
||||
}
|
||||
opts.Cwd = iter.Next()
|
||||
continue
|
||||
}
|
||||
if argStr == "--detach" {
|
||||
opts.Detach = true
|
||||
continue
|
||||
}
|
||||
if argStr == "--pty" {
|
||||
opts.UsePty = true
|
||||
continue
|
||||
}
|
||||
if argStr == "--debug" {
|
||||
opts.Debug = true
|
||||
continue
|
||||
}
|
||||
if argStr == "--sudo" {
|
||||
opts.Sudo = true
|
||||
continue
|
||||
}
|
||||
if argStr == "--sudo-with-password" {
|
||||
if !iter.HasNext() {
|
||||
return nil, fmt.Errorf("'--sudo-with-password [pw]', missing password")
|
||||
}
|
||||
opts.Sudo = true
|
||||
opts.SudoWithPass = true
|
||||
opts.SudoPw = iter.Next()
|
||||
continue
|
||||
}
|
||||
if argStr == "--sudo-with-passfile" {
|
||||
if !iter.IsNextPlain() {
|
||||
return nil, fmt.Errorf("'--sudo-with-passfile [file]', missing file")
|
||||
}
|
||||
opts.Sudo = true
|
||||
opts.SudoWithPass = true
|
||||
fileName := iter.Next()
|
||||
contents, err := os.ReadFile(fileName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot read --sudo-with-passfile file '%s': %w", fileName, err)
|
||||
}
|
||||
if newlineIdx := bytes.Index(contents, []byte{'\n'}); newlineIdx != -1 {
|
||||
contents = contents[0:newlineIdx]
|
||||
}
|
||||
opts.SudoPw = string(contents) + "\n"
|
||||
continue
|
||||
}
|
||||
if argStr == "--" {
|
||||
if !iter.HasNext() {
|
||||
return nil, fmt.Errorf("'--' should be followed by command")
|
||||
}
|
||||
opts.Command = strings.Join(iter.Rest(), " ")
|
||||
break
|
||||
}
|
||||
return nil, fmt.Errorf("invalid option '%s' passed to mshell", argStr)
|
||||
}
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
func handleClient() (int, error) {
|
||||
opts, err := parseClientOpts()
|
||||
if err != nil {
|
||||
return 1, fmt.Errorf("parsing opts: %w", err)
|
||||
}
|
||||
if opts.Debug {
|
||||
packet.GlobalDebug = true
|
||||
}
|
||||
if opts.Command == "" {
|
||||
return 1, fmt.Errorf("no [command] specified. [command] follows '--' option (see usage)")
|
||||
}
|
||||
fds, err := detectOpenFds()
|
||||
if err != nil {
|
||||
return 1, err
|
||||
}
|
||||
opts.Fds = fds
|
||||
err = shexec.ValidateRemoteFds(opts.Fds)
|
||||
if err != nil {
|
||||
return 1, err
|
||||
}
|
||||
runPacket, err := opts.MakeRunPacket() // modifies opts
|
||||
if err != nil {
|
||||
return 1, err
|
||||
}
|
||||
if runPacket.Detached {
|
||||
return 1, fmt.Errorf("cannot run detached command from command line client")
|
||||
}
|
||||
donePacket, err := shexec.RunClientSSHCommandAndWait(runPacket, shexec.StdContext{}, opts.SSHOpts, nil, opts.Debug)
|
||||
if err != nil {
|
||||
return 1, err
|
||||
}
|
||||
return donePacket.ExitCode, nil
|
||||
}
|
||||
|
||||
func handleInstall() (int, error) {
|
||||
opts, err := parseInstallOpts()
|
||||
if err != nil {
|
||||
return 1, fmt.Errorf("parsing opts: %w", err)
|
||||
}
|
||||
if opts.SSHOpts.SSHHost == "" {
|
||||
return 1, fmt.Errorf("cannot install without '--ssh user@host' option")
|
||||
}
|
||||
if opts.Detect && opts.ArchStr != "" {
|
||||
return 1, fmt.Errorf("cannot supply both --detect and arch '%s'", opts.ArchStr)
|
||||
}
|
||||
if opts.ArchStr == "" && !opts.Detect {
|
||||
return 1, fmt.Errorf("must supply an arch string or '--detect' to auto detect")
|
||||
}
|
||||
if opts.ArchStr != "" {
|
||||
fullArch := opts.ArchStr
|
||||
fields := strings.SplitN(fullArch, ".", 2)
|
||||
if len(fields) != 2 {
|
||||
return 1, fmt.Errorf("invalid arch format '%s' passed to mshell --install", fullArch)
|
||||
}
|
||||
goos, goarch := fields[0], fields[1]
|
||||
if !base.ValidGoArch(goos, goarch) {
|
||||
return 1, fmt.Errorf("invalid arch '%s' passed to mshell --install", fullArch)
|
||||
}
|
||||
optName := base.GoArchOptFile(base.MShellVersion, goos, goarch)
|
||||
_, err = os.Stat(optName)
|
||||
if err != nil {
|
||||
return 1, fmt.Errorf("cannot install mshell to remote host, cannot read '%s': %w", optName, err)
|
||||
}
|
||||
opts.OptName = optName
|
||||
}
|
||||
err = shexec.RunInstallFromOpts(opts)
|
||||
if err != nil {
|
||||
return 1, err
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func handleEnv() (int, error) {
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return 1, err
|
||||
}
|
||||
fmt.Printf("%s\x00\x00", cwd)
|
||||
fullEnv := os.Environ()
|
||||
var linePrinted bool
|
||||
for _, envLine := range fullEnv {
|
||||
if envLine != "" {
|
||||
fmt.Printf("%s\x00", envLine)
|
||||
linePrinted = true
|
||||
}
|
||||
}
|
||||
if linePrinted {
|
||||
fmt.Printf("\x00")
|
||||
} else {
|
||||
fmt.Printf("\x00\x00")
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func handleUsage() {
|
||||
usage := `
|
||||
Client Usage: mshell [opts] --ssh user@host -- [command]
|
||||
|
||||
mshell multiplexes input and output streams to a remote command over ssh.
|
||||
|
||||
Options:
|
||||
-i [identity-file] - used to set '-i' option for ssh command
|
||||
-l [user] - used to set '-l' option for ssh command
|
||||
--cwd [dir] - execute remote command in [dir]
|
||||
--ssh-opts [opts] - addition options to pass to ssh command
|
||||
[command] - the remote command to execute
|
||||
|
||||
Sudo Options:
|
||||
--sudo - use only if sudo never requires a password
|
||||
--sudo-with-password [pw] - not recommended, use --sudo-with-passfile if possible
|
||||
--sudo-with-passfile [file]
|
||||
|
||||
Sudo options allow you to run the given command using "sudo". The first
|
||||
option only works when you can sudo without a password. Your password will be passed
|
||||
securely through a high numbered fd to "sudo -S". Note that to use high numbered
|
||||
file descriptors with sudo, you will need to add this line to your /etc/sudoers file:
|
||||
Defaults closefrom_override
|
||||
See full documentation for more details.
|
||||
|
||||
Examples:
|
||||
# execute a python script remotely, with stdin still hooked up correctly
|
||||
mshell --cwd "~/work" -i key.pem --ssh ubuntu@somehost -- "python3 /dev/fd/4" 4< myscript.py
|
||||
|
||||
# capture multiple outputs
|
||||
mshell --ssh ubuntu@test -- "cat file1.txt > /dev/fd/3; cat file2.txt > /dev/fd/4" 3> file1.txt 4> file2.txt
|
||||
|
||||
# execute a script, catpure stdout/stderr in fd-3 and fd-4
|
||||
# useful if you need to see stdout for interacting with ssh (password or host auth)
|
||||
mshell --ssh user@host -- "test.sh > /dev/fd/3 2> /dev/fd/4" 3> test.stdout 4> test.stderr
|
||||
|
||||
# run a script as root (via sudo), capture output
|
||||
mshell --sudo-with-passfile pw.txt --ssh ubuntu@somehost -- "python3 /dev/fd/3 > /dev/fd/4" 3< myscript.py 4> script-output.txt < script-input.txt
|
||||
`
|
||||
fmt.Printf("%s\n\n", strings.TrimSpace(usage))
|
||||
}
|
||||
|
||||
func main() {
|
||||
base.SetBuildTime(BuildTime)
|
||||
if len(os.Args) == 1 {
|
||||
handleUsage()
|
||||
return
|
||||
}
|
||||
firstArg := os.Args[1]
|
||||
if firstArg == "--help" {
|
||||
handleUsage()
|
||||
return
|
||||
} else if firstArg == "--version" {
|
||||
fmt.Printf("mshell %s+%s\n", base.MShellVersion, base.BuildTime)
|
||||
return
|
||||
} else if firstArg == "--test-env" {
|
||||
state, err := shexec.GetShellState()
|
||||
if state != nil {
|
||||
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "[error] %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
} else if firstArg == "--single" {
|
||||
base.InitDebugLog("single")
|
||||
handleSingle(false)
|
||||
return
|
||||
} else if firstArg == "--single-from-server" {
|
||||
base.InitDebugLog("single")
|
||||
handleSingle(true)
|
||||
return
|
||||
} else if firstArg == "--server" {
|
||||
base.InitDebugLog("server")
|
||||
rtnCode, err := server.RunServer()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "[error] %v\n", err)
|
||||
}
|
||||
if rtnCode != 0 {
|
||||
os.Exit(rtnCode)
|
||||
}
|
||||
return
|
||||
} else if firstArg == "--install" {
|
||||
rtnCode, err := handleInstall()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "[error] %v\n", err)
|
||||
}
|
||||
os.Exit(rtnCode)
|
||||
return
|
||||
} else {
|
||||
rtnCode, err := handleClient()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "[error] %v\n", err)
|
||||
}
|
||||
if rtnCode != 0 {
|
||||
os.Exit(rtnCode)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
381
waveshell/pkg/base/base.go
Normal file
381
waveshell/pkg/base/base.go
Normal file
@ -0,0 +1,381 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"log"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/mod/semver"
|
||||
)
|
||||
|
||||
const HomeVarName = "HOME"
|
||||
const DefaultMShellHome = "~/.mshell"
|
||||
const DefaultMShellName = "mshell"
|
||||
const MShellPathVarName = "MSHELL_PATH"
|
||||
const MShellHomeVarName = "MSHELL_HOME"
|
||||
const MShellInstallBinVarName = "MSHELL_INSTALLBIN_PATH"
|
||||
const SSHCommandVarName = "SSH_COMMAND"
|
||||
const MShellDebugVarName = "MSHELL_DEBUG"
|
||||
const SessionsDirBaseName = "sessions"
|
||||
const MShellVersion = "v0.3.0"
|
||||
const RemoteIdFile = "remoteid"
|
||||
const DefaultMShellInstallBinDir = "/opt/mshell/bin"
|
||||
const LogFileName = "mshell.log"
|
||||
const ForceDebugLog = false
|
||||
|
||||
const DebugFlag_LogRcFile = "logrc"
|
||||
const LogRcFileName = "debug.rcfile"
|
||||
|
||||
var sessionDirCache = make(map[string]string)
|
||||
var baseLock = &sync.Mutex{}
|
||||
var DebugLogEnabled = false
|
||||
var DebugLogger *log.Logger
|
||||
var BuildTime string = "0"
|
||||
|
||||
type CommandFileNames struct {
|
||||
PtyOutFile string
|
||||
StdinFifo string
|
||||
RunnerOutFile string
|
||||
}
|
||||
|
||||
type CommandKey string
|
||||
|
||||
func SetBuildTime(build string) {
|
||||
BuildTime = build
|
||||
}
|
||||
|
||||
func MakeCommandKey(sessionId string, cmdId string) CommandKey {
|
||||
if sessionId == "" && cmdId == "" {
|
||||
return CommandKey("")
|
||||
}
|
||||
return CommandKey(fmt.Sprintf("%s/%s", sessionId, cmdId))
|
||||
}
|
||||
|
||||
func (ckey CommandKey) IsEmpty() bool {
|
||||
return string(ckey) == ""
|
||||
}
|
||||
|
||||
func Logf(fmtStr string, args ...interface{}) {
|
||||
if (!DebugLogEnabled && !ForceDebugLog) || DebugLogger == nil {
|
||||
return
|
||||
}
|
||||
DebugLogger.Printf(fmtStr, args...)
|
||||
}
|
||||
|
||||
func InitDebugLog(prefix string) {
|
||||
homeDir := GetMShellHomeDir()
|
||||
err := os.MkdirAll(homeDir, 0777)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
logFile := path.Join(homeDir, LogFileName)
|
||||
fd, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
DebugLogger = log.New(fd, prefix+" ", log.LstdFlags)
|
||||
Logf("logger initialized\n")
|
||||
}
|
||||
|
||||
func SetEnableDebugLog(enable bool) {
|
||||
DebugLogEnabled = enable
|
||||
}
|
||||
|
||||
// deprecated (use GetGroupId instead)
|
||||
func (ckey CommandKey) GetSessionId() string {
|
||||
return ckey.GetGroupId()
|
||||
}
|
||||
|
||||
func (ckey CommandKey) GetGroupId() string {
|
||||
slashIdx := strings.Index(string(ckey), "/")
|
||||
if slashIdx == -1 {
|
||||
return ""
|
||||
}
|
||||
return string(ckey[0:slashIdx])
|
||||
}
|
||||
|
||||
func (ckey CommandKey) GetCmdId() string {
|
||||
slashIdx := strings.Index(string(ckey), "/")
|
||||
if slashIdx == -1 {
|
||||
return ""
|
||||
}
|
||||
return string(ckey[slashIdx+1:])
|
||||
}
|
||||
|
||||
func (ckey CommandKey) Split() (string, string) {
|
||||
fields := strings.SplitN(string(ckey), "/", 2)
|
||||
if len(fields) < 2 {
|
||||
return "", ""
|
||||
}
|
||||
return fields[0], fields[1]
|
||||
}
|
||||
|
||||
func (ckey CommandKey) Validate(typeStr string) error {
|
||||
if typeStr == "" {
|
||||
typeStr = "ck"
|
||||
}
|
||||
if ckey == "" {
|
||||
return fmt.Errorf("%s has empty commandkey", typeStr)
|
||||
}
|
||||
sessionId, cmdId := ckey.Split()
|
||||
if sessionId == "" {
|
||||
return fmt.Errorf("%s does not have sessionid", typeStr)
|
||||
}
|
||||
_, err := uuid.Parse(sessionId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s has invalid sessionid '%s'", typeStr, sessionId)
|
||||
}
|
||||
if cmdId == "" {
|
||||
return fmt.Errorf("%s does not have cmdid", typeStr)
|
||||
}
|
||||
_, err = uuid.Parse(cmdId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s has invalid cmdid '%s'", typeStr, cmdId)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func HasDebugFlag(envMap map[string]string, flagName string) bool {
|
||||
msDebug := envMap[MShellDebugVarName]
|
||||
flags := strings.Split(msDebug, ",")
|
||||
Logf("hasdebugflag[%s]: %s [%#v]\n", flagName, msDebug, flags)
|
||||
for _, flag := range flags {
|
||||
if strings.TrimSpace(flag) == flagName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func GetDebugRcFileName() string {
|
||||
msHome := GetMShellHomeDir()
|
||||
return path.Join(msHome, LogRcFileName)
|
||||
}
|
||||
|
||||
func GetHomeDir() string {
|
||||
homeVar := os.Getenv(HomeVarName)
|
||||
if homeVar == "" {
|
||||
return "/"
|
||||
}
|
||||
return homeVar
|
||||
}
|
||||
|
||||
func GetMShellHomeDir() string {
|
||||
homeVar := os.Getenv(MShellHomeVarName)
|
||||
if homeVar != "" {
|
||||
return homeVar
|
||||
}
|
||||
return ExpandHomeDir(DefaultMShellHome)
|
||||
}
|
||||
|
||||
func GetCommandFileNames(ck CommandKey) (*CommandFileNames, error) {
|
||||
if err := ck.Validate("ck"); err != nil {
|
||||
return nil, fmt.Errorf("cannot get command files: %w", err)
|
||||
}
|
||||
sessionId, cmdId := ck.Split()
|
||||
sdir, err := EnsureSessionDir(sessionId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
base := path.Join(sdir, cmdId)
|
||||
return &CommandFileNames{
|
||||
PtyOutFile: base + ".ptyout",
|
||||
StdinFifo: base + ".stdin",
|
||||
RunnerOutFile: base + ".runout",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func CleanUpCmdFiles(sessionId string, cmdId string) error {
|
||||
if cmdId == "" {
|
||||
return fmt.Errorf("bad cmdid, cannot clean up")
|
||||
}
|
||||
sdir, err := EnsureSessionDir(sessionId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmdFileGlob := path.Join(sdir, cmdId+".*")
|
||||
matches, err := filepath.Glob(cmdFileGlob)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, file := range matches {
|
||||
rmErr := os.Remove(file)
|
||||
if err == nil && rmErr != nil {
|
||||
err = rmErr
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func GetSessionsDir() string {
|
||||
mhome := GetMShellHomeDir()
|
||||
sdir := path.Join(mhome, SessionsDirBaseName)
|
||||
return sdir
|
||||
}
|
||||
|
||||
func EnsureSessionDir(sessionId string) (string, error) {
|
||||
if sessionId == "" {
|
||||
return "", fmt.Errorf("Bad sessionid, cannot be empty")
|
||||
}
|
||||
baseLock.Lock()
|
||||
sdir, ok := sessionDirCache[sessionId]
|
||||
baseLock.Unlock()
|
||||
if ok {
|
||||
return sdir, nil
|
||||
}
|
||||
mhome := GetMShellHomeDir()
|
||||
sdir = path.Join(mhome, SessionsDirBaseName, sessionId)
|
||||
info, err := os.Stat(sdir)
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
err = os.MkdirAll(sdir, 0777)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("cannot make mshell session directory[%s]: %w", sdir, err)
|
||||
}
|
||||
info, err = os.Stat(sdir)
|
||||
}
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return "", fmt.Errorf("session dir '%s' must be a directory", sdir)
|
||||
}
|
||||
baseLock.Lock()
|
||||
sessionDirCache[sessionId] = sdir
|
||||
baseLock.Unlock()
|
||||
return sdir, nil
|
||||
}
|
||||
|
||||
func GetMShellPath() (string, error) {
|
||||
msPath := os.Getenv(MShellPathVarName) // use MSHELL_PATH
|
||||
if msPath != "" {
|
||||
return exec.LookPath(msPath)
|
||||
}
|
||||
mhome := GetMShellHomeDir()
|
||||
userMShellPath := path.Join(mhome, DefaultMShellName) // look in ~/.mshell
|
||||
msPath, err := exec.LookPath(userMShellPath)
|
||||
if err == nil {
|
||||
return msPath, nil
|
||||
}
|
||||
return exec.LookPath(DefaultMShellName) // standard path lookup for 'mshell'
|
||||
}
|
||||
|
||||
func GetMShellSessionsDir() (string, error) {
|
||||
mhome := GetMShellHomeDir()
|
||||
return path.Join(mhome, SessionsDirBaseName), nil
|
||||
}
|
||||
|
||||
func ExpandHomeDir(pathStr string) string {
|
||||
if pathStr != "~" && !strings.HasPrefix(pathStr, "~/") {
|
||||
return pathStr
|
||||
}
|
||||
homeDir := GetHomeDir()
|
||||
if pathStr == "~" {
|
||||
return homeDir
|
||||
}
|
||||
return path.Join(homeDir, pathStr[2:])
|
||||
}
|
||||
|
||||
func ValidGoArch(goos string, goarch string) bool {
|
||||
return (goos == "darwin" || goos == "linux") && (goarch == "amd64" || goarch == "arm64")
|
||||
}
|
||||
|
||||
func GoArchOptFile(version string, goos string, goarch string) string {
|
||||
installBinDir := os.Getenv(MShellInstallBinVarName)
|
||||
if installBinDir == "" {
|
||||
installBinDir = DefaultMShellInstallBinDir
|
||||
}
|
||||
versionStr := semver.MajorMinor(version)
|
||||
if versionStr == "" {
|
||||
versionStr = "unknown"
|
||||
}
|
||||
binBaseName := fmt.Sprintf("mshell-%s-%s.%s", versionStr, goos, goarch)
|
||||
return fmt.Sprintf(path.Join(installBinDir, binBaseName))
|
||||
}
|
||||
|
||||
func MShellBinaryFromOptDir(version string, goos string, goarch string) (io.ReadCloser, error) {
|
||||
if !ValidGoArch(goos, goarch) {
|
||||
return nil, fmt.Errorf("invalid goos/goarch combination: %s/%s", goos, goarch)
|
||||
}
|
||||
versionStr := semver.MajorMinor(version)
|
||||
if versionStr == "" {
|
||||
return nil, fmt.Errorf("invalid mshell version: %q", version)
|
||||
}
|
||||
fileName := GoArchOptFile(version, goos, goarch)
|
||||
fd, err := os.Open(fileName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot open mshell binary %q: %v", fileName, err)
|
||||
}
|
||||
return fd, nil
|
||||
}
|
||||
|
||||
func GetRemoteId() (string, error) {
|
||||
mhome := GetMShellHomeDir()
|
||||
homeInfo, err := os.Stat(mhome)
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
err = os.MkdirAll(mhome, 0777)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("cannot make mshell home directory[%s]: %w", mhome, err)
|
||||
}
|
||||
homeInfo, err = os.Stat(mhome)
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("cannot stat mshell home directory[%s]: %w", mhome, err)
|
||||
}
|
||||
if !homeInfo.IsDir() {
|
||||
return "", fmt.Errorf("mshell home directory[%s] is not a directory", mhome)
|
||||
}
|
||||
remoteIdFile := path.Join(mhome, RemoteIdFile)
|
||||
fd, err := os.Open(remoteIdFile)
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
// write the file
|
||||
remoteId := uuid.New().String()
|
||||
err = os.WriteFile(remoteIdFile, []byte(remoteId), 0644)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("cannot write remoteid to '%s': %w", remoteIdFile, err)
|
||||
}
|
||||
return remoteId, nil
|
||||
} else if err != nil {
|
||||
return "", fmt.Errorf("cannot read remoteid file '%s': %w", remoteIdFile, err)
|
||||
} else {
|
||||
defer fd.Close()
|
||||
contents, err := io.ReadAll(fd)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("cannot read remoteid file '%s': %w", remoteIdFile, err)
|
||||
}
|
||||
uuidStr := string(contents)
|
||||
_, err = uuid.Parse(uuidStr)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid uuid read from '%s': %w", remoteIdFile, err)
|
||||
}
|
||||
return uuidStr, nil
|
||||
}
|
||||
}
|
||||
|
||||
func BoundInt(ival int, minVal int, maxVal int) int {
|
||||
if ival < minVal {
|
||||
return minVal
|
||||
}
|
||||
if ival > maxVal {
|
||||
return maxVal
|
||||
}
|
||||
return ival
|
||||
}
|
||||
|
||||
func BoundInt64(ival int64, minVal int64, maxVal int64) int64 {
|
||||
if ival < minVal {
|
||||
return minVal
|
||||
}
|
||||
if ival > maxVal {
|
||||
return maxVal
|
||||
}
|
||||
return ival
|
||||
}
|
47
waveshell/pkg/base/optsiter.go
Normal file
47
waveshell/pkg/base/optsiter.go
Normal file
@ -0,0 +1,47 @@
|
||||
package base
|
||||
|
||||
import "strings"
|
||||
|
||||
type OptsIter struct {
|
||||
Pos int
|
||||
Opts []string
|
||||
}
|
||||
|
||||
func MakeOptsIter(opts []string) *OptsIter {
|
||||
return &OptsIter{Opts: opts}
|
||||
}
|
||||
|
||||
func IsOption(argStr string) bool {
|
||||
return strings.HasPrefix(argStr, "-") && argStr != "-" && !strings.HasPrefix(argStr, "-/")
|
||||
}
|
||||
|
||||
func (iter *OptsIter) HasNext() bool {
|
||||
return iter.Pos <= len(iter.Opts)-1
|
||||
}
|
||||
|
||||
func (iter *OptsIter) IsNextPlain() bool {
|
||||
if !iter.HasNext() {
|
||||
return false
|
||||
}
|
||||
return !IsOption(iter.Opts[iter.Pos])
|
||||
}
|
||||
|
||||
func (iter *OptsIter) Next() string {
|
||||
if iter.Pos >= len(iter.Opts) {
|
||||
return ""
|
||||
}
|
||||
rtn := iter.Opts[iter.Pos]
|
||||
iter.Pos++
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (iter *OptsIter) Current() string {
|
||||
if iter.Pos == 0 {
|
||||
return ""
|
||||
}
|
||||
return iter.Opts[iter.Pos-1]
|
||||
}
|
||||
|
||||
func (iter *OptsIter) Rest() []string {
|
||||
return iter.Opts[iter.Pos:]
|
||||
}
|
127
waveshell/pkg/binpack/binpack.go
Normal file
127
waveshell/pkg/binpack/binpack.go
Normal file
@ -0,0 +1,127 @@
|
||||
package binpack
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
type Unpacker struct {
|
||||
R FullByteReader
|
||||
Err error
|
||||
}
|
||||
|
||||
type FullByteReader interface {
|
||||
io.ByteReader
|
||||
io.Reader
|
||||
}
|
||||
|
||||
func PackValue(w io.Writer, barr []byte) error {
|
||||
viBuf := make([]byte, binary.MaxVarintLen64)
|
||||
viLen := binary.PutUvarint(viBuf, uint64(len(barr)))
|
||||
_, err := w.Write(viBuf[0:viLen])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(barr) > 0 {
|
||||
_, err = w.Write(barr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func PackStrArr(w io.Writer, strs []string) error {
|
||||
barr, err := json.Marshal(strs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return PackValue(w, barr)
|
||||
}
|
||||
|
||||
func PackInt(w io.Writer, ival int) error {
|
||||
viBuf := make([]byte, binary.MaxVarintLen64)
|
||||
l := binary.PutUvarint(viBuf, uint64(ival))
|
||||
_, err := w.Write(viBuf[0:l])
|
||||
return err
|
||||
}
|
||||
|
||||
func UnpackValue(r FullByteReader) ([]byte, error) {
|
||||
lenVal, err := binary.ReadUvarint(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if lenVal == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
rtnBuf := make([]byte, int(lenVal))
|
||||
_, err = io.ReadFull(r, rtnBuf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rtnBuf, nil
|
||||
}
|
||||
|
||||
func UnpackStrArr(r FullByteReader) ([]string, error) {
|
||||
barr, err := UnpackValue(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var strs []string
|
||||
err = json.Unmarshal(barr, &strs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return strs, nil
|
||||
}
|
||||
|
||||
func UnpackInt(r io.ByteReader) (int, error) {
|
||||
ival64, err := binary.ReadVarint(r)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int(ival64), nil
|
||||
}
|
||||
|
||||
func (u *Unpacker) UnpackValue(name string) []byte {
|
||||
if u.Err != nil {
|
||||
return nil
|
||||
}
|
||||
rtn, err := UnpackValue(u.R)
|
||||
if err != nil {
|
||||
u.Err = fmt.Errorf("cannot unpack %s: %v", name, err)
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (u *Unpacker) UnpackInt(name string) int {
|
||||
if u.Err != nil {
|
||||
return 0
|
||||
}
|
||||
rtn, err := UnpackInt(u.R)
|
||||
if err != nil {
|
||||
u.Err = fmt.Errorf("cannot unpack %s: %v", name, err)
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (u *Unpacker) UnpackStrArr(name string) []string {
|
||||
if u.Err != nil {
|
||||
return nil
|
||||
}
|
||||
rtn, err := UnpackStrArr(u.R)
|
||||
if err != nil {
|
||||
u.Err = fmt.Errorf("cannot unpack %s: %v", name, err)
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (u *Unpacker) Error() error {
|
||||
return u.Err
|
||||
}
|
||||
|
||||
func MakeUnpacker(r FullByteReader) *Unpacker {
|
||||
return &Unpacker{R: r}
|
||||
}
|
570
waveshell/pkg/cirfile/cirfile.go
Normal file
570
waveshell/pkg/cirfile/cirfile.go
Normal file
@ -0,0 +1,570 @@
|
||||
package cirfile
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CBUF[version] [maxsize] [fileoffset] [startpos] [endpos]
|
||||
const HeaderFmt1 = "CBUF%02d %19d %19d %19d %19d\n" // 87 bytes
|
||||
const HeaderLen = 256 // set to 256 for future expandability
|
||||
const FullHeaderFmt = "%-255s\n" // 256 bytes (255 + newline)
|
||||
const CurrentVersion = 1
|
||||
const FilePosEmpty = -1 // sentinel, if startpos is set to -1, file is empty
|
||||
|
||||
const InitialLockDelay = 10 * time.Millisecond
|
||||
const InitialLockTries = 5
|
||||
const LockDelay = 100 * time.Millisecond
|
||||
|
||||
// File objects are *not* multithread safe, operations must be externally synchronized
|
||||
type File struct {
|
||||
OSFile *os.File
|
||||
Version byte
|
||||
MaxSize int64
|
||||
FileOffset int64
|
||||
StartPos int64
|
||||
EndPos int64
|
||||
FileDataSize int64 // size of data (does not include header size)
|
||||
FlockStatus int
|
||||
}
|
||||
|
||||
type Stat struct {
|
||||
Location string
|
||||
Version byte
|
||||
MaxSize int64
|
||||
FileOffset int64
|
||||
DataSize int64
|
||||
}
|
||||
|
||||
func (f *File) flock(ctx context.Context, lockType int) error {
|
||||
err := syscall.Flock(int(f.OSFile.Fd()), lockType|syscall.LOCK_NB)
|
||||
if err == nil {
|
||||
f.FlockStatus = lockType
|
||||
return nil
|
||||
}
|
||||
if err != syscall.EWOULDBLOCK {
|
||||
return err
|
||||
}
|
||||
if ctx == nil {
|
||||
return syscall.EWOULDBLOCK
|
||||
}
|
||||
// busy-wait with context
|
||||
numWaits := 0
|
||||
for {
|
||||
numWaits++
|
||||
var timeout time.Duration
|
||||
if numWaits <= InitialLockTries {
|
||||
timeout = InitialLockDelay
|
||||
} else {
|
||||
timeout = LockDelay
|
||||
}
|
||||
select {
|
||||
case <-time.After(timeout):
|
||||
break
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
err = syscall.Flock(int(f.OSFile.Fd()), lockType|syscall.LOCK_NB)
|
||||
if err == nil {
|
||||
f.FlockStatus = lockType
|
||||
return nil
|
||||
}
|
||||
if err != syscall.EWOULDBLOCK {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("could not acquire lock")
|
||||
}
|
||||
|
||||
func (f *File) unflock() {
|
||||
if f.FlockStatus != 0 {
|
||||
syscall.Flock(int(f.OSFile.Fd()), syscall.LOCK_UN) // ignore error (nothing to do about it anyway)
|
||||
f.FlockStatus = 0
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// does not read metadata because locking could block/fail. we want to be able
|
||||
// to return a valid file struct without blocking.
|
||||
func OpenCirFile(fileName string) (*File, error) {
|
||||
fd, err := os.OpenFile(fileName, os.O_RDWR, 0777)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
finfo, err := fd.Stat()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if finfo.Size() < HeaderLen {
|
||||
return nil, fmt.Errorf("invalid cirfile, file length[%d] less than HeaderLen[%d]", finfo.Size(), HeaderLen)
|
||||
}
|
||||
rtn := &File{OSFile: fd}
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func StatCirFile(ctx context.Context, fileName string) (*Stat, error) {
|
||||
file, err := OpenCirFile(fileName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
fileOffset, dataSize, err := file.GetStartOffsetAndSize(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Stat{
|
||||
Location: fileName,
|
||||
Version: file.Version,
|
||||
MaxSize: file.MaxSize,
|
||||
FileOffset: fileOffset,
|
||||
DataSize: dataSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// if the file already exists, it is an error.
|
||||
// there is a race condition if two goroutines try to create the same file between Stat() and Create(), so
|
||||
// they both might get no error, but only one file will be valid. if this is a concern, this call
|
||||
// should be externally synchronized.
|
||||
func CreateCirFile(fileName string, maxSize int64) (*File, error) {
|
||||
if maxSize <= 0 {
|
||||
return nil, fmt.Errorf("invalid maxsize[%d]", maxSize)
|
||||
}
|
||||
_, err := os.Stat(fileName)
|
||||
if err == nil {
|
||||
return nil, fmt.Errorf("file[%s] already exists", fileName)
|
||||
}
|
||||
if !os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("cannot stat: %w", err)
|
||||
}
|
||||
fd, err := os.Create(fileName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rtn := &File{OSFile: fd, Version: CurrentVersion, MaxSize: maxSize, StartPos: FilePosEmpty}
|
||||
err = rtn.flock(nil, syscall.LOCK_EX)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rtn.unflock()
|
||||
err = rtn.writeMeta()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func (f *File) Close() error {
|
||||
return f.OSFile.Close()
|
||||
}
|
||||
|
||||
func (f *File) ReadMeta(ctx context.Context) error {
|
||||
err := f.flock(ctx, syscall.LOCK_SH)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.unflock()
|
||||
return f.readMeta()
|
||||
}
|
||||
|
||||
func (f *File) hasShLock() bool {
|
||||
return f.FlockStatus == syscall.LOCK_EX || f.FlockStatus == syscall.LOCK_SH
|
||||
}
|
||||
|
||||
func (f *File) hasExLock() bool {
|
||||
return f.FlockStatus == syscall.LOCK_EX
|
||||
}
|
||||
|
||||
func (f *File) readMeta() error {
|
||||
if f.OSFile == nil {
|
||||
return fmt.Errorf("no *os.File")
|
||||
}
|
||||
if !f.hasShLock() {
|
||||
return fmt.Errorf("writeMeta must hold LOCK_SH")
|
||||
}
|
||||
_, err := f.OSFile.Seek(0, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot seek file: %w", err)
|
||||
}
|
||||
finfo, err := f.OSFile.Stat()
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot stat file: %w", err)
|
||||
}
|
||||
if finfo.Size() < 256 {
|
||||
return fmt.Errorf("invalid cbuf file size[%d] < 256", finfo.Size())
|
||||
}
|
||||
f.FileDataSize = finfo.Size() - 256
|
||||
buf := make([]byte, 256)
|
||||
_, err = io.ReadFull(f.OSFile, buf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading header: %w", err)
|
||||
}
|
||||
// currently only one version, so we don't need to have special logic here yet
|
||||
_, err = fmt.Sscanf(string(buf), HeaderFmt1, &f.Version, &f.MaxSize, &f.FileOffset, &f.StartPos, &f.EndPos)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sscanf error: %w", err)
|
||||
}
|
||||
if f.Version != CurrentVersion {
|
||||
return fmt.Errorf("invalid cbuf version[%d]", f.Version)
|
||||
}
|
||||
// possible incomplete write, fix start/end pos to be within filesize
|
||||
if f.FileDataSize == 0 {
|
||||
f.StartPos = FilePosEmpty
|
||||
f.EndPos = 0
|
||||
} else if f.StartPos >= f.FileDataSize && f.EndPos >= f.FileDataSize {
|
||||
f.StartPos = FilePosEmpty
|
||||
f.EndPos = 0
|
||||
} else if f.StartPos >= f.FileDataSize {
|
||||
f.StartPos = 0
|
||||
} else if f.EndPos >= f.FileDataSize {
|
||||
f.EndPos = f.FileDataSize - 1
|
||||
}
|
||||
if f.MaxSize <= 0 || f.FileOffset < 0 || (f.StartPos < 0 && f.StartPos != FilePosEmpty) || f.StartPos >= f.MaxSize || f.EndPos < 0 || f.EndPos >= f.MaxSize {
|
||||
return fmt.Errorf("invalid cbuf metadata version[%d] filedatasize[%d] maxsize[%d] fileoffset[%d] startpos[%d] endpos[%d]", f.Version, f.FileDataSize, f.MaxSize, f.FileOffset, f.StartPos, f.EndPos)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// no error checking of meta values
|
||||
func (f *File) writeMeta() error {
|
||||
if f.OSFile == nil {
|
||||
return fmt.Errorf("no *os.File")
|
||||
}
|
||||
if !f.hasExLock() {
|
||||
return fmt.Errorf("writeMeta must hold LOCK_EX")
|
||||
}
|
||||
_, err := f.OSFile.Seek(0, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot seek file: %w", err)
|
||||
}
|
||||
metaStr := fmt.Sprintf(HeaderFmt1, f.Version, f.MaxSize, f.FileOffset, f.StartPos, f.EndPos)
|
||||
fullMetaStr := fmt.Sprintf(FullHeaderFmt, metaStr)
|
||||
_, err = f.OSFile.WriteString(fullMetaStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("write error: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// returns (fileOffset, datasize, error)
|
||||
// datasize is the current amount of readable data held in the cirfile
|
||||
func (f *File) GetStartOffsetAndSize(ctx context.Context) (int64, int64, error) {
|
||||
err := f.flock(ctx, syscall.LOCK_SH)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
defer f.unflock()
|
||||
err = f.readMeta()
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
chunks := f.getFileChunks()
|
||||
return f.FileOffset, totalChunksSize(chunks), nil
|
||||
}
|
||||
|
||||
type fileChunk struct {
|
||||
StartPos int64
|
||||
Len int64
|
||||
}
|
||||
|
||||
func totalChunksSize(chunks []fileChunk) int64 {
|
||||
var rtn int64
|
||||
for _, chunk := range chunks {
|
||||
rtn += chunk.Len
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func advanceChunks(chunks []fileChunk, offset int64) []fileChunk {
|
||||
if offset < 0 {
|
||||
panic(fmt.Sprintf("invalid negative offset: %d", offset))
|
||||
}
|
||||
if offset == 0 {
|
||||
return chunks
|
||||
}
|
||||
var rtn []fileChunk
|
||||
for _, chunk := range chunks {
|
||||
if offset >= chunk.Len {
|
||||
offset = offset - chunk.Len
|
||||
continue
|
||||
}
|
||||
if offset == 0 {
|
||||
rtn = append(rtn, chunk)
|
||||
} else {
|
||||
rtn = append(rtn, fileChunk{chunk.StartPos + offset, chunk.Len - offset})
|
||||
offset = 0
|
||||
}
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (f *File) getFileChunks() []fileChunk {
|
||||
if f.StartPos == FilePosEmpty {
|
||||
return nil
|
||||
}
|
||||
if f.EndPos >= f.StartPos {
|
||||
return []fileChunk{fileChunk{f.StartPos, f.EndPos - f.StartPos + 1}}
|
||||
}
|
||||
return []fileChunk{
|
||||
fileChunk{f.StartPos, f.FileDataSize - f.StartPos},
|
||||
fileChunk{0, f.EndPos + 1},
|
||||
}
|
||||
}
|
||||
|
||||
func (f *File) getFreeChunks() []fileChunk {
|
||||
if f.StartPos == FilePosEmpty {
|
||||
return []fileChunk{fileChunk{0, f.MaxSize}}
|
||||
}
|
||||
if (f.EndPos == f.StartPos-1) || (f.StartPos == 0 && f.EndPos == f.MaxSize-1) {
|
||||
return nil
|
||||
}
|
||||
if f.EndPos < f.StartPos {
|
||||
return []fileChunk{fileChunk{f.EndPos + 1, f.StartPos - f.EndPos - 1}}
|
||||
}
|
||||
var rtn []fileChunk
|
||||
if f.EndPos < f.MaxSize-1 {
|
||||
rtn = append(rtn, fileChunk{f.EndPos + 1, f.MaxSize - f.EndPos - 1})
|
||||
}
|
||||
if f.StartPos > 0 {
|
||||
rtn = append(rtn, fileChunk{0, f.StartPos})
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
// returns (offset, data, err)
|
||||
func (f *File) ReadAll(ctx context.Context) (int64, []byte, error) {
|
||||
err := f.flock(ctx, syscall.LOCK_SH)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
defer f.unflock()
|
||||
err = f.readMeta()
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
chunks := f.getFileChunks()
|
||||
curSize := totalChunksSize(chunks)
|
||||
buf := make([]byte, curSize)
|
||||
realOffset, nr, err := f.internalReadNext(buf, 0)
|
||||
return realOffset, buf[0:nr], err
|
||||
}
|
||||
|
||||
func (f *File) ReadAtWithMax(ctx context.Context, offset int64, maxSize int64) (int64, []byte, error) {
|
||||
err := f.flock(ctx, syscall.LOCK_SH)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
defer f.unflock()
|
||||
err = f.readMeta()
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
chunks := f.getFileChunks()
|
||||
curSize := totalChunksSize(chunks)
|
||||
var buf []byte
|
||||
if maxSize > curSize {
|
||||
buf = make([]byte, curSize)
|
||||
} else {
|
||||
buf = make([]byte, maxSize)
|
||||
}
|
||||
realOffset, nr, err := f.internalReadNext(buf, offset)
|
||||
return realOffset, buf[0:nr], err
|
||||
}
|
||||
|
||||
func (f *File) internalReadNext(buf []byte, offset int64) (int64, int, error) {
|
||||
if offset < f.FileOffset {
|
||||
offset = f.FileOffset
|
||||
}
|
||||
relativeOffset := offset - f.FileOffset
|
||||
chunks := f.getFileChunks()
|
||||
curSize := totalChunksSize(chunks)
|
||||
if offset >= f.FileOffset+curSize {
|
||||
return f.FileOffset + curSize, 0, nil
|
||||
}
|
||||
chunks = advanceChunks(chunks, relativeOffset)
|
||||
numRead := 0
|
||||
for _, chunk := range chunks {
|
||||
if numRead >= len(buf) {
|
||||
break
|
||||
}
|
||||
toRead := len(buf) - numRead
|
||||
if toRead > int(chunk.Len) {
|
||||
toRead = int(chunk.Len)
|
||||
}
|
||||
nr, err := f.OSFile.ReadAt(buf[numRead:numRead+toRead], chunk.StartPos+HeaderLen)
|
||||
if err != nil {
|
||||
return offset, 0, err
|
||||
}
|
||||
numRead += nr
|
||||
}
|
||||
return offset, numRead, nil
|
||||
}
|
||||
|
||||
// returns (realOffset, numread, error)
|
||||
// will only return io.EOF when len(data) == 0, otherwise will just do a short read
|
||||
func (f *File) ReadNext(ctx context.Context, buf []byte, offset int64) (int64, int, error) {
|
||||
err := f.flock(ctx, syscall.LOCK_SH)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
defer f.unflock()
|
||||
err = f.readMeta()
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
return f.internalReadNext(buf, offset)
|
||||
}
|
||||
|
||||
func (f *File) ensureFreeSpace(requiredSpace int64) error {
|
||||
chunks := f.getFileChunks()
|
||||
curSpace := f.MaxSize - totalChunksSize(chunks)
|
||||
if curSpace >= requiredSpace {
|
||||
return nil
|
||||
}
|
||||
neededSpace := requiredSpace - curSpace
|
||||
if requiredSpace >= f.MaxSize || f.StartPos == FilePosEmpty {
|
||||
f.StartPos = FilePosEmpty
|
||||
f.EndPos = 0
|
||||
f.FileOffset += neededSpace
|
||||
} else {
|
||||
f.StartPos = (f.StartPos + neededSpace) % f.MaxSize
|
||||
f.FileOffset += neededSpace
|
||||
}
|
||||
return f.writeMeta()
|
||||
}
|
||||
|
||||
// does not implement io.WriterAt (needs context)
|
||||
func (f *File) WriteAt(ctx context.Context, buf []byte, writePos int64) error {
|
||||
if writePos < 0 {
|
||||
return fmt.Errorf("WriteAt got invalid writePos[%d]", writePos)
|
||||
}
|
||||
err := f.flock(ctx, syscall.LOCK_EX)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.unflock()
|
||||
err = f.readMeta()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
chunks := f.getFileChunks()
|
||||
currentSize := totalChunksSize(chunks)
|
||||
if writePos < f.FileOffset {
|
||||
negOffset := f.FileOffset - writePos
|
||||
if negOffset >= int64(len(buf)) {
|
||||
return nil
|
||||
}
|
||||
buf = buf[negOffset:]
|
||||
writePos = f.FileOffset
|
||||
}
|
||||
if writePos > f.FileOffset+currentSize {
|
||||
// fill gap with zero bytes
|
||||
posOffset := writePos - (f.FileOffset + currentSize)
|
||||
err = f.ensureFreeSpace(int64(posOffset))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var zeroBuf []byte
|
||||
if posOffset >= f.MaxSize {
|
||||
zeroBuf = make([]byte, f.MaxSize)
|
||||
} else {
|
||||
zeroBuf = make([]byte, posOffset)
|
||||
}
|
||||
err = f.internalAppendData(zeroBuf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// recalc chunks/currentSize
|
||||
chunks = f.getFileChunks()
|
||||
currentSize = totalChunksSize(chunks)
|
||||
// after writing the zero bytes, writePos == f.FileOffset+currentSize (the rest is a straight append)
|
||||
}
|
||||
// now writePos >= f.FileOffset && writePos <= f.FileOffset+currentSize (check invariant)
|
||||
if writePos < f.FileOffset || writePos > f.FileOffset+currentSize {
|
||||
panic(fmt.Sprintf("invalid writePos, invariant violated writepos[%d] fileoffset[%d] currentsize[%d]", writePos, f.FileOffset, currentSize))
|
||||
}
|
||||
// overwrite existing data (in chunks). advance by writePosOffset
|
||||
writePosOffset := writePos - f.FileOffset
|
||||
if writePosOffset < currentSize {
|
||||
advChunks := advanceChunks(chunks, writePosOffset)
|
||||
nw, err := f.writeToChunks(buf, advChunks, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buf = buf[nw:]
|
||||
if len(buf) == 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
// buf contains what was unwritten. this unwritten data is now just a straight append
|
||||
return f.internalAppendData(buf)
|
||||
}
|
||||
|
||||
// try writing to chunks, returns (nw, error)
|
||||
func (f *File) writeToChunks(buf []byte, chunks []fileChunk, updatePos bool) (int64, error) {
|
||||
var numWrite int64
|
||||
for _, chunk := range chunks {
|
||||
if numWrite >= int64(len(buf)) {
|
||||
break
|
||||
}
|
||||
if chunk.Len == 0 {
|
||||
continue
|
||||
}
|
||||
toWrite := int64(len(buf)) - numWrite
|
||||
if toWrite > chunk.Len {
|
||||
toWrite = chunk.Len
|
||||
}
|
||||
nw, err := f.OSFile.WriteAt(buf[numWrite:numWrite+toWrite], chunk.StartPos+HeaderLen)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if updatePos {
|
||||
if chunk.StartPos+int64(nw) > f.FileDataSize {
|
||||
f.FileDataSize = chunk.StartPos + int64(nw)
|
||||
}
|
||||
if f.StartPos == FilePosEmpty {
|
||||
f.StartPos = chunk.StartPos
|
||||
}
|
||||
f.EndPos = chunk.StartPos + int64(nw) - 1
|
||||
}
|
||||
numWrite += int64(nw)
|
||||
}
|
||||
return numWrite, nil
|
||||
}
|
||||
|
||||
func (f *File) internalAppendData(buf []byte) error {
|
||||
err := f.ensureFreeSpace(int64(len(buf)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(buf) >= int(f.MaxSize) {
|
||||
buf = buf[len(buf)-int(f.MaxSize):]
|
||||
}
|
||||
chunks := f.getFreeChunks()
|
||||
// don't track nw because we know we have enough free space to write entire buf
|
||||
_, err = f.writeToChunks(buf, chunks, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = f.writeMeta()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *File) AppendData(ctx context.Context, buf []byte) error {
|
||||
err := f.flock(ctx, syscall.LOCK_EX)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.unflock()
|
||||
err = f.readMeta()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return f.internalAppendData(buf)
|
||||
}
|
282
waveshell/pkg/cirfile/cirfile_test.go
Normal file
282
waveshell/pkg/cirfile/cirfile_test.go
Normal file
@ -0,0 +1,282 @@
|
||||
package cirfile
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func validateFileSize(t *testing.T, name string, size int) {
|
||||
finfo, err := os.Stat(name)
|
||||
if err != nil {
|
||||
t.Fatalf("error stating file[%s]: %v", name, err)
|
||||
}
|
||||
if int(finfo.Size()) != size {
|
||||
t.Fatalf("invalid file[%s] expected[%d] got[%d]", name, size, finfo.Size())
|
||||
}
|
||||
}
|
||||
|
||||
func validateMeta(t *testing.T, desc string, f *File, startPos int64, endPos int64, dataSize int64, offset int64) {
|
||||
if f.StartPos != startPos || f.EndPos != endPos || f.FileDataSize != dataSize || f.FileOffset != offset {
|
||||
t.Fatalf("metadata error (%s): startpos[%d %d] endpos[%d %d] filedatasize[%d %d] fileoffset[%d %d]", desc, f.StartPos, startPos, f.EndPos, endPos, f.FileDataSize, dataSize, f.FileOffset, offset)
|
||||
}
|
||||
}
|
||||
|
||||
func dumpFile(name string) {
|
||||
barr, _ := os.ReadFile(name)
|
||||
str := string(barr)
|
||||
str = strings.ReplaceAll(str, "\x00", ".")
|
||||
fmt.Printf("%s<<<\n%s\n>>>\n", name, str)
|
||||
}
|
||||
|
||||
func makeData(size int) string {
|
||||
var rtn string
|
||||
for {
|
||||
if len(rtn) >= size {
|
||||
break
|
||||
}
|
||||
needed := size - len(rtn)
|
||||
if needed < 10 {
|
||||
rtn += "123456789\n"[0:needed]
|
||||
break
|
||||
}
|
||||
rtn += "123456789\n"
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func TestCreate(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
f1Name := path.Join(tempDir, "f1.cf")
|
||||
f, err := OpenCirFile(f1Name)
|
||||
if err == nil || f != nil {
|
||||
t.Fatalf("OpenCirFile f1.cf should fail (no file)")
|
||||
}
|
||||
f, err = CreateCirFile(f1Name, 100)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateCirFile f1.cf failed: %v", err)
|
||||
}
|
||||
if f == nil {
|
||||
t.Fatalf("CreateCirFile f1.cf returned nil")
|
||||
}
|
||||
err = f.ReadMeta(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("cannot readmeta from f1.cf: %v", err)
|
||||
}
|
||||
validateFileSize(t, f1Name, 256)
|
||||
if f.Version != CurrentVersion || f.MaxSize != 100 || f.FileOffset != 0 || f.StartPos != FilePosEmpty || f.EndPos != 0 || f.FileDataSize != 0 || f.FlockStatus != 0 {
|
||||
t.Fatalf("error with initial metadata #%v", f)
|
||||
}
|
||||
buf := make([]byte, 200)
|
||||
realOffset, nr, err := f.ReadNext(context.Background(), buf, 0)
|
||||
if realOffset != 0 || nr != 0 || err != nil {
|
||||
t.Fatalf("error with empty read: real-offset[%d] nr[%d] err[%v]", realOffset, nr, err)
|
||||
}
|
||||
realOffset, nr, err = f.ReadNext(context.Background(), buf, 1000)
|
||||
if realOffset != 0 || nr != 0 || err != nil {
|
||||
t.Fatalf("error with empty read: real-offset[%d] nr[%d] err[%v]", realOffset, nr, err)
|
||||
}
|
||||
f2, err := CreateCirFile(f1Name, 100)
|
||||
if err == nil || f2 != nil {
|
||||
t.Fatalf("should be an error to create duplicate CirFile")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFile(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
f1Name := path.Join(tempDir, "f1.cf")
|
||||
f, err := CreateCirFile(f1Name, 100)
|
||||
if err != nil {
|
||||
t.Fatalf("cannot create cirfile: %v", err)
|
||||
}
|
||||
err = f.AppendData(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("cannot append data: %v", err)
|
||||
}
|
||||
validateFileSize(t, f1Name, HeaderLen)
|
||||
validateMeta(t, "1", f, FilePosEmpty, 0, 0, 0)
|
||||
err = f.AppendData(context.Background(), []byte("hello"))
|
||||
if err != nil {
|
||||
t.Fatalf("cannot append data: %v", err)
|
||||
}
|
||||
validateFileSize(t, f1Name, HeaderLen+5)
|
||||
validateMeta(t, "2", f, 0, 4, 5, 0)
|
||||
err = f.AppendData(context.Background(), []byte(" foo"))
|
||||
if err != nil {
|
||||
t.Fatalf("cannot append data: %v", err)
|
||||
}
|
||||
validateFileSize(t, f1Name, HeaderLen+9)
|
||||
validateMeta(t, "3", f, 0, 8, 9, 0)
|
||||
err = f.AppendData(context.Background(), []byte("\n"+makeData(20)))
|
||||
if err != nil {
|
||||
t.Fatalf("cannot append data: %v", err)
|
||||
}
|
||||
validateFileSize(t, f1Name, HeaderLen+30)
|
||||
validateMeta(t, "4", f, 0, 29, 30, 0)
|
||||
|
||||
data120 := makeData(120)
|
||||
err = f.AppendData(context.Background(), []byte(data120))
|
||||
if err != nil {
|
||||
t.Fatalf("cannot append data: %v", err)
|
||||
}
|
||||
validateFileSize(t, f1Name, HeaderLen+100)
|
||||
validateMeta(t, "5", f, 0, 99, 100, 50)
|
||||
err = f.AppendData(context.Background(), []byte("foo "))
|
||||
if err != nil {
|
||||
t.Fatalf("cannot append data: %v", err)
|
||||
}
|
||||
validateFileSize(t, f1Name, HeaderLen+100)
|
||||
validateMeta(t, "6", f, 4, 3, 100, 54)
|
||||
|
||||
buf := make([]byte, 5)
|
||||
realOffset, nr, err := f.ReadNext(context.Background(), buf, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("cannot ReadNext: %v", err)
|
||||
}
|
||||
if realOffset != 54 {
|
||||
t.Fatalf("wrong realoffset got[%d] expected[%d]", realOffset, 54)
|
||||
}
|
||||
if nr != 5 {
|
||||
t.Fatalf("wrong nr got[%d] expected[%d]", nr, 5)
|
||||
}
|
||||
if string(buf[0:nr]) != "56789" {
|
||||
t.Fatalf("wrong buf return got[%s] expected[%s]", string(buf[0:nr]), "56789")
|
||||
}
|
||||
realOffset, nr, err = f.ReadNext(context.Background(), buf, 60)
|
||||
if err != nil {
|
||||
t.Fatalf("cannot readnext: %v", err)
|
||||
}
|
||||
if realOffset != 60 && nr != 5 {
|
||||
t.Fatalf("invalid rtn realoffset[%d] nr[%d]", realOffset, nr)
|
||||
}
|
||||
if string(buf[0:nr]) != "12345" {
|
||||
t.Fatalf("invalid rtn buf[%s]", string(buf[0:nr]))
|
||||
}
|
||||
realOffset, nr, err = f.ReadNext(context.Background(), buf, 800)
|
||||
if err != nil || realOffset != 154 || nr != 0 {
|
||||
t.Fatalf("invalid past end read: err[%v] realoffset[%d] nr[%d]", err, realOffset, nr)
|
||||
}
|
||||
realOffset, nr, err = f.ReadNext(context.Background(), buf, 150)
|
||||
if err != nil || realOffset != 150 || nr != 4 || string(buf[0:nr]) != "foo " {
|
||||
t.Fatalf("invalid end read: err[%v] realoffset[%d] nr[%d] buf[%s]", err, realOffset, nr, string(buf[0:nr]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlock(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
f1Name := path.Join(tempDir, "f1.cf")
|
||||
f, err := CreateCirFile(f1Name, 100)
|
||||
if err != nil {
|
||||
t.Fatalf("cannot create cirfile: %v", err)
|
||||
}
|
||||
fd2, err := os.OpenFile(f1Name, os.O_RDWR, 0777)
|
||||
if err != nil {
|
||||
t.Fatalf("cannot open file: %v", err)
|
||||
}
|
||||
err = syscall.Flock(int(fd2.Fd()), syscall.LOCK_EX)
|
||||
if err != nil {
|
||||
t.Fatalf("cannot lock fd: %v", err)
|
||||
}
|
||||
err = f.AppendData(nil, []byte("hello"))
|
||||
if err != syscall.EWOULDBLOCK {
|
||||
t.Fatalf("append should fail with EWOULDBLOCK")
|
||||
}
|
||||
timeoutCtx, _ := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||||
startTs := time.Now()
|
||||
err = f.ReadMeta(timeoutCtx)
|
||||
if err != context.DeadlineExceeded {
|
||||
t.Fatalf("readmeta should fail with context.DeadlineExceeded")
|
||||
}
|
||||
dur := time.Now().Sub(startTs)
|
||||
if dur < 20*time.Millisecond {
|
||||
t.Fatalf("readmeta should take at least 20ms")
|
||||
}
|
||||
syscall.Flock(int(fd2.Fd()), syscall.LOCK_UN)
|
||||
err = f.ReadMeta(timeoutCtx)
|
||||
if err != nil {
|
||||
t.Fatalf("readmeta err: %v", err)
|
||||
}
|
||||
err = syscall.Flock(int(fd2.Fd()), syscall.LOCK_SH)
|
||||
if err != nil {
|
||||
t.Fatalf("cannot flock: %v", err)
|
||||
}
|
||||
err = f.AppendData(nil, []byte("hello"))
|
||||
if err != syscall.EWOULDBLOCK {
|
||||
t.Fatalf("append should fail with EWOULDBLOCK")
|
||||
}
|
||||
err = f.ReadMeta(timeoutCtx)
|
||||
if err != nil {
|
||||
t.Fatalf("readmeta err (should work because LOCK_SH): %v", err)
|
||||
}
|
||||
fd2.Close()
|
||||
err = f.AppendData(nil, []byte("hello"))
|
||||
if err != nil {
|
||||
t.Fatalf("append error (should work fd2 was closed): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteAt(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
f1Name := path.Join(tempDir, "f1.cf")
|
||||
f, err := CreateCirFile(f1Name, 100)
|
||||
if err != nil {
|
||||
t.Fatalf("cannot create cirfile: %v", err)
|
||||
}
|
||||
err = f.WriteAt(nil, []byte("hello\nmike"), 4)
|
||||
if err != nil {
|
||||
t.Fatalf("writeat error: %v", err)
|
||||
}
|
||||
err = f.WriteAt(nil, []byte("t"), 2)
|
||||
if err != nil {
|
||||
t.Fatalf("writeat error: %v", err)
|
||||
}
|
||||
err = f.WriteAt(nil, []byte("more"), 30)
|
||||
if err != nil {
|
||||
t.Fatalf("writeat error: %v", err)
|
||||
}
|
||||
err = f.WriteAt(nil, []byte("\n"), 19)
|
||||
if err != nil {
|
||||
t.Fatalf("writeat error: %v", err)
|
||||
}
|
||||
dumpFile(f1Name)
|
||||
err = f.WriteAt(nil, []byte("hello"), 200)
|
||||
if err != nil {
|
||||
t.Fatalf("writeat error: %v", err)
|
||||
}
|
||||
buf := make([]byte, 10)
|
||||
realOffset, nr, err := f.ReadNext(context.Background(), buf, 200)
|
||||
if err != nil || realOffset != 200 || nr != 5 || string(buf[0:nr]) != "hello" {
|
||||
t.Fatalf("invalid readnext: err[%v] realoffset[%d] nr[%d] buf[%s]", err, realOffset, nr, string(buf[0:nr]))
|
||||
}
|
||||
err = f.WriteAt(nil, []byte("0123456789\n"), 100)
|
||||
if err != nil {
|
||||
t.Fatalf("writeat error: %v", err)
|
||||
}
|
||||
dumpFile(f1Name)
|
||||
dataStr := makeData(200)
|
||||
err = f.WriteAt(nil, []byte(dataStr), 50)
|
||||
if err != nil {
|
||||
t.Fatalf("writeat error: %v", err)
|
||||
}
|
||||
dumpFile(f1Name)
|
||||
|
||||
dataStr = makeData(1000)
|
||||
err = f.WriteAt(nil, []byte(dataStr), 1002)
|
||||
if err != nil {
|
||||
t.Fatalf("writeat error: %v", err)
|
||||
}
|
||||
err = f.WriteAt(nil, []byte("hello\n"), 2010)
|
||||
if err != nil {
|
||||
t.Fatalf("writeat error: %v", err)
|
||||
}
|
||||
err = f.AppendData(nil, []byte("foo\n"))
|
||||
if err != nil {
|
||||
t.Fatalf("appenddata error: %v", err)
|
||||
}
|
||||
dumpFile(f1Name)
|
||||
}
|
471
waveshell/pkg/cmdtail/cmdtail.go
Normal file
471
waveshell/pkg/cmdtail/cmdtail.go
Normal file
@ -0,0 +1,471 @@
|
||||
package cmdtail
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"regexp"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/commandlinedev/apishell/pkg/base"
|
||||
"github.com/commandlinedev/apishell/pkg/packet"
|
||||
)
|
||||
|
||||
const MaxDataBytes = 4096
|
||||
const FileTypePty = "ptyout"
|
||||
const FileTypeRun = "runout"
|
||||
|
||||
type Tailer struct {
|
||||
Lock *sync.Mutex
|
||||
WatchList map[base.CommandKey]CmdWatchEntry
|
||||
Watcher *fsnotify.Watcher
|
||||
Sender *packet.PacketSender
|
||||
Gen FileNameGenerator
|
||||
Sessions map[string]bool
|
||||
}
|
||||
|
||||
type TailPos struct {
|
||||
ReqId string
|
||||
Running bool // an active tailer sending data
|
||||
TailPtyPos int64
|
||||
TailRunPos int64
|
||||
Follow bool
|
||||
}
|
||||
|
||||
type CmdWatchEntry struct {
|
||||
CmdKey base.CommandKey
|
||||
FilePtyLen int64
|
||||
FileRunLen int64
|
||||
Tails []TailPos
|
||||
Done bool
|
||||
}
|
||||
|
||||
type FileNameGenerator interface {
|
||||
PtyOutFile(ck base.CommandKey) string
|
||||
RunOutFile(ck base.CommandKey) string
|
||||
SessionDir(sessionId string) string
|
||||
}
|
||||
|
||||
func (w CmdWatchEntry) getTailPos(reqId string) (TailPos, bool) {
|
||||
for _, pos := range w.Tails {
|
||||
if pos.ReqId == reqId {
|
||||
return pos, true
|
||||
}
|
||||
}
|
||||
return TailPos{}, false
|
||||
}
|
||||
|
||||
func (w *CmdWatchEntry) updateTailPos(reqId string, newPos TailPos) {
|
||||
for idx, pos := range w.Tails {
|
||||
if pos.ReqId == reqId {
|
||||
w.Tails[idx] = newPos
|
||||
return
|
||||
}
|
||||
}
|
||||
w.Tails = append(w.Tails, newPos)
|
||||
}
|
||||
|
||||
func (w *CmdWatchEntry) removeTailPos(reqId string) {
|
||||
var newTails []TailPos
|
||||
for _, pos := range w.Tails {
|
||||
if pos.ReqId == reqId {
|
||||
continue
|
||||
}
|
||||
newTails = append(newTails, pos)
|
||||
}
|
||||
w.Tails = newTails
|
||||
}
|
||||
|
||||
func (pos TailPos) IsCurrent(entry CmdWatchEntry) bool {
|
||||
return pos.TailPtyPos >= entry.FilePtyLen && pos.TailRunPos >= entry.FileRunLen
|
||||
}
|
||||
|
||||
func (t *Tailer) updateTailPos_nolock(cmdKey base.CommandKey, reqId string, pos TailPos) {
|
||||
entry, found := t.WatchList[cmdKey]
|
||||
if !found {
|
||||
return
|
||||
}
|
||||
entry.updateTailPos(reqId, pos)
|
||||
t.WatchList[cmdKey] = entry
|
||||
}
|
||||
|
||||
func (t *Tailer) removeTailPos(cmdKey base.CommandKey, reqId string) {
|
||||
t.Lock.Lock()
|
||||
defer t.Lock.Unlock()
|
||||
t.removeTailPos_nolock(cmdKey, reqId)
|
||||
}
|
||||
|
||||
func (t *Tailer) removeTailPos_nolock(cmdKey base.CommandKey, reqId string) {
|
||||
entry, found := t.WatchList[cmdKey]
|
||||
if !found {
|
||||
return
|
||||
}
|
||||
entry.removeTailPos(reqId)
|
||||
t.WatchList[cmdKey] = entry
|
||||
if len(entry.Tails) == 0 {
|
||||
t.removeWatch_nolock(cmdKey)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tailer) removeWatch_nolock(cmdKey base.CommandKey) {
|
||||
// delete from watchlist, remove watches
|
||||
delete(t.WatchList, cmdKey)
|
||||
t.Watcher.Remove(t.Gen.PtyOutFile(cmdKey))
|
||||
t.Watcher.Remove(t.Gen.RunOutFile(cmdKey))
|
||||
}
|
||||
|
||||
func (t *Tailer) getEntryAndPos_nolock(cmdKey base.CommandKey, reqId string) (CmdWatchEntry, TailPos, bool) {
|
||||
entry, found := t.WatchList[cmdKey]
|
||||
if !found {
|
||||
return CmdWatchEntry{}, TailPos{}, false
|
||||
}
|
||||
pos, found := entry.getTailPos(reqId)
|
||||
if !found {
|
||||
return CmdWatchEntry{}, TailPos{}, false
|
||||
}
|
||||
return entry, pos, true
|
||||
}
|
||||
|
||||
func (t *Tailer) addSessionWatcher(sessionId string) error {
|
||||
t.Lock.Lock()
|
||||
defer t.Lock.Unlock()
|
||||
|
||||
if t.Sessions[sessionId] {
|
||||
return
|
||||
}
|
||||
sdir := t.Gen.SessionDir(sessionId)
|
||||
err := t.Watcher.Add(sdir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.Sessions[sessionId] = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Tailer) removeSessionWatcher(sessionId string) {
|
||||
t.Lock.Lock()
|
||||
defer t.Lock.Unlock()
|
||||
|
||||
if !t.Sessions[sessionId] {
|
||||
return
|
||||
}
|
||||
sdir := t.Gen.SessionDir(sessionId)
|
||||
t.Watcher.Remove(sdir)
|
||||
}
|
||||
|
||||
func MakeTailer(sender *packet.PacketSender, gen FileNameGenerator) (*Tailer, error) {
|
||||
rtn := &Tailer{
|
||||
Lock: &sync.Mutex{},
|
||||
WatchList: make(map[base.CommandKey]CmdWatchEntry),
|
||||
Sessions: make(map[string]bool),
|
||||
Sender: sender,
|
||||
Gen: gen,
|
||||
}
|
||||
var err error
|
||||
rtn.Watcher, err = fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func (t *Tailer) readDataFromFile(fileName string, pos int64, maxBytes int) ([]byte, error) {
|
||||
fd, err := os.Open(fileName)
|
||||
defer fd.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buf := make([]byte, maxBytes)
|
||||
nr, err := fd.ReadAt(buf, pos)
|
||||
if err != nil && err != io.EOF { // ignore EOF error
|
||||
return nil, err
|
||||
}
|
||||
return buf[0:nr], nil
|
||||
}
|
||||
|
||||
func (t *Tailer) makeCmdDataPacket(entry CmdWatchEntry, pos TailPos) (*packet.CmdDataPacketType, error) {
|
||||
dataPacket := packet.MakeCmdDataPacket(pos.ReqId)
|
||||
dataPacket.CK = entry.CmdKey
|
||||
dataPacket.PtyPos = pos.TailPtyPos
|
||||
dataPacket.RunPos = pos.TailRunPos
|
||||
if entry.FilePtyLen > pos.TailPtyPos {
|
||||
ptyData, err := t.readDataFromFile(t.Gen.PtyOutFile(entry.CmdKey), pos.TailPtyPos, MaxDataBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dataPacket.PtyData64 = base64.StdEncoding.EncodeToString(ptyData)
|
||||
dataPacket.PtyDataLen = len(ptyData)
|
||||
}
|
||||
if entry.FileRunLen > pos.TailRunPos {
|
||||
runData, err := t.readDataFromFile(t.Gen.RunOutFile(entry.CmdKey), pos.TailRunPos, MaxDataBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dataPacket.RunData64 = base64.StdEncoding.EncodeToString(runData)
|
||||
dataPacket.RunDataLen = len(runData)
|
||||
}
|
||||
return dataPacket, nil
|
||||
}
|
||||
|
||||
// returns (data-packet, keepRunning)
|
||||
func (t *Tailer) runSingleDataTransfer(key base.CommandKey, reqId string) (*packet.CmdDataPacketType, bool, error) {
|
||||
t.Lock.Lock()
|
||||
entry, pos, foundPos := t.getEntryAndPos_nolock(key, reqId)
|
||||
t.Lock.Unlock()
|
||||
if !foundPos {
|
||||
return nil, false, nil
|
||||
}
|
||||
dataPacket, dataErr := t.makeCmdDataPacket(entry, pos)
|
||||
|
||||
t.Lock.Lock()
|
||||
defer t.Lock.Unlock()
|
||||
entry, pos, foundPos = t.getEntryAndPos_nolock(key, reqId)
|
||||
if !foundPos {
|
||||
return nil, false, nil
|
||||
}
|
||||
// pos was updated between first and second get, throw out data-packet and re-run
|
||||
if pos.TailPtyPos != dataPacket.PtyPos || pos.TailRunPos != dataPacket.RunPos {
|
||||
return nil, true, nil
|
||||
}
|
||||
if dataErr != nil {
|
||||
// error, so return error packet, and stop running
|
||||
pos.Running = false
|
||||
t.updateTailPos_nolock(key, reqId, pos)
|
||||
return nil, false, dataErr
|
||||
}
|
||||
pos.TailPtyPos += int64(dataPacket.PtyDataLen)
|
||||
pos.TailRunPos += int64(dataPacket.RunDataLen)
|
||||
if pos.IsCurrent(entry) {
|
||||
// we caught up, tail position equals file length
|
||||
pos.Running = false
|
||||
}
|
||||
t.updateTailPos_nolock(key, reqId, pos)
|
||||
return dataPacket, pos.Running, nil
|
||||
}
|
||||
|
||||
// returns (removed)
|
||||
func (t *Tailer) checkRemove(cmdKey base.CommandKey, reqId string) bool {
|
||||
t.Lock.Lock()
|
||||
defer t.Lock.Unlock()
|
||||
entry, pos, foundPos := t.getEntryAndPos_nolock(cmdKey, reqId)
|
||||
if !foundPos {
|
||||
return false
|
||||
}
|
||||
if !pos.IsCurrent(entry) {
|
||||
return false
|
||||
}
|
||||
if !pos.Follow || entry.Done {
|
||||
t.removeTailPos_nolock(cmdKey, reqId)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *Tailer) RunDataTransfer(key base.CommandKey, reqId string) {
|
||||
for {
|
||||
dataPacket, keepRunning, err := t.runSingleDataTransfer(key, reqId)
|
||||
if dataPacket != nil {
|
||||
t.Sender.SendPacket(dataPacket)
|
||||
}
|
||||
if err != nil {
|
||||
t.removeTailPos(key, reqId)
|
||||
t.Sender.SendErrorResponse(reqId, err)
|
||||
break
|
||||
}
|
||||
if !keepRunning {
|
||||
removed := t.checkRemove(key, reqId)
|
||||
if removed {
|
||||
t.Sender.SendResponse(reqId, true)
|
||||
}
|
||||
break
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tailer) tryStartRun_nolock(entry CmdWatchEntry, pos TailPos) {
|
||||
if pos.Running {
|
||||
return
|
||||
}
|
||||
if pos.IsCurrent(entry) {
|
||||
return
|
||||
}
|
||||
pos.Running = true
|
||||
t.updateTailPos_nolock(entry.CmdKey, pos.ReqId, pos)
|
||||
go t.RunDataTransfer(entry.CmdKey, pos.ReqId)
|
||||
}
|
||||
|
||||
var updateFileRe = regexp.MustCompile("/([a-z0-9-]+)/([a-z0-9-]+)\\.(ptyout|runout)$")
|
||||
|
||||
func (t *Tailer) updateFile(relFileName string) {
|
||||
m := updateFileRe.FindStringSubmatch(relFileName)
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
finfo, err := os.Stat(relFileName)
|
||||
if err != nil {
|
||||
t.Sender.SendPacket(packet.FmtMessagePacket("error trying to stat file '%s': %v", relFileName, err))
|
||||
return
|
||||
}
|
||||
cmdKey := base.MakeCommandKey(m[1], m[2])
|
||||
t.Lock.Lock()
|
||||
defer t.Lock.Unlock()
|
||||
entry, foundEntry := t.WatchList[cmdKey]
|
||||
if !foundEntry {
|
||||
return
|
||||
}
|
||||
fileType := m[3]
|
||||
if fileType == FileTypePty {
|
||||
entry.FilePtyLen = finfo.Size()
|
||||
} else if fileType == FileTypeRun {
|
||||
entry.FileRunLen = finfo.Size()
|
||||
}
|
||||
t.WatchList[cmdKey] = entry
|
||||
for _, pos := range entry.Tails {
|
||||
t.tryStartRun_nolock(entry, pos)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tailer) Run() {
|
||||
for {
|
||||
select {
|
||||
case event, ok := <-t.Watcher.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if event.Op&fsnotify.Write == fsnotify.Write {
|
||||
t.updateFile(event.Name)
|
||||
}
|
||||
|
||||
case err, ok := <-t.Watcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
// what to do with this error? just send a message
|
||||
t.Sender.SendPacket(packet.FmtMessagePacket("error in tailer: %v", err))
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (t *Tailer) Close() error {
|
||||
return t.Watcher.Close()
|
||||
}
|
||||
|
||||
func max(v1 int64, v2 int64) int64 {
|
||||
if v1 > v2 {
|
||||
return v1
|
||||
}
|
||||
return v2
|
||||
}
|
||||
|
||||
func (entry *CmdWatchEntry) fillFilePos(gen FileNameGenerator) {
|
||||
ptyInfo, _ := os.Stat(gen.PtyOutFile(entry.CmdKey))
|
||||
if ptyInfo != nil {
|
||||
entry.FilePtyLen = ptyInfo.Size()
|
||||
}
|
||||
runoutInfo, _ := os.Stat(gen.RunOutFile(entry.CmdKey))
|
||||
if runoutInfo != nil {
|
||||
entry.FileRunLen = runoutInfo.Size()
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tailer) KeyDone(key base.CommandKey) {
|
||||
t.Lock.Lock()
|
||||
defer t.Lock.Unlock()
|
||||
entry, foundEntry := t.WatchList[key]
|
||||
if !foundEntry {
|
||||
return
|
||||
}
|
||||
entry.Done = true
|
||||
var newTails []TailPos
|
||||
for _, pos := range entry.Tails {
|
||||
if pos.IsCurrent(entry) {
|
||||
continue
|
||||
}
|
||||
newTails = append(newTails, pos)
|
||||
}
|
||||
entry.Tails = newTails
|
||||
t.WatchList[key] = entry
|
||||
if len(entry.Tails) == 0 {
|
||||
t.removeWatch_nolock(key)
|
||||
}
|
||||
t.WatchList[key] = entry
|
||||
}
|
||||
|
||||
func (t *Tailer) RemoveWatch(pk *packet.UntailCmdPacketType) {
|
||||
t.Lock.Lock()
|
||||
defer t.Lock.Unlock()
|
||||
t.removeTailPos_nolock(pk.CK, pk.ReqId)
|
||||
}
|
||||
|
||||
func (t *Tailer) AddFileWatches_nolock(key base.CommandKey, ptyOnly bool) error {
|
||||
ptyName := t.Gen.PtyOutFile(key)
|
||||
runName := t.Gen.RunOutFile(key)
|
||||
fmt.Printf("WATCH> add %s\n", ptyName)
|
||||
err := t.Watcher.Add(ptyName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ptyOnly {
|
||||
return nil
|
||||
}
|
||||
err = t.Watcher.Add(runName)
|
||||
if err != nil {
|
||||
t.Watcher.Remove(ptyName) // best effort clean up
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// returns (up-to-date/done, error)
|
||||
func (t *Tailer) AddWatch(getPacket *packet.GetCmdPacketType) (bool, error) {
|
||||
if err := getPacket.CK.Validate("getcmd"); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if getPacket.ReqId == "" {
|
||||
return false, fmt.Errorf("getcmd, no reqid specified")
|
||||
}
|
||||
t.Lock.Lock()
|
||||
defer t.Lock.Unlock()
|
||||
key := getPacket.CK
|
||||
entry, foundEntry := t.WatchList[key]
|
||||
if !foundEntry {
|
||||
// initialize entry, add watches
|
||||
entry = CmdWatchEntry{CmdKey: key}
|
||||
entry.fillFilePos(t.Gen)
|
||||
}
|
||||
pos, foundPos := entry.getTailPos(getPacket.ReqId)
|
||||
if !foundPos {
|
||||
// initialize a new tailpos
|
||||
pos = TailPos{ReqId: getPacket.ReqId}
|
||||
}
|
||||
// update tailpos with new values from getpacket
|
||||
pos.TailPtyPos = getPacket.PtyPos
|
||||
pos.TailRunPos = getPacket.RunPos
|
||||
pos.Follow = getPacket.Tail
|
||||
// convert negative pos to positive
|
||||
if pos.TailPtyPos < 0 {
|
||||
pos.TailPtyPos = max(0, entry.FilePtyLen+pos.TailPtyPos) // + because negative
|
||||
}
|
||||
if pos.TailRunPos < 0 {
|
||||
pos.TailRunPos = max(0, entry.FileRunLen+pos.TailRunPos) // + because negative
|
||||
}
|
||||
entry.updateTailPos(pos.ReqId, pos)
|
||||
if !pos.Follow && pos.IsCurrent(entry) {
|
||||
// don't add to t.WatchList, don't t.AddFileWatches_nolock, send rpc response
|
||||
return true, nil
|
||||
}
|
||||
if !foundEntry {
|
||||
err := t.AddFileWatches_nolock(key, getPacket.PtyOnly)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
t.WatchList[key] = entry
|
||||
t.tryStartRun_nolock(entry, pos)
|
||||
return false, nil
|
||||
}
|
144
waveshell/pkg/mpio/bufreader.go
Normal file
144
waveshell/pkg/mpio/bufreader.go
Normal file
@ -0,0 +1,144 @@
|
||||
package mpio
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/commandlinedev/apishell/pkg/packet"
|
||||
)
|
||||
|
||||
type FdReader struct {
|
||||
CVar *sync.Cond
|
||||
M *Multiplexer
|
||||
FdNum int
|
||||
Fd io.ReadCloser
|
||||
BufSize int
|
||||
Closed bool
|
||||
ShouldCloseFd bool
|
||||
IsPty bool
|
||||
}
|
||||
|
||||
func MakeFdReader(m *Multiplexer, fd io.ReadCloser, fdNum int, shouldCloseFd bool, isPty bool) *FdReader {
|
||||
fr := &FdReader{
|
||||
CVar: sync.NewCond(&sync.Mutex{}),
|
||||
M: m,
|
||||
FdNum: fdNum,
|
||||
Fd: fd,
|
||||
BufSize: 0,
|
||||
ShouldCloseFd: shouldCloseFd,
|
||||
IsPty: isPty,
|
||||
}
|
||||
return fr
|
||||
}
|
||||
|
||||
func (r *FdReader) Close() {
|
||||
r.CVar.L.Lock()
|
||||
defer r.CVar.L.Unlock()
|
||||
if r.Closed {
|
||||
return
|
||||
}
|
||||
if r.Fd != nil && r.ShouldCloseFd {
|
||||
r.Fd.Close()
|
||||
}
|
||||
r.CVar.Broadcast()
|
||||
}
|
||||
|
||||
func (r *FdReader) GetBufSize() int {
|
||||
r.CVar.L.Lock()
|
||||
defer r.CVar.L.Unlock()
|
||||
return r.BufSize
|
||||
}
|
||||
|
||||
func (r *FdReader) NotifyAck(ackLen int) {
|
||||
r.CVar.L.Lock()
|
||||
defer r.CVar.L.Unlock()
|
||||
if r.Closed {
|
||||
return
|
||||
}
|
||||
r.BufSize -= ackLen
|
||||
if r.BufSize < 0 {
|
||||
r.BufSize = 0
|
||||
}
|
||||
r.CVar.Broadcast()
|
||||
}
|
||||
|
||||
// !! inverse locking. must already hold the lock when you call this method.
|
||||
// will *unlock*, send the packet, and then *relock* once it is done.
|
||||
// this can prevent an unlikely deadlock where we are holding r.CVar.L and stuck on sender.SendCh
|
||||
func (r *FdReader) sendPacket_unlock(pk packet.PacketType) {
|
||||
r.CVar.L.Unlock()
|
||||
defer r.CVar.L.Lock()
|
||||
r.M.sendPacket(pk)
|
||||
}
|
||||
|
||||
// returns (success)
|
||||
func (r *FdReader) WriteWait(data []byte, isEof bool) bool {
|
||||
r.CVar.L.Lock()
|
||||
defer r.CVar.L.Unlock()
|
||||
for {
|
||||
bufAvail := ReadBufSize - r.BufSize
|
||||
if r.Closed {
|
||||
return false
|
||||
}
|
||||
if bufAvail == 0 {
|
||||
r.CVar.Wait()
|
||||
continue
|
||||
}
|
||||
writeLen := min(bufAvail, len(data))
|
||||
pk := r.M.makeDataPacket(r.FdNum, data[0:writeLen], nil)
|
||||
pk.Eof = isEof && (writeLen == len(data))
|
||||
r.BufSize += writeLen
|
||||
data = data[writeLen:]
|
||||
r.sendPacket_unlock(pk)
|
||||
if len(data) == 0 {
|
||||
return true
|
||||
}
|
||||
// do *not* do a CVar.Wait() here -- because we *unlocked* to send the packet, we should
|
||||
// recheck the condition before waiting to avoid deadlock.
|
||||
}
|
||||
}
|
||||
|
||||
func min(v1 int, v2 int) int {
|
||||
if v1 <= v2 {
|
||||
return v1
|
||||
}
|
||||
return v2
|
||||
}
|
||||
|
||||
func (r *FdReader) isClosed() bool {
|
||||
r.CVar.L.Lock()
|
||||
defer r.CVar.L.Unlock()
|
||||
return r.Closed
|
||||
}
|
||||
|
||||
func (r *FdReader) ReadLoop(wg *sync.WaitGroup) {
|
||||
defer r.Close()
|
||||
if wg != nil {
|
||||
defer wg.Done()
|
||||
}
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
nr, err := r.Fd.Read(buf)
|
||||
if r.isClosed() {
|
||||
return // should not send data or error if we already closed the fd
|
||||
}
|
||||
if nr > 0 || err == io.EOF {
|
||||
isOpen := r.WriteWait(buf[0:nr], (err == io.EOF))
|
||||
if !isOpen {
|
||||
return
|
||||
}
|
||||
if err == io.EOF {
|
||||
return
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if r.IsPty {
|
||||
r.WriteWait(nil, true)
|
||||
return
|
||||
}
|
||||
errPk := r.M.makeDataPacket(r.FdNum, nil, err)
|
||||
r.M.sendPacket(errPk)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
112
waveshell/pkg/mpio/bufwriter.go
Normal file
112
waveshell/pkg/mpio/bufwriter.go
Normal file
@ -0,0 +1,112 @@
|
||||
package mpio
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type FdWriter struct {
|
||||
CVar *sync.Cond
|
||||
M *Multiplexer
|
||||
FdNum int
|
||||
Buffer []byte
|
||||
BufferLimit int
|
||||
Fd io.WriteCloser
|
||||
Eof bool
|
||||
Closed bool
|
||||
ShouldCloseFd bool
|
||||
Desc string
|
||||
}
|
||||
|
||||
func MakeFdWriter(m *Multiplexer, fd io.WriteCloser, fdNum int, shouldCloseFd bool, desc string) *FdWriter {
|
||||
fw := &FdWriter{
|
||||
CVar: sync.NewCond(&sync.Mutex{}),
|
||||
Fd: fd,
|
||||
M: m,
|
||||
FdNum: fdNum,
|
||||
ShouldCloseFd: shouldCloseFd,
|
||||
Desc: desc,
|
||||
BufferLimit: WriteBufSize,
|
||||
}
|
||||
return fw
|
||||
}
|
||||
|
||||
func (w *FdWriter) Close() {
|
||||
w.CVar.L.Lock()
|
||||
defer w.CVar.L.Unlock()
|
||||
if w.Closed {
|
||||
return
|
||||
}
|
||||
w.Closed = true
|
||||
if w.Fd != nil && w.ShouldCloseFd {
|
||||
w.Fd.Close()
|
||||
}
|
||||
w.Buffer = nil
|
||||
w.CVar.Broadcast()
|
||||
}
|
||||
|
||||
func (w *FdWriter) WaitForData() ([]byte, bool) {
|
||||
w.CVar.L.Lock()
|
||||
defer w.CVar.L.Unlock()
|
||||
for {
|
||||
if len(w.Buffer) > 0 || w.Eof || w.Closed {
|
||||
toWrite := w.Buffer
|
||||
w.Buffer = nil
|
||||
return toWrite, w.Eof
|
||||
}
|
||||
w.CVar.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
func (w *FdWriter) AddData(data []byte, eof bool) error {
|
||||
w.CVar.L.Lock()
|
||||
defer w.CVar.L.Unlock()
|
||||
if w.Closed || w.Eof {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("write to closed file %q (fd:%d) eof[%v]", w.Desc, w.FdNum, w.Eof)
|
||||
}
|
||||
if len(data) > 0 {
|
||||
if len(data)+len(w.Buffer) > w.BufferLimit {
|
||||
return fmt.Errorf("write exceeds buffer size %q (fd:%d) bufsize=%d (max=%d)", w.Desc, w.FdNum, len(data)+len(w.Buffer), w.BufferLimit)
|
||||
}
|
||||
w.Buffer = append(w.Buffer, data...)
|
||||
}
|
||||
if eof {
|
||||
w.Eof = true
|
||||
}
|
||||
w.CVar.Broadcast()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *FdWriter) WriteLoop(wg *sync.WaitGroup) {
|
||||
defer w.Close()
|
||||
if wg != nil {
|
||||
defer wg.Done()
|
||||
}
|
||||
for {
|
||||
data, isEof := w.WaitForData()
|
||||
// chunk the writes to make sure we send ample ack packets
|
||||
for len(data) > 0 {
|
||||
if w.Closed {
|
||||
return
|
||||
}
|
||||
chunkSize := min(len(data), MaxSingleWriteSize)
|
||||
chunk := data[0:chunkSize]
|
||||
nw, err := w.Fd.Write(chunk)
|
||||
if nw > 0 || err != nil {
|
||||
ack := w.M.makeDataAckPacket(w.FdNum, nw, err)
|
||||
w.M.sendPacket(ack)
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
data = data[chunkSize:]
|
||||
}
|
||||
if isEof {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
303
waveshell/pkg/mpio/mpio.go
Normal file
303
waveshell/pkg/mpio/mpio.go
Normal file
@ -0,0 +1,303 @@
|
||||
package mpio
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/commandlinedev/apishell/pkg/base"
|
||||
"github.com/commandlinedev/apishell/pkg/packet"
|
||||
)
|
||||
|
||||
const ReadBufSize = 128 * 1024
|
||||
const WriteBufSize = 128 * 1024
|
||||
const MaxSingleWriteSize = 4 * 1024
|
||||
const MaxTotalRunDataSize = 10 * ReadBufSize
|
||||
|
||||
type Multiplexer struct {
|
||||
Lock *sync.Mutex
|
||||
CK base.CommandKey
|
||||
FdReaders map[int]*FdReader // synchronized
|
||||
FdWriters map[int]*FdWriter // synchronized
|
||||
RunData map[int]*FdReader // synchronized
|
||||
CloseAfterStart []*os.File // synchronized
|
||||
|
||||
Sender *packet.PacketSender
|
||||
Input *packet.PacketParser
|
||||
Started bool
|
||||
UPR packet.UnknownPacketReporter
|
||||
|
||||
Debug bool
|
||||
}
|
||||
|
||||
func MakeMultiplexer(ck base.CommandKey, upr packet.UnknownPacketReporter) *Multiplexer {
|
||||
if upr == nil {
|
||||
upr = packet.DefaultUPR{}
|
||||
}
|
||||
return &Multiplexer{
|
||||
Lock: &sync.Mutex{},
|
||||
CK: ck,
|
||||
FdReaders: make(map[int]*FdReader),
|
||||
FdWriters: make(map[int]*FdWriter),
|
||||
UPR: upr,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Multiplexer) Close() {
|
||||
m.Lock.Lock()
|
||||
defer m.Lock.Unlock()
|
||||
|
||||
for _, fr := range m.FdReaders {
|
||||
fr.Close()
|
||||
}
|
||||
for _, fw := range m.FdWriters {
|
||||
fw.Close()
|
||||
}
|
||||
for _, fd := range m.CloseAfterStart {
|
||||
fd.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Multiplexer) HandleInputDone() {
|
||||
m.Lock.Lock()
|
||||
defer m.Lock.Unlock()
|
||||
|
||||
// close readers (obviously the done command needs no more input)
|
||||
for _, fr := range m.FdReaders {
|
||||
fr.Close()
|
||||
}
|
||||
|
||||
// ensure EOF on all writers (ignore error)
|
||||
for _, fw := range m.FdWriters {
|
||||
fw.AddData(nil, true)
|
||||
}
|
||||
}
|
||||
|
||||
// returns the *writer* to connect to process, reader is put in FdReaders
|
||||
func (m *Multiplexer) MakeReaderPipe(fdNum int) (*os.File, error) {
|
||||
pr, pw, err := os.Pipe()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.Lock.Lock()
|
||||
defer m.Lock.Unlock()
|
||||
m.FdReaders[fdNum] = MakeFdReader(m, pr, fdNum, true, false)
|
||||
m.CloseAfterStart = append(m.CloseAfterStart, pw)
|
||||
return pw, nil
|
||||
}
|
||||
|
||||
// returns the *reader* to connect to process, writer is put in FdWriters
|
||||
func (m *Multiplexer) MakeWriterPipe(fdNum int, desc string) (*os.File, error) {
|
||||
pr, pw, err := os.Pipe()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.Lock.Lock()
|
||||
defer m.Lock.Unlock()
|
||||
m.FdWriters[fdNum] = MakeFdWriter(m, pw, fdNum, true, desc)
|
||||
m.CloseAfterStart = append(m.CloseAfterStart, pr)
|
||||
return pr, nil
|
||||
}
|
||||
|
||||
// returns the *reader* to connect to process, writer is put in FdWriters
|
||||
func (m *Multiplexer) MakeStaticWriterPipe(fdNum int, data []byte, bufferLimit int, desc string) (*os.File, error) {
|
||||
pr, pw, err := os.Pipe()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.Lock.Lock()
|
||||
defer m.Lock.Unlock()
|
||||
fdWriter := MakeFdWriter(m, pw, fdNum, true, desc)
|
||||
fdWriter.BufferLimit = bufferLimit
|
||||
err = fdWriter.AddData(data, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.FdWriters[fdNum] = fdWriter
|
||||
m.CloseAfterStart = append(m.CloseAfterStart, pr)
|
||||
return pr, nil
|
||||
}
|
||||
|
||||
func (m *Multiplexer) MakeRawFdReader(fdNum int, fd io.ReadCloser, shouldClose bool, isPty bool) {
|
||||
m.Lock.Lock()
|
||||
defer m.Lock.Unlock()
|
||||
m.FdReaders[fdNum] = MakeFdReader(m, fd, fdNum, shouldClose, isPty)
|
||||
}
|
||||
|
||||
func (m *Multiplexer) MakeRawFdWriter(fdNum int, fd io.WriteCloser, shouldClose bool, desc string) {
|
||||
m.Lock.Lock()
|
||||
defer m.Lock.Unlock()
|
||||
m.FdWriters[fdNum] = MakeFdWriter(m, fd, fdNum, shouldClose, desc)
|
||||
}
|
||||
|
||||
func (m *Multiplexer) makeDataAckPacket(fdNum int, ackLen int, err error) *packet.DataAckPacketType {
|
||||
ack := packet.MakeDataAckPacket()
|
||||
ack.CK = m.CK
|
||||
ack.FdNum = fdNum
|
||||
ack.AckLen = ackLen
|
||||
if err != nil {
|
||||
ack.Error = err.Error()
|
||||
}
|
||||
return ack
|
||||
}
|
||||
|
||||
func (m *Multiplexer) makeDataPacket(fdNum int, data []byte, err error) *packet.DataPacketType {
|
||||
pk := packet.MakeDataPacket()
|
||||
pk.CK = m.CK
|
||||
pk.FdNum = fdNum
|
||||
pk.Data64 = base64.StdEncoding.EncodeToString(data)
|
||||
if err != nil {
|
||||
pk.Error = err.Error()
|
||||
}
|
||||
return pk
|
||||
}
|
||||
|
||||
func (m *Multiplexer) sendPacket(p packet.PacketType) {
|
||||
m.Sender.SendPacket(p)
|
||||
}
|
||||
|
||||
func (m *Multiplexer) launchWriters(wg *sync.WaitGroup) {
|
||||
m.Lock.Lock()
|
||||
defer m.Lock.Unlock()
|
||||
if wg != nil {
|
||||
wg.Add(len(m.FdWriters))
|
||||
}
|
||||
for _, fw := range m.FdWriters {
|
||||
go fw.WriteLoop(wg)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Multiplexer) launchReaders(wg *sync.WaitGroup) {
|
||||
m.Lock.Lock()
|
||||
defer m.Lock.Unlock()
|
||||
if wg != nil {
|
||||
wg.Add(len(m.FdReaders))
|
||||
}
|
||||
for _, fr := range m.FdReaders {
|
||||
go fr.ReadLoop(wg)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Multiplexer) startIO(packetParser *packet.PacketParser, sender *packet.PacketSender) {
|
||||
m.Lock.Lock()
|
||||
defer m.Lock.Unlock()
|
||||
if m.Started {
|
||||
panic("Multiplexer is already running, cannot start again")
|
||||
}
|
||||
m.Input = packetParser
|
||||
m.Sender = sender
|
||||
m.Started = true
|
||||
}
|
||||
|
||||
func (m *Multiplexer) runPacketInputLoop() *packet.CmdDonePacketType {
|
||||
defer m.HandleInputDone()
|
||||
for pk := range m.Input.MainCh {
|
||||
if m.Debug {
|
||||
fmt.Printf("PK-M> %s\n", packet.AsString(pk))
|
||||
}
|
||||
if pk.GetType() == packet.DataPacketStr {
|
||||
dataPacket := pk.(*packet.DataPacketType)
|
||||
err := m.processDataPacket(dataPacket)
|
||||
if err != nil {
|
||||
errPacket := m.makeDataAckPacket(dataPacket.FdNum, 0, err)
|
||||
m.sendPacket(errPacket)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if pk.GetType() == packet.DataAckPacketStr {
|
||||
ackPacket := pk.(*packet.DataAckPacketType)
|
||||
m.processAckPacket(ackPacket)
|
||||
continue
|
||||
}
|
||||
if pk.GetType() == packet.CmdDonePacketStr {
|
||||
donePacket := pk.(*packet.CmdDonePacketType)
|
||||
return donePacket
|
||||
}
|
||||
m.UPR.UnknownPacket(pk)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Multiplexer) WriteDataToFd(fdNum int, data []byte, isEof bool) error {
|
||||
m.Lock.Lock()
|
||||
defer m.Lock.Unlock()
|
||||
fw := m.FdWriters[fdNum]
|
||||
if fw == nil {
|
||||
// add a closed FdWriter as a placeholder so we only send one error
|
||||
fw := MakeFdWriter(m, nil, fdNum, false, "invalid-fd")
|
||||
fw.Close()
|
||||
m.FdWriters[fdNum] = fw
|
||||
return fmt.Errorf("write to closed file (no fd)")
|
||||
}
|
||||
err := fw.AddData(data, isEof)
|
||||
if err != nil {
|
||||
fw.Close()
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Multiplexer) processDataPacket(dataPacket *packet.DataPacketType) error {
|
||||
realData, err := base64.StdEncoding.DecodeString(dataPacket.Data64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decoding base64 data: %w", err)
|
||||
}
|
||||
return m.WriteDataToFd(dataPacket.FdNum, realData, dataPacket.Eof)
|
||||
}
|
||||
|
||||
func (m *Multiplexer) processAckPacket(ackPacket *packet.DataAckPacketType) {
|
||||
m.Lock.Lock()
|
||||
defer m.Lock.Unlock()
|
||||
fr := m.FdReaders[ackPacket.FdNum]
|
||||
if fr == nil {
|
||||
return
|
||||
}
|
||||
fr.NotifyAck(ackPacket.AckLen)
|
||||
}
|
||||
|
||||
func (m *Multiplexer) closeTempStartFds() {
|
||||
m.Lock.Lock()
|
||||
defer m.Lock.Unlock()
|
||||
for _, fd := range m.CloseAfterStart {
|
||||
fd.Close()
|
||||
}
|
||||
m.CloseAfterStart = nil
|
||||
}
|
||||
|
||||
func (m *Multiplexer) RunIOAndWait(packetParser *packet.PacketParser, sender *packet.PacketSender, waitOnReaders bool, waitOnWriters bool, waitForInputLoop bool) *packet.CmdDonePacketType {
|
||||
m.startIO(packetParser, sender)
|
||||
m.closeTempStartFds()
|
||||
var wg sync.WaitGroup
|
||||
if waitOnReaders {
|
||||
m.launchReaders(&wg)
|
||||
} else {
|
||||
m.launchReaders(nil)
|
||||
}
|
||||
if waitOnWriters {
|
||||
m.launchWriters(&wg)
|
||||
} else {
|
||||
m.launchWriters(nil)
|
||||
}
|
||||
var donePacket *packet.CmdDonePacketType
|
||||
if waitForInputLoop {
|
||||
wg.Add(1)
|
||||
}
|
||||
go func() {
|
||||
if waitForInputLoop {
|
||||
defer wg.Done()
|
||||
}
|
||||
pkRtn := m.runPacketInputLoop()
|
||||
if pkRtn != nil {
|
||||
m.Lock.Lock()
|
||||
donePacket = pkRtn
|
||||
m.Lock.Unlock()
|
||||
}
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
m.Lock.Lock()
|
||||
defer m.Lock.Unlock()
|
||||
return donePacket
|
||||
}
|
34
waveshell/pkg/packet/combined.go
Normal file
34
waveshell/pkg/packet/combined.go
Normal file
@ -0,0 +1,34 @@
|
||||
package packet
|
||||
|
||||
type CombinedPacket struct {
|
||||
Type string `json:"type"`
|
||||
Success bool `json:"success"`
|
||||
Ts int64 `json:"ts"`
|
||||
Id string `json:"id,omitempty"`
|
||||
|
||||
SessionId string `json:"sessionid"`
|
||||
CmdId string `json:"cmdid"`
|
||||
|
||||
PtyPos int64 `json:"ptypos"`
|
||||
PtyLen int64 `json:"ptylen"`
|
||||
RunPos int64 `json:"runpos"`
|
||||
RunLen int64 `json:"runlen"`
|
||||
|
||||
Error string `json:"error"`
|
||||
NotFound bool `json:"notfound,omitempty"`
|
||||
Tail bool `json:"tail,omitempty"`
|
||||
Dir string `json:"dir"`
|
||||
ChDir string `json:"chdir,omitempty"`
|
||||
|
||||
Data string `json:"data"`
|
||||
PtyData string `json:"ptydata"`
|
||||
RunData string `json:"rundata"`
|
||||
Message string `json:"message"`
|
||||
Command string `json:"command"`
|
||||
|
||||
ScHomeDir string `json:"schomedir"`
|
||||
HomeDir string `json:"homedir"`
|
||||
Env []string `json:"env"`
|
||||
ExitCode int `json:"exitcode"`
|
||||
RunnerPid int `json:"runnerpid"`
|
||||
}
|
1178
waveshell/pkg/packet/packet.go
Normal file
1178
waveshell/pkg/packet/packet.go
Normal file
File diff suppressed because it is too large
Load Diff
238
waveshell/pkg/packet/parser.go
Normal file
238
waveshell/pkg/packet/parser.go
Normal file
@ -0,0 +1,238 @@
|
||||
package packet
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type PacketParser struct {
|
||||
Lock *sync.Mutex
|
||||
MainCh chan PacketType
|
||||
RpcMap map[string]*RpcEntry
|
||||
RpcHandler bool
|
||||
Err error
|
||||
}
|
||||
|
||||
type RpcEntry struct {
|
||||
ReqId string
|
||||
RespCh chan RpcResponsePacketType
|
||||
}
|
||||
|
||||
type RpcResponseIter struct {
|
||||
ReqId string
|
||||
Parser *PacketParser
|
||||
}
|
||||
|
||||
func (iter *RpcResponseIter) Next(ctx context.Context) (RpcResponsePacketType, error) {
|
||||
// will unregister the rpc on ResponseDone
|
||||
return iter.Parser.GetNextResponse(ctx, iter.ReqId)
|
||||
}
|
||||
|
||||
func (iter *RpcResponseIter) Close() {
|
||||
iter.Parser.UnRegisterRpc(iter.ReqId)
|
||||
}
|
||||
|
||||
func CombinePacketParsers(p1 *PacketParser, p2 *PacketParser, rpcHandler bool) *PacketParser {
|
||||
rtnParser := &PacketParser{
|
||||
Lock: &sync.Mutex{},
|
||||
MainCh: make(chan PacketType),
|
||||
RpcMap: make(map[string]*RpcEntry),
|
||||
RpcHandler: rpcHandler,
|
||||
}
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for pk := range p1.MainCh {
|
||||
if rtnParser.RpcHandler {
|
||||
sent := rtnParser.trySendRpcResponse(pk)
|
||||
if sent {
|
||||
continue
|
||||
}
|
||||
}
|
||||
rtnParser.MainCh <- pk
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for pk := range p2.MainCh {
|
||||
if rtnParser.RpcHandler {
|
||||
sent := rtnParser.trySendRpcResponse(pk)
|
||||
if sent {
|
||||
continue
|
||||
}
|
||||
}
|
||||
rtnParser.MainCh <- pk
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(rtnParser.MainCh)
|
||||
}()
|
||||
return rtnParser
|
||||
}
|
||||
|
||||
// should have already registered rpc
|
||||
func (p *PacketParser) WaitForResponse(ctx context.Context, reqId string) RpcResponsePacketType {
|
||||
entry := p.getRpcEntry(reqId)
|
||||
if entry == nil {
|
||||
return nil
|
||||
}
|
||||
defer p.UnRegisterRpc(reqId)
|
||||
select {
|
||||
case resp := <-entry.RespCh:
|
||||
return resp
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PacketParser) GetResponseIter(reqId string) *RpcResponseIter {
|
||||
return &RpcResponseIter{Parser: p, ReqId: reqId}
|
||||
}
|
||||
|
||||
func (p *PacketParser) GetNextResponse(ctx context.Context, reqId string) (RpcResponsePacketType, error) {
|
||||
entry := p.getRpcEntry(reqId)
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
}
|
||||
select {
|
||||
case resp := <-entry.RespCh:
|
||||
if resp.GetResponseDone() {
|
||||
p.UnRegisterRpc(reqId)
|
||||
}
|
||||
return resp, nil
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PacketParser) UnRegisterRpc(reqId string) {
|
||||
p.Lock.Lock()
|
||||
defer p.Lock.Unlock()
|
||||
entry := p.RpcMap[reqId]
|
||||
if entry != nil {
|
||||
close(entry.RespCh)
|
||||
delete(p.RpcMap, reqId)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PacketParser) RegisterRpc(reqId string) chan RpcResponsePacketType {
|
||||
return p.RegisterRpcSz(reqId, 2)
|
||||
}
|
||||
|
||||
func (p *PacketParser) RegisterRpcSz(reqId string, queueSize int) chan RpcResponsePacketType {
|
||||
p.Lock.Lock()
|
||||
defer p.Lock.Unlock()
|
||||
ch := make(chan RpcResponsePacketType, queueSize)
|
||||
entry := &RpcEntry{ReqId: reqId, RespCh: ch}
|
||||
p.RpcMap[reqId] = entry
|
||||
return ch
|
||||
}
|
||||
|
||||
func (p *PacketParser) getRpcEntry(reqId string) *RpcEntry {
|
||||
p.Lock.Lock()
|
||||
defer p.Lock.Unlock()
|
||||
entry := p.RpcMap[reqId]
|
||||
return entry
|
||||
}
|
||||
|
||||
func (p *PacketParser) trySendRpcResponse(pk PacketType) bool {
|
||||
respPk, ok := pk.(RpcResponsePacketType)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
p.Lock.Lock()
|
||||
defer p.Lock.Unlock()
|
||||
entry := p.RpcMap[respPk.GetResponseId()]
|
||||
if entry == nil {
|
||||
return false
|
||||
}
|
||||
// nonblocking send
|
||||
select {
|
||||
case entry.RespCh <- respPk:
|
||||
default:
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *PacketParser) GetErr() error {
|
||||
p.Lock.Lock()
|
||||
defer p.Lock.Unlock()
|
||||
return p.Err
|
||||
}
|
||||
|
||||
func (p *PacketParser) SetErr(err error) {
|
||||
p.Lock.Lock()
|
||||
defer p.Lock.Unlock()
|
||||
if p.Err == nil {
|
||||
p.Err = err
|
||||
}
|
||||
}
|
||||
|
||||
func MakePacketParser(input io.Reader, rpcHandler bool) *PacketParser {
|
||||
parser := &PacketParser{
|
||||
Lock: &sync.Mutex{},
|
||||
MainCh: make(chan PacketType),
|
||||
RpcMap: make(map[string]*RpcEntry),
|
||||
RpcHandler: rpcHandler,
|
||||
}
|
||||
bufReader := bufio.NewReader(input)
|
||||
go func() {
|
||||
defer func() {
|
||||
close(parser.MainCh)
|
||||
}()
|
||||
for {
|
||||
line, err := bufReader.ReadString('\n')
|
||||
if err == io.EOF {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
parser.SetErr(err)
|
||||
return
|
||||
}
|
||||
if line == "\n" {
|
||||
continue
|
||||
}
|
||||
// ##[len][json]\n
|
||||
// ##14{"hello":true}\n
|
||||
// ##N{...}
|
||||
bracePos := strings.Index(line, "{")
|
||||
if !strings.HasPrefix(line, "##") || bracePos == -1 {
|
||||
parser.MainCh <- MakeRawPacket(line[:len(line)-1])
|
||||
continue
|
||||
}
|
||||
packetLen := -1
|
||||
if line[2:bracePos] != "N" {
|
||||
packetLen, err = strconv.Atoi(line[2:bracePos])
|
||||
if err != nil || packetLen != len(line)-bracePos-1 {
|
||||
parser.MainCh <- MakeRawPacket(line[:len(line)-1])
|
||||
continue
|
||||
}
|
||||
}
|
||||
pk, err := ParseJsonPacket([]byte(line[bracePos:]))
|
||||
if err != nil {
|
||||
parser.MainCh <- MakeRawPacket(line[:len(line)-1])
|
||||
continue
|
||||
}
|
||||
if pk.GetType() == DonePacketStr {
|
||||
return
|
||||
}
|
||||
if pk.GetType() == PingPacketStr {
|
||||
continue
|
||||
}
|
||||
if parser.RpcHandler {
|
||||
sent := parser.trySendRpcResponse(pk)
|
||||
if sent {
|
||||
continue
|
||||
}
|
||||
}
|
||||
parser.MainCh <- pk
|
||||
}
|
||||
}()
|
||||
return parser
|
||||
}
|
192
waveshell/pkg/packet/shellstate.go
Normal file
192
waveshell/pkg/packet/shellstate.go
Normal file
@ -0,0 +1,192 @@
|
||||
package packet
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/commandlinedev/apishell/pkg/binpack"
|
||||
"github.com/commandlinedev/apishell/pkg/statediff"
|
||||
)
|
||||
|
||||
const ShellStatePackVersion = 0
|
||||
const ShellStateDiffPackVersion = 0
|
||||
|
||||
type ShellState struct {
|
||||
Version string `json:"version"` // [type] [semver]
|
||||
Cwd string `json:"cwd,omitempty"`
|
||||
ShellVars []byte `json:"shellvars,omitempty"`
|
||||
Aliases string `json:"aliases,omitempty"`
|
||||
Funcs string `json:"funcs,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
HashVal string `json:"-"`
|
||||
}
|
||||
|
||||
type ShellStateDiff struct {
|
||||
Version string `json:"version"` // [type] [semver]
|
||||
BaseHash string `json:"basehash"`
|
||||
DiffHashArr []string `json:"diffhasharr,omitempty"`
|
||||
Cwd string `json:"cwd,omitempty"`
|
||||
VarsDiff []byte `json:"shellvarsdiff,omitempty"` // vardiff
|
||||
AliasesDiff []byte `json:"aliasesdiff,omitempty"` // linediff
|
||||
FuncsDiff []byte `json:"funcsdiff,omitempty"` // linediff
|
||||
Error string `json:"error,omitempty"`
|
||||
HashVal string `json:"-"`
|
||||
}
|
||||
|
||||
func (state ShellState) IsEmpty() bool {
|
||||
return state.Version == "" && state.Cwd == "" && len(state.ShellVars) == 0 && state.Aliases == "" && state.Funcs == "" && state.Error == ""
|
||||
}
|
||||
|
||||
// returns base64 hash of data
|
||||
func sha1Hash(data []byte) string {
|
||||
hvalRaw := sha1.Sum(data)
|
||||
hval := base64.StdEncoding.EncodeToString(hvalRaw[:])
|
||||
return hval
|
||||
}
|
||||
|
||||
// returns (SHA1, encoded-state)
|
||||
func (state ShellState) EncodeAndHash() (string, []byte) {
|
||||
var buf bytes.Buffer
|
||||
binpack.PackInt(&buf, ShellStatePackVersion)
|
||||
binpack.PackValue(&buf, []byte(state.Version))
|
||||
binpack.PackValue(&buf, []byte(state.Cwd))
|
||||
binpack.PackValue(&buf, state.ShellVars)
|
||||
binpack.PackValue(&buf, []byte(state.Aliases))
|
||||
binpack.PackValue(&buf, []byte(state.Funcs))
|
||||
binpack.PackValue(&buf, []byte(state.Error))
|
||||
return sha1Hash(buf.Bytes()), buf.Bytes()
|
||||
}
|
||||
|
||||
func (state ShellState) MarshalJSON() ([]byte, error) {
|
||||
_, encodedBytes := state.EncodeAndHash()
|
||||
return json.Marshal(encodedBytes)
|
||||
}
|
||||
|
||||
// caches HashVal in struct
|
||||
func (state *ShellState) GetHashVal(force bool) string {
|
||||
if state.HashVal == "" || force {
|
||||
state.HashVal, _ = state.EncodeAndHash()
|
||||
}
|
||||
return state.HashVal
|
||||
}
|
||||
|
||||
func (state *ShellState) DecodeShellState(barr []byte) error {
|
||||
state.HashVal = sha1Hash(barr)
|
||||
buf := bytes.NewBuffer(barr)
|
||||
u := binpack.MakeUnpacker(buf)
|
||||
version := u.UnpackInt("ShellState pack version")
|
||||
if version != ShellStatePackVersion {
|
||||
return fmt.Errorf("invalid ShellState pack version: %d", version)
|
||||
}
|
||||
state.Version = string(u.UnpackValue("ShellState.Version"))
|
||||
state.Cwd = string(u.UnpackValue("ShellState.Cwd"))
|
||||
state.ShellVars = u.UnpackValue("ShellState.ShellVars")
|
||||
state.Aliases = string(u.UnpackValue("ShellState.Aliases"))
|
||||
state.Funcs = string(u.UnpackValue("ShellState.Funcs"))
|
||||
state.Error = string(u.UnpackValue("ShellState.Error"))
|
||||
return u.Error()
|
||||
}
|
||||
|
||||
func (state *ShellState) UnmarshalJSON(jsonBytes []byte) error {
|
||||
var barr []byte
|
||||
err := json.Unmarshal(jsonBytes, &barr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return state.DecodeShellState(barr)
|
||||
}
|
||||
|
||||
func (sdiff ShellStateDiff) EncodeAndHash() (string, []byte) {
|
||||
var buf bytes.Buffer
|
||||
binpack.PackInt(&buf, ShellStateDiffPackVersion)
|
||||
binpack.PackValue(&buf, []byte(sdiff.Version))
|
||||
binpack.PackValue(&buf, []byte(sdiff.BaseHash))
|
||||
binpack.PackStrArr(&buf, sdiff.DiffHashArr)
|
||||
binpack.PackValue(&buf, []byte(sdiff.Cwd))
|
||||
binpack.PackValue(&buf, sdiff.VarsDiff)
|
||||
binpack.PackValue(&buf, sdiff.AliasesDiff)
|
||||
binpack.PackValue(&buf, sdiff.FuncsDiff)
|
||||
binpack.PackValue(&buf, []byte(sdiff.Error))
|
||||
return sha1Hash(buf.Bytes()), buf.Bytes()
|
||||
}
|
||||
|
||||
func (sdiff ShellStateDiff) MarshalJSON() ([]byte, error) {
|
||||
_, encodedBytes := sdiff.EncodeAndHash()
|
||||
return json.Marshal(encodedBytes)
|
||||
}
|
||||
|
||||
func (sdiff *ShellStateDiff) DecodeShellStateDiff(barr []byte) error {
|
||||
sdiff.HashVal = sha1Hash(barr)
|
||||
buf := bytes.NewBuffer(barr)
|
||||
u := binpack.MakeUnpacker(buf)
|
||||
version := u.UnpackInt("ShellState pack version")
|
||||
if version != ShellStateDiffPackVersion {
|
||||
return fmt.Errorf("invalid ShellStateDiff pack version: %d", version)
|
||||
}
|
||||
sdiff.Version = string(u.UnpackValue("ShellStateDiff.Version"))
|
||||
sdiff.BaseHash = string(u.UnpackValue("ShellStateDiff.BaseHash"))
|
||||
sdiff.DiffHashArr = u.UnpackStrArr("ShellStateDiff.DiffHashArr")
|
||||
sdiff.Cwd = string(u.UnpackValue("ShellStateDiff.Cwd"))
|
||||
sdiff.VarsDiff = u.UnpackValue("ShellStateDiff.VarsDiff")
|
||||
sdiff.AliasesDiff = u.UnpackValue("ShellStateDiff.AliasesDiff")
|
||||
sdiff.FuncsDiff = u.UnpackValue("ShellStateDiff.FuncsDiff")
|
||||
sdiff.Error = string(u.UnpackValue("ShellStateDiff.Error"))
|
||||
return u.Error()
|
||||
}
|
||||
|
||||
func (sdiff *ShellStateDiff) UnmarshalJSON(jsonBytes []byte) error {
|
||||
var barr []byte
|
||||
err := json.Unmarshal(jsonBytes, &barr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sdiff.DecodeShellStateDiff(barr)
|
||||
}
|
||||
|
||||
// caches HashVal in struct
|
||||
func (sdiff *ShellStateDiff) GetHashVal(force bool) string {
|
||||
if sdiff.HashVal == "" || force {
|
||||
sdiff.HashVal, _ = sdiff.EncodeAndHash()
|
||||
}
|
||||
return sdiff.HashVal
|
||||
}
|
||||
|
||||
func (sdiff ShellStateDiff) Dump(vars bool, aliases bool, funcs bool) {
|
||||
fmt.Printf("ShellStateDiff:\n")
|
||||
fmt.Printf(" version: %s\n", sdiff.Version)
|
||||
fmt.Printf(" base: %s\n", sdiff.BaseHash)
|
||||
fmt.Printf(" vars: %d, aliases: %d, funcs: %d\n", len(sdiff.VarsDiff), len(sdiff.AliasesDiff), len(sdiff.FuncsDiff))
|
||||
if sdiff.Error != "" {
|
||||
fmt.Printf(" error: %s\n", sdiff.Error)
|
||||
}
|
||||
if vars {
|
||||
var mdiff statediff.MapDiffType
|
||||
err := mdiff.Decode(sdiff.VarsDiff)
|
||||
if err != nil {
|
||||
fmt.Printf(" vars: error[%s]\n", err.Error())
|
||||
} else {
|
||||
mdiff.Dump()
|
||||
}
|
||||
}
|
||||
if aliases && len(sdiff.AliasesDiff) > 0 {
|
||||
var ldiff statediff.LineDiffType
|
||||
err := ldiff.Decode(sdiff.AliasesDiff)
|
||||
if err != nil {
|
||||
fmt.Printf(" aliases: error[%s]\n", err.Error())
|
||||
} else {
|
||||
ldiff.Dump()
|
||||
}
|
||||
}
|
||||
if funcs && len(sdiff.FuncsDiff) > 0 {
|
||||
var ldiff statediff.LineDiffType
|
||||
err := ldiff.Decode(sdiff.FuncsDiff)
|
||||
if err != nil {
|
||||
fmt.Printf(" funcs: error[%s]\n", err.Error())
|
||||
} else {
|
||||
ldiff.Dump()
|
||||
}
|
||||
}
|
||||
}
|
747
waveshell/pkg/server/server.go
Normal file
747
waveshell/pkg/server/server.go
Normal file
@ -0,0 +1,747 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/alessio/shellescape"
|
||||
"github.com/commandlinedev/apishell/pkg/base"
|
||||
"github.com/commandlinedev/apishell/pkg/packet"
|
||||
"github.com/commandlinedev/apishell/pkg/shexec"
|
||||
)
|
||||
|
||||
const MaxFileDataPacketSize = 16 * 1024
|
||||
const WriteFileContextTimeout = 30 * time.Second
|
||||
const cleanLoopTime = 5 * time.Second
|
||||
const MaxWriteFileContextData = 100
|
||||
|
||||
// TODO create unblockable packet-sender (backed by an array) for clientproc
|
||||
type MServer struct {
|
||||
Lock *sync.Mutex
|
||||
MainInput *packet.PacketParser
|
||||
Sender *packet.PacketSender
|
||||
ClientMap map[base.CommandKey]*shexec.ClientProc
|
||||
Debug bool
|
||||
StateMap map[string]*packet.ShellState // sha1->state
|
||||
CurrentState string // sha1
|
||||
WriteErrorCh chan bool // closed if there is a I/O write error
|
||||
WriteErrorChOnce *sync.Once
|
||||
WriteFileContextMap map[string]*WriteFileContext
|
||||
Done bool
|
||||
}
|
||||
|
||||
type WriteFileContext struct {
|
||||
CVar *sync.Cond
|
||||
Data []*packet.FileDataPacketType
|
||||
LastActive time.Time
|
||||
Err error
|
||||
Done bool
|
||||
}
|
||||
|
||||
func (m *MServer) Close() {
|
||||
m.Sender.Close()
|
||||
m.Sender.WaitForDone()
|
||||
m.Lock.Lock()
|
||||
defer m.Lock.Unlock()
|
||||
m.Done = true
|
||||
}
|
||||
|
||||
func (m *MServer) checkDone() bool {
|
||||
m.Lock.Lock()
|
||||
defer m.Lock.Unlock()
|
||||
return m.Done
|
||||
}
|
||||
|
||||
func (m *MServer) getWriteFileContext(reqId string) *WriteFileContext {
|
||||
m.Lock.Lock()
|
||||
defer m.Lock.Unlock()
|
||||
wfc := m.WriteFileContextMap[reqId]
|
||||
if wfc == nil {
|
||||
wfc = &WriteFileContext{
|
||||
CVar: sync.NewCond(&sync.Mutex{}),
|
||||
LastActive: time.Now(),
|
||||
}
|
||||
m.WriteFileContextMap[reqId] = wfc
|
||||
}
|
||||
return wfc
|
||||
}
|
||||
|
||||
func (m *MServer) addFileDataPacket(pk *packet.FileDataPacketType) {
|
||||
m.Lock.Lock()
|
||||
wfc := m.WriteFileContextMap[pk.RespId]
|
||||
m.Lock.Unlock()
|
||||
if wfc == nil {
|
||||
return
|
||||
}
|
||||
wfc.CVar.L.Lock()
|
||||
defer wfc.CVar.L.Unlock()
|
||||
if wfc.Done || wfc.Err != nil {
|
||||
return
|
||||
}
|
||||
if len(wfc.Data) > MaxWriteFileContextData {
|
||||
wfc.Err = errors.New("write-file buffer length exceeded")
|
||||
wfc.Data = nil
|
||||
wfc.CVar.Broadcast()
|
||||
return
|
||||
}
|
||||
wfc.LastActive = time.Now()
|
||||
wfc.Data = append(wfc.Data, pk)
|
||||
wfc.CVar.Signal()
|
||||
}
|
||||
|
||||
func (wfc *WriteFileContext) setDone() {
|
||||
wfc.CVar.L.Lock()
|
||||
defer wfc.CVar.L.Unlock()
|
||||
wfc.Done = true
|
||||
wfc.Data = nil
|
||||
wfc.CVar.Broadcast()
|
||||
}
|
||||
|
||||
func (m *MServer) cleanWriteFileContexts() {
|
||||
now := time.Now()
|
||||
var staleWfcs []*WriteFileContext
|
||||
m.Lock.Lock()
|
||||
for reqId, wfc := range m.WriteFileContextMap {
|
||||
if now.Sub(wfc.LastActive) > WriteFileContextTimeout {
|
||||
staleWfcs = append(staleWfcs, wfc)
|
||||
delete(m.WriteFileContextMap, reqId)
|
||||
}
|
||||
}
|
||||
m.Lock.Unlock()
|
||||
|
||||
// we do this outside of m.Lock just in case there is some lock contention (end of WriteFile could theoretically be slow)
|
||||
for _, wfc := range staleWfcs {
|
||||
wfc.setDone()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MServer) ProcessCommandPacket(pk packet.CommandPacketType) {
|
||||
ck := pk.GetCK()
|
||||
if ck == "" {
|
||||
m.Sender.SendMessageFmt("received '%s' packet without ck", pk.GetType())
|
||||
return
|
||||
}
|
||||
m.Lock.Lock()
|
||||
cproc := m.ClientMap[ck]
|
||||
m.Lock.Unlock()
|
||||
if cproc == nil {
|
||||
m.Sender.SendCmdError(ck, fmt.Errorf("no client proc for ck '%s', pk=%s", ck, packet.AsString(pk)))
|
||||
return
|
||||
}
|
||||
cproc.Input.SendPacket(pk)
|
||||
return
|
||||
}
|
||||
|
||||
func runSingleCompGen(cwd string, compType string, prefix string) ([]string, bool, error) {
|
||||
if !packet.IsValidCompGenType(compType) {
|
||||
return nil, false, fmt.Errorf("invalid compgen type '%s'", compType)
|
||||
}
|
||||
compGenCmdStr := fmt.Sprintf("cd %s; compgen -A %s -- %s | sort | uniq | head -n %d", shellescape.Quote(cwd), shellescape.Quote(compType), shellescape.Quote(prefix), packet.MaxCompGenValues+1)
|
||||
ecmd := exec.Command("bash", "-c", compGenCmdStr)
|
||||
outputBytes, err := ecmd.Output()
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("compgen error: %w", err)
|
||||
}
|
||||
outputStr := string(outputBytes)
|
||||
parts := strings.Split(outputStr, "\n")
|
||||
if len(parts) > 0 && parts[len(parts)-1] == "" {
|
||||
parts = parts[0 : len(parts)-1]
|
||||
}
|
||||
hasMore := false
|
||||
if len(parts) > packet.MaxCompGenValues {
|
||||
hasMore = true
|
||||
parts = parts[0:packet.MaxCompGenValues]
|
||||
}
|
||||
return parts, hasMore, nil
|
||||
}
|
||||
|
||||
func appendSlashes(comps []string) {
|
||||
for idx, comp := range comps {
|
||||
comps[idx] = comp + "/"
|
||||
}
|
||||
}
|
||||
|
||||
func strArrToMap(strs []string) map[string]bool {
|
||||
rtn := make(map[string]bool)
|
||||
for _, s := range strs {
|
||||
rtn[s] = true
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (m *MServer) runMixedCompGen(compPk *packet.CompGenPacketType) {
|
||||
// get directories and files, unique them and put slashes on directories for completion
|
||||
reqId := compPk.GetReqId()
|
||||
compDirs, hasMoreDirs, err := runSingleCompGen(compPk.Cwd, "directory", compPk.Prefix)
|
||||
if err != nil {
|
||||
m.Sender.SendErrorResponse(reqId, err)
|
||||
return
|
||||
}
|
||||
compFiles, hasMoreFiles, err := runSingleCompGen(compPk.Cwd, compPk.CompType, compPk.Prefix)
|
||||
if err != nil {
|
||||
m.Sender.SendErrorResponse(reqId, err)
|
||||
return
|
||||
}
|
||||
|
||||
dirMap := strArrToMap(compDirs)
|
||||
// seed comps with dirs (but append slashes)
|
||||
comps := compDirs
|
||||
appendSlashes(comps)
|
||||
// add files that are not directories (look up in dirMap)
|
||||
for _, file := range compFiles {
|
||||
if dirMap[file] {
|
||||
continue
|
||||
}
|
||||
comps = append(comps, file)
|
||||
}
|
||||
sort.Strings(comps) // resort
|
||||
m.Sender.SendResponse(reqId, map[string]interface{}{"comps": comps, "hasmore": (hasMoreFiles || hasMoreDirs)})
|
||||
return
|
||||
}
|
||||
|
||||
func (m *MServer) runCompGen(compPk *packet.CompGenPacketType) {
|
||||
reqId := compPk.GetReqId()
|
||||
if compPk.CompType == "file" || compPk.CompType == "command" {
|
||||
m.runMixedCompGen(compPk)
|
||||
return
|
||||
}
|
||||
comps, hasMore, err := runSingleCompGen(compPk.Cwd, compPk.CompType, compPk.Prefix)
|
||||
if err != nil {
|
||||
m.Sender.SendErrorResponse(reqId, err)
|
||||
return
|
||||
}
|
||||
if compPk.CompType == "directory" {
|
||||
appendSlashes(comps)
|
||||
}
|
||||
m.Sender.SendResponse(reqId, map[string]interface{}{"comps": comps, "hasmore": hasMore})
|
||||
return
|
||||
}
|
||||
|
||||
func (m *MServer) setCurrentState(state *packet.ShellState) {
|
||||
if state == nil {
|
||||
return
|
||||
}
|
||||
hval, _ := state.EncodeAndHash()
|
||||
m.Lock.Lock()
|
||||
defer m.Lock.Unlock()
|
||||
m.StateMap[hval] = state
|
||||
m.CurrentState = hval
|
||||
}
|
||||
|
||||
func (m *MServer) reinit(reqId string) {
|
||||
initPk, err := shexec.MakeServerInitPacket()
|
||||
if err != nil {
|
||||
m.Sender.SendErrorResponse(reqId, fmt.Errorf("error creating init packet: %w", err))
|
||||
return
|
||||
}
|
||||
m.setCurrentState(initPk.State)
|
||||
initPk.RespId = reqId
|
||||
m.Sender.SendPacket(initPk)
|
||||
}
|
||||
|
||||
func makeTemp(path string, mode fs.FileMode) (*os.File, error) {
|
||||
dirName := filepath.Dir(path)
|
||||
baseName := filepath.Base(path)
|
||||
baseTempName := baseName + ".tmp."
|
||||
writeFd, err := os.CreateTemp(dirName, baseTempName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = writeFd.Chmod(mode)
|
||||
if err != nil {
|
||||
writeFd.Close()
|
||||
os.Remove(writeFd.Name())
|
||||
return nil, fmt.Errorf("error setting tempfile permissions: %w", err)
|
||||
}
|
||||
return writeFd, nil
|
||||
}
|
||||
|
||||
func checkFileWritable(path string) error {
|
||||
finfo, err := os.Stat(path) // ok to follow symlinks
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
dirName := filepath.Dir(path)
|
||||
dirInfo, err := os.Stat(dirName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("file does not exist, error trying to stat parent directory: %w", err)
|
||||
}
|
||||
if !dirInfo.IsDir() {
|
||||
return fmt.Errorf("file does not exist, parent path [%s] is not a directory", dirName)
|
||||
}
|
||||
return nil
|
||||
} else {
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot stat: %w", err)
|
||||
}
|
||||
if finfo.IsDir() {
|
||||
return fmt.Errorf("invalid path, cannot write a directory")
|
||||
}
|
||||
if (finfo.Mode() & fs.ModeSymlink) != 0 {
|
||||
return fmt.Errorf("writefile does not support symlinks") // note this shouldn't happen because we're using Stat (not Lstat)
|
||||
}
|
||||
if (finfo.Mode() & (fs.ModeNamedPipe | fs.ModeSocket | fs.ModeDevice)) != 0 {
|
||||
return fmt.Errorf("writefile does not support special files (named pipes, sockets, devices): mode=%v", finfo.Mode())
|
||||
}
|
||||
writePerm := (finfo.Mode().Perm() & 0o222)
|
||||
if writePerm == 0 {
|
||||
return fmt.Errorf("file is not writable, perms: %v", finfo.Mode().Perm())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func copyFile(dstName string, srcName string) error {
|
||||
srcFd, err := os.Open(srcName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer srcFd.Close()
|
||||
dstFd, err := os.OpenFile(dstName, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o666) // use 666 because OpenFile respects umask
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// we don't defer dstFd.Close() so we can return an error if dstFd.Close() returns an error
|
||||
_, err = io.Copy(dstFd, srcFd)
|
||||
if err != nil {
|
||||
dstFd.Close()
|
||||
return err
|
||||
}
|
||||
return dstFd.Close()
|
||||
}
|
||||
|
||||
func (m *MServer) writeFile(pk *packet.WriteFilePacketType, wfc *WriteFileContext) {
|
||||
defer wfc.setDone()
|
||||
if pk.Path == "" {
|
||||
resp := packet.MakeWriteFileReadyPacket(pk.ReqId)
|
||||
resp.Error = "invalid write-file request, no path specified"
|
||||
m.Sender.SendPacket(resp)
|
||||
return
|
||||
}
|
||||
err := checkFileWritable(pk.Path)
|
||||
if err != nil {
|
||||
resp := packet.MakeWriteFileReadyPacket(pk.ReqId)
|
||||
resp.Error = err.Error()
|
||||
m.Sender.SendPacket(resp)
|
||||
return
|
||||
}
|
||||
var writeFd *os.File
|
||||
if pk.UseTemp {
|
||||
writeFd, err = os.CreateTemp("", "mshell.writefile.*") // "" means make this file in standard TempDir
|
||||
if err != nil {
|
||||
resp := packet.MakeWriteFileReadyPacket(pk.ReqId)
|
||||
resp.Error = fmt.Sprintf("cannot create temp file: %v", err)
|
||||
m.Sender.SendPacket(resp)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
writeFd, err = os.OpenFile(pk.Path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o666) // use 666 because OpenFile respects umask
|
||||
if err != nil {
|
||||
resp := packet.MakeWriteFileReadyPacket(pk.ReqId)
|
||||
resp.Error = fmt.Sprintf("write-file could not open file: %v", err)
|
||||
m.Sender.SendPacket(resp)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// ok, so now writeFd is valid, send the "ready" response
|
||||
resp := packet.MakeWriteFileReadyPacket(pk.ReqId)
|
||||
m.Sender.SendPacket(resp)
|
||||
|
||||
// now we wait for data (cond var)
|
||||
// this Unlock() runs first (because it is a later defer) so we can still run wfc.setDone() safely
|
||||
wfc.CVar.L.Lock()
|
||||
defer wfc.CVar.L.Unlock()
|
||||
var doneErr error
|
||||
for {
|
||||
if wfc.Done {
|
||||
break
|
||||
}
|
||||
if wfc.Err != nil {
|
||||
doneErr = wfc.Err
|
||||
break
|
||||
}
|
||||
if len(wfc.Data) == 0 {
|
||||
wfc.CVar.Wait()
|
||||
continue
|
||||
}
|
||||
dataPk := wfc.Data[0]
|
||||
wfc.Data = wfc.Data[1:]
|
||||
if dataPk.Error != "" {
|
||||
doneErr = fmt.Errorf("error received from client: %v", errors.New(dataPk.Error))
|
||||
break
|
||||
}
|
||||
if len(dataPk.Data) > 0 {
|
||||
_, err := writeFd.Write(dataPk.Data)
|
||||
if err != nil {
|
||||
doneErr = fmt.Errorf("error writing data to file: %v", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
if dataPk.Eof {
|
||||
break
|
||||
}
|
||||
}
|
||||
closeErr := writeFd.Close()
|
||||
if doneErr == nil && closeErr != nil {
|
||||
doneErr = fmt.Errorf("error closing file: %v", closeErr)
|
||||
}
|
||||
if pk.UseTemp {
|
||||
if doneErr != nil {
|
||||
os.Remove(writeFd.Name())
|
||||
} else {
|
||||
// copy file between writeFd.Name() and pk.Path
|
||||
copyErr := copyFile(pk.Path, writeFd.Name())
|
||||
if err != nil {
|
||||
doneErr = fmt.Errorf("error writing file: %v", copyErr)
|
||||
}
|
||||
os.Remove(writeFd.Name())
|
||||
}
|
||||
}
|
||||
donePk := packet.MakeWriteFileDonePacket(pk.ReqId)
|
||||
if doneErr != nil {
|
||||
donePk.Error = doneErr.Error()
|
||||
}
|
||||
m.Sender.SendPacket(donePk)
|
||||
}
|
||||
|
||||
func (m *MServer) returnStreamFileNewFileResponse(pk *packet.StreamFilePacketType) {
|
||||
// ok, file doesn't exist, so try to check the directory at least to see if we can write a file here
|
||||
resp := packet.MakeStreamFileResponse(pk.ReqId)
|
||||
defer func() {
|
||||
if resp.Error == "" {
|
||||
resp.Done = true
|
||||
}
|
||||
m.Sender.SendPacket(resp)
|
||||
}()
|
||||
dirName := filepath.Dir(pk.Path)
|
||||
dirInfo, err := os.Stat(dirName)
|
||||
if err != nil {
|
||||
resp.Error = fmt.Sprintf("file does not exist, error trying to stat parent directory: %v", err)
|
||||
return
|
||||
}
|
||||
if !dirInfo.IsDir() {
|
||||
resp.Error = fmt.Sprintf("file does not exist, parent path [%s] is not a directory", dirName)
|
||||
return
|
||||
}
|
||||
resp.Info = &packet.FileInfo{
|
||||
Name: pk.Path,
|
||||
Size: 0,
|
||||
ModTs: 0,
|
||||
IsDir: false,
|
||||
Perm: int(dirInfo.Mode().Perm()),
|
||||
NotFound: true,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m *MServer) streamFile(pk *packet.StreamFilePacketType) {
|
||||
resp := packet.MakeStreamFileResponse(pk.ReqId)
|
||||
finfo, err := os.Stat(pk.Path)
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
// special return
|
||||
m.returnStreamFileNewFileResponse(pk)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
resp.Error = fmt.Sprintf("cannot stat file %q: %v", pk.Path, err)
|
||||
m.Sender.SendPacket(resp)
|
||||
return
|
||||
}
|
||||
resp.Info = &packet.FileInfo{
|
||||
Name: pk.Path,
|
||||
Size: finfo.Size(),
|
||||
ModTs: finfo.ModTime().UnixMilli(),
|
||||
IsDir: finfo.IsDir(),
|
||||
Perm: int(finfo.Mode().Perm()),
|
||||
}
|
||||
if pk.StatOnly {
|
||||
resp.Done = true
|
||||
m.Sender.SendPacket(resp)
|
||||
return
|
||||
}
|
||||
// like the http Range header. range header is end inclusive. for us, endByte is non-inclusive (so we add 1)
|
||||
var startByte, endByte int64
|
||||
if len(pk.ByteRange) == 0 {
|
||||
endByte = finfo.Size()
|
||||
} else if len(pk.ByteRange) == 1 && pk.ByteRange[0] >= 0 {
|
||||
startByte = pk.ByteRange[0]
|
||||
endByte = finfo.Size()
|
||||
} else if len(pk.ByteRange) == 1 && pk.ByteRange[0] < 0 {
|
||||
startByte = finfo.Size() + pk.ByteRange[0] // "+" since ByteRange[0] is less than 0
|
||||
endByte = finfo.Size()
|
||||
} else if len(pk.ByteRange) == 2 {
|
||||
startByte = pk.ByteRange[0]
|
||||
endByte = pk.ByteRange[1] + 1
|
||||
} else {
|
||||
resp.Error = fmt.Sprintf("invalid byte range (%d entries)", len(pk.ByteRange))
|
||||
m.Sender.SendPacket(resp)
|
||||
return
|
||||
}
|
||||
if startByte < 0 {
|
||||
startByte = 0
|
||||
}
|
||||
if endByte > finfo.Size() {
|
||||
endByte = finfo.Size()
|
||||
}
|
||||
if startByte >= endByte {
|
||||
resp.Done = true
|
||||
m.Sender.SendPacket(resp)
|
||||
return
|
||||
}
|
||||
fd, err := os.Open(pk.Path)
|
||||
if err != nil {
|
||||
resp.Error = fmt.Sprintf("opening file: %v", err)
|
||||
m.Sender.SendPacket(resp)
|
||||
return
|
||||
}
|
||||
defer fd.Close()
|
||||
m.Sender.SendPacket(resp)
|
||||
var buffer [MaxFileDataPacketSize]byte
|
||||
var sentDone bool
|
||||
first := true
|
||||
for ; startByte < endByte; startByte += MaxFileDataPacketSize {
|
||||
if !first {
|
||||
// throttle packet sending @ 1000 packets/s, or 16M/s
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
}
|
||||
first = false
|
||||
readLen := int64Min(MaxFileDataPacketSize, endByte-startByte)
|
||||
bufSlice := buffer[0:readLen]
|
||||
nr, err := fd.ReadAt(bufSlice, startByte)
|
||||
dataPk := packet.MakeFileDataPacket(pk.ReqId)
|
||||
dataPk.Data = make([]byte, nr)
|
||||
copy(dataPk.Data, bufSlice)
|
||||
if err == io.EOF {
|
||||
dataPk.Eof = true
|
||||
} else if err != nil {
|
||||
dataPk.Error = err.Error()
|
||||
}
|
||||
m.Sender.SendPacket(dataPk)
|
||||
if dataPk.GetResponseDone() {
|
||||
sentDone = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !sentDone {
|
||||
dataPk := packet.MakeFileDataPacket(pk.ReqId)
|
||||
dataPk.Eof = true
|
||||
m.Sender.SendPacket(dataPk)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func int64Min(v1 int64, v2 int64) int64 {
|
||||
if v1 < v2 {
|
||||
return v1
|
||||
}
|
||||
return v2
|
||||
}
|
||||
|
||||
func (m *MServer) ProcessRpcPacket(pk packet.RpcPacketType) {
|
||||
reqId := pk.GetReqId()
|
||||
if cdPk, ok := pk.(*packet.CdPacketType); ok {
|
||||
err := os.Chdir(cdPk.Dir)
|
||||
if err != nil {
|
||||
m.Sender.SendErrorResponse(reqId, fmt.Errorf("cannot change directory: %w", err))
|
||||
return
|
||||
}
|
||||
m.Sender.SendResponse(reqId, true)
|
||||
return
|
||||
}
|
||||
if compPk, ok := pk.(*packet.CompGenPacketType); ok {
|
||||
go m.runCompGen(compPk)
|
||||
return
|
||||
}
|
||||
if _, ok := pk.(*packet.ReInitPacketType); ok {
|
||||
go m.reinit(reqId)
|
||||
return
|
||||
}
|
||||
if streamPk, ok := pk.(*packet.StreamFilePacketType); ok {
|
||||
go m.streamFile(streamPk)
|
||||
return
|
||||
}
|
||||
if writePk, ok := pk.(*packet.WriteFilePacketType); ok {
|
||||
wfc := m.getWriteFileContext(writePk.ReqId)
|
||||
go m.writeFile(writePk, wfc)
|
||||
return
|
||||
}
|
||||
m.Sender.SendErrorResponse(reqId, fmt.Errorf("invalid rpc type '%s'", pk.GetType()))
|
||||
return
|
||||
}
|
||||
|
||||
func (m *MServer) getCurrentState() (string, *packet.ShellState) {
|
||||
m.Lock.Lock()
|
||||
defer m.Lock.Unlock()
|
||||
return m.CurrentState, m.StateMap[m.CurrentState]
|
||||
}
|
||||
|
||||
func (m *MServer) clientPacketCallback(pk packet.PacketType) {
|
||||
if pk.GetType() != packet.CmdDonePacketStr {
|
||||
return
|
||||
}
|
||||
donePk := pk.(*packet.CmdDonePacketType)
|
||||
if donePk.FinalState == nil {
|
||||
return
|
||||
}
|
||||
stateHash, curState := m.getCurrentState()
|
||||
if curState == nil {
|
||||
return
|
||||
}
|
||||
diff, err := shexec.MakeShellStateDiff(*curState, stateHash, *donePk.FinalState)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
donePk.FinalState = nil
|
||||
donePk.FinalStateDiff = &diff
|
||||
}
|
||||
|
||||
func (m *MServer) runCommand(runPacket *packet.RunPacketType) {
|
||||
if err := runPacket.CK.Validate("packet"); err != nil {
|
||||
m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("server run packets require valid ck: %s", err))
|
||||
return
|
||||
}
|
||||
ecmd, err := shexec.SSHOpts{}.MakeMShellSingleCmd(true)
|
||||
if err != nil {
|
||||
m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("server run packets require valid ck: %s", err))
|
||||
return
|
||||
}
|
||||
cproc, _, err := shexec.MakeClientProc(context.Background(), ecmd)
|
||||
if err != nil {
|
||||
m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("starting mshell client: %s", err))
|
||||
return
|
||||
}
|
||||
m.Lock.Lock()
|
||||
m.ClientMap[runPacket.CK] = cproc
|
||||
m.Lock.Unlock()
|
||||
go func() {
|
||||
defer func() {
|
||||
r := recover()
|
||||
finalPk := packet.MakeCmdFinalPacket(runPacket.CK)
|
||||
finalPk.Ts = time.Now().UnixMilli()
|
||||
if r != nil {
|
||||
finalPk.Error = fmt.Sprintf("%s", r)
|
||||
}
|
||||
m.Sender.SendPacket(finalPk)
|
||||
m.Lock.Lock()
|
||||
delete(m.ClientMap, runPacket.CK)
|
||||
m.Lock.Unlock()
|
||||
cproc.Close()
|
||||
}()
|
||||
shexec.SendRunPacketAndRunData(context.Background(), cproc.Input, runPacket)
|
||||
cproc.ProxySingleOutput(runPacket.CK, m.Sender, m.clientPacketCallback)
|
||||
}()
|
||||
}
|
||||
|
||||
func (m *MServer) packetSenderErrorHandler(sender *packet.PacketSender, pk packet.PacketType, err error) {
|
||||
if serr, ok := err.(*packet.SendError); ok && serr.IsMarshalError {
|
||||
msg := packet.MakeMessagePacket(err.Error())
|
||||
if cpk, ok := pk.(packet.CommandPacketType); ok {
|
||||
msg.CK = cpk.GetCK()
|
||||
}
|
||||
sender.SendPacket(msg)
|
||||
return
|
||||
} else {
|
||||
// I/O error: close the WriteErrorCh to signal that we are dead (cannot continue if we can't write output)
|
||||
m.WriteErrorChOnce.Do(func() {
|
||||
close(m.WriteErrorCh)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (server *MServer) runReadLoop() {
|
||||
builder := packet.MakeRunPacketBuilder()
|
||||
for pk := range server.MainInput.MainCh {
|
||||
if server.Debug {
|
||||
fmt.Printf("PK> %s\n", packet.AsString(pk))
|
||||
}
|
||||
ok, runPacket := builder.ProcessPacket(pk)
|
||||
if ok {
|
||||
if runPacket != nil {
|
||||
server.runCommand(runPacket)
|
||||
continue
|
||||
}
|
||||
continue
|
||||
}
|
||||
if cmdPk, ok := pk.(packet.CommandPacketType); ok {
|
||||
server.ProcessCommandPacket(cmdPk)
|
||||
continue
|
||||
}
|
||||
if rpcPk, ok := pk.(packet.RpcPacketType); ok {
|
||||
server.ProcessRpcPacket(rpcPk)
|
||||
continue
|
||||
}
|
||||
if fileDataPk, ok := pk.(*packet.FileDataPacketType); ok {
|
||||
server.addFileDataPacket(fileDataPk)
|
||||
continue
|
||||
}
|
||||
server.Sender.SendMessageFmt("invalid packet '%s' sent to mshell server", packet.AsString(pk))
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
func RunServer() (int, error) {
|
||||
debug := false
|
||||
if len(os.Args) >= 3 && os.Args[2] == "--debug" {
|
||||
debug = true
|
||||
}
|
||||
server := &MServer{
|
||||
Lock: &sync.Mutex{},
|
||||
ClientMap: make(map[base.CommandKey]*shexec.ClientProc),
|
||||
StateMap: make(map[string]*packet.ShellState),
|
||||
Debug: debug,
|
||||
WriteErrorCh: make(chan bool),
|
||||
WriteErrorChOnce: &sync.Once{},
|
||||
WriteFileContextMap: make(map[string]*WriteFileContext),
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
if server.checkDone() {
|
||||
return
|
||||
}
|
||||
time.Sleep(cleanLoopTime)
|
||||
server.cleanWriteFileContexts()
|
||||
}
|
||||
}()
|
||||
if debug {
|
||||
packet.GlobalDebug = true
|
||||
}
|
||||
server.MainInput = packet.MakePacketParser(os.Stdin, false)
|
||||
server.Sender = packet.MakePacketSender(os.Stdout, server.packetSenderErrorHandler)
|
||||
defer server.Close()
|
||||
var err error
|
||||
initPacket, err := shexec.MakeServerInitPacket()
|
||||
if err != nil {
|
||||
return 1, err
|
||||
}
|
||||
server.setCurrentState(initPacket.State)
|
||||
server.Sender.SendPacket(initPacket)
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
go func() {
|
||||
for range ticker.C {
|
||||
server.Sender.SendPacket(packet.MakePingPacket())
|
||||
}
|
||||
}()
|
||||
defer ticker.Stop()
|
||||
readLoopDoneCh := make(chan bool)
|
||||
go func() {
|
||||
defer close(readLoopDoneCh)
|
||||
server.runReadLoop()
|
||||
}()
|
||||
select {
|
||||
case <-readLoopDoneCh:
|
||||
break
|
||||
|
||||
case <-server.WriteErrorCh:
|
||||
break
|
||||
}
|
||||
return 0, nil
|
||||
}
|
132
waveshell/pkg/shexec/client.go
Normal file
132
waveshell/pkg/shexec/client.go
Normal file
@ -0,0 +1,132 @@
|
||||
package shexec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os/exec"
|
||||
"time"
|
||||
|
||||
"github.com/commandlinedev/apishell/pkg/base"
|
||||
"github.com/commandlinedev/apishell/pkg/packet"
|
||||
"golang.org/x/mod/semver"
|
||||
)
|
||||
|
||||
// TODO - track buffer sizes for sending input
|
||||
|
||||
const NotFoundVersion = "v0.0"
|
||||
|
||||
type ClientProc struct {
|
||||
Cmd *exec.Cmd
|
||||
InitPk *packet.InitPacketType
|
||||
StartTs time.Time
|
||||
StdinWriter io.WriteCloser
|
||||
StdoutReader io.ReadCloser
|
||||
StderrReader io.ReadCloser
|
||||
Input *packet.PacketSender
|
||||
Output *packet.PacketParser
|
||||
}
|
||||
|
||||
// returns (clientproc, initpk, error)
|
||||
func MakeClientProc(ctx context.Context, ecmd *exec.Cmd) (*ClientProc, *packet.InitPacketType, error) {
|
||||
inputWriter, err := ecmd.StdinPipe()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("creating stdin pipe: %v", err)
|
||||
}
|
||||
stdoutReader, err := ecmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("creating stdout pipe: %v", err)
|
||||
}
|
||||
stderrReader, err := ecmd.StderrPipe()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("creating stderr pipe: %v", err)
|
||||
}
|
||||
startTs := time.Now()
|
||||
err = ecmd.Start()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("running local client: %w", err)
|
||||
}
|
||||
sender := packet.MakePacketSender(inputWriter, nil)
|
||||
stdoutPacketParser := packet.MakePacketParser(stdoutReader, false)
|
||||
stderrPacketParser := packet.MakePacketParser(stderrReader, false)
|
||||
packetParser := packet.CombinePacketParsers(stdoutPacketParser, stderrPacketParser, true)
|
||||
cproc := &ClientProc{
|
||||
Cmd: ecmd,
|
||||
StartTs: startTs,
|
||||
StdinWriter: inputWriter,
|
||||
StdoutReader: stdoutReader,
|
||||
StderrReader: stderrReader,
|
||||
Input: sender,
|
||||
Output: packetParser,
|
||||
}
|
||||
|
||||
var pk packet.PacketType
|
||||
select {
|
||||
case pk = <-packetParser.MainCh:
|
||||
case <-ctx.Done():
|
||||
cproc.Close()
|
||||
return nil, nil, ctx.Err()
|
||||
}
|
||||
if pk != nil {
|
||||
if pk.GetType() != packet.InitPacketStr {
|
||||
cproc.Close()
|
||||
return nil, nil, fmt.Errorf("invalid packet received from mshell client: %s", packet.AsString(pk))
|
||||
}
|
||||
initPk := pk.(*packet.InitPacketType)
|
||||
if initPk.NotFound {
|
||||
cproc.Close()
|
||||
return nil, initPk, fmt.Errorf("mshell client not found")
|
||||
}
|
||||
if semver.MajorMinor(initPk.Version) != semver.MajorMinor(base.MShellVersion) {
|
||||
cproc.Close()
|
||||
return nil, initPk, fmt.Errorf("invalid remote mshell version '%s', must be '=%s'", initPk.Version, semver.MajorMinor(base.MShellVersion))
|
||||
}
|
||||
cproc.InitPk = initPk
|
||||
}
|
||||
if cproc.InitPk == nil {
|
||||
cproc.Close()
|
||||
return nil, nil, fmt.Errorf("no init packet received from mshell client")
|
||||
}
|
||||
return cproc, cproc.InitPk, nil
|
||||
}
|
||||
|
||||
func (cproc *ClientProc) Close() {
|
||||
if cproc.Input != nil {
|
||||
cproc.Input.Close()
|
||||
}
|
||||
if cproc.StdinWriter != nil {
|
||||
cproc.StdinWriter.Close()
|
||||
}
|
||||
if cproc.StdoutReader != nil {
|
||||
cproc.StdoutReader.Close()
|
||||
}
|
||||
if cproc.StderrReader != nil {
|
||||
cproc.StderrReader.Close()
|
||||
}
|
||||
if cproc.Cmd != nil {
|
||||
cproc.Cmd.Process.Kill()
|
||||
}
|
||||
}
|
||||
|
||||
func (cproc *ClientProc) ProxySingleOutput(ck base.CommandKey, sender *packet.PacketSender, packetCallback func(packet.PacketType)) {
|
||||
sentDonePk := false
|
||||
for pk := range cproc.Output.MainCh {
|
||||
if packetCallback != nil {
|
||||
packetCallback(pk)
|
||||
}
|
||||
if pk.GetType() == packet.CmdDonePacketStr {
|
||||
sentDonePk = true
|
||||
}
|
||||
sender.SendPacket(pk)
|
||||
}
|
||||
exitErr := cproc.Cmd.Wait()
|
||||
if !sentDonePk {
|
||||
endTs := time.Now()
|
||||
cmdDuration := endTs.Sub(cproc.StartTs)
|
||||
donePacket := packet.MakeCmdDonePacket(ck)
|
||||
donePacket.Ts = endTs.UnixMilli()
|
||||
donePacket.ExitCode = GetExitCode(exitErr)
|
||||
donePacket.DurationMs = int64(cmdDuration / time.Millisecond)
|
||||
sender.SendPacket(donePacket)
|
||||
}
|
||||
}
|
638
waveshell/pkg/shexec/parser.go
Normal file
638
waveshell/pkg/shexec/parser.go
Normal file
@ -0,0 +1,638 @@
|
||||
package shexec
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/alessio/shellescape"
|
||||
"github.com/commandlinedev/apishell/pkg/packet"
|
||||
"github.com/commandlinedev/apishell/pkg/simpleexpand"
|
||||
"github.com/commandlinedev/apishell/pkg/statediff"
|
||||
"mvdan.cc/sh/v3/expand"
|
||||
"mvdan.cc/sh/v3/syntax"
|
||||
)
|
||||
|
||||
const (
|
||||
DeclTypeArray = "array"
|
||||
DeclTypeAssocArray = "assoc"
|
||||
DeclTypeInt = "int"
|
||||
DeclTypeNormal = "normal"
|
||||
)
|
||||
|
||||
type ParseEnviron struct {
|
||||
Env map[string]string
|
||||
}
|
||||
|
||||
func (e *ParseEnviron) Get(name string) expand.Variable {
|
||||
val, ok := e.Env[name]
|
||||
if !ok {
|
||||
return expand.Variable{}
|
||||
}
|
||||
return expand.Variable{
|
||||
Exported: true,
|
||||
Kind: expand.String,
|
||||
Str: val,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *ParseEnviron) Each(fn func(name string, vr expand.Variable) bool) {
|
||||
for key, _ := range e.Env {
|
||||
rtn := fn(key, e.Get(key))
|
||||
if !rtn {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func doCmdSubst(commandStr string, w io.Writer, word *syntax.CmdSubst) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func doProcSubst(w *syntax.ProcSubst) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func GetParserConfig(envMap map[string]string) *expand.Config {
|
||||
cfg := &expand.Config{
|
||||
Env: &ParseEnviron{Env: envMap},
|
||||
GlobStar: false,
|
||||
NullGlob: false,
|
||||
NoUnset: false,
|
||||
CmdSubst: func(w io.Writer, word *syntax.CmdSubst) error { return doCmdSubst("", w, word) },
|
||||
ProcSubst: doProcSubst,
|
||||
ReadDir: nil,
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func writeIndent(buf *bytes.Buffer, num int) {
|
||||
for i := 0; i < num; i++ {
|
||||
buf.WriteByte(' ')
|
||||
}
|
||||
}
|
||||
|
||||
func makeSpaceStr(num int) string {
|
||||
barr := make([]byte, num)
|
||||
for i := 0; i < num; i++ {
|
||||
barr[i] = ' '
|
||||
}
|
||||
return string(barr)
|
||||
}
|
||||
|
||||
// https://wiki.bash-hackers.org/syntax/shellvars
|
||||
var NoStoreVarNames = map[string]bool{
|
||||
"BASH": true,
|
||||
"BASHOPTS": true,
|
||||
"BASHPID": true,
|
||||
"BASH_ALIASES": true,
|
||||
"BASH_ARGC": true,
|
||||
"BASH_ARGV": true,
|
||||
"BASH_ARGV0": true,
|
||||
"BASH_CMDS": true,
|
||||
"BASH_COMMAND": true,
|
||||
"BASH_EXECUTION_STRING": true,
|
||||
"LINENO": true,
|
||||
"BASH_LINENO": true,
|
||||
"BASH_REMATCH": true,
|
||||
"BASH_SOURCE": true,
|
||||
"BASH_SUBSHELL": true,
|
||||
"COPROC": true,
|
||||
"DIRSTACK": true,
|
||||
"EPOCHREALTIME": true,
|
||||
"EPOCHSECONDS": true,
|
||||
"FUNCNAME": true,
|
||||
"HISTCMD": true,
|
||||
"OLDPWD": true,
|
||||
"PIPESTATUS": true,
|
||||
"PPID": true,
|
||||
"PWD": true,
|
||||
"RANDOM": true,
|
||||
"SECONDS": true,
|
||||
"SHLVL": true,
|
||||
"HISTFILE": true,
|
||||
"HISTFILESIZE": true,
|
||||
"HISTCONTROL": true,
|
||||
"HISTIGNORE": true,
|
||||
"HISTSIZE": true,
|
||||
"HISTTIMEFORMAT": true,
|
||||
"SRANDOM": true,
|
||||
"COLUMNS": true,
|
||||
|
||||
// we want these in our remote state object
|
||||
// "EUID": true,
|
||||
// "SHELLOPTS": true,
|
||||
// "UID": true,
|
||||
// "BASH_VERSINFO": true,
|
||||
// "BASH_VERSION": true,
|
||||
}
|
||||
|
||||
type DeclareDeclType struct {
|
||||
Args string
|
||||
Name string
|
||||
|
||||
// this holds the raw quoted value suitable for bash. this is *not* the real expanded variable value
|
||||
Value string
|
||||
}
|
||||
|
||||
var declareDeclArgsRe = regexp.MustCompile("^[aAxrifx]*$")
|
||||
var bashValidIdentifierRe = regexp.MustCompile("^[a-zA-Z_][a-zA-Z0-9_]*$")
|
||||
|
||||
func (d *DeclareDeclType) Validate() error {
|
||||
if len(d.Name) == 0 || !IsValidBashIdentifier(d.Name) {
|
||||
return fmt.Errorf("invalid shell variable name (invalid bash identifier)")
|
||||
}
|
||||
if strings.Index(d.Value, "\x00") >= 0 {
|
||||
return fmt.Errorf("invalid shell variable value (cannot contain 0 byte)")
|
||||
}
|
||||
if !declareDeclArgsRe.MatchString(d.Args) {
|
||||
return fmt.Errorf("invalid shell variable type %s", shellescape.Quote(d.Args))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DeclareDeclType) Serialize() string {
|
||||
return fmt.Sprintf("%s|%s=%s\x00", d.Args, d.Name, d.Value)
|
||||
}
|
||||
|
||||
func (d *DeclareDeclType) DeclareStmt() string {
|
||||
var argsStr string
|
||||
if d.Args == "" {
|
||||
argsStr = "--"
|
||||
} else {
|
||||
argsStr = "-" + d.Args
|
||||
}
|
||||
return fmt.Sprintf("declare %s %s=%s", argsStr, d.Name, d.Value)
|
||||
}
|
||||
|
||||
// envline should be valid
|
||||
func ParseDeclLine(envLine string) *DeclareDeclType {
|
||||
eqIdx := strings.Index(envLine, "=")
|
||||
if eqIdx == -1 {
|
||||
return nil
|
||||
}
|
||||
namePart := envLine[0:eqIdx]
|
||||
valPart := envLine[eqIdx+1:]
|
||||
pipeIdx := strings.Index(namePart, "|")
|
||||
if pipeIdx == -1 {
|
||||
return nil
|
||||
}
|
||||
return &DeclareDeclType{
|
||||
Args: namePart[0:pipeIdx],
|
||||
Name: namePart[pipeIdx+1:],
|
||||
Value: valPart,
|
||||
}
|
||||
}
|
||||
|
||||
// returns name => full-line
|
||||
func parseDeclLineToKV(envLine string) (string, string) {
|
||||
decl := ParseDeclLine(envLine)
|
||||
if decl == nil {
|
||||
return "", ""
|
||||
}
|
||||
return decl.Name, envLine
|
||||
}
|
||||
|
||||
func shellStateVarsToMap(shellVars []byte) map[string]string {
|
||||
if len(shellVars) == 0 {
|
||||
return nil
|
||||
}
|
||||
rtn := make(map[string]string)
|
||||
vars := bytes.Split(shellVars, []byte{0})
|
||||
for _, varLine := range vars {
|
||||
name, val := parseDeclLineToKV(string(varLine))
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
rtn[name] = val
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func strMapToShellStateVars(varMap map[string]string) []byte {
|
||||
var buf bytes.Buffer
|
||||
orderedKeys := getOrderedKeysStrMap(varMap)
|
||||
for _, key := range orderedKeys {
|
||||
val := varMap[key]
|
||||
buf.WriteString(val)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func getOrderedKeysStrMap(m map[string]string) []string {
|
||||
keys := make([]string, 0, len(m))
|
||||
for key, _ := range m {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
return keys
|
||||
}
|
||||
|
||||
func getOrderedKeysDeclMap(m map[string]*DeclareDeclType) []string {
|
||||
keys := make([]string, 0, len(m))
|
||||
for key, _ := range m {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
return keys
|
||||
}
|
||||
|
||||
func DeclMapFromState(state *packet.ShellState) map[string]*DeclareDeclType {
|
||||
if state == nil {
|
||||
return nil
|
||||
}
|
||||
rtn := make(map[string]*DeclareDeclType)
|
||||
vars := bytes.Split(state.ShellVars, []byte{0})
|
||||
for _, varLine := range vars {
|
||||
decl := ParseDeclLine(string(varLine))
|
||||
if decl != nil {
|
||||
rtn[decl.Name] = decl
|
||||
}
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func SerializeDeclMap(declMap map[string]*DeclareDeclType) []byte {
|
||||
var rtn bytes.Buffer
|
||||
orderedKeys := getOrderedKeysDeclMap(declMap)
|
||||
for _, key := range orderedKeys {
|
||||
decl := declMap[key]
|
||||
rtn.WriteString(decl.Serialize())
|
||||
}
|
||||
return rtn.Bytes()
|
||||
}
|
||||
|
||||
func EnvMapFromState(state *packet.ShellState) map[string]string {
|
||||
if state == nil {
|
||||
return nil
|
||||
}
|
||||
rtn := make(map[string]string)
|
||||
ectx := simpleexpand.SimpleExpandContext{}
|
||||
vars := bytes.Split(state.ShellVars, []byte{0})
|
||||
for _, varLine := range vars {
|
||||
decl := ParseDeclLine(string(varLine))
|
||||
if decl != nil && decl.IsExport() {
|
||||
rtn[decl.Name], _ = simpleexpand.SimpleExpandPartialWord(ectx, decl.Value, false)
|
||||
}
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func ShellVarMapFromState(state *packet.ShellState) map[string]string {
|
||||
if state == nil {
|
||||
return nil
|
||||
}
|
||||
rtn := make(map[string]string)
|
||||
ectx := simpleexpand.SimpleExpandContext{}
|
||||
vars := bytes.Split(state.ShellVars, []byte{0})
|
||||
for _, varLine := range vars {
|
||||
decl := ParseDeclLine(string(varLine))
|
||||
if decl != nil {
|
||||
rtn[decl.Name], _ = simpleexpand.SimpleExpandPartialWord(ectx, decl.Value, false)
|
||||
}
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func DumpVarMapFromState(state *packet.ShellState) {
|
||||
fmt.Printf("DUMP-STATE-VARS:\n")
|
||||
if state == nil {
|
||||
fmt.Printf(" nil\n")
|
||||
return
|
||||
}
|
||||
vars := bytes.Split(state.ShellVars, []byte{0})
|
||||
for _, varLine := range vars {
|
||||
fmt.Printf(" %s\n", varLine)
|
||||
}
|
||||
}
|
||||
|
||||
func VarDeclsFromState(state *packet.ShellState) []*DeclareDeclType {
|
||||
if state == nil {
|
||||
return nil
|
||||
}
|
||||
var rtn []*DeclareDeclType
|
||||
vars := bytes.Split(state.ShellVars, []byte{0})
|
||||
for _, varLine := range vars {
|
||||
decl := ParseDeclLine(string(varLine))
|
||||
if decl != nil {
|
||||
rtn = append(rtn, decl)
|
||||
}
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func IsValidBashIdentifier(s string) bool {
|
||||
return bashValidIdentifierRe.MatchString(s)
|
||||
}
|
||||
|
||||
func (d *DeclareDeclType) IsExport() bool {
|
||||
return strings.Index(d.Args, "x") >= 0
|
||||
}
|
||||
|
||||
func (d *DeclareDeclType) IsReadOnly() bool {
|
||||
return strings.Index(d.Args, "r") >= 0
|
||||
}
|
||||
|
||||
func (d *DeclareDeclType) DataType() string {
|
||||
if strings.Index(d.Args, "a") >= 0 {
|
||||
return DeclTypeArray
|
||||
}
|
||||
if strings.Index(d.Args, "A") >= 0 {
|
||||
return DeclTypeAssocArray
|
||||
}
|
||||
if strings.Index(d.Args, "i") >= 0 {
|
||||
return DeclTypeInt
|
||||
}
|
||||
return DeclTypeNormal
|
||||
}
|
||||
|
||||
func parseDeclareStmt(stmt *syntax.Stmt, src string) (*DeclareDeclType, error) {
|
||||
cmd := stmt.Cmd
|
||||
decl, ok := cmd.(*syntax.DeclClause)
|
||||
if !ok || decl.Variant.Value != "declare" || len(decl.Args) != 2 {
|
||||
return nil, fmt.Errorf("invalid declare variant")
|
||||
}
|
||||
rtn := &DeclareDeclType{}
|
||||
declArgs := decl.Args[0]
|
||||
if !declArgs.Naked || len(declArgs.Value.Parts) != 1 {
|
||||
return nil, fmt.Errorf("wrong number of declare args parts")
|
||||
}
|
||||
declArgsLit, ok := declArgs.Value.Parts[0].(*syntax.Lit)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("declare args is not a literal")
|
||||
}
|
||||
if !strings.HasPrefix(declArgsLit.Value, "-") {
|
||||
return nil, fmt.Errorf("declare args not an argument (does not start with '-')")
|
||||
}
|
||||
if declArgsLit.Value == "--" {
|
||||
rtn.Args = ""
|
||||
} else {
|
||||
rtn.Args = declArgsLit.Value[1:]
|
||||
}
|
||||
declAssign := decl.Args[1]
|
||||
if declAssign.Name == nil {
|
||||
return nil, fmt.Errorf("declare does not have a valid name")
|
||||
}
|
||||
rtn.Name = declAssign.Name.Value
|
||||
if declAssign.Naked || declAssign.Index != nil || declAssign.Append {
|
||||
return nil, fmt.Errorf("invalid decl format")
|
||||
}
|
||||
if declAssign.Value != nil {
|
||||
rtn.Value = string(src[declAssign.Value.Pos().Offset():declAssign.Value.End().Offset()])
|
||||
} else if declAssign.Array != nil {
|
||||
rtn.Value = string(src[declAssign.Array.Pos().Offset():declAssign.Array.End().Offset()])
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid decl, not plain value or array")
|
||||
}
|
||||
err := rtn.normalize()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = rtn.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func parseDeclareOutput(state *packet.ShellState, declareBytes []byte, pvarBytes []byte) error {
|
||||
declareStr := string(declareBytes)
|
||||
r := bytes.NewReader(declareBytes)
|
||||
parser := syntax.NewParser(syntax.Variant(syntax.LangBash))
|
||||
file, err := parser.Parse(r, "aliases")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var firstParseErr error
|
||||
declMap := make(map[string]*DeclareDeclType)
|
||||
for _, stmt := range file.Stmts {
|
||||
decl, err := parseDeclareStmt(stmt, declareStr)
|
||||
if err != nil {
|
||||
if firstParseErr == nil {
|
||||
firstParseErr = err
|
||||
}
|
||||
}
|
||||
if decl != nil && !NoStoreVarNames[decl.Name] {
|
||||
declMap[decl.Name] = decl
|
||||
}
|
||||
}
|
||||
pvars := bytes.Split(pvarBytes, []byte{0})
|
||||
for _, pvarBA := range pvars {
|
||||
pvarStr := string(pvarBA)
|
||||
pvarFields := strings.SplitN(pvarStr, " ", 2)
|
||||
if len(pvarFields) != 2 {
|
||||
continue
|
||||
}
|
||||
if pvarFields[0] == "" {
|
||||
continue
|
||||
}
|
||||
decl := &DeclareDeclType{Args: "x"}
|
||||
decl.Name = "PROMPTVAR_" + pvarFields[0]
|
||||
decl.Value = shellescape.Quote(pvarFields[1])
|
||||
declMap[decl.Name] = decl
|
||||
}
|
||||
state.ShellVars = SerializeDeclMap(declMap) // this writes out the decls in a canonical order
|
||||
if firstParseErr != nil {
|
||||
state.Error = firstParseErr.Error()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ParseShellStateOutput(outputBytes []byte) (*packet.ShellState, error) {
|
||||
// 5 fields: version, cwd, env/vars, aliases, funcs
|
||||
fields := bytes.Split(outputBytes, []byte{0, 0})
|
||||
if len(fields) != 6 {
|
||||
return nil, fmt.Errorf("invalid shell state output, wrong number of fields, fields=%d", len(fields))
|
||||
}
|
||||
rtn := &packet.ShellState{}
|
||||
rtn.Version = strings.TrimSpace(string(fields[0]))
|
||||
if strings.Index(rtn.Version, "bash") == -1 {
|
||||
return nil, fmt.Errorf("invalid shell state output, only bash is supported")
|
||||
}
|
||||
rtn.Version = rtn.Version
|
||||
cwdStr := string(fields[1])
|
||||
if strings.HasSuffix(cwdStr, "\r\n") {
|
||||
cwdStr = cwdStr[0 : len(cwdStr)-2]
|
||||
} else if strings.HasSuffix(cwdStr, "\n") {
|
||||
cwdStr = cwdStr[0 : len(cwdStr)-1]
|
||||
}
|
||||
rtn.Cwd = string(cwdStr)
|
||||
err := parseDeclareOutput(rtn, fields[2], fields[5])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rtn.Aliases = strings.ReplaceAll(string(fields[3]), "\r\n", "\n")
|
||||
rtn.Funcs = strings.ReplaceAll(string(fields[4]), "\r\n", "\n")
|
||||
rtn.Funcs = removeFunc(rtn.Funcs, "_mshell_exittrap")
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func removeFunc(funcs string, toRemove string) string {
|
||||
lines := strings.Split(funcs, "\n")
|
||||
var newLines []string
|
||||
removeLine := fmt.Sprintf("%s ()", toRemove)
|
||||
doingRemove := false
|
||||
for _, line := range lines {
|
||||
if line == removeLine {
|
||||
doingRemove = true
|
||||
continue
|
||||
}
|
||||
if doingRemove {
|
||||
if line == "}" {
|
||||
doingRemove = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
newLines = append(newLines, line)
|
||||
}
|
||||
return strings.Join(newLines, "\n")
|
||||
}
|
||||
|
||||
func (d *DeclareDeclType) normalize() error {
|
||||
if d.DataType() == DeclTypeAssocArray {
|
||||
return d.normalizeAssocArrayDecl()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// normalizes order of assoc array keys so value is stable
|
||||
func (d *DeclareDeclType) normalizeAssocArrayDecl() error {
|
||||
if d.DataType() != DeclTypeAssocArray {
|
||||
return fmt.Errorf("invalid decltype passed to assocArrayDeclToStr: %s", d.DataType())
|
||||
}
|
||||
varMap, err := assocArrayVarToMap(d)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
keys := make([]string, 0, len(varMap))
|
||||
for key, _ := range varMap {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
var buf bytes.Buffer
|
||||
buf.WriteByte('(')
|
||||
for _, key := range keys {
|
||||
buf.WriteByte('[')
|
||||
buf.WriteString(key)
|
||||
buf.WriteByte(']')
|
||||
buf.WriteByte('=')
|
||||
buf.WriteString(varMap[key])
|
||||
buf.WriteByte(' ')
|
||||
}
|
||||
buf.WriteByte(')')
|
||||
d.Value = buf.String()
|
||||
return nil
|
||||
}
|
||||
|
||||
func assocArrayVarToMap(d *DeclareDeclType) (map[string]string, error) {
|
||||
if d.DataType() != DeclTypeAssocArray {
|
||||
return nil, fmt.Errorf("decl is not an assoc-array")
|
||||
}
|
||||
refStr := "X=" + d.Value
|
||||
r := strings.NewReader(refStr)
|
||||
parser := syntax.NewParser(syntax.Variant(syntax.LangBash))
|
||||
file, err := parser.Parse(r, "assocdecl")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(file.Stmts) != 1 {
|
||||
return nil, fmt.Errorf("invalid assoc-array parse (multiple stmts)")
|
||||
}
|
||||
stmt := file.Stmts[0]
|
||||
callExpr, ok := stmt.Cmd.(*syntax.CallExpr)
|
||||
if !ok || len(callExpr.Args) != 0 || len(callExpr.Assigns) != 1 {
|
||||
return nil, fmt.Errorf("invalid assoc-array parse (bad expr)")
|
||||
}
|
||||
assign := callExpr.Assigns[0]
|
||||
arrayExpr := assign.Array
|
||||
if arrayExpr == nil {
|
||||
return nil, fmt.Errorf("invalid assoc-array parse (no array expr)")
|
||||
}
|
||||
rtn := make(map[string]string)
|
||||
for _, elem := range arrayExpr.Elems {
|
||||
indexStr := refStr[elem.Index.Pos().Offset():elem.Index.End().Offset()]
|
||||
valStr := refStr[elem.Value.Pos().Offset():elem.Value.End().Offset()]
|
||||
rtn[indexStr] = valStr
|
||||
}
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func strMapsEqual(m1 map[string]string, m2 map[string]string) bool {
|
||||
if len(m1) != len(m2) {
|
||||
return false
|
||||
}
|
||||
for key, val1 := range m1 {
|
||||
val2, found := m2[key]
|
||||
if !found || val1 != val2 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
for key, _ := range m2 {
|
||||
_, found := m1[key]
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func DeclsEqual(compareName bool, d1 *DeclareDeclType, d2 *DeclareDeclType) bool {
|
||||
if d1.IsExport() != d2.IsExport() {
|
||||
return false
|
||||
}
|
||||
if d1.DataType() != d2.DataType() {
|
||||
return false
|
||||
}
|
||||
if compareName && d1.Name != d2.Name {
|
||||
return false
|
||||
}
|
||||
return d1.Value == d2.Value // this works even for assoc arrays because we normalize them when parsing
|
||||
}
|
||||
|
||||
func MakeShellStateDiff(oldState packet.ShellState, oldStateHash string, newState packet.ShellState) (packet.ShellStateDiff, error) {
|
||||
var rtn packet.ShellStateDiff
|
||||
rtn.BaseHash = oldStateHash
|
||||
if oldState.Version != newState.Version {
|
||||
return rtn, fmt.Errorf("cannot diff, states have different versions")
|
||||
}
|
||||
rtn.Version = newState.Version
|
||||
if oldState.Cwd != newState.Cwd {
|
||||
rtn.Cwd = newState.Cwd
|
||||
}
|
||||
rtn.Error = newState.Error
|
||||
oldVars := shellStateVarsToMap(oldState.ShellVars)
|
||||
newVars := shellStateVarsToMap(newState.ShellVars)
|
||||
rtn.VarsDiff = statediff.MakeMapDiff(oldVars, newVars)
|
||||
rtn.AliasesDiff = statediff.MakeLineDiff(oldState.Aliases, newState.Aliases)
|
||||
rtn.FuncsDiff = statediff.MakeLineDiff(oldState.Funcs, newState.Funcs)
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func ApplyShellStateDiff(oldState packet.ShellState, diff packet.ShellStateDiff) (packet.ShellState, error) {
|
||||
var rtnState packet.ShellState
|
||||
var err error
|
||||
rtnState.Version = oldState.Version
|
||||
rtnState.Cwd = oldState.Cwd
|
||||
if diff.Cwd != "" {
|
||||
rtnState.Cwd = diff.Cwd
|
||||
}
|
||||
rtnState.Error = diff.Error
|
||||
oldVars := shellStateVarsToMap(oldState.ShellVars)
|
||||
newVars, err := statediff.ApplyMapDiff(oldVars, diff.VarsDiff)
|
||||
if err != nil {
|
||||
return rtnState, fmt.Errorf("applying mapdiff 'vars': %v", err)
|
||||
}
|
||||
rtnState.ShellVars = strMapToShellStateVars(newVars)
|
||||
rtnState.Aliases, err = statediff.ApplyLineDiff(oldState.Aliases, diff.AliasesDiff)
|
||||
if err != nil {
|
||||
return rtnState, fmt.Errorf("applying diff 'aliases': %v", err)
|
||||
}
|
||||
rtnState.Funcs, err = statediff.ApplyLineDiff(oldState.Funcs, diff.FuncsDiff)
|
||||
if err != nil {
|
||||
return rtnState, fmt.Errorf("applying diff 'funcs': %v", err)
|
||||
}
|
||||
return rtnState, nil
|
||||
}
|
1542
waveshell/pkg/shexec/shexec.go
Normal file
1542
waveshell/pkg/shexec/shexec.go
Normal file
File diff suppressed because it is too large
Load Diff
222
waveshell/pkg/simpleexpand/simpleexpand.go
Normal file
222
waveshell/pkg/simpleexpand/simpleexpand.go
Normal file
@ -0,0 +1,222 @@
|
||||
package simpleexpand
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
|
||||
"mvdan.cc/sh/v3/expand"
|
||||
"mvdan.cc/sh/v3/syntax"
|
||||
)
|
||||
|
||||
type SimpleExpandContext struct {
|
||||
HomeDir string
|
||||
}
|
||||
|
||||
type SimpleExpandInfo struct {
|
||||
HasTilde bool // only ~ as the first character when SimpleExpandContext.HomeDir is set
|
||||
HasVar bool // $x, $$, ${...}
|
||||
HasGlob bool // *, ?, [, {
|
||||
HasExtGlob bool // ?(...) ... ?*+@!
|
||||
HasHistory bool // ! (anywhere)
|
||||
HasSpecial bool // subshell, arith
|
||||
}
|
||||
|
||||
func expandHomeDir(info *SimpleExpandInfo, litVal string, multiPart bool, homeDir string) string {
|
||||
if homeDir == "" {
|
||||
return litVal
|
||||
}
|
||||
if litVal == "~" && !multiPart {
|
||||
return homeDir
|
||||
}
|
||||
if strings.HasPrefix(litVal, "~/") {
|
||||
info.HasTilde = true
|
||||
return homeDir + litVal[1:]
|
||||
}
|
||||
return litVal
|
||||
}
|
||||
|
||||
func expandLiteral(buf *bytes.Buffer, info *SimpleExpandInfo, litVal string) {
|
||||
var lastBackSlash bool
|
||||
var lastExtGlob bool
|
||||
var lastDollar bool
|
||||
for _, ch := range litVal {
|
||||
if ch == 0 {
|
||||
break
|
||||
}
|
||||
if lastBackSlash {
|
||||
lastBackSlash = false
|
||||
if ch == '\n' {
|
||||
// special case, backslash *and* newline are ignored
|
||||
continue
|
||||
}
|
||||
buf.WriteRune(ch)
|
||||
continue
|
||||
}
|
||||
if ch == '\\' {
|
||||
lastBackSlash = true
|
||||
lastExtGlob = false
|
||||
lastDollar = false
|
||||
continue
|
||||
}
|
||||
if ch == '*' || ch == '?' || ch == '[' || ch == '{' {
|
||||
info.HasGlob = true
|
||||
}
|
||||
if ch == '`' {
|
||||
info.HasSpecial = true
|
||||
}
|
||||
if ch == '!' {
|
||||
info.HasHistory = true
|
||||
}
|
||||
if lastExtGlob && ch == '(' {
|
||||
info.HasExtGlob = true
|
||||
}
|
||||
if lastDollar && (ch != ' ' && ch != '"' && ch != '\'' && ch != '(' || ch != '[') {
|
||||
info.HasVar = true
|
||||
}
|
||||
if lastDollar && (ch == '(' || ch == '[') {
|
||||
info.HasSpecial = true
|
||||
}
|
||||
lastExtGlob = (ch == '?' || ch == '*' || ch == '+' || ch == '@' || ch == '!')
|
||||
lastDollar = (ch == '$')
|
||||
buf.WriteRune(ch)
|
||||
}
|
||||
if lastBackSlash {
|
||||
buf.WriteByte('\\')
|
||||
}
|
||||
}
|
||||
|
||||
// also expands ~
|
||||
func expandLiteralPlus(buf *bytes.Buffer, info *SimpleExpandInfo, litVal string, multiPart bool, ectx SimpleExpandContext) {
|
||||
litVal = expandHomeDir(info, litVal, multiPart, ectx.HomeDir)
|
||||
expandLiteral(buf, info, litVal)
|
||||
}
|
||||
|
||||
func expandSQANSILiteral(buf *bytes.Buffer, litVal string) {
|
||||
// no info specials
|
||||
if strings.HasSuffix(litVal, "'") {
|
||||
litVal = litVal[0 : len(litVal)-1]
|
||||
}
|
||||
str, _, _ := expand.Format(nil, litVal, nil)
|
||||
buf.WriteString(str)
|
||||
}
|
||||
|
||||
func expandSQLiteral(buf *bytes.Buffer, litVal string) {
|
||||
// no info specials
|
||||
if strings.HasSuffix(litVal, "'") {
|
||||
litVal = litVal[0 : len(litVal)-1]
|
||||
}
|
||||
buf.WriteString(litVal)
|
||||
}
|
||||
|
||||
// will also work for partial double quoted strings
|
||||
func expandDQLiteral(buf *bytes.Buffer, info *SimpleExpandInfo, litVal string) {
|
||||
var lastBackSlash bool
|
||||
var lastDollar bool
|
||||
for _, ch := range litVal {
|
||||
if ch == 0 {
|
||||
break
|
||||
}
|
||||
if lastBackSlash {
|
||||
lastBackSlash = false
|
||||
if ch == '"' || ch == '\\' || ch == '$' || ch == '`' {
|
||||
buf.WriteRune(ch)
|
||||
continue
|
||||
}
|
||||
buf.WriteRune('\\')
|
||||
buf.WriteRune(ch)
|
||||
continue
|
||||
}
|
||||
if ch == '\\' {
|
||||
lastBackSlash = true
|
||||
lastDollar = false
|
||||
continue
|
||||
}
|
||||
if ch == '"' {
|
||||
break
|
||||
}
|
||||
|
||||
// similar to expandLiteral, but no globbing
|
||||
if ch == '`' {
|
||||
info.HasSpecial = true
|
||||
}
|
||||
if ch == '!' {
|
||||
info.HasHistory = true
|
||||
}
|
||||
if lastDollar && (ch != ' ' && ch != '"' && ch != '\'' && ch != '(' || ch != '[') {
|
||||
info.HasVar = true
|
||||
}
|
||||
if lastDollar && (ch == '(' || ch == '[') {
|
||||
info.HasSpecial = true
|
||||
}
|
||||
lastDollar = (ch == '$')
|
||||
buf.WriteRune(ch)
|
||||
}
|
||||
// in a valid parsed DQ string, you cannot have a trailing backslash (because \" would not end the string)
|
||||
// still putting the case here though in case we ever deal with incomplete strings (e.g. completion)
|
||||
if lastBackSlash {
|
||||
buf.WriteByte('\\')
|
||||
}
|
||||
}
|
||||
|
||||
func simpleExpandWordInternal(buf *bytes.Buffer, info *SimpleExpandInfo, ectx SimpleExpandContext, parts []syntax.WordPart, sourceStr string, inDoubleQuote bool, level int) {
|
||||
for partIdx, untypedPart := range parts {
|
||||
switch part := untypedPart.(type) {
|
||||
case *syntax.Lit:
|
||||
if !inDoubleQuote && partIdx == 0 && level == 1 && ectx.HomeDir != "" {
|
||||
expandLiteralPlus(buf, info, part.Value, len(parts) > 1, ectx)
|
||||
} else if inDoubleQuote {
|
||||
expandDQLiteral(buf, info, part.Value)
|
||||
} else {
|
||||
expandLiteral(buf, info, part.Value)
|
||||
}
|
||||
|
||||
case *syntax.SglQuoted:
|
||||
if part.Dollar {
|
||||
expandSQANSILiteral(buf, part.Value)
|
||||
} else {
|
||||
expandSQLiteral(buf, part.Value)
|
||||
}
|
||||
|
||||
case *syntax.DblQuoted:
|
||||
simpleExpandWordInternal(buf, info, ectx, part.Parts, sourceStr, true, level+1)
|
||||
|
||||
default:
|
||||
rawStr := sourceStr[part.Pos().Offset():part.End().Offset()]
|
||||
buf.WriteString(rawStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// simple word expansion
|
||||
// expands: literals, single-quoted strings, double-quoted strings (recursively)
|
||||
// does *not* expand: params (variables), command substitution, arithmetic expressions, process substituions, globs
|
||||
// for the not expands, they will show up as the literal string
|
||||
// this is different than expand.Literal which will replace variables as empty string if they aren't defined.
|
||||
// so "a"'foo'${bar}$x => "afoo${bar}$x", but expand.Literal would produce => "afoo"
|
||||
// note will do ~ expansion (will not do ~user expansion)
|
||||
func SimpleExpandWord(ectx SimpleExpandContext, word *syntax.Word, sourceStr string) (string, SimpleExpandInfo) {
|
||||
var buf bytes.Buffer
|
||||
var info SimpleExpandInfo
|
||||
simpleExpandWordInternal(&buf, &info, ectx, word.Parts, sourceStr, false, 1)
|
||||
return buf.String(), info
|
||||
}
|
||||
|
||||
func SimpleExpandPartialWord(ectx SimpleExpandContext, partialWord string, multiPart bool) (string, SimpleExpandInfo) {
|
||||
var buf bytes.Buffer
|
||||
var info SimpleExpandInfo
|
||||
if partialWord == "" {
|
||||
return "", info
|
||||
}
|
||||
if strings.HasPrefix(partialWord, "\"") {
|
||||
expandDQLiteral(&buf, &info, partialWord[1:])
|
||||
} else if strings.HasPrefix(partialWord, "$\"") {
|
||||
expandDQLiteral(&buf, &info, partialWord[2:])
|
||||
} else if strings.HasPrefix(partialWord, "'") {
|
||||
expandSQLiteral(&buf, partialWord[1:])
|
||||
} else if strings.HasPrefix(partialWord, "$'") {
|
||||
expandSQANSILiteral(&buf, partialWord[2:])
|
||||
} else {
|
||||
expandLiteralPlus(&buf, &info, partialWord, multiPart, ectx)
|
||||
}
|
||||
return buf.String(), info
|
||||
}
|
188
waveshell/pkg/statediff/linediff.go
Normal file
188
waveshell/pkg/statediff/linediff.go
Normal file
@ -0,0 +1,188 @@
|
||||
package statediff
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const LineDiffVersion = 0
|
||||
|
||||
type SingleLineEntry struct {
|
||||
LineVal int
|
||||
Run int
|
||||
}
|
||||
|
||||
type LineDiffType struct {
|
||||
Lines []SingleLineEntry
|
||||
NewData []string
|
||||
}
|
||||
|
||||
func (diff LineDiffType) Dump() {
|
||||
fmt.Printf("DIFF:\n")
|
||||
pos := 1
|
||||
for _, entry := range diff.Lines {
|
||||
fmt.Printf(" %d-%d: %d\n", pos, pos+entry.Run, entry.LineVal)
|
||||
pos += entry.Run
|
||||
}
|
||||
for idx, str := range diff.NewData {
|
||||
fmt.Printf(" n%d: %s\n", idx+1, str)
|
||||
}
|
||||
}
|
||||
|
||||
// simple encoding
|
||||
// a 0 means read a line from NewData
|
||||
// a non-zero number means read the 1-indexed line from OldData
|
||||
func (diff LineDiffType) applyDiff(oldData []string) ([]string, error) {
|
||||
rtn := make([]string, 0, len(diff.Lines))
|
||||
newDataPos := 0
|
||||
for _, entry := range diff.Lines {
|
||||
if entry.LineVal == 0 {
|
||||
for i := 0; i < entry.Run; i++ {
|
||||
if newDataPos >= len(diff.NewData) {
|
||||
return nil, fmt.Errorf("not enough newdata for diff")
|
||||
}
|
||||
rtn = append(rtn, diff.NewData[newDataPos])
|
||||
newDataPos++
|
||||
}
|
||||
} else {
|
||||
oldDataPos := entry.LineVal - 1 // 1-indexed
|
||||
for i := 0; i < entry.Run; i++ {
|
||||
realPos := oldDataPos + i
|
||||
if realPos < 0 || realPos >= len(oldData) {
|
||||
return nil, fmt.Errorf("diff index out of bounds %d old-data-len:%d", realPos, len(oldData))
|
||||
}
|
||||
rtn = append(rtn, oldData[realPos])
|
||||
}
|
||||
}
|
||||
}
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func putUVarint(buf *bytes.Buffer, viBuf []byte, ival int) {
|
||||
l := binary.PutUvarint(viBuf, uint64(ival))
|
||||
buf.Write(viBuf[0:l])
|
||||
}
|
||||
|
||||
// simple encoding
|
||||
// write varints. first version, then len, then len-number-of-varints, then fill the rest with newdata
|
||||
// [version] [len-varint] [varint]xlen... newdata (bytes)
|
||||
func (diff LineDiffType) Encode() []byte {
|
||||
var buf bytes.Buffer
|
||||
viBuf := make([]byte, binary.MaxVarintLen64)
|
||||
putUVarint(&buf, viBuf, LineDiffVersion)
|
||||
putUVarint(&buf, viBuf, len(diff.Lines))
|
||||
for _, entry := range diff.Lines {
|
||||
putUVarint(&buf, viBuf, entry.LineVal)
|
||||
putUVarint(&buf, viBuf, entry.Run)
|
||||
}
|
||||
for idx, str := range diff.NewData {
|
||||
buf.WriteString(str)
|
||||
if idx != len(diff.NewData)-1 {
|
||||
buf.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func (rtn *LineDiffType) Decode(diffBytes []byte) error {
|
||||
r := bytes.NewBuffer(diffBytes)
|
||||
version, err := binary.ReadUvarint(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid diff, cannot read version: %v", err)
|
||||
}
|
||||
if version != LineDiffVersion {
|
||||
return fmt.Errorf("invalid diff, bad version: %d", version)
|
||||
}
|
||||
linesLen64, err := binary.ReadUvarint(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid diff, cannot read lines length: %v", err)
|
||||
}
|
||||
linesLen := int(linesLen64)
|
||||
rtn.Lines = make([]SingleLineEntry, linesLen)
|
||||
for idx := 0; idx < linesLen; idx++ {
|
||||
lineVal, err := binary.ReadUvarint(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid diff, cannot read line %d: %v", idx, err)
|
||||
}
|
||||
lineRun, err := binary.ReadUvarint(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid diff, cannot read line-run %d: %v", idx, err)
|
||||
}
|
||||
rtn.Lines[idx] = SingleLineEntry{LineVal: int(lineVal), Run: int(lineRun)}
|
||||
}
|
||||
restOfInput := string(r.Bytes())
|
||||
if len(restOfInput) > 0 {
|
||||
rtn.NewData = strings.Split(restOfInput, "\n")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func makeLineDiff(oldData []string, newData []string) LineDiffType {
|
||||
var rtn LineDiffType
|
||||
oldDataMap := make(map[string]int) // 1-indexed
|
||||
for idx, str := range oldData {
|
||||
if _, found := oldDataMap[str]; found {
|
||||
continue
|
||||
}
|
||||
oldDataMap[str] = idx + 1
|
||||
}
|
||||
var cur *SingleLineEntry
|
||||
rtn.Lines = make([]SingleLineEntry, 0)
|
||||
for _, str := range newData {
|
||||
oldIdx, found := oldDataMap[str]
|
||||
if cur != nil && cur.LineVal != 0 {
|
||||
checkLine := cur.LineVal + cur.Run - 1
|
||||
if checkLine < len(oldData) && oldData[checkLine] == str {
|
||||
cur.Run++
|
||||
continue
|
||||
}
|
||||
} else if cur != nil && cur.LineVal == 0 && !found {
|
||||
cur.Run++
|
||||
rtn.NewData = append(rtn.NewData, str)
|
||||
continue
|
||||
}
|
||||
if cur != nil {
|
||||
rtn.Lines = append(rtn.Lines, *cur)
|
||||
}
|
||||
cur = &SingleLineEntry{Run: 1}
|
||||
if found {
|
||||
cur.LineVal = oldIdx
|
||||
} else {
|
||||
cur.LineVal = 0
|
||||
rtn.NewData = append(rtn.NewData, str)
|
||||
}
|
||||
}
|
||||
if cur != nil {
|
||||
rtn.Lines = append(rtn.Lines, *cur)
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func MakeLineDiff(str1 string, str2 string) []byte {
|
||||
if str1 == str2 {
|
||||
return nil
|
||||
}
|
||||
str1Arr := strings.Split(str1, "\n")
|
||||
str2Arr := strings.Split(str2, "\n")
|
||||
diff := makeLineDiff(str1Arr, str2Arr)
|
||||
return diff.Encode()
|
||||
}
|
||||
|
||||
func ApplyLineDiff(str1 string, diffBytes []byte) (string, error) {
|
||||
if len(diffBytes) == 0 {
|
||||
return str1, nil
|
||||
}
|
||||
var diff LineDiffType
|
||||
err := diff.Decode(diffBytes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
str1Arr := strings.Split(str1, "\n")
|
||||
str2Arr, err := diff.applyDiff(str1Arr)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strings.Join(str2Arr, "\n"), nil
|
||||
}
|
130
waveshell/pkg/statediff/mapdiff.go
Normal file
130
waveshell/pkg/statediff/mapdiff.go
Normal file
@ -0,0 +1,130 @@
|
||||
package statediff
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const MapDiffVersion = 0
|
||||
|
||||
// 0-bytes are not allowed in entries or keys (same as bash)
|
||||
|
||||
type MapDiffType struct {
|
||||
ToAdd map[string]string
|
||||
ToRemove []string
|
||||
}
|
||||
|
||||
func (diff MapDiffType) Dump() {
|
||||
fmt.Printf("VAR-DIFF\n")
|
||||
for name, val := range diff.ToAdd {
|
||||
fmt.Printf(" add[%s] %s\n", name, val)
|
||||
}
|
||||
for _, name := range diff.ToRemove {
|
||||
fmt.Printf(" rem[%s]\n", name)
|
||||
}
|
||||
}
|
||||
|
||||
func makeMapDiff(oldMap map[string]string, newMap map[string]string) MapDiffType {
|
||||
var rtn MapDiffType
|
||||
rtn.ToAdd = make(map[string]string)
|
||||
for name, newVal := range newMap {
|
||||
oldVal, found := oldMap[name]
|
||||
if !found || oldVal != newVal {
|
||||
rtn.ToAdd[name] = newVal
|
||||
continue
|
||||
}
|
||||
}
|
||||
for name, _ := range oldMap {
|
||||
_, found := newMap[name]
|
||||
if !found {
|
||||
rtn.ToRemove = append(rtn.ToRemove, name)
|
||||
}
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (diff MapDiffType) apply(oldMap map[string]string) map[string]string {
|
||||
rtn := make(map[string]string)
|
||||
for name, val := range oldMap {
|
||||
rtn[name] = val
|
||||
}
|
||||
for name, val := range diff.ToAdd {
|
||||
rtn[name] = val
|
||||
}
|
||||
for _, name := range diff.ToRemove {
|
||||
delete(rtn, name)
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (diff MapDiffType) Encode() []byte {
|
||||
var buf bytes.Buffer
|
||||
viBuf := make([]byte, binary.MaxVarintLen64)
|
||||
putUVarint(&buf, viBuf, MapDiffVersion)
|
||||
putUVarint(&buf, viBuf, len(diff.ToAdd))
|
||||
for key, val := range diff.ToAdd {
|
||||
buf.WriteString(key)
|
||||
buf.WriteByte(0)
|
||||
buf.WriteString(val)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
for _, val := range diff.ToRemove {
|
||||
buf.WriteString(val)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func (diff *MapDiffType) Decode(diffBytes []byte) error {
|
||||
r := bytes.NewBuffer(diffBytes)
|
||||
version, err := binary.ReadUvarint(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid diff, cannot read version: %v", err)
|
||||
}
|
||||
if version != MapDiffVersion {
|
||||
return fmt.Errorf("invalid diff, bad version: %d", version)
|
||||
}
|
||||
mapLen64, err := binary.ReadUvarint(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid diff, cannot map length: %v", err)
|
||||
}
|
||||
mapLen := int(mapLen64)
|
||||
fields := bytes.Split(r.Bytes(), []byte{0})
|
||||
if len(fields) < 2*mapLen {
|
||||
return fmt.Errorf("invalid diff, not enough fields, maplen:%d fields:%d", mapLen, len(fields))
|
||||
}
|
||||
mapFields := fields[0 : 2*mapLen]
|
||||
removeFields := fields[2*mapLen:]
|
||||
diff.ToAdd = make(map[string]string)
|
||||
for i := 0; i < len(mapFields); i += 2 {
|
||||
diff.ToAdd[string(mapFields[i])] = string(mapFields[i+1])
|
||||
}
|
||||
for _, removeVal := range removeFields {
|
||||
if len(removeVal) == 0 {
|
||||
continue
|
||||
}
|
||||
diff.ToRemove = append(diff.ToRemove, string(removeVal))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func MakeMapDiff(m1 map[string]string, m2 map[string]string) []byte {
|
||||
diff := makeMapDiff(m1, m2)
|
||||
if len(diff.ToAdd) == 0 && len(diff.ToRemove) == 0 {
|
||||
return nil
|
||||
}
|
||||
return diff.Encode()
|
||||
}
|
||||
|
||||
func ApplyMapDiff(oldMap map[string]string, diffBytes []byte) (map[string]string, error) {
|
||||
if len(diffBytes) == 0 {
|
||||
return oldMap, nil
|
||||
}
|
||||
var diff MapDiffType
|
||||
err := diff.Decode(diffBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return diff.apply(oldMap), nil
|
||||
}
|
99
waveshell/pkg/statediff/statediff_test.go
Normal file
99
waveshell/pkg/statediff/statediff_test.go
Normal file
@ -0,0 +1,99 @@
|
||||
package statediff
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const Str1 = `
|
||||
hello
|
||||
line #2
|
||||
apple
|
||||
grapes
|
||||
banana
|
||||
apple
|
||||
`
|
||||
|
||||
const Str2 = `
|
||||
line #2
|
||||
apple
|
||||
grapes
|
||||
banana
|
||||
`
|
||||
|
||||
const Str3 = `
|
||||
more
|
||||
stuff
|
||||
banana
|
||||
coconut
|
||||
`
|
||||
|
||||
const Str4 = `
|
||||
more
|
||||
stuff
|
||||
banana2
|
||||
coconut
|
||||
`
|
||||
|
||||
func testLineDiff(t *testing.T, str1 string, str2 string) {
|
||||
diffBytes := MakeLineDiff(str1, str2)
|
||||
fmt.Printf("diff-len: %d\n", len(diffBytes))
|
||||
out, err := ApplyLineDiff(str1, diffBytes)
|
||||
if err != nil {
|
||||
t.Errorf("error in diff: %v", err)
|
||||
return
|
||||
}
|
||||
if out != str2 {
|
||||
t.Errorf("bad diff output")
|
||||
}
|
||||
var dt LineDiffType
|
||||
err = dt.Decode(diffBytes)
|
||||
if err != nil {
|
||||
t.Errorf("error decoding diff: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLineDiff(t *testing.T) {
|
||||
testLineDiff(t, Str1, Str2)
|
||||
testLineDiff(t, Str2, Str3)
|
||||
testLineDiff(t, Str1, Str3)
|
||||
testLineDiff(t, Str3, Str1)
|
||||
testLineDiff(t, Str3, Str4)
|
||||
}
|
||||
|
||||
func strMapsEqual(m1 map[string]string, m2 map[string]string) bool {
|
||||
if len(m1) != len(m2) {
|
||||
return false
|
||||
}
|
||||
for key, val := range m1 {
|
||||
val2, ok := m2[key]
|
||||
if !ok || val != val2 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
for key, val := range m2 {
|
||||
val2, ok := m1[key]
|
||||
if !ok || val != val2 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func TestMapDiff(t *testing.T) {
|
||||
m1 := map[string]string{"a": "5", "b": "hello", "c": "mike"}
|
||||
m2 := map[string]string{"a": "5", "b": "goodbye", "d": "more"}
|
||||
diffBytes := MakeMapDiff(m1, m2)
|
||||
fmt.Printf("mapdifflen: %d\n", len(diffBytes))
|
||||
var diff MapDiffType
|
||||
diff.Decode(diffBytes)
|
||||
diff.Dump()
|
||||
mcheck, err := ApplyMapDiff(m1, diffBytes)
|
||||
if err != nil {
|
||||
t.Fatalf("error applying map diff: %v", err)
|
||||
}
|
||||
if !strMapsEqual(m2, mcheck) {
|
||||
t.Errorf("maps not equal")
|
||||
}
|
||||
fmt.Printf("%v\n", mcheck)
|
||||
}
|
18
waveshell/scripthaus.md
Normal file
18
waveshell/scripthaus.md
Normal file
@ -0,0 +1,18 @@
|
||||
|
||||
```bash
|
||||
# @scripthaus command build
|
||||
GO_LDFLAGS="-s -w -X main.BuildTime=$(date +'%Y%m%d%H%M')"
|
||||
go build -ldflags="$GO_LDFLAGS" -o bin/mshell-v0.3-darwin.amd64 main-waveshell.go
|
||||
```
|
||||
|
||||
```bash
|
||||
# @scripthaus command fullbuild
|
||||
GO_LDFLAGS="-s -w -X main.BuildTime=$(date +'%Y%m%d%H%M')"
|
||||
go build -ldflags="$GO_LDFLAGS" -o ~/.mshell/mshell-v0.2 main-waveshell.go
|
||||
GOOS=linux GOARCH=amd64 go build -ldflags="$GO_LDFLAGS" -o bin/mshell-v0.3-linux.amd64 main-waveshell.go
|
||||
GOOS=linux GOARCH=arm64 go build -ldflags="$GO_LDFLAGS" -o bin/mshell-v0.3-linux.arm64 main-waveshell.go
|
||||
GOOS=darwin GOARCH=amd64 go build -ldflags="$GO_LDFLAGS" -o bin/mshell-v0.3-darwin.amd64 main-waveshell.go
|
||||
GOOS=darwin GOARCH=arm64 go build -ldflags="$GO_LDFLAGS" -o bin/mshell-v0.3-darwin.arm64 main-waveshell.go
|
||||
```
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user