merge waveshell into waveterm repo

sawka 2023-10-16 13:11:26 -07:00
commit a8055489f8
26 changed files with 8419 additions and 1 deletion

.gitignore vendored (6 changed lines)

@@ -3,13 +3,17 @@ dist-dev/
node_modules/
*~
*.log
*.out
out/
.DS_Store
bin/
waveshell/bin/
wavesrv/bin/
dev-bin
local-server-bin
*.pw
build/
*.dmg
webshare/dist/
webshare/dist-dev/
webshare/dist-dev/

waveshell/go.mod Normal file (13 lines)

@@ -0,0 +1,13 @@
module github.com/commandlinedev/apishell
go 1.18
require (
github.com/alessio/shellescape v1.4.1
github.com/creack/pty v1.1.18
github.com/fsnotify/fsnotify v1.6.0
github.com/google/uuid v1.3.0
golang.org/x/mod v0.5.1
golang.org/x/sys v0.10.0
mvdan.cc/sh/v3 v3.7.0
)

waveshell/go.sum Normal file (20 lines)

@@ -0,0 +1,20 @@
github.com/alessio/shellescape v1.4.1 h1:V7yhSDDn8LP4lc4jS8pFkt0zCnzVJlG5JXy9BVKJUX0=
github.com/alessio/shellescape v1.4.1/go.mod h1:PZAiSCk0LJaZkiCSkPv8qIobYglO3FPpyFjDCtHLS30=
github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
github.com/frankban/quicktest v1.14.5 h1:dfYrrRyLtiqT9GyKXgdh+k4inNeTvmGbuSgZ3lx3GhA=
github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY=
github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/rogpeppe/go-internal v1.10.1-0.20230524175051-ec119421bb97 h1:3RPlVWzZ/PDqmVuf/FKHARG5EMid/tl7cv54Sw/QRVY=
golang.org/x/mod v0.5.1 h1:OJxoQ/rynoF0dcCdI7cLPktw/hR2cueqYfjm43oqK38=
golang.org/x/mod v0.5.1/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro=
golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
mvdan.cc/sh/v3 v3.7.0 h1:lSTjdP/1xsddtaKfGg7Myu7DnlHItd3/M2tomOcNNBg=
mvdan.cc/sh/v3 v3.7.0/go.mod h1:K2gwkaesF/D7av7Kxl0HbF5kGOd2ArupNTX3X44+8l8=

waveshell/main-waveshell.go Normal file (586 lines)

@@ -0,0 +1,586 @@
package main
import (
"bytes"
"fmt"
"os"
"strconv"
"strings"
"syscall"
"time"
"github.com/commandlinedev/apishell/pkg/base"
"github.com/commandlinedev/apishell/pkg/packet"
"github.com/commandlinedev/apishell/pkg/server"
"github.com/commandlinedev/apishell/pkg/shexec"
"golang.org/x/sys/unix"
)
var BuildTime = "0"
// func doMainRun(pk *packet.RunPacketType, sender *packet.PacketSender) {
// err := shexec.ValidateRunPacket(pk)
// if err != nil {
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("invalid run packet: %v", err))
// return
// }
// fileNames, err := base.GetCommandFileNames(pk.CK)
// if err != nil {
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot get command file names: %v", err))
// return
// }
// cmd, err := shexec.MakeRunnerExec(pk.CK)
// if err != nil {
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot make mshell command: %v", err))
// return
// }
// cmdStdin, err := cmd.StdinPipe()
// if err != nil {
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot pipe stdin to command: %v", err))
// return
// }
// // touch ptyout file (should exist for tailer to work correctly)
// ptyOutFd, err := os.OpenFile(fileNames.PtyOutFile, os.O_CREATE|os.O_TRUNC|os.O_APPEND|os.O_WRONLY, 0600)
// if err != nil {
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot open pty out file '%s': %v", fileNames.PtyOutFile, err))
// return
// }
// ptyOutFd.Close() // just opened to create the file, can close right after
// runnerOutFd, err := os.OpenFile(fileNames.RunnerOutFile, os.O_CREATE|os.O_TRUNC|os.O_APPEND|os.O_WRONLY, 0600)
// if err != nil {
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("cannot open runner out file '%s': %v", fileNames.RunnerOutFile, err))
// return
// }
// defer runnerOutFd.Close()
// cmd.Stdout = runnerOutFd
// cmd.Stderr = runnerOutFd
// err = cmd.Start()
// if err != nil {
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("error starting command: %v", err))
// return
// }
// go func() {
// err = packet.SendPacket(cmdStdin, pk)
// if err != nil {
// sender.SendCKErrorPacket(pk.CK, fmt.Sprintf("error sending forked runner command: %v", err))
// return
// }
// cmdStdin.Close()
// // clean up zombies
// cmd.Wait()
// }()
// }
// func doGetCmd(tailer *cmdtail.Tailer, pk *packet.GetCmdPacketType, sender *packet.PacketSender) error {
// err := tailer.AddWatch(pk)
// if err != nil {
// return err
// }
// return nil
// }
// func doMain() {
// homeDir := base.GetHomeDir()
// err := os.Chdir(homeDir)
// if err != nil {
// packet.SendErrorPacket(os.Stdout, fmt.Sprintf("cannot change directory to $HOME '%s': %v", homeDir, err))
// return
// }
// _, err = base.GetMShellPath()
// if err != nil {
// packet.SendErrorPacket(os.Stdout, err.Error())
// return
// }
// packetParser := packet.MakePacketParser(os.Stdin)
// sender := packet.MakePacketSender(os.Stdout)
// tailer, err := cmdtail.MakeTailer(sender)
// if err != nil {
// packet.SendErrorPacket(os.Stdout, err.Error())
// return
// }
// go tailer.Run()
// initPacket := shexec.MakeInitPacket()
// sender.SendPacket(initPacket)
// for pk := range packetParser.MainCh {
// if pk.GetType() == packet.RunPacketStr {
// doMainRun(pk.(*packet.RunPacketType), sender)
// continue
// }
// if pk.GetType() == packet.GetCmdPacketStr {
// err = doGetCmd(tailer, pk.(*packet.GetCmdPacketType), sender)
// if err != nil {
// errPk := packet.MakeErrorPacket(err.Error())
// sender.SendPacket(errPk)
// continue
// }
// continue
// }
// if pk.GetType() == packet.CdPacketStr {
// cdPacket := pk.(*packet.CdPacketType)
// err := os.Chdir(cdPacket.Dir)
// resp := packet.MakeResponsePacket(cdPacket.ReqId)
// if err != nil {
// resp.Error = err.Error()
// } else {
// resp.Success = true
// }
// sender.SendPacket(resp)
// continue
// }
// if pk.GetType() == packet.ErrorPacketStr {
// errPk := pk.(*packet.ErrorPacketType)
// errPk.Error = "invalid packet sent to mshell: " + errPk.Error
// sender.SendPacket(errPk)
// continue
// }
// sender.SendErrorPacket(fmt.Sprintf("invalid packet '%s' sent to mshell", pk.GetType()))
// }
// }
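// readFullRunPacket feeds incoming packets to a RunPacketBuilder until it produces a complete
// run packet; any packet the builder rejects ends the read with an error.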
func readFullRunPacket(packetParser *packet.PacketParser) (*packet.RunPacketType, error) {
rpb := packet.MakeRunPacketBuilder()
for pk := range packetParser.MainCh {
ok, runPacket := rpb.ProcessPacket(pk)
if runPacket != nil {
return runPacket, nil
}
if !ok {
return nil, fmt.Errorf("invalid packet '%s' sent to mshell", pk.GetType())
}
}
return nil, fmt.Errorf("no run packet received")
}
func handleSingle(fromServer bool) {
packetParser := packet.MakePacketParser(os.Stdin, false)
sender := packet.MakePacketSender(os.Stdout, nil)
defer func() {
sender.Close()
sender.WaitForDone()
}()
initPacket := shexec.MakeInitPacket()
sender.SendPacket(initPacket)
if len(os.Args) >= 3 && os.Args[2] == "--version" {
return
}
runPacket, err := readFullRunPacket(packetParser)
if err != nil {
sender.SendErrorResponse(runPacket.ReqId, err)
return
}
err = shexec.ValidateRunPacket(runPacket)
if err != nil {
sender.SendErrorResponse(runPacket.ReqId, err)
return
}
if fromServer {
err = runPacket.CK.Validate("run packet")
if err != nil {
sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("run packets from server must have a CK: %v", err))
}
}
if runPacket.Detached {
cmd, startPk, err := shexec.RunCommandDetached(runPacket, sender)
if err != nil {
sender.SendErrorResponse(runPacket.ReqId, err)
return
}
sender.SendPacket(startPk)
sender.Close()
sender.WaitForDone()
cmd.DetachedWait(startPk)
return
} else {
shexec.IgnoreSigPipe()
ticker := time.NewTicker(1 * time.Minute)
go func() {
for range ticker.C {
// this will let the command detect when the server has gone away
// that will then trigger cmd.SendHup() to send SIGHUP to the exec'ed process
sender.SendPacket(packet.MakePingPacket())
}
}()
defer ticker.Stop()
cmd, err := shexec.RunCommandSimple(runPacket, sender, true)
if err != nil {
sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("error running command: %w", err))
return
}
defer cmd.Close()
startPacket := cmd.MakeCmdStartPacket(runPacket.ReqId)
sender.SendPacket(startPacket)
go func() {
exitErr := sender.WaitForDone()
if exitErr != nil {
base.Logf("I/O error talking to server, sending SIGHUP to children\n")
cmd.SendSignal(syscall.SIGHUP)
}
}()
cmd.RunRemoteIOAndWait(packetParser, sender)
return
}
}
func detectOpenFds() ([]packet.RemoteFd, error) {
var fds []packet.RemoteFd
for fdNum := 3; fdNum <= 64; fdNum++ {
flags, err := unix.FcntlInt(uintptr(fdNum), unix.F_GETFL, 0)
if err != nil {
continue
}
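// keep only the access-mode bits (O_ACCMODE == 3): 0=read-only, 1=write-only, 2=read-write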
flags = flags & 3
rfd := packet.RemoteFd{FdNum: fdNum}
if flags&2 == 2 {
return nil, fmt.Errorf("invalid fd=%d, mshell does not support fds open for reading and writing", fdNum)
}
if flags&1 == 1 {
rfd.Write = true
} else {
rfd.Read = true
}
fds = append(fds, rfd)
}
return fds, nil
}
func parseInstallOpts() (*shexec.InstallOpts, error) {
opts := &shexec.InstallOpts{}
iter := base.MakeOptsIter(os.Args[2:]) // first arg is --install
for iter.HasNext() {
argStr := iter.Next()
found, err := tryParseSSHOpt(iter, &opts.SSHOpts)
if err != nil {
return nil, err
}
if found {
continue
}
if argStr == "--detect" {
opts.Detect = true
continue
}
if base.IsOption(argStr) {
return nil, fmt.Errorf("invalid option '%s' passed to mshell --install", argStr)
}
opts.ArchStr = argStr
break
}
return opts, nil
}
func tryParseSSHOpt(iter *base.OptsIter, sshOpts *shexec.SSHOpts) (bool, error) {
argStr := iter.Current()
if argStr == "--ssh" {
if !iter.IsNextPlain() {
return false, fmt.Errorf("'--ssh [user@host]' missing host")
}
sshOpts.SSHHost = iter.Next()
return true, nil
}
if argStr == "--ssh-opts" {
if !iter.HasNext() {
return false, fmt.Errorf("'--ssh-opts [options]' missing options")
}
sshOpts.SSHOptsStr = iter.Next()
return true, nil
}
if argStr == "-i" {
if !iter.IsNextPlain() {
return false, fmt.Errorf("-i [identity-file]' missing file")
}
sshOpts.SSHIdentity = iter.Next()
return true, nil
}
if argStr == "-l" {
if !iter.IsNextPlain() {
return false, fmt.Errorf("-l [user]' missing user")
}
sshOpts.SSHUser = iter.Next()
return true, nil
}
if argStr == "-p" {
if !iter.IsNextPlain() {
return false, fmt.Errorf("-p [port]' missing port")
}
nextArgStr := iter.Next()
portVal, err := strconv.Atoi(nextArgStr)
if err != nil {
return false, fmt.Errorf("-p [port]' invalid port: %v", err)
}
if portVal <= 0 {
return false, fmt.Errorf("-p [port]' invalid port: %d", portVal)
}
sshOpts.SSHPort = portVal
return true, nil
}
return false, nil
}
func parseClientOpts() (*shexec.ClientOpts, error) {
opts := &shexec.ClientOpts{}
iter := base.MakeOptsIter(os.Args[1:])
for iter.HasNext() {
argStr := iter.Next()
found, err := tryParseSSHOpt(iter, &opts.SSHOpts)
if err != nil {
return nil, err
}
if found {
continue
}
if argStr == "--cwd" {
if !iter.IsNextPlain() {
return nil, fmt.Errorf("'--cwd [dir]' missing directory")
}
opts.Cwd = iter.Next()
continue
}
if argStr == "--detach" {
opts.Detach = true
continue
}
if argStr == "--pty" {
opts.UsePty = true
continue
}
if argStr == "--debug" {
opts.Debug = true
continue
}
if argStr == "--sudo" {
opts.Sudo = true
continue
}
if argStr == "--sudo-with-password" {
if !iter.HasNext() {
return nil, fmt.Errorf("'--sudo-with-password [pw]', missing password")
}
opts.Sudo = true
opts.SudoWithPass = true
opts.SudoPw = iter.Next()
continue
}
if argStr == "--sudo-with-passfile" {
if !iter.IsNextPlain() {
return nil, fmt.Errorf("'--sudo-with-passfile [file]', missing file")
}
opts.Sudo = true
opts.SudoWithPass = true
fileName := iter.Next()
contents, err := os.ReadFile(fileName)
if err != nil {
return nil, fmt.Errorf("cannot read --sudo-with-passfile file '%s': %w", fileName, err)
}
if newlineIdx := bytes.Index(contents, []byte{'\n'}); newlineIdx != -1 {
contents = contents[0:newlineIdx]
}
opts.SudoPw = string(contents) + "\n"
continue
}
if argStr == "--" {
if !iter.HasNext() {
return nil, fmt.Errorf("'--' should be followed by command")
}
opts.Command = strings.Join(iter.Rest(), " ")
break
}
return nil, fmt.Errorf("invalid option '%s' passed to mshell", argStr)
}
return opts, nil
}
func handleClient() (int, error) {
opts, err := parseClientOpts()
if err != nil {
return 1, fmt.Errorf("parsing opts: %w", err)
}
if opts.Debug {
packet.GlobalDebug = true
}
if opts.Command == "" {
return 1, fmt.Errorf("no [command] specified. [command] follows '--' option (see usage)")
}
fds, err := detectOpenFds()
if err != nil {
return 1, err
}
opts.Fds = fds
err = shexec.ValidateRemoteFds(opts.Fds)
if err != nil {
return 1, err
}
runPacket, err := opts.MakeRunPacket() // modifies opts
if err != nil {
return 1, err
}
if runPacket.Detached {
return 1, fmt.Errorf("cannot run detached command from command line client")
}
donePacket, err := shexec.RunClientSSHCommandAndWait(runPacket, shexec.StdContext{}, opts.SSHOpts, nil, opts.Debug)
if err != nil {
return 1, err
}
return donePacket.ExitCode, nil
}
func handleInstall() (int, error) {
opts, err := parseInstallOpts()
if err != nil {
return 1, fmt.Errorf("parsing opts: %w", err)
}
if opts.SSHOpts.SSHHost == "" {
return 1, fmt.Errorf("cannot install without '--ssh user@host' option")
}
if opts.Detect && opts.ArchStr != "" {
return 1, fmt.Errorf("cannot supply both --detect and arch '%s'", opts.ArchStr)
}
if opts.ArchStr == "" && !opts.Detect {
return 1, fmt.Errorf("must supply an arch string or '--detect' to auto detect")
}
if opts.ArchStr != "" {
fullArch := opts.ArchStr
fields := strings.SplitN(fullArch, ".", 2)
if len(fields) != 2 {
return 1, fmt.Errorf("invalid arch format '%s' passed to mshell --install", fullArch)
}
goos, goarch := fields[0], fields[1]
if !base.ValidGoArch(goos, goarch) {
return 1, fmt.Errorf("invalid arch '%s' passed to mshell --install", fullArch)
}
optName := base.GoArchOptFile(base.MShellVersion, goos, goarch)
_, err = os.Stat(optName)
if err != nil {
return 1, fmt.Errorf("cannot install mshell to remote host, cannot read '%s': %w", optName, err)
}
opts.OptName = optName
}
err = shexec.RunInstallFromOpts(opts)
if err != nil {
return 1, err
}
return 0, nil
}
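// handleEnv prints the current working directory terminated by two NUL bytes, then each
// environment entry terminated by a NUL, with an empty entry (a double NUL) marking the end.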
func handleEnv() (int, error) {
cwd, err := os.Getwd()
if err != nil {
return 1, err
}
fmt.Printf("%s\x00\x00", cwd)
fullEnv := os.Environ()
var linePrinted bool
for _, envLine := range fullEnv {
if envLine != "" {
fmt.Printf("%s\x00", envLine)
linePrinted = true
}
}
if linePrinted {
fmt.Printf("\x00")
} else {
fmt.Printf("\x00\x00")
}
return 0, nil
}
func handleUsage() {
usage := `
Client Usage: mshell [opts] --ssh user@host -- [command]
mshell multiplexes input and output streams to a remote command over ssh.
Options:
-i [identity-file] - used to set '-i' option for ssh command
-l [user] - used to set '-l' option for ssh command
--cwd [dir] - execute remote command in [dir]
--ssh-opts [opts] - additional options to pass to ssh command
[command] - the remote command to execute
Sudo Options:
--sudo - use only if sudo never requires a password
--sudo-with-password [pw] - not recommended, use --sudo-with-passfile if possible
--sudo-with-passfile [file]
Sudo options allow you to run the given command using "sudo". The first
option only works when you can sudo without a password. Your password will be passed
securely through a high numbered fd to "sudo -S". Note that to use high numbered
file descriptors with sudo, you will need to add this line to your /etc/sudoers file:
Defaults closefrom_override
See full documentation for more details.
Examples:
# execute a python script remotely, with stdin still hooked up correctly
mshell --cwd "~/work" -i key.pem --ssh ubuntu@somehost -- "python3 /dev/fd/4" 4< myscript.py
# capture multiple outputs
mshell --ssh ubuntu@test -- "cat file1.txt > /dev/fd/3; cat file2.txt > /dev/fd/4" 3> file1.txt 4> file2.txt
# execute a script, capture stdout/stderr in fd-3 and fd-4
# useful if you need to see stdout for interacting with ssh (password or host auth)
mshell --ssh user@host -- "test.sh > /dev/fd/3 2> /dev/fd/4" 3> test.stdout 4> test.stderr
# run a script as root (via sudo), capture output
mshell --sudo-with-passfile pw.txt --ssh ubuntu@somehost -- "python3 /dev/fd/3 > /dev/fd/4" 3< myscript.py 4> script-output.txt < script-input.txt
`
fmt.Printf("%s\n\n", strings.TrimSpace(usage))
}
func main() {
base.SetBuildTime(BuildTime)
if len(os.Args) == 1 {
handleUsage()
return
}
firstArg := os.Args[1]
if firstArg == "--help" {
handleUsage()
return
} else if firstArg == "--version" {
fmt.Printf("mshell %s+%s\n", base.MShellVersion, base.BuildTime)
return
} else if firstArg == "--test-env" {
state, err := shexec.GetShellState()
if state != nil {
}
if err != nil {
fmt.Fprintf(os.Stderr, "[error] %v\n", err)
os.Exit(1)
}
} else if firstArg == "--single" {
base.InitDebugLog("single")
handleSingle(false)
return
} else if firstArg == "--single-from-server" {
base.InitDebugLog("single")
handleSingle(true)
return
} else if firstArg == "--server" {
base.InitDebugLog("server")
rtnCode, err := server.RunServer()
if err != nil {
fmt.Fprintf(os.Stderr, "[error] %v\n", err)
}
if rtnCode != 0 {
os.Exit(rtnCode)
}
return
} else if firstArg == "--install" {
rtnCode, err := handleInstall()
if err != nil {
fmt.Fprintf(os.Stderr, "[error] %v\n", err)
}
os.Exit(rtnCode)
return
} else {
rtnCode, err := handleClient()
if err != nil {
fmt.Fprintf(os.Stderr, "[error] %v\n", err)
}
if rtnCode != 0 {
os.Exit(rtnCode)
}
return
}
}

waveshell/pkg/base/base.go Normal file (381 lines)

@@ -0,0 +1,381 @@
package base
import (
"errors"
"fmt"
"io"
"io/fs"
"log"
"os"
"os/exec"
"path"
"path/filepath"
"strings"
"sync"
"github.com/google/uuid"
"golang.org/x/mod/semver"
)
const HomeVarName = "HOME"
const DefaultMShellHome = "~/.mshell"
const DefaultMShellName = "mshell"
const MShellPathVarName = "MSHELL_PATH"
const MShellHomeVarName = "MSHELL_HOME"
const MShellInstallBinVarName = "MSHELL_INSTALLBIN_PATH"
const SSHCommandVarName = "SSH_COMMAND"
const MShellDebugVarName = "MSHELL_DEBUG"
const SessionsDirBaseName = "sessions"
const MShellVersion = "v0.3.0"
const RemoteIdFile = "remoteid"
const DefaultMShellInstallBinDir = "/opt/mshell/bin"
const LogFileName = "mshell.log"
const ForceDebugLog = false
const DebugFlag_LogRcFile = "logrc"
const LogRcFileName = "debug.rcfile"
var sessionDirCache = make(map[string]string)
var baseLock = &sync.Mutex{}
var DebugLogEnabled = false
var DebugLogger *log.Logger
var BuildTime string = "0"
type CommandFileNames struct {
PtyOutFile string
StdinFifo string
RunnerOutFile string
}
type CommandKey string
func SetBuildTime(build string) {
BuildTime = build
}
func MakeCommandKey(sessionId string, cmdId string) CommandKey {
if sessionId == "" && cmdId == "" {
return CommandKey("")
}
return CommandKey(fmt.Sprintf("%s/%s", sessionId, cmdId))
}
func (ckey CommandKey) IsEmpty() bool {
return string(ckey) == ""
}
func Logf(fmtStr string, args ...interface{}) {
if (!DebugLogEnabled && !ForceDebugLog) || DebugLogger == nil {
return
}
DebugLogger.Printf(fmtStr, args...)
}
func InitDebugLog(prefix string) {
homeDir := GetMShellHomeDir()
err := os.MkdirAll(homeDir, 0777)
if err != nil {
return
}
logFile := path.Join(homeDir, LogFileName)
fd, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
if err != nil {
return
}
DebugLogger = log.New(fd, prefix+" ", log.LstdFlags)
Logf("logger initialized\n")
}
func SetEnableDebugLog(enable bool) {
DebugLogEnabled = enable
}
// deprecated (use GetGroupId instead)
func (ckey CommandKey) GetSessionId() string {
return ckey.GetGroupId()
}
func (ckey CommandKey) GetGroupId() string {
slashIdx := strings.Index(string(ckey), "/")
if slashIdx == -1 {
return ""
}
return string(ckey[0:slashIdx])
}
func (ckey CommandKey) GetCmdId() string {
slashIdx := strings.Index(string(ckey), "/")
if slashIdx == -1 {
return ""
}
return string(ckey[slashIdx+1:])
}
func (ckey CommandKey) Split() (string, string) {
fields := strings.SplitN(string(ckey), "/", 2)
if len(fields) < 2 {
return "", ""
}
return fields[0], fields[1]
}
func (ckey CommandKey) Validate(typeStr string) error {
if typeStr == "" {
typeStr = "ck"
}
if ckey == "" {
return fmt.Errorf("%s has empty commandkey", typeStr)
}
sessionId, cmdId := ckey.Split()
if sessionId == "" {
return fmt.Errorf("%s does not have sessionid", typeStr)
}
_, err := uuid.Parse(sessionId)
if err != nil {
return fmt.Errorf("%s has invalid sessionid '%s'", typeStr, sessionId)
}
if cmdId == "" {
return fmt.Errorf("%s does not have cmdid", typeStr)
}
_, err = uuid.Parse(cmdId)
if err != nil {
return fmt.Errorf("%s has invalid cmdid '%s'", typeStr, cmdId)
}
return nil
}
func HasDebugFlag(envMap map[string]string, flagName string) bool {
msDebug := envMap[MShellDebugVarName]
flags := strings.Split(msDebug, ",")
Logf("hasdebugflag[%s]: %s [%#v]\n", flagName, msDebug, flags)
for _, flag := range flags {
if strings.TrimSpace(flag) == flagName {
return true
}
}
return false
}
func GetDebugRcFileName() string {
msHome := GetMShellHomeDir()
return path.Join(msHome, LogRcFileName)
}
func GetHomeDir() string {
homeVar := os.Getenv(HomeVarName)
if homeVar == "" {
return "/"
}
return homeVar
}
func GetMShellHomeDir() string {
homeVar := os.Getenv(MShellHomeVarName)
if homeVar != "" {
return homeVar
}
return ExpandHomeDir(DefaultMShellHome)
}
func GetCommandFileNames(ck CommandKey) (*CommandFileNames, error) {
if err := ck.Validate("ck"); err != nil {
return nil, fmt.Errorf("cannot get command files: %w", err)
}
sessionId, cmdId := ck.Split()
sdir, err := EnsureSessionDir(sessionId)
if err != nil {
return nil, err
}
base := path.Join(sdir, cmdId)
return &CommandFileNames{
PtyOutFile: base + ".ptyout",
StdinFifo: base + ".stdin",
RunnerOutFile: base + ".runout",
}, nil
}
func CleanUpCmdFiles(sessionId string, cmdId string) error {
if cmdId == "" {
return fmt.Errorf("bad cmdid, cannot clean up")
}
sdir, err := EnsureSessionDir(sessionId)
if err != nil {
return err
}
cmdFileGlob := path.Join(sdir, cmdId+".*")
matches, err := filepath.Glob(cmdFileGlob)
if err != nil {
return err
}
for _, file := range matches {
rmErr := os.Remove(file)
if err == nil && rmErr != nil {
err = rmErr
}
}
return err
}
func GetSessionsDir() string {
mhome := GetMShellHomeDir()
sdir := path.Join(mhome, SessionsDirBaseName)
return sdir
}
func EnsureSessionDir(sessionId string) (string, error) {
if sessionId == "" {
return "", fmt.Errorf("Bad sessionid, cannot be empty")
}
baseLock.Lock()
sdir, ok := sessionDirCache[sessionId]
baseLock.Unlock()
if ok {
return sdir, nil
}
mhome := GetMShellHomeDir()
sdir = path.Join(mhome, SessionsDirBaseName, sessionId)
info, err := os.Stat(sdir)
if errors.Is(err, fs.ErrNotExist) {
err = os.MkdirAll(sdir, 0777)
if err != nil {
return "", fmt.Errorf("cannot make mshell session directory[%s]: %w", sdir, err)
}
info, err = os.Stat(sdir)
}
if err != nil {
return "", err
}
if !info.IsDir() {
return "", fmt.Errorf("session dir '%s' must be a directory", sdir)
}
baseLock.Lock()
sessionDirCache[sessionId] = sdir
baseLock.Unlock()
return sdir, nil
}
func GetMShellPath() (string, error) {
msPath := os.Getenv(MShellPathVarName) // use MSHELL_PATH
if msPath != "" {
return exec.LookPath(msPath)
}
mhome := GetMShellHomeDir()
userMShellPath := path.Join(mhome, DefaultMShellName) // look in ~/.mshell
msPath, err := exec.LookPath(userMShellPath)
if err == nil {
return msPath, nil
}
return exec.LookPath(DefaultMShellName) // standard path lookup for 'mshell'
}
func GetMShellSessionsDir() (string, error) {
mhome := GetMShellHomeDir()
return path.Join(mhome, SessionsDirBaseName), nil
}
func ExpandHomeDir(pathStr string) string {
if pathStr != "~" && !strings.HasPrefix(pathStr, "~/") {
return pathStr
}
homeDir := GetHomeDir()
if pathStr == "~" {
return homeDir
}
return path.Join(homeDir, pathStr[2:])
}
func ValidGoArch(goos string, goarch string) bool {
return (goos == "darwin" || goos == "linux") && (goarch == "amd64" || goarch == "arm64")
}
func GoArchOptFile(version string, goos string, goarch string) string {
installBinDir := os.Getenv(MShellInstallBinVarName)
if installBinDir == "" {
installBinDir = DefaultMShellInstallBinDir
}
versionStr := semver.MajorMinor(version)
if versionStr == "" {
versionStr = "unknown"
}
binBaseName := fmt.Sprintf("mshell-%s-%s.%s", versionStr, goos, goarch)
return fmt.Sprintf(path.Join(installBinDir, binBaseName))
}
func MShellBinaryFromOptDir(version string, goos string, goarch string) (io.ReadCloser, error) {
if !ValidGoArch(goos, goarch) {
return nil, fmt.Errorf("invalid goos/goarch combination: %s/%s", goos, goarch)
}
versionStr := semver.MajorMinor(version)
if versionStr == "" {
return nil, fmt.Errorf("invalid mshell version: %q", version)
}
fileName := GoArchOptFile(version, goos, goarch)
fd, err := os.Open(fileName)
if err != nil {
return nil, fmt.Errorf("cannot open mshell binary %q: %v", fileName, err)
}
return fd, nil
}
func GetRemoteId() (string, error) {
mhome := GetMShellHomeDir()
homeInfo, err := os.Stat(mhome)
if errors.Is(err, fs.ErrNotExist) {
err = os.MkdirAll(mhome, 0777)
if err != nil {
return "", fmt.Errorf("cannot make mshell home directory[%s]: %w", mhome, err)
}
homeInfo, err = os.Stat(mhome)
}
if err != nil {
return "", fmt.Errorf("cannot stat mshell home directory[%s]: %w", mhome, err)
}
if !homeInfo.IsDir() {
return "", fmt.Errorf("mshell home directory[%s] is not a directory", mhome)
}
remoteIdFile := path.Join(mhome, RemoteIdFile)
fd, err := os.Open(remoteIdFile)
if errors.Is(err, fs.ErrNotExist) {
// write the file
remoteId := uuid.New().String()
err = os.WriteFile(remoteIdFile, []byte(remoteId), 0644)
if err != nil {
return "", fmt.Errorf("cannot write remoteid to '%s': %w", remoteIdFile, err)
}
return remoteId, nil
} else if err != nil {
return "", fmt.Errorf("cannot read remoteid file '%s': %w", remoteIdFile, err)
} else {
defer fd.Close()
contents, err := io.ReadAll(fd)
if err != nil {
return "", fmt.Errorf("cannot read remoteid file '%s': %w", remoteIdFile, err)
}
uuidStr := string(contents)
_, err = uuid.Parse(uuidStr)
if err != nil {
return "", fmt.Errorf("invalid uuid read from '%s': %w", remoteIdFile, err)
}
return uuidStr, nil
}
}
func BoundInt(ival int, minVal int, maxVal int) int {
if ival < minVal {
return minVal
}
if ival > maxVal {
return maxVal
}
return ival
}
func BoundInt64(ival int64, minVal int64, maxVal int64) int64 {
if ival < minVal {
return minVal
}
if ival > maxVal {
return maxVal
}
return ival
}

@@ -0,0 +1,47 @@
package base
import "strings"
type OptsIter struct {
Pos int
Opts []string
}
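// an OptsIter is a simple cursor over an argument slice: Next consumes and returns the next
// argument, Current returns the most recently consumed argument, and IsNextPlain reports
// whether the upcoming argument exists and is not itself an option flag.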
func MakeOptsIter(opts []string) *OptsIter {
return &OptsIter{Opts: opts}
}
func IsOption(argStr string) bool {
return strings.HasPrefix(argStr, "-") && argStr != "-" && !strings.HasPrefix(argStr, "-/")
}
func (iter *OptsIter) HasNext() bool {
return iter.Pos <= len(iter.Opts)-1
}
func (iter *OptsIter) IsNextPlain() bool {
if !iter.HasNext() {
return false
}
return !IsOption(iter.Opts[iter.Pos])
}
func (iter *OptsIter) Next() string {
if iter.Pos >= len(iter.Opts) {
return ""
}
rtn := iter.Opts[iter.Pos]
iter.Pos++
return rtn
}
func (iter *OptsIter) Current() string {
if iter.Pos == 0 {
return ""
}
return iter.Opts[iter.Pos-1]
}
func (iter *OptsIter) Rest() []string {
return iter.Opts[iter.Pos:]
}

@@ -0,0 +1,127 @@
package binpack
import (
"encoding/binary"
"encoding/json"
"fmt"
"io"
)
type Unpacker struct {
R FullByteReader
Err error
}
type FullByteReader interface {
io.ByteReader
io.Reader
}
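// PackValue writes a length-prefixed value: a uvarint byte count followed by that many raw
// bytes; a zero count encodes an empty (or nil) value with no payload.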
func PackValue(w io.Writer, barr []byte) error {
viBuf := make([]byte, binary.MaxVarintLen64)
viLen := binary.PutUvarint(viBuf, uint64(len(barr)))
_, err := w.Write(viBuf[0:viLen])
if err != nil {
return err
}
if len(barr) > 0 {
_, err = w.Write(barr)
if err != nil {
return err
}
}
return nil
}
func PackStrArr(w io.Writer, strs []string) error {
barr, err := json.Marshal(strs)
if err != nil {
return err
}
return PackValue(w, barr)
}
func PackInt(w io.Writer, ival int) error {
viBuf := make([]byte, binary.MaxVarintLen64)
l := binary.PutUvarint(viBuf, uint64(ival))
_, err := w.Write(viBuf[0:l])
return err
}
func UnpackValue(r FullByteReader) ([]byte, error) {
lenVal, err := binary.ReadUvarint(r)
if err != nil {
return nil, err
}
if lenVal == 0 {
return nil, nil
}
rtnBuf := make([]byte, int(lenVal))
_, err = io.ReadFull(r, rtnBuf)
if err != nil {
return nil, err
}
return rtnBuf, nil
}
func UnpackStrArr(r FullByteReader) ([]string, error) {
barr, err := UnpackValue(r)
if err != nil {
return nil, err
}
var strs []string
err = json.Unmarshal(barr, &strs)
if err != nil {
return nil, err
}
return strs, nil
}
func UnpackInt(r io.ByteReader) (int, error) {
ival64, err := binary.ReadVarint(r)
if err != nil {
return 0, err
}
return int(ival64), nil
}
func (u *Unpacker) UnpackValue(name string) []byte {
if u.Err != nil {
return nil
}
rtn, err := UnpackValue(u.R)
if err != nil {
u.Err = fmt.Errorf("cannot unpack %s: %v", name, err)
}
return rtn
}
func (u *Unpacker) UnpackInt(name string) int {
if u.Err != nil {
return 0
}
rtn, err := UnpackInt(u.R)
if err != nil {
u.Err = fmt.Errorf("cannot unpack %s: %v", name, err)
}
return rtn
}
func (u *Unpacker) UnpackStrArr(name string) []string {
if u.Err != nil {
return nil
}
rtn, err := UnpackStrArr(u.R)
if err != nil {
u.Err = fmt.Errorf("cannot unpack %s: %v", name, err)
}
return rtn
}
func (u *Unpacker) Error() error {
return u.Err
}
func MakeUnpacker(r FullByteReader) *Unpacker {
return &Unpacker{R: r}
}

@@ -0,0 +1,570 @@
package cirfile
import (
"context"
"fmt"
"io"
"os"
"syscall"
"time"
)
// CBUF[version] [maxsize] [fileoffset] [startpos] [endpos]
const HeaderFmt1 = "CBUF%02d %19d %19d %19d %19d\n" // 87 bytes
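// each %19d field is right-aligned and space-padded, so a rendered header line is always
// 87 bytes ("CBUF" + 2-digit version + four 20-byte fields + newline)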
const HeaderLen = 256 // set to 256 for future expandability
const FullHeaderFmt = "%-255s\n" // 256 bytes (255 + newline)
const CurrentVersion = 1
const FilePosEmpty = -1 // sentinel, if startpos is set to -1, file is empty
const InitialLockDelay = 10 * time.Millisecond
const InitialLockTries = 5
const LockDelay = 100 * time.Millisecond
// File objects are *not* multithread safe, operations must be externally synchronized
type File struct {
OSFile *os.File
Version byte
MaxSize int64
FileOffset int64
StartPos int64
EndPos int64
FileDataSize int64 // size of data (does not include header size)
FlockStatus int
}
type Stat struct {
Location string
Version byte
MaxSize int64
FileOffset int64
DataSize int64
}
func (f *File) flock(ctx context.Context, lockType int) error {
err := syscall.Flock(int(f.OSFile.Fd()), lockType|syscall.LOCK_NB)
if err == nil {
f.FlockStatus = lockType
return nil
}
if err != syscall.EWOULDBLOCK {
return err
}
if ctx == nil {
return syscall.EWOULDBLOCK
}
// busy-wait with context
numWaits := 0
for {
numWaits++
var timeout time.Duration
if numWaits <= InitialLockTries {
timeout = InitialLockDelay
} else {
timeout = LockDelay
}
select {
case <-time.After(timeout):
break
case <-ctx.Done():
return ctx.Err()
}
err = syscall.Flock(int(f.OSFile.Fd()), lockType|syscall.LOCK_NB)
if err == nil {
f.FlockStatus = lockType
return nil
}
if err != syscall.EWOULDBLOCK {
return err
}
}
return fmt.Errorf("could not acquire lock")
}
func (f *File) unflock() {
if f.FlockStatus != 0 {
syscall.Flock(int(f.OSFile.Fd()), syscall.LOCK_UN) // ignore error (nothing to do about it anyway)
f.FlockStatus = 0
}
return
}
// does not read metadata because locking could block/fail. we want to be able
// to return a valid file struct without blocking.
func OpenCirFile(fileName string) (*File, error) {
fd, err := os.OpenFile(fileName, os.O_RDWR, 0777)
if err != nil {
return nil, err
}
finfo, err := fd.Stat()
if err != nil {
return nil, err
}
if finfo.Size() < HeaderLen {
return nil, fmt.Errorf("invalid cirfile, file length[%d] less than HeaderLen[%d]", finfo.Size(), HeaderLen)
}
rtn := &File{OSFile: fd}
return rtn, nil
}
func StatCirFile(ctx context.Context, fileName string) (*Stat, error) {
file, err := OpenCirFile(fileName)
if err != nil {
return nil, err
}
defer file.Close()
fileOffset, dataSize, err := file.GetStartOffsetAndSize(ctx)
if err != nil {
return nil, err
}
return &Stat{
Location: fileName,
Version: file.Version,
MaxSize: file.MaxSize,
FileOffset: fileOffset,
DataSize: dataSize,
}, nil
}
// if the file already exists, it is an error.
// there is a race condition if two goroutines try to create the same file between Stat() and Create(), so
// they both might get no error, but only one file will be valid. if this is a concern, this call
// should be externally synchronized.
func CreateCirFile(fileName string, maxSize int64) (*File, error) {
if maxSize <= 0 {
return nil, fmt.Errorf("invalid maxsize[%d]", maxSize)
}
_, err := os.Stat(fileName)
if err == nil {
return nil, fmt.Errorf("file[%s] already exists", fileName)
}
if !os.IsNotExist(err) {
return nil, fmt.Errorf("cannot stat: %w", err)
}
fd, err := os.Create(fileName)
if err != nil {
return nil, err
}
rtn := &File{OSFile: fd, Version: CurrentVersion, MaxSize: maxSize, StartPos: FilePosEmpty}
err = rtn.flock(nil, syscall.LOCK_EX)
if err != nil {
return nil, err
}
defer rtn.unflock()
err = rtn.writeMeta()
if err != nil {
return nil, err
}
return rtn, nil
}
func (f *File) Close() error {
return f.OSFile.Close()
}
func (f *File) ReadMeta(ctx context.Context) error {
err := f.flock(ctx, syscall.LOCK_SH)
if err != nil {
return err
}
defer f.unflock()
return f.readMeta()
}
func (f *File) hasShLock() bool {
return f.FlockStatus == syscall.LOCK_EX || f.FlockStatus == syscall.LOCK_SH
}
func (f *File) hasExLock() bool {
return f.FlockStatus == syscall.LOCK_EX
}
func (f *File) readMeta() error {
if f.OSFile == nil {
return fmt.Errorf("no *os.File")
}
if !f.hasShLock() {
return fmt.Errorf("writeMeta must hold LOCK_SH")
}
_, err := f.OSFile.Seek(0, 0)
if err != nil {
return fmt.Errorf("cannot seek file: %w", err)
}
finfo, err := f.OSFile.Stat()
if err != nil {
return fmt.Errorf("cannot stat file: %w", err)
}
if finfo.Size() < 256 {
return fmt.Errorf("invalid cbuf file size[%d] < 256", finfo.Size())
}
f.FileDataSize = finfo.Size() - 256
buf := make([]byte, 256)
_, err = io.ReadFull(f.OSFile, buf)
if err != nil {
return fmt.Errorf("error reading header: %w", err)
}
// currently only one version, so we don't need to have special logic here yet
_, err = fmt.Sscanf(string(buf), HeaderFmt1, &f.Version, &f.MaxSize, &f.FileOffset, &f.StartPos, &f.EndPos)
if err != nil {
return fmt.Errorf("sscanf error: %w", err)
}
if f.Version != CurrentVersion {
return fmt.Errorf("invalid cbuf version[%d]", f.Version)
}
// possible incomplete write, fix start/end pos to be within filesize
if f.FileDataSize == 0 {
f.StartPos = FilePosEmpty
f.EndPos = 0
} else if f.StartPos >= f.FileDataSize && f.EndPos >= f.FileDataSize {
f.StartPos = FilePosEmpty
f.EndPos = 0
} else if f.StartPos >= f.FileDataSize {
f.StartPos = 0
} else if f.EndPos >= f.FileDataSize {
f.EndPos = f.FileDataSize - 1
}
if f.MaxSize <= 0 || f.FileOffset < 0 || (f.StartPos < 0 && f.StartPos != FilePosEmpty) || f.StartPos >= f.MaxSize || f.EndPos < 0 || f.EndPos >= f.MaxSize {
return fmt.Errorf("invalid cbuf metadata version[%d] filedatasize[%d] maxsize[%d] fileoffset[%d] startpos[%d] endpos[%d]", f.Version, f.FileDataSize, f.MaxSize, f.FileOffset, f.StartPos, f.EndPos)
}
return nil
}
// no error checking of meta values
func (f *File) writeMeta() error {
if f.OSFile == nil {
return fmt.Errorf("no *os.File")
}
if !f.hasExLock() {
return fmt.Errorf("writeMeta must hold LOCK_EX")
}
_, err := f.OSFile.Seek(0, 0)
if err != nil {
return fmt.Errorf("cannot seek file: %w", err)
}
metaStr := fmt.Sprintf(HeaderFmt1, f.Version, f.MaxSize, f.FileOffset, f.StartPos, f.EndPos)
fullMetaStr := fmt.Sprintf(FullHeaderFmt, metaStr)
_, err = f.OSFile.WriteString(fullMetaStr)
if err != nil {
return fmt.Errorf("write error: %w", err)
}
return nil
}
// returns (fileOffset, datasize, error)
// datasize is the current amount of readable data held in the cirfile
func (f *File) GetStartOffsetAndSize(ctx context.Context) (int64, int64, error) {
err := f.flock(ctx, syscall.LOCK_SH)
if err != nil {
return 0, 0, err
}
defer f.unflock()
err = f.readMeta()
if err != nil {
return 0, 0, err
}
chunks := f.getFileChunks()
return f.FileOffset, totalChunksSize(chunks), nil
}
type fileChunk struct {
StartPos int64
Len int64
}
func totalChunksSize(chunks []fileChunk) int64 {
var rtn int64
for _, chunk := range chunks {
rtn += chunk.Len
}
return rtn
}
func advanceChunks(chunks []fileChunk, offset int64) []fileChunk {
if offset < 0 {
panic(fmt.Sprintf("invalid negative offset: %d", offset))
}
if offset == 0 {
return chunks
}
var rtn []fileChunk
for _, chunk := range chunks {
if offset >= chunk.Len {
offset = offset - chunk.Len
continue
}
if offset == 0 {
rtn = append(rtn, chunk)
} else {
rtn = append(rtn, fileChunk{chunk.StartPos + offset, chunk.Len - offset})
offset = 0
}
}
return rtn
}
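// getFileChunks returns the readable region(s) of the circular buffer in read order; when the
// data wraps past the end of the data area (EndPos < StartPos) it comes back as two chunks.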
func (f *File) getFileChunks() []fileChunk {
if f.StartPos == FilePosEmpty {
return nil
}
if f.EndPos >= f.StartPos {
return []fileChunk{fileChunk{f.StartPos, f.EndPos - f.StartPos + 1}}
}
return []fileChunk{
fileChunk{f.StartPos, f.FileDataSize - f.StartPos},
fileChunk{0, f.EndPos + 1},
}
}
func (f *File) getFreeChunks() []fileChunk {
if f.StartPos == FilePosEmpty {
return []fileChunk{fileChunk{0, f.MaxSize}}
}
if (f.EndPos == f.StartPos-1) || (f.StartPos == 0 && f.EndPos == f.MaxSize-1) {
return nil
}
if f.EndPos < f.StartPos {
return []fileChunk{fileChunk{f.EndPos + 1, f.StartPos - f.EndPos - 1}}
}
var rtn []fileChunk
if f.EndPos < f.MaxSize-1 {
rtn = append(rtn, fileChunk{f.EndPos + 1, f.MaxSize - f.EndPos - 1})
}
if f.StartPos > 0 {
rtn = append(rtn, fileChunk{0, f.StartPos})
}
return rtn
}
// returns (offset, data, err)
func (f *File) ReadAll(ctx context.Context) (int64, []byte, error) {
err := f.flock(ctx, syscall.LOCK_SH)
if err != nil {
return 0, nil, err
}
defer f.unflock()
err = f.readMeta()
if err != nil {
return 0, nil, err
}
chunks := f.getFileChunks()
curSize := totalChunksSize(chunks)
buf := make([]byte, curSize)
realOffset, nr, err := f.internalReadNext(buf, 0)
return realOffset, buf[0:nr], err
}
func (f *File) ReadAtWithMax(ctx context.Context, offset int64, maxSize int64) (int64, []byte, error) {
err := f.flock(ctx, syscall.LOCK_SH)
if err != nil {
return 0, nil, err
}
defer f.unflock()
err = f.readMeta()
if err != nil {
return 0, nil, err
}
chunks := f.getFileChunks()
curSize := totalChunksSize(chunks)
var buf []byte
if maxSize > curSize {
buf = make([]byte, curSize)
} else {
buf = make([]byte, maxSize)
}
realOffset, nr, err := f.internalReadNext(buf, offset)
return realOffset, buf[0:nr], err
}
func (f *File) internalReadNext(buf []byte, offset int64) (int64, int, error) {
if offset < f.FileOffset {
offset = f.FileOffset
}
relativeOffset := offset - f.FileOffset
chunks := f.getFileChunks()
curSize := totalChunksSize(chunks)
if offset >= f.FileOffset+curSize {
return f.FileOffset + curSize, 0, nil
}
chunks = advanceChunks(chunks, relativeOffset)
numRead := 0
for _, chunk := range chunks {
if numRead >= len(buf) {
break
}
toRead := len(buf) - numRead
if toRead > int(chunk.Len) {
toRead = int(chunk.Len)
}
nr, err := f.OSFile.ReadAt(buf[numRead:numRead+toRead], chunk.StartPos+HeaderLen)
if err != nil {
return offset, 0, err
}
numRead += nr
}
return offset, numRead, nil
}
// returns (realOffset, numread, error)
// will only return io.EOF when len(data) == 0, otherwise will just do a short read
func (f *File) ReadNext(ctx context.Context, buf []byte, offset int64) (int64, int, error) {
err := f.flock(ctx, syscall.LOCK_SH)
if err != nil {
return 0, 0, err
}
defer f.unflock()
err = f.readMeta()
if err != nil {
return 0, 0, err
}
return f.internalReadNext(buf, offset)
}
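// ensureFreeSpace evicts the oldest data (advancing StartPos and FileOffset) when fewer than
// requiredSpace bytes are free; if the whole buffer is needed, the file is reset to empty.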
func (f *File) ensureFreeSpace(requiredSpace int64) error {
chunks := f.getFileChunks()
curSpace := f.MaxSize - totalChunksSize(chunks)
if curSpace >= requiredSpace {
return nil
}
neededSpace := requiredSpace - curSpace
if requiredSpace >= f.MaxSize || f.StartPos == FilePosEmpty {
f.StartPos = FilePosEmpty
f.EndPos = 0
f.FileOffset += neededSpace
} else {
f.StartPos = (f.StartPos + neededSpace) % f.MaxSize
f.FileOffset += neededSpace
}
return f.writeMeta()
}
// does not implement io.WriterAt (needs context)
func (f *File) WriteAt(ctx context.Context, buf []byte, writePos int64) error {
if writePos < 0 {
return fmt.Errorf("WriteAt got invalid writePos[%d]", writePos)
}
err := f.flock(ctx, syscall.LOCK_EX)
if err != nil {
return err
}
defer f.unflock()
err = f.readMeta()
if err != nil {
return err
}
chunks := f.getFileChunks()
currentSize := totalChunksSize(chunks)
if writePos < f.FileOffset {
negOffset := f.FileOffset - writePos
if negOffset >= int64(len(buf)) {
return nil
}
buf = buf[negOffset:]
writePos = f.FileOffset
}
if writePos > f.FileOffset+currentSize {
// fill gap with zero bytes
posOffset := writePos - (f.FileOffset + currentSize)
err = f.ensureFreeSpace(int64(posOffset))
if err != nil {
return err
}
var zeroBuf []byte
if posOffset >= f.MaxSize {
zeroBuf = make([]byte, f.MaxSize)
} else {
zeroBuf = make([]byte, posOffset)
}
err = f.internalAppendData(zeroBuf)
if err != nil {
return err
}
// recalc chunks/currentSize
chunks = f.getFileChunks()
currentSize = totalChunksSize(chunks)
// after writing the zero bytes, writePos == f.FileOffset+currentSize (the rest is a straight append)
}
// now writePos >= f.FileOffset && writePos <= f.FileOffset+currentSize (check invariant)
if writePos < f.FileOffset || writePos > f.FileOffset+currentSize {
panic(fmt.Sprintf("invalid writePos, invariant violated writepos[%d] fileoffset[%d] currentsize[%d]", writePos, f.FileOffset, currentSize))
}
// overwrite existing data (in chunks). advance by writePosOffset
writePosOffset := writePos - f.FileOffset
if writePosOffset < currentSize {
advChunks := advanceChunks(chunks, writePosOffset)
nw, err := f.writeToChunks(buf, advChunks, false)
if err != nil {
return err
}
buf = buf[nw:]
if len(buf) == 0 {
return nil
}
}
// buf contains what was unwritten. this unwritten data is now just a straight append
return f.internalAppendData(buf)
}
// try writing to chunks, returns (nw, error)
func (f *File) writeToChunks(buf []byte, chunks []fileChunk, updatePos bool) (int64, error) {
var numWrite int64
for _, chunk := range chunks {
if numWrite >= int64(len(buf)) {
break
}
if chunk.Len == 0 {
continue
}
toWrite := int64(len(buf)) - numWrite
if toWrite > chunk.Len {
toWrite = chunk.Len
}
nw, err := f.OSFile.WriteAt(buf[numWrite:numWrite+toWrite], chunk.StartPos+HeaderLen)
if err != nil {
return 0, err
}
if updatePos {
if chunk.StartPos+int64(nw) > f.FileDataSize {
f.FileDataSize = chunk.StartPos + int64(nw)
}
if f.StartPos == FilePosEmpty {
f.StartPos = chunk.StartPos
}
f.EndPos = chunk.StartPos + int64(nw) - 1
}
numWrite += int64(nw)
}
return numWrite, nil
}
func (f *File) internalAppendData(buf []byte) error {
err := f.ensureFreeSpace(int64(len(buf)))
if err != nil {
return err
}
if len(buf) >= int(f.MaxSize) {
buf = buf[len(buf)-int(f.MaxSize):]
}
chunks := f.getFreeChunks()
// don't track nw because we know we have enough free space to write entire buf
_, err = f.writeToChunks(buf, chunks, true)
if err != nil {
return err
}
err = f.writeMeta()
if err != nil {
return err
}
return nil
}
func (f *File) AppendData(ctx context.Context, buf []byte) error {
err := f.flock(ctx, syscall.LOCK_EX)
if err != nil {
return err
}
defer f.unflock()
err = f.readMeta()
if err != nil {
return err
}
return f.internalAppendData(buf)
}

@@ -0,0 +1,282 @@
package cirfile
import (
"context"
"fmt"
"os"
"path"
"strings"
"syscall"
"testing"
"time"
)
func validateFileSize(t *testing.T, name string, size int) {
finfo, err := os.Stat(name)
if err != nil {
t.Fatalf("error stating file[%s]: %v", name, err)
}
if int(finfo.Size()) != size {
t.Fatalf("invalid file[%s] expected[%d] got[%d]", name, size, finfo.Size())
}
}
func validateMeta(t *testing.T, desc string, f *File, startPos int64, endPos int64, dataSize int64, offset int64) {
if f.StartPos != startPos || f.EndPos != endPos || f.FileDataSize != dataSize || f.FileOffset != offset {
t.Fatalf("metadata error (%s): startpos[%d %d] endpos[%d %d] filedatasize[%d %d] fileoffset[%d %d]", desc, f.StartPos, startPos, f.EndPos, endPos, f.FileDataSize, dataSize, f.FileOffset, offset)
}
}
func dumpFile(name string) {
barr, _ := os.ReadFile(name)
str := string(barr)
str = strings.ReplaceAll(str, "\x00", ".")
fmt.Printf("%s<<<\n%s\n>>>\n", name, str)
}
func makeData(size int) string {
var rtn string
for {
if len(rtn) >= size {
break
}
needed := size - len(rtn)
if needed < 10 {
rtn += "123456789\n"[0:needed]
break
}
rtn += "123456789\n"
}
return rtn
}
func TestCreate(t *testing.T) {
tempDir := t.TempDir()
f1Name := path.Join(tempDir, "f1.cf")
f, err := OpenCirFile(f1Name)
if err == nil || f != nil {
t.Fatalf("OpenCirFile f1.cf should fail (no file)")
}
f, err = CreateCirFile(f1Name, 100)
if err != nil {
t.Fatalf("CreateCirFile f1.cf failed: %v", err)
}
if f == nil {
t.Fatalf("CreateCirFile f1.cf returned nil")
}
err = f.ReadMeta(context.Background())
if err != nil {
t.Fatalf("cannot readmeta from f1.cf: %v", err)
}
validateFileSize(t, f1Name, 256)
if f.Version != CurrentVersion || f.MaxSize != 100 || f.FileOffset != 0 || f.StartPos != FilePosEmpty || f.EndPos != 0 || f.FileDataSize != 0 || f.FlockStatus != 0 {
t.Fatalf("error with initial metadata #%v", f)
}
buf := make([]byte, 200)
realOffset, nr, err := f.ReadNext(context.Background(), buf, 0)
if realOffset != 0 || nr != 0 || err != nil {
t.Fatalf("error with empty read: real-offset[%d] nr[%d] err[%v]", realOffset, nr, err)
}
realOffset, nr, err = f.ReadNext(context.Background(), buf, 1000)
if realOffset != 0 || nr != 0 || err != nil {
t.Fatalf("error with empty read: real-offset[%d] nr[%d] err[%v]", realOffset, nr, err)
}
f2, err := CreateCirFile(f1Name, 100)
if err == nil || f2 != nil {
t.Fatalf("should be an error to create duplicate CirFile")
}
}
func TestFile(t *testing.T) {
tempDir := t.TempDir()
f1Name := path.Join(tempDir, "f1.cf")
f, err := CreateCirFile(f1Name, 100)
if err != nil {
t.Fatalf("cannot create cirfile: %v", err)
}
err = f.AppendData(context.Background(), nil)
if err != nil {
t.Fatalf("cannot append data: %v", err)
}
validateFileSize(t, f1Name, HeaderLen)
validateMeta(t, "1", f, FilePosEmpty, 0, 0, 0)
err = f.AppendData(context.Background(), []byte("hello"))
if err != nil {
t.Fatalf("cannot append data: %v", err)
}
validateFileSize(t, f1Name, HeaderLen+5)
validateMeta(t, "2", f, 0, 4, 5, 0)
err = f.AppendData(context.Background(), []byte(" foo"))
if err != nil {
t.Fatalf("cannot append data: %v", err)
}
validateFileSize(t, f1Name, HeaderLen+9)
validateMeta(t, "3", f, 0, 8, 9, 0)
err = f.AppendData(context.Background(), []byte("\n"+makeData(20)))
if err != nil {
t.Fatalf("cannot append data: %v", err)
}
validateFileSize(t, f1Name, HeaderLen+30)
validateMeta(t, "4", f, 0, 29, 30, 0)
data120 := makeData(120)
err = f.AppendData(context.Background(), []byte(data120))
if err != nil {
t.Fatalf("cannot append data: %v", err)
}
validateFileSize(t, f1Name, HeaderLen+100)
validateMeta(t, "5", f, 0, 99, 100, 50)
err = f.AppendData(context.Background(), []byte("foo "))
if err != nil {
t.Fatalf("cannot append data: %v", err)
}
validateFileSize(t, f1Name, HeaderLen+100)
validateMeta(t, "6", f, 4, 3, 100, 54)
buf := make([]byte, 5)
realOffset, nr, err := f.ReadNext(context.Background(), buf, 0)
if err != nil {
t.Fatalf("cannot ReadNext: %v", err)
}
if realOffset != 54 {
t.Fatalf("wrong realoffset got[%d] expected[%d]", realOffset, 54)
}
if nr != 5 {
t.Fatalf("wrong nr got[%d] expected[%d]", nr, 5)
}
if string(buf[0:nr]) != "56789" {
t.Fatalf("wrong buf return got[%s] expected[%s]", string(buf[0:nr]), "56789")
}
realOffset, nr, err = f.ReadNext(context.Background(), buf, 60)
if err != nil {
t.Fatalf("cannot readnext: %v", err)
}
if realOffset != 60 && nr != 5 {
t.Fatalf("invalid rtn realoffset[%d] nr[%d]", realOffset, nr)
}
if string(buf[0:nr]) != "12345" {
t.Fatalf("invalid rtn buf[%s]", string(buf[0:nr]))
}
realOffset, nr, err = f.ReadNext(context.Background(), buf, 800)
if err != nil || realOffset != 154 || nr != 0 {
t.Fatalf("invalid past end read: err[%v] realoffset[%d] nr[%d]", err, realOffset, nr)
}
realOffset, nr, err = f.ReadNext(context.Background(), buf, 150)
if err != nil || realOffset != 150 || nr != 4 || string(buf[0:nr]) != "foo " {
t.Fatalf("invalid end read: err[%v] realoffset[%d] nr[%d] buf[%s]", err, realOffset, nr, string(buf[0:nr]))
}
}
func TestFlock(t *testing.T) {
tempDir := t.TempDir()
f1Name := path.Join(tempDir, "f1.cf")
f, err := CreateCirFile(f1Name, 100)
if err != nil {
t.Fatalf("cannot create cirfile: %v", err)
}
fd2, err := os.OpenFile(f1Name, os.O_RDWR, 0777)
if err != nil {
t.Fatalf("cannot open file: %v", err)
}
err = syscall.Flock(int(fd2.Fd()), syscall.LOCK_EX)
if err != nil {
t.Fatalf("cannot lock fd: %v", err)
}
err = f.AppendData(nil, []byte("hello"))
if err != syscall.EWOULDBLOCK {
t.Fatalf("append should fail with EWOULDBLOCK")
}
timeoutCtx, _ := context.WithTimeout(context.Background(), 20*time.Millisecond)
startTs := time.Now()
err = f.ReadMeta(timeoutCtx)
if err != context.DeadlineExceeded {
t.Fatalf("readmeta should fail with context.DeadlineExceeded")
}
dur := time.Now().Sub(startTs)
if dur < 20*time.Millisecond {
t.Fatalf("readmeta should take at least 20ms")
}
syscall.Flock(int(fd2.Fd()), syscall.LOCK_UN)
err = f.ReadMeta(timeoutCtx)
if err != nil {
t.Fatalf("readmeta err: %v", err)
}
err = syscall.Flock(int(fd2.Fd()), syscall.LOCK_SH)
if err != nil {
t.Fatalf("cannot flock: %v", err)
}
err = f.AppendData(nil, []byte("hello"))
if err != syscall.EWOULDBLOCK {
t.Fatalf("append should fail with EWOULDBLOCK")
}
err = f.ReadMeta(timeoutCtx)
if err != nil {
t.Fatalf("readmeta err (should work because LOCK_SH): %v", err)
}
fd2.Close()
err = f.AppendData(nil, []byte("hello"))
if err != nil {
t.Fatalf("append error (should work fd2 was closed): %v", err)
}
}
func TestWriteAt(t *testing.T) {
tempDir := t.TempDir()
f1Name := path.Join(tempDir, "f1.cf")
f, err := CreateCirFile(f1Name, 100)
if err != nil {
t.Fatalf("cannot create cirfile: %v", err)
}
err = f.WriteAt(nil, []byte("hello\nmike"), 4)
if err != nil {
t.Fatalf("writeat error: %v", err)
}
err = f.WriteAt(nil, []byte("t"), 2)
if err != nil {
t.Fatalf("writeat error: %v", err)
}
err = f.WriteAt(nil, []byte("more"), 30)
if err != nil {
t.Fatalf("writeat error: %v", err)
}
err = f.WriteAt(nil, []byte("\n"), 19)
if err != nil {
t.Fatalf("writeat error: %v", err)
}
dumpFile(f1Name)
err = f.WriteAt(nil, []byte("hello"), 200)
if err != nil {
t.Fatalf("writeat error: %v", err)
}
buf := make([]byte, 10)
realOffset, nr, err := f.ReadNext(context.Background(), buf, 200)
if err != nil || realOffset != 200 || nr != 5 || string(buf[0:nr]) != "hello" {
t.Fatalf("invalid readnext: err[%v] realoffset[%d] nr[%d] buf[%s]", err, realOffset, nr, string(buf[0:nr]))
}
err = f.WriteAt(nil, []byte("0123456789\n"), 100)
if err != nil {
t.Fatalf("writeat error: %v", err)
}
dumpFile(f1Name)
dataStr := makeData(200)
err = f.WriteAt(nil, []byte(dataStr), 50)
if err != nil {
t.Fatalf("writeat error: %v", err)
}
dumpFile(f1Name)
dataStr = makeData(1000)
err = f.WriteAt(nil, []byte(dataStr), 1002)
if err != nil {
t.Fatalf("writeat error: %v", err)
}
err = f.WriteAt(nil, []byte("hello\n"), 2010)
if err != nil {
t.Fatalf("writeat error: %v", err)
}
err = f.AppendData(nil, []byte("foo\n"))
if err != nil {
t.Fatalf("appenddata error: %v", err)
}
dumpFile(f1Name)
}

@@ -0,0 +1,471 @@
package cmdtail
import (
"encoding/base64"
"fmt"
"io"
"os"
"regexp"
"sync"
"time"
"github.com/fsnotify/fsnotify"
"github.com/commandlinedev/apishell/pkg/base"
"github.com/commandlinedev/apishell/pkg/packet"
)
const MaxDataBytes = 4096
const FileTypePty = "ptyout"
const FileTypeRun = "runout"
type Tailer struct {
Lock *sync.Mutex
WatchList map[base.CommandKey]CmdWatchEntry
Watcher *fsnotify.Watcher
Sender *packet.PacketSender
Gen FileNameGenerator
Sessions map[string]bool
}
type TailPos struct {
ReqId string
Running bool // an active tailer sending data
TailPtyPos int64
TailRunPos int64
Follow bool
}
type CmdWatchEntry struct {
CmdKey base.CommandKey
FilePtyLen int64
FileRunLen int64
Tails []TailPos
Done bool
}
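// a CmdWatchEntry tracks the last known sizes of a command's ptyout/runout files along with
// the active tail requests (Tails) reading from them; Done marks the command as finished.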
type FileNameGenerator interface {
PtyOutFile(ck base.CommandKey) string
RunOutFile(ck base.CommandKey) string
SessionDir(sessionId string) string
}
func (w CmdWatchEntry) getTailPos(reqId string) (TailPos, bool) {
for _, pos := range w.Tails {
if pos.ReqId == reqId {
return pos, true
}
}
return TailPos{}, false
}
func (w *CmdWatchEntry) updateTailPos(reqId string, newPos TailPos) {
for idx, pos := range w.Tails {
if pos.ReqId == reqId {
w.Tails[idx] = newPos
return
}
}
w.Tails = append(w.Tails, newPos)
}
func (w *CmdWatchEntry) removeTailPos(reqId string) {
var newTails []TailPos
for _, pos := range w.Tails {
if pos.ReqId == reqId {
continue
}
newTails = append(newTails, pos)
}
w.Tails = newTails
}
func (pos TailPos) IsCurrent(entry CmdWatchEntry) bool {
return pos.TailPtyPos >= entry.FilePtyLen && pos.TailRunPos >= entry.FileRunLen
}
func (t *Tailer) updateTailPos_nolock(cmdKey base.CommandKey, reqId string, pos TailPos) {
entry, found := t.WatchList[cmdKey]
if !found {
return
}
entry.updateTailPos(reqId, pos)
t.WatchList[cmdKey] = entry
}
func (t *Tailer) removeTailPos(cmdKey base.CommandKey, reqId string) {
t.Lock.Lock()
defer t.Lock.Unlock()
t.removeTailPos_nolock(cmdKey, reqId)
}
func (t *Tailer) removeTailPos_nolock(cmdKey base.CommandKey, reqId string) {
entry, found := t.WatchList[cmdKey]
if !found {
return
}
entry.removeTailPos(reqId)
t.WatchList[cmdKey] = entry
if len(entry.Tails) == 0 {
t.removeWatch_nolock(cmdKey)
}
}
func (t *Tailer) removeWatch_nolock(cmdKey base.CommandKey) {
// delete from watchlist, remove watches
delete(t.WatchList, cmdKey)
t.Watcher.Remove(t.Gen.PtyOutFile(cmdKey))
t.Watcher.Remove(t.Gen.RunOutFile(cmdKey))
}
func (t *Tailer) getEntryAndPos_nolock(cmdKey base.CommandKey, reqId string) (CmdWatchEntry, TailPos, bool) {
entry, found := t.WatchList[cmdKey]
if !found {
return CmdWatchEntry{}, TailPos{}, false
}
pos, found := entry.getTailPos(reqId)
if !found {
return CmdWatchEntry{}, TailPos{}, false
}
return entry, pos, true
}
func (t *Tailer) addSessionWatcher(sessionId string) error {
t.Lock.Lock()
defer t.Lock.Unlock()
if t.Sessions[sessionId] {
return nil
}
sdir := t.Gen.SessionDir(sessionId)
err := t.Watcher.Add(sdir)
if err != nil {
return err
}
t.Sessions[sessionId] = true
return nil
}
func (t *Tailer) removeSessionWatcher(sessionId string) {
t.Lock.Lock()
defer t.Lock.Unlock()
if !t.Sessions[sessionId] {
return
}
sdir := t.Gen.SessionDir(sessionId)
t.Watcher.Remove(sdir)
}
func MakeTailer(sender *packet.PacketSender, gen FileNameGenerator) (*Tailer, error) {
rtn := &Tailer{
Lock: &sync.Mutex{},
WatchList: make(map[base.CommandKey]CmdWatchEntry),
Sessions: make(map[string]bool),
Sender: sender,
Gen: gen,
}
var err error
rtn.Watcher, err = fsnotify.NewWatcher()
if err != nil {
return nil, err
}
return rtn, nil
}
func (t *Tailer) readDataFromFile(fileName string, pos int64, maxBytes int) ([]byte, error) {
fd, err := os.Open(fileName)
defer fd.Close()
if err != nil {
return nil, err
}
buf := make([]byte, maxBytes)
nr, err := fd.ReadAt(buf, pos)
if err != nil && err != io.EOF { // ignore EOF error
return nil, err
}
return buf[0:nr], nil
}
func (t *Tailer) makeCmdDataPacket(entry CmdWatchEntry, pos TailPos) (*packet.CmdDataPacketType, error) {
dataPacket := packet.MakeCmdDataPacket(pos.ReqId)
dataPacket.CK = entry.CmdKey
dataPacket.PtyPos = pos.TailPtyPos
dataPacket.RunPos = pos.TailRunPos
if entry.FilePtyLen > pos.TailPtyPos {
ptyData, err := t.readDataFromFile(t.Gen.PtyOutFile(entry.CmdKey), pos.TailPtyPos, MaxDataBytes)
if err != nil {
return nil, err
}
dataPacket.PtyData64 = base64.StdEncoding.EncodeToString(ptyData)
dataPacket.PtyDataLen = len(ptyData)
}
if entry.FileRunLen > pos.TailRunPos {
runData, err := t.readDataFromFile(t.Gen.RunOutFile(entry.CmdKey), pos.TailRunPos, MaxDataBytes)
if err != nil {
return nil, err
}
dataPacket.RunData64 = base64.StdEncoding.EncodeToString(runData)
dataPacket.RunDataLen = len(runData)
}
return dataPacket, nil
}
// returns (data-packet, keepRunning)
func (t *Tailer) runSingleDataTransfer(key base.CommandKey, reqId string) (*packet.CmdDataPacketType, bool, error) {
t.Lock.Lock()
entry, pos, foundPos := t.getEntryAndPos_nolock(key, reqId)
t.Lock.Unlock()
if !foundPos {
return nil, false, nil
}
dataPacket, dataErr := t.makeCmdDataPacket(entry, pos)
t.Lock.Lock()
defer t.Lock.Unlock()
entry, pos, foundPos = t.getEntryAndPos_nolock(key, reqId)
if !foundPos {
return nil, false, nil
}
if dataErr != nil {
// error, so stop running and return the error (checked before touching dataPacket, which is nil on error)
pos.Running = false
t.updateTailPos_nolock(key, reqId, pos)
return nil, false, dataErr
}
// pos was updated between first and second get, throw out data-packet and re-run
if pos.TailPtyPos != dataPacket.PtyPos || pos.TailRunPos != dataPacket.RunPos {
return nil, true, nil
}
pos.TailPtyPos += int64(dataPacket.PtyDataLen)
pos.TailRunPos += int64(dataPacket.RunDataLen)
if pos.IsCurrent(entry) {
// we caught up, tail position equals file length
pos.Running = false
}
t.updateTailPos_nolock(key, reqId, pos)
return dataPacket, pos.Running, nil
}
// returns (removed)
func (t *Tailer) checkRemove(cmdKey base.CommandKey, reqId string) bool {
t.Lock.Lock()
defer t.Lock.Unlock()
entry, pos, foundPos := t.getEntryAndPos_nolock(cmdKey, reqId)
if !foundPos {
return false
}
if !pos.IsCurrent(entry) {
return false
}
if !pos.Follow || entry.Done {
t.removeTailPos_nolock(cmdKey, reqId)
return true
}
return false
}
func (t *Tailer) RunDataTransfer(key base.CommandKey, reqId string) {
for {
dataPacket, keepRunning, err := t.runSingleDataTransfer(key, reqId)
if dataPacket != nil {
t.Sender.SendPacket(dataPacket)
}
if err != nil {
t.removeTailPos(key, reqId)
t.Sender.SendErrorResponse(reqId, err)
break
}
if !keepRunning {
removed := t.checkRemove(key, reqId)
if removed {
t.Sender.SendResponse(reqId, true)
}
break
}
time.Sleep(10 * time.Millisecond)
}
}
func (t *Tailer) tryStartRun_nolock(entry CmdWatchEntry, pos TailPos) {
if pos.Running {
return
}
if pos.IsCurrent(entry) {
return
}
pos.Running = true
t.updateTailPos_nolock(entry.CmdKey, pos.ReqId, pos)
go t.RunDataTransfer(entry.CmdKey, pos.ReqId)
}
var updateFileRe = regexp.MustCompile("/([a-z0-9-]+)/([a-z0-9-]+)\\.(ptyout|runout)$")
func (t *Tailer) updateFile(relFileName string) {
m := updateFileRe.FindStringSubmatch(relFileName)
if m == nil {
return
}
finfo, err := os.Stat(relFileName)
if err != nil {
t.Sender.SendPacket(packet.FmtMessagePacket("error trying to stat file '%s': %v", relFileName, err))
return
}
cmdKey := base.MakeCommandKey(m[1], m[2])
t.Lock.Lock()
defer t.Lock.Unlock()
entry, foundEntry := t.WatchList[cmdKey]
if !foundEntry {
return
}
fileType := m[3]
if fileType == FileTypePty {
entry.FilePtyLen = finfo.Size()
} else if fileType == FileTypeRun {
entry.FileRunLen = finfo.Size()
}
t.WatchList[cmdKey] = entry
for _, pos := range entry.Tails {
t.tryStartRun_nolock(entry, pos)
}
}
func (t *Tailer) Run() {
for {
select {
case event, ok := <-t.Watcher.Events:
if !ok {
return
}
if event.Op&fsnotify.Write == fsnotify.Write {
t.updateFile(event.Name)
}
case err, ok := <-t.Watcher.Errors:
if !ok {
return
}
// what to do with this error? just send a message
t.Sender.SendPacket(packet.FmtMessagePacket("error in tailer: %v", err))
}
}
return
}
func (t *Tailer) Close() error {
return t.Watcher.Close()
}
func max(v1 int64, v2 int64) int64 {
if v1 > v2 {
return v1
}
return v2
}
func (entry *CmdWatchEntry) fillFilePos(gen FileNameGenerator) {
ptyInfo, _ := os.Stat(gen.PtyOutFile(entry.CmdKey))
if ptyInfo != nil {
entry.FilePtyLen = ptyInfo.Size()
}
runoutInfo, _ := os.Stat(gen.RunOutFile(entry.CmdKey))
if runoutInfo != nil {
entry.FileRunLen = runoutInfo.Size()
}
}
func (t *Tailer) KeyDone(key base.CommandKey) {
t.Lock.Lock()
defer t.Lock.Unlock()
entry, foundEntry := t.WatchList[key]
if !foundEntry {
return
}
entry.Done = true
var newTails []TailPos
for _, pos := range entry.Tails {
if pos.IsCurrent(entry) {
continue
}
newTails = append(newTails, pos)
}
entry.Tails = newTails
t.WatchList[key] = entry
if len(entry.Tails) == 0 {
t.removeWatch_nolock(key)
}
}
func (t *Tailer) RemoveWatch(pk *packet.UntailCmdPacketType) {
t.Lock.Lock()
defer t.Lock.Unlock()
t.removeTailPos_nolock(pk.CK, pk.ReqId)
}
func (t *Tailer) AddFileWatches_nolock(key base.CommandKey, ptyOnly bool) error {
ptyName := t.Gen.PtyOutFile(key)
runName := t.Gen.RunOutFile(key)
fmt.Printf("WATCH> add %s\n", ptyName)
err := t.Watcher.Add(ptyName)
if err != nil {
return err
}
if ptyOnly {
return nil
}
err = t.Watcher.Add(runName)
if err != nil {
t.Watcher.Remove(ptyName) // best effort clean up
return err
}
return nil
}
// returns (up-to-date/done, error)
func (t *Tailer) AddWatch(getPacket *packet.GetCmdPacketType) (bool, error) {
if err := getPacket.CK.Validate("getcmd"); err != nil {
return false, err
}
if getPacket.ReqId == "" {
return false, fmt.Errorf("getcmd, no reqid specified")
}
t.Lock.Lock()
defer t.Lock.Unlock()
key := getPacket.CK
entry, foundEntry := t.WatchList[key]
if !foundEntry {
// initialize entry, add watches
entry = CmdWatchEntry{CmdKey: key}
entry.fillFilePos(t.Gen)
}
pos, foundPos := entry.getTailPos(getPacket.ReqId)
if !foundPos {
// initialize a new tailpos
pos = TailPos{ReqId: getPacket.ReqId}
}
// update tailpos with new values from getpacket
pos.TailPtyPos = getPacket.PtyPos
pos.TailRunPos = getPacket.RunPos
pos.Follow = getPacket.Tail
// convert negative pos to positive
if pos.TailPtyPos < 0 {
pos.TailPtyPos = max(0, entry.FilePtyLen+pos.TailPtyPos) // + because negative
}
if pos.TailRunPos < 0 {
pos.TailRunPos = max(0, entry.FileRunLen+pos.TailRunPos) // + because negative
}
entry.updateTailPos(pos.ReqId, pos)
if !pos.Follow && pos.IsCurrent(entry) {
// don't add to t.WatchList, don't t.AddFileWatches_nolock, send rpc response
return true, nil
}
if !foundEntry {
err := t.AddFileWatches_nolock(key, getPacket.PtyOnly)
if err != nil {
return false, err
}
}
t.WatchList[key] = entry
t.tryStartRun_nolock(entry, pos)
return false, nil
}
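A hedged usage sketch of the Tailer above (not part of this commit): the FileNameGenerator ("gen"), the PacketSender ("sender"), and the GetCmdPacketType literal with its ids are assumptions for illustration; in practice the getcmd packet comes off the wire.
    // illustrative fragment; "sender" and "gen" are assumed to already exist
    tailer, err := MakeTailer(sender, gen)
    if err != nil {
        return err
    }
    defer tailer.Close()
    go tailer.Run() // consume fsnotify events until the watcher closes

    getPk := &packet.GetCmdPacketType{
        ReqId:  "req-1",                                     // hypothetical request id
        CK:     base.MakeCommandKey("session-id", "cmd-id"), // hypothetical ids
        PtyPos: -1024, // negative positions are offsets from the end of the file
        Tail:   true,  // keep following as the files grow
    }
    upToDate, err := tailer.AddWatch(getPk)
    // upToDate == true means the request was already satisfied and no watch was added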

View File

@ -0,0 +1,144 @@
package mpio
import (
"io"
"sync"
"github.com/commandlinedev/apishell/pkg/packet"
)
type FdReader struct {
CVar *sync.Cond
M *Multiplexer
FdNum int
Fd io.ReadCloser
BufSize int
Closed bool
ShouldCloseFd bool
IsPty bool
}
func MakeFdReader(m *Multiplexer, fd io.ReadCloser, fdNum int, shouldCloseFd bool, isPty bool) *FdReader {
fr := &FdReader{
CVar: sync.NewCond(&sync.Mutex{}),
M: m,
FdNum: fdNum,
Fd: fd,
BufSize: 0,
ShouldCloseFd: shouldCloseFd,
IsPty: isPty,
}
return fr
}
func (r *FdReader) Close() {
r.CVar.L.Lock()
defer r.CVar.L.Unlock()
if r.Closed {
return
}
r.Closed = true
if r.Fd != nil && r.ShouldCloseFd {
r.Fd.Close()
}
r.CVar.Broadcast()
}
func (r *FdReader) GetBufSize() int {
r.CVar.L.Lock()
defer r.CVar.L.Unlock()
return r.BufSize
}
func (r *FdReader) NotifyAck(ackLen int) {
r.CVar.L.Lock()
defer r.CVar.L.Unlock()
if r.Closed {
return
}
r.BufSize -= ackLen
if r.BufSize < 0 {
r.BufSize = 0
}
r.CVar.Broadcast()
}
// !! inverse locking. must already hold the lock when you call this method.
// will *unlock*, send the packet, and then *relock* once it is done.
// this can prevent an unlikely deadlock where we are holding r.CVar.L and stuck on sender.SendCh
func (r *FdReader) sendPacket_unlock(pk packet.PacketType) {
r.CVar.L.Unlock()
defer r.CVar.L.Lock()
r.M.sendPacket(pk)
}
// returns (success)
func (r *FdReader) WriteWait(data []byte, isEof bool) bool {
r.CVar.L.Lock()
defer r.CVar.L.Unlock()
for {
bufAvail := ReadBufSize - r.BufSize
if r.Closed {
return false
}
if bufAvail == 0 {
r.CVar.Wait()
continue
}
writeLen := min(bufAvail, len(data))
pk := r.M.makeDataPacket(r.FdNum, data[0:writeLen], nil)
pk.Eof = isEof && (writeLen == len(data))
r.BufSize += writeLen
data = data[writeLen:]
r.sendPacket_unlock(pk)
if len(data) == 0 {
return true
}
// do *not* do a CVar.Wait() here -- because we *unlocked* to send the packet, we should
// recheck the condition before waiting to avoid deadlock.
}
}
func min(v1 int, v2 int) int {
if v1 <= v2 {
return v1
}
return v2
}
func (r *FdReader) isClosed() bool {
r.CVar.L.Lock()
defer r.CVar.L.Unlock()
return r.Closed
}
func (r *FdReader) ReadLoop(wg *sync.WaitGroup) {
defer r.Close()
if wg != nil {
defer wg.Done()
}
buf := make([]byte, 4096)
for {
nr, err := r.Fd.Read(buf)
if r.isClosed() {
return // should not send data or error if we already closed the fd
}
if nr > 0 || err == io.EOF {
isOpen := r.WriteWait(buf[0:nr], (err == io.EOF))
if !isOpen {
return
}
if err == io.EOF {
return
}
}
if err != nil {
if r.IsPty {
r.WriteWait(nil, true)
return
}
errPk := r.M.makeDataPacket(r.FdNum, nil, err)
r.M.sendPacket(errPk)
return
}
}
}
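The FdReader above implements ack-based flow control: WriteWait blocks once ReadBufSize bytes are unacknowledged, and NotifyAck (driven by incoming data-ack packets, see processAckPacket in mpio.go below) releases it. A minimal standalone sketch of the same idea, with hypothetical names:
    package main

    import "sync"

    // window is a toy version of the FdReader buffer accounting: senders block
    // while more than "limit" bytes are unacknowledged; acks wake them up.
    type window struct {
        cv       *sync.Cond
        inFlight int
        limit    int
    }

    func newWindow(limit int) *window {
        return &window{cv: sync.NewCond(&sync.Mutex{}), limit: limit}
    }

    func (w *window) send(n int) {
        w.cv.L.Lock()
        defer w.cv.L.Unlock()
        for w.inFlight+n > w.limit {
            w.cv.Wait()
        }
        w.inFlight += n
    }

    func (w *window) ack(n int) {
        w.cv.L.Lock()
        defer w.cv.L.Unlock()
        w.inFlight -= n
        if w.inFlight < 0 {
            w.inFlight = 0
        }
        w.cv.Broadcast()
    }

    func main() {
        w := newWindow(128 * 1024) // same order of magnitude as ReadBufSize
        w.send(4096)               // would block if 128K bytes were already unacknowledged
        w.ack(4096)
    }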

View File

@ -0,0 +1,112 @@
package mpio
import (
"fmt"
"io"
"sync"
)
type FdWriter struct {
CVar *sync.Cond
M *Multiplexer
FdNum int
Buffer []byte
BufferLimit int
Fd io.WriteCloser
Eof bool
Closed bool
ShouldCloseFd bool
Desc string
}
func MakeFdWriter(m *Multiplexer, fd io.WriteCloser, fdNum int, shouldCloseFd bool, desc string) *FdWriter {
fw := &FdWriter{
CVar: sync.NewCond(&sync.Mutex{}),
Fd: fd,
M: m,
FdNum: fdNum,
ShouldCloseFd: shouldCloseFd,
Desc: desc,
BufferLimit: WriteBufSize,
}
return fw
}
func (w *FdWriter) Close() {
w.CVar.L.Lock()
defer w.CVar.L.Unlock()
if w.Closed {
return
}
w.Closed = true
if w.Fd != nil && w.ShouldCloseFd {
w.Fd.Close()
}
w.Buffer = nil
w.CVar.Broadcast()
}
func (w *FdWriter) WaitForData() ([]byte, bool) {
w.CVar.L.Lock()
defer w.CVar.L.Unlock()
for {
if len(w.Buffer) > 0 || w.Eof || w.Closed {
toWrite := w.Buffer
w.Buffer = nil
return toWrite, w.Eof
}
w.CVar.Wait()
}
}
func (w *FdWriter) AddData(data []byte, eof bool) error {
w.CVar.L.Lock()
defer w.CVar.L.Unlock()
if w.Closed || w.Eof {
if len(data) == 0 {
return nil
}
return fmt.Errorf("write to closed file %q (fd:%d) eof[%v]", w.Desc, w.FdNum, w.Eof)
}
if len(data) > 0 {
if len(data)+len(w.Buffer) > w.BufferLimit {
return fmt.Errorf("write exceeds buffer size %q (fd:%d) bufsize=%d (max=%d)", w.Desc, w.FdNum, len(data)+len(w.Buffer), w.BufferLimit)
}
w.Buffer = append(w.Buffer, data...)
}
if eof {
w.Eof = true
}
w.CVar.Broadcast()
return nil
}
func (w *FdWriter) WriteLoop(wg *sync.WaitGroup) {
defer w.Close()
if wg != nil {
defer wg.Done()
}
for {
data, isEof := w.WaitForData()
// chunk the writes to make sure we send ample ack packets
for len(data) > 0 {
if w.Closed {
return
}
chunkSize := min(len(data), MaxSingleWriteSize)
chunk := data[0:chunkSize]
nw, err := w.Fd.Write(chunk)
if nw > 0 || err != nil {
ack := w.M.makeDataAckPacket(w.FdNum, nw, err)
w.M.sendPacket(ack)
}
if err != nil {
return
}
data = data[chunkSize:]
}
if isEof {
return
}
}
}

303
waveshell/pkg/mpio/mpio.go Normal file
View File

@ -0,0 +1,303 @@
package mpio
import (
"encoding/base64"
"fmt"
"io"
"os"
"sync"
"github.com/commandlinedev/apishell/pkg/base"
"github.com/commandlinedev/apishell/pkg/packet"
)
const ReadBufSize = 128 * 1024
const WriteBufSize = 128 * 1024
const MaxSingleWriteSize = 4 * 1024
const MaxTotalRunDataSize = 10 * ReadBufSize
type Multiplexer struct {
Lock *sync.Mutex
CK base.CommandKey
FdReaders map[int]*FdReader // synchronized
FdWriters map[int]*FdWriter // synchronized
RunData map[int]*FdReader // synchronized
CloseAfterStart []*os.File // synchronized
Sender *packet.PacketSender
Input *packet.PacketParser
Started bool
UPR packet.UnknownPacketReporter
Debug bool
}
func MakeMultiplexer(ck base.CommandKey, upr packet.UnknownPacketReporter) *Multiplexer {
if upr == nil {
upr = packet.DefaultUPR{}
}
return &Multiplexer{
Lock: &sync.Mutex{},
CK: ck,
FdReaders: make(map[int]*FdReader),
FdWriters: make(map[int]*FdWriter),
UPR: upr,
}
}
func (m *Multiplexer) Close() {
m.Lock.Lock()
defer m.Lock.Unlock()
for _, fr := range m.FdReaders {
fr.Close()
}
for _, fw := range m.FdWriters {
fw.Close()
}
for _, fd := range m.CloseAfterStart {
fd.Close()
}
}
func (m *Multiplexer) HandleInputDone() {
m.Lock.Lock()
defer m.Lock.Unlock()
// close readers (obviously the done command needs no more input)
for _, fr := range m.FdReaders {
fr.Close()
}
// ensure EOF on all writers (ignore error)
for _, fw := range m.FdWriters {
fw.AddData(nil, true)
}
}
// returns the *writer* to connect to process, reader is put in FdReaders
func (m *Multiplexer) MakeReaderPipe(fdNum int) (*os.File, error) {
pr, pw, err := os.Pipe()
if err != nil {
return nil, err
}
m.Lock.Lock()
defer m.Lock.Unlock()
m.FdReaders[fdNum] = MakeFdReader(m, pr, fdNum, true, false)
m.CloseAfterStart = append(m.CloseAfterStart, pw)
return pw, nil
}
// returns the *reader* to connect to process, writer is put in FdWriters
func (m *Multiplexer) MakeWriterPipe(fdNum int, desc string) (*os.File, error) {
pr, pw, err := os.Pipe()
if err != nil {
return nil, err
}
m.Lock.Lock()
defer m.Lock.Unlock()
m.FdWriters[fdNum] = MakeFdWriter(m, pw, fdNum, true, desc)
m.CloseAfterStart = append(m.CloseAfterStart, pr)
return pr, nil
}
// returns the *reader* to connect to process, writer is put in FdWriters
func (m *Multiplexer) MakeStaticWriterPipe(fdNum int, data []byte, bufferLimit int, desc string) (*os.File, error) {
pr, pw, err := os.Pipe()
if err != nil {
return nil, err
}
m.Lock.Lock()
defer m.Lock.Unlock()
fdWriter := MakeFdWriter(m, pw, fdNum, true, desc)
fdWriter.BufferLimit = bufferLimit
err = fdWriter.AddData(data, true)
if err != nil {
return nil, err
}
m.FdWriters[fdNum] = fdWriter
m.CloseAfterStart = append(m.CloseAfterStart, pr)
return pr, nil
}
func (m *Multiplexer) MakeRawFdReader(fdNum int, fd io.ReadCloser, shouldClose bool, isPty bool) {
m.Lock.Lock()
defer m.Lock.Unlock()
m.FdReaders[fdNum] = MakeFdReader(m, fd, fdNum, shouldClose, isPty)
}
func (m *Multiplexer) MakeRawFdWriter(fdNum int, fd io.WriteCloser, shouldClose bool, desc string) {
m.Lock.Lock()
defer m.Lock.Unlock()
m.FdWriters[fdNum] = MakeFdWriter(m, fd, fdNum, shouldClose, desc)
}
func (m *Multiplexer) makeDataAckPacket(fdNum int, ackLen int, err error) *packet.DataAckPacketType {
ack := packet.MakeDataAckPacket()
ack.CK = m.CK
ack.FdNum = fdNum
ack.AckLen = ackLen
if err != nil {
ack.Error = err.Error()
}
return ack
}
func (m *Multiplexer) makeDataPacket(fdNum int, data []byte, err error) *packet.DataPacketType {
pk := packet.MakeDataPacket()
pk.CK = m.CK
pk.FdNum = fdNum
pk.Data64 = base64.StdEncoding.EncodeToString(data)
if err != nil {
pk.Error = err.Error()
}
return pk
}
func (m *Multiplexer) sendPacket(p packet.PacketType) {
m.Sender.SendPacket(p)
}
func (m *Multiplexer) launchWriters(wg *sync.WaitGroup) {
m.Lock.Lock()
defer m.Lock.Unlock()
if wg != nil {
wg.Add(len(m.FdWriters))
}
for _, fw := range m.FdWriters {
go fw.WriteLoop(wg)
}
}
func (m *Multiplexer) launchReaders(wg *sync.WaitGroup) {
m.Lock.Lock()
defer m.Lock.Unlock()
if wg != nil {
wg.Add(len(m.FdReaders))
}
for _, fr := range m.FdReaders {
go fr.ReadLoop(wg)
}
}
func (m *Multiplexer) startIO(packetParser *packet.PacketParser, sender *packet.PacketSender) {
m.Lock.Lock()
defer m.Lock.Unlock()
if m.Started {
panic("Multiplexer is already running, cannot start again")
}
m.Input = packetParser
m.Sender = sender
m.Started = true
}
func (m *Multiplexer) runPacketInputLoop() *packet.CmdDonePacketType {
defer m.HandleInputDone()
for pk := range m.Input.MainCh {
if m.Debug {
fmt.Printf("PK-M> %s\n", packet.AsString(pk))
}
if pk.GetType() == packet.DataPacketStr {
dataPacket := pk.(*packet.DataPacketType)
err := m.processDataPacket(dataPacket)
if err != nil {
errPacket := m.makeDataAckPacket(dataPacket.FdNum, 0, err)
m.sendPacket(errPacket)
}
continue
}
if pk.GetType() == packet.DataAckPacketStr {
ackPacket := pk.(*packet.DataAckPacketType)
m.processAckPacket(ackPacket)
continue
}
if pk.GetType() == packet.CmdDonePacketStr {
donePacket := pk.(*packet.CmdDonePacketType)
return donePacket
}
m.UPR.UnknownPacket(pk)
}
return nil
}
func (m *Multiplexer) WriteDataToFd(fdNum int, data []byte, isEof bool) error {
m.Lock.Lock()
defer m.Lock.Unlock()
fw := m.FdWriters[fdNum]
if fw == nil {
// add a closed FdWriter as a placeholder so we only send one error
fw := MakeFdWriter(m, nil, fdNum, false, "invalid-fd")
fw.Close()
m.FdWriters[fdNum] = fw
return fmt.Errorf("write to closed file (no fd)")
}
err := fw.AddData(data, isEof)
if err != nil {
fw.Close()
return err
}
return nil
}
func (m *Multiplexer) processDataPacket(dataPacket *packet.DataPacketType) error {
realData, err := base64.StdEncoding.DecodeString(dataPacket.Data64)
if err != nil {
return fmt.Errorf("decoding base64 data: %w", err)
}
return m.WriteDataToFd(dataPacket.FdNum, realData, dataPacket.Eof)
}
func (m *Multiplexer) processAckPacket(ackPacket *packet.DataAckPacketType) {
m.Lock.Lock()
defer m.Lock.Unlock()
fr := m.FdReaders[ackPacket.FdNum]
if fr == nil {
return
}
fr.NotifyAck(ackPacket.AckLen)
}
func (m *Multiplexer) closeTempStartFds() {
m.Lock.Lock()
defer m.Lock.Unlock()
for _, fd := range m.CloseAfterStart {
fd.Close()
}
m.CloseAfterStart = nil
}
func (m *Multiplexer) RunIOAndWait(packetParser *packet.PacketParser, sender *packet.PacketSender, waitOnReaders bool, waitOnWriters bool, waitForInputLoop bool) *packet.CmdDonePacketType {
m.startIO(packetParser, sender)
m.closeTempStartFds()
var wg sync.WaitGroup
if waitOnReaders {
m.launchReaders(&wg)
} else {
m.launchReaders(nil)
}
if waitOnWriters {
m.launchWriters(&wg)
} else {
m.launchWriters(nil)
}
var donePacket *packet.CmdDonePacketType
if waitForInputLoop {
wg.Add(1)
}
go func() {
if waitForInputLoop {
defer wg.Done()
}
pkRtn := m.runPacketInputLoop()
if pkRtn != nil {
m.Lock.Lock()
donePacket = pkRtn
m.Lock.Unlock()
}
}()
wg.Wait()
m.Lock.Lock()
defer m.Lock.Unlock()
return donePacket
}
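A hedged sketch of how the Multiplexer gets wired to a child process (the exec wiring and fd numbers are illustrative, not this commit's code; parser/sender construction mirrors RunServer further down):
    // illustrative fragment (error handling elided); imports: os, os/exec, plus the
    // mpio, packet, and base packages from this repo
    ck := base.MakeCommandKey("session-id", "cmd-id") // hypothetical ids
    m := mpio.MakeMultiplexer(ck, nil)
    ecmd := exec.Command("ls", "-l")
    ecmd.Stdin, _ = m.MakeWriterPipe(0, "stdin") // incoming data packets -> child stdin
    ecmd.Stdout, _ = m.MakeReaderPipe(1)         // child stdout -> outgoing data packets on fd 1
    ecmd.Stderr, _ = m.MakeReaderPipe(2)         // child stderr -> outgoing data packets on fd 2
    _ = ecmd.Start()
    parser := packet.MakePacketParser(os.Stdin, false)
    sender := packet.MakePacketSender(os.Stdout, nil)
    donePk := m.RunIOAndWait(parser, sender, true, true, true)
    _ = donePk // non-nil once a cmddone packet arrives on the input stream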

View File

@ -0,0 +1,34 @@
package packet
type CombinedPacket struct {
Type string `json:"type"`
Success bool `json:"success"`
Ts int64 `json:"ts"`
Id string `json:"id,omitempty"`
SessionId string `json:"sessionid"`
CmdId string `json:"cmdid"`
PtyPos int64 `json:"ptypos"`
PtyLen int64 `json:"ptylen"`
RunPos int64 `json:"runpos"`
RunLen int64 `json:"runlen"`
Error string `json:"error"`
NotFound bool `json:"notfound,omitempty"`
Tail bool `json:"tail,omitempty"`
Dir string `json:"dir"`
ChDir string `json:"chdir,omitempty"`
Data string `json:"data"`
PtyData string `json:"ptydata"`
RunData string `json:"rundata"`
Message string `json:"message"`
Command string `json:"command"`
ScHomeDir string `json:"schomedir"`
HomeDir string `json:"homedir"`
Env []string `json:"env"`
ExitCode int `json:"exitcode"`
RunnerPid int `json:"runnerpid"`
}

File diff suppressed because it is too large

View File

@ -0,0 +1,238 @@
package packet
import (
"bufio"
"context"
"io"
"strconv"
"strings"
"sync"
)
type PacketParser struct {
Lock *sync.Mutex
MainCh chan PacketType
RpcMap map[string]*RpcEntry
RpcHandler bool
Err error
}
type RpcEntry struct {
ReqId string
RespCh chan RpcResponsePacketType
}
type RpcResponseIter struct {
ReqId string
Parser *PacketParser
}
func (iter *RpcResponseIter) Next(ctx context.Context) (RpcResponsePacketType, error) {
// will unregister the rpc on ResponseDone
return iter.Parser.GetNextResponse(ctx, iter.ReqId)
}
func (iter *RpcResponseIter) Close() {
iter.Parser.UnRegisterRpc(iter.ReqId)
}
func CombinePacketParsers(p1 *PacketParser, p2 *PacketParser, rpcHandler bool) *PacketParser {
rtnParser := &PacketParser{
Lock: &sync.Mutex{},
MainCh: make(chan PacketType),
RpcMap: make(map[string]*RpcEntry),
RpcHandler: rpcHandler,
}
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
for pk := range p1.MainCh {
if rtnParser.RpcHandler {
sent := rtnParser.trySendRpcResponse(pk)
if sent {
continue
}
}
rtnParser.MainCh <- pk
}
}()
go func() {
defer wg.Done()
for pk := range p2.MainCh {
if rtnParser.RpcHandler {
sent := rtnParser.trySendRpcResponse(pk)
if sent {
continue
}
}
rtnParser.MainCh <- pk
}
}()
go func() {
wg.Wait()
close(rtnParser.MainCh)
}()
return rtnParser
}
// should have already registered rpc
func (p *PacketParser) WaitForResponse(ctx context.Context, reqId string) RpcResponsePacketType {
entry := p.getRpcEntry(reqId)
if entry == nil {
return nil
}
defer p.UnRegisterRpc(reqId)
select {
case resp := <-entry.RespCh:
return resp
case <-ctx.Done():
return nil
}
}
func (p *PacketParser) GetResponseIter(reqId string) *RpcResponseIter {
return &RpcResponseIter{Parser: p, ReqId: reqId}
}
func (p *PacketParser) GetNextResponse(ctx context.Context, reqId string) (RpcResponsePacketType, error) {
entry := p.getRpcEntry(reqId)
if entry == nil {
return nil, nil
}
select {
case resp := <-entry.RespCh:
if resp.GetResponseDone() {
p.UnRegisterRpc(reqId)
}
return resp, nil
case <-ctx.Done():
return nil, ctx.Err()
}
}
func (p *PacketParser) UnRegisterRpc(reqId string) {
p.Lock.Lock()
defer p.Lock.Unlock()
entry := p.RpcMap[reqId]
if entry != nil {
close(entry.RespCh)
delete(p.RpcMap, reqId)
}
}
func (p *PacketParser) RegisterRpc(reqId string) chan RpcResponsePacketType {
return p.RegisterRpcSz(reqId, 2)
}
func (p *PacketParser) RegisterRpcSz(reqId string, queueSize int) chan RpcResponsePacketType {
p.Lock.Lock()
defer p.Lock.Unlock()
ch := make(chan RpcResponsePacketType, queueSize)
entry := &RpcEntry{ReqId: reqId, RespCh: ch}
p.RpcMap[reqId] = entry
return ch
}
func (p *PacketParser) getRpcEntry(reqId string) *RpcEntry {
p.Lock.Lock()
defer p.Lock.Unlock()
entry := p.RpcMap[reqId]
return entry
}
func (p *PacketParser) trySendRpcResponse(pk PacketType) bool {
respPk, ok := pk.(RpcResponsePacketType)
if !ok {
return false
}
p.Lock.Lock()
defer p.Lock.Unlock()
entry := p.RpcMap[respPk.GetResponseId()]
if entry == nil {
return false
}
// nonblocking send
select {
case entry.RespCh <- respPk:
default:
}
return true
}
func (p *PacketParser) GetErr() error {
p.Lock.Lock()
defer p.Lock.Unlock()
return p.Err
}
func (p *PacketParser) SetErr(err error) {
p.Lock.Lock()
defer p.Lock.Unlock()
if p.Err == nil {
p.Err = err
}
}
func MakePacketParser(input io.Reader, rpcHandler bool) *PacketParser {
parser := &PacketParser{
Lock: &sync.Mutex{},
MainCh: make(chan PacketType),
RpcMap: make(map[string]*RpcEntry),
RpcHandler: rpcHandler,
}
bufReader := bufio.NewReader(input)
go func() {
defer func() {
close(parser.MainCh)
}()
for {
line, err := bufReader.ReadString('\n')
if err == io.EOF {
return
}
if err != nil {
parser.SetErr(err)
return
}
if line == "\n" {
continue
}
// ##[len][json]\n
// ##14{"hello":true}\n
// ##N{...}
bracePos := strings.Index(line, "{")
if !strings.HasPrefix(line, "##") || bracePos == -1 {
parser.MainCh <- MakeRawPacket(line[:len(line)-1])
continue
}
packetLen := -1
if line[2:bracePos] != "N" {
packetLen, err = strconv.Atoi(line[2:bracePos])
if err != nil || packetLen != len(line)-bracePos-1 {
parser.MainCh <- MakeRawPacket(line[:len(line)-1])
continue
}
}
pk, err := ParseJsonPacket([]byte(line[bracePos:]))
if err != nil {
parser.MainCh <- MakeRawPacket(line[:len(line)-1])
continue
}
if pk.GetType() == DonePacketStr {
return
}
if pk.GetType() == PingPacketStr {
continue
}
if parser.RpcHandler {
sent := parser.trySendRpcResponse(pk)
if sent {
continue
}
}
parser.MainCh <- pk
}
}()
return parser
}
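A small hedged sketch of the framing the parser accepts: raw lines pass through as raw packets, ##-framed lines are length-checked and JSON-decoded, and whether the JSON becomes a typed packet depends on ParseJsonPacket's registered types.
    // illustrative fragment; imports: fmt, strings, plus the packet package
    input := strings.NewReader("plain output line\n##14{\"hello\":true}\n")
    parser := packet.MakePacketParser(input, false)
    for pk := range parser.MainCh {
        // first line arrives as a raw packet; the second is valid framing
        // (14 == len of the JSON body) but only decodes if its type is known
        fmt.Println(packet.AsString(pk))
    }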

View File

@ -0,0 +1,192 @@
package packet
import (
"bytes"
"crypto/sha1"
"encoding/base64"
"encoding/json"
"fmt"
"github.com/commandlinedev/apishell/pkg/binpack"
"github.com/commandlinedev/apishell/pkg/statediff"
)
const ShellStatePackVersion = 0
const ShellStateDiffPackVersion = 0
type ShellState struct {
Version string `json:"version"` // [type] [semver]
Cwd string `json:"cwd,omitempty"`
ShellVars []byte `json:"shellvars,omitempty"`
Aliases string `json:"aliases,omitempty"`
Funcs string `json:"funcs,omitempty"`
Error string `json:"error,omitempty"`
HashVal string `json:"-"`
}
type ShellStateDiff struct {
Version string `json:"version"` // [type] [semver]
BaseHash string `json:"basehash"`
DiffHashArr []string `json:"diffhasharr,omitempty"`
Cwd string `json:"cwd,omitempty"`
VarsDiff []byte `json:"shellvarsdiff,omitempty"` // vardiff
AliasesDiff []byte `json:"aliasesdiff,omitempty"` // linediff
FuncsDiff []byte `json:"funcsdiff,omitempty"` // linediff
Error string `json:"error,omitempty"`
HashVal string `json:"-"`
}
func (state ShellState) IsEmpty() bool {
return state.Version == "" && state.Cwd == "" && len(state.ShellVars) == 0 && state.Aliases == "" && state.Funcs == "" && state.Error == ""
}
// returns base64 hash of data
func sha1Hash(data []byte) string {
hvalRaw := sha1.Sum(data)
hval := base64.StdEncoding.EncodeToString(hvalRaw[:])
return hval
}
// returns (SHA1, encoded-state)
func (state ShellState) EncodeAndHash() (string, []byte) {
var buf bytes.Buffer
binpack.PackInt(&buf, ShellStatePackVersion)
binpack.PackValue(&buf, []byte(state.Version))
binpack.PackValue(&buf, []byte(state.Cwd))
binpack.PackValue(&buf, state.ShellVars)
binpack.PackValue(&buf, []byte(state.Aliases))
binpack.PackValue(&buf, []byte(state.Funcs))
binpack.PackValue(&buf, []byte(state.Error))
return sha1Hash(buf.Bytes()), buf.Bytes()
}
func (state ShellState) MarshalJSON() ([]byte, error) {
_, encodedBytes := state.EncodeAndHash()
return json.Marshal(encodedBytes)
}
// caches HashVal in struct
func (state *ShellState) GetHashVal(force bool) string {
if state.HashVal == "" || force {
state.HashVal, _ = state.EncodeAndHash()
}
return state.HashVal
}
func (state *ShellState) DecodeShellState(barr []byte) error {
state.HashVal = sha1Hash(barr)
buf := bytes.NewBuffer(barr)
u := binpack.MakeUnpacker(buf)
version := u.UnpackInt("ShellState pack version")
if version != ShellStatePackVersion {
return fmt.Errorf("invalid ShellState pack version: %d", version)
}
state.Version = string(u.UnpackValue("ShellState.Version"))
state.Cwd = string(u.UnpackValue("ShellState.Cwd"))
state.ShellVars = u.UnpackValue("ShellState.ShellVars")
state.Aliases = string(u.UnpackValue("ShellState.Aliases"))
state.Funcs = string(u.UnpackValue("ShellState.Funcs"))
state.Error = string(u.UnpackValue("ShellState.Error"))
return u.Error()
}
func (state *ShellState) UnmarshalJSON(jsonBytes []byte) error {
var barr []byte
err := json.Unmarshal(jsonBytes, &barr)
if err != nil {
return err
}
return state.DecodeShellState(barr)
}
func (sdiff ShellStateDiff) EncodeAndHash() (string, []byte) {
var buf bytes.Buffer
binpack.PackInt(&buf, ShellStateDiffPackVersion)
binpack.PackValue(&buf, []byte(sdiff.Version))
binpack.PackValue(&buf, []byte(sdiff.BaseHash))
binpack.PackStrArr(&buf, sdiff.DiffHashArr)
binpack.PackValue(&buf, []byte(sdiff.Cwd))
binpack.PackValue(&buf, sdiff.VarsDiff)
binpack.PackValue(&buf, sdiff.AliasesDiff)
binpack.PackValue(&buf, sdiff.FuncsDiff)
binpack.PackValue(&buf, []byte(sdiff.Error))
return sha1Hash(buf.Bytes()), buf.Bytes()
}
func (sdiff ShellStateDiff) MarshalJSON() ([]byte, error) {
_, encodedBytes := sdiff.EncodeAndHash()
return json.Marshal(encodedBytes)
}
func (sdiff *ShellStateDiff) DecodeShellStateDiff(barr []byte) error {
sdiff.HashVal = sha1Hash(barr)
buf := bytes.NewBuffer(barr)
u := binpack.MakeUnpacker(buf)
version := u.UnpackInt("ShellState pack version")
if version != ShellStateDiffPackVersion {
return fmt.Errorf("invalid ShellStateDiff pack version: %d", version)
}
sdiff.Version = string(u.UnpackValue("ShellStateDiff.Version"))
sdiff.BaseHash = string(u.UnpackValue("ShellStateDiff.BaseHash"))
sdiff.DiffHashArr = u.UnpackStrArr("ShellStateDiff.DiffHashArr")
sdiff.Cwd = string(u.UnpackValue("ShellStateDiff.Cwd"))
sdiff.VarsDiff = u.UnpackValue("ShellStateDiff.VarsDiff")
sdiff.AliasesDiff = u.UnpackValue("ShellStateDiff.AliasesDiff")
sdiff.FuncsDiff = u.UnpackValue("ShellStateDiff.FuncsDiff")
sdiff.Error = string(u.UnpackValue("ShellStateDiff.Error"))
return u.Error()
}
func (sdiff *ShellStateDiff) UnmarshalJSON(jsonBytes []byte) error {
var barr []byte
err := json.Unmarshal(jsonBytes, &barr)
if err != nil {
return err
}
return sdiff.DecodeShellStateDiff(barr)
}
// caches HashVal in struct
func (sdiff *ShellStateDiff) GetHashVal(force bool) string {
if sdiff.HashVal == "" || force {
sdiff.HashVal, _ = sdiff.EncodeAndHash()
}
return sdiff.HashVal
}
func (sdiff ShellStateDiff) Dump(vars bool, aliases bool, funcs bool) {
fmt.Printf("ShellStateDiff:\n")
fmt.Printf(" version: %s\n", sdiff.Version)
fmt.Printf(" base: %s\n", sdiff.BaseHash)
fmt.Printf(" vars: %d, aliases: %d, funcs: %d\n", len(sdiff.VarsDiff), len(sdiff.AliasesDiff), len(sdiff.FuncsDiff))
if sdiff.Error != "" {
fmt.Printf(" error: %s\n", sdiff.Error)
}
if vars {
var mdiff statediff.MapDiffType
err := mdiff.Decode(sdiff.VarsDiff)
if err != nil {
fmt.Printf(" vars: error[%s]\n", err.Error())
} else {
mdiff.Dump()
}
}
if aliases && len(sdiff.AliasesDiff) > 0 {
var ldiff statediff.LineDiffType
err := ldiff.Decode(sdiff.AliasesDiff)
if err != nil {
fmt.Printf(" aliases: error[%s]\n", err.Error())
} else {
ldiff.Dump()
}
}
if funcs && len(sdiff.FuncsDiff) > 0 {
var ldiff statediff.LineDiffType
err := ldiff.Decode(sdiff.FuncsDiff)
if err != nil {
fmt.Printf(" funcs: error[%s]\n", err.Error())
} else {
ldiff.Dump()
}
}
}
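A hedged round-trip sketch of the custom JSON encoding above: ShellState marshals as the binpacked byte array (a base64 string in JSON), not as its struct fields. The field values are made up.
    // illustrative fragment; imports: encoding/json, fmt, plus the packet package
    state := packet.ShellState{Version: "bash v5.0.17", Cwd: "/home/user"} // hypothetical values
    hashVal, encoded := state.EncodeAndHash()
    fmt.Println(hashVal, len(encoded))

    jsonBytes, _ := json.Marshal(state) // base64 JSON string of the packed bytes, not the struct fields
    var decoded packet.ShellState
    _ = json.Unmarshal(jsonBytes, &decoded)
    fmt.Println(decoded.Cwd == state.Cwd, decoded.GetHashVal(false) == hashVal) // true true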

View File

@ -0,0 +1,747 @@
package server
import (
"context"
"errors"
"fmt"
"io"
"io/fs"
"os"
"os/exec"
"path/filepath"
"sort"
"strings"
"sync"
"time"
"github.com/alessio/shellescape"
"github.com/commandlinedev/apishell/pkg/base"
"github.com/commandlinedev/apishell/pkg/packet"
"github.com/commandlinedev/apishell/pkg/shexec"
)
const MaxFileDataPacketSize = 16 * 1024
const WriteFileContextTimeout = 30 * time.Second
const cleanLoopTime = 5 * time.Second
const MaxWriteFileContextData = 100
// TODO create unblockable packet-sender (backed by an array) for clientproc
type MServer struct {
Lock *sync.Mutex
MainInput *packet.PacketParser
Sender *packet.PacketSender
ClientMap map[base.CommandKey]*shexec.ClientProc
Debug bool
StateMap map[string]*packet.ShellState // sha1->state
CurrentState string // sha1
WriteErrorCh chan bool // closed if there is an I/O write error
WriteErrorChOnce *sync.Once
WriteFileContextMap map[string]*WriteFileContext
Done bool
}
type WriteFileContext struct {
CVar *sync.Cond
Data []*packet.FileDataPacketType
LastActive time.Time
Err error
Done bool
}
func (m *MServer) Close() {
m.Sender.Close()
m.Sender.WaitForDone()
m.Lock.Lock()
defer m.Lock.Unlock()
m.Done = true
}
func (m *MServer) checkDone() bool {
m.Lock.Lock()
defer m.Lock.Unlock()
return m.Done
}
func (m *MServer) getWriteFileContext(reqId string) *WriteFileContext {
m.Lock.Lock()
defer m.Lock.Unlock()
wfc := m.WriteFileContextMap[reqId]
if wfc == nil {
wfc = &WriteFileContext{
CVar: sync.NewCond(&sync.Mutex{}),
LastActive: time.Now(),
}
m.WriteFileContextMap[reqId] = wfc
}
return wfc
}
func (m *MServer) addFileDataPacket(pk *packet.FileDataPacketType) {
m.Lock.Lock()
wfc := m.WriteFileContextMap[pk.RespId]
m.Lock.Unlock()
if wfc == nil {
return
}
wfc.CVar.L.Lock()
defer wfc.CVar.L.Unlock()
if wfc.Done || wfc.Err != nil {
return
}
if len(wfc.Data) > MaxWriteFileContextData {
wfc.Err = errors.New("write-file buffer length exceeded")
wfc.Data = nil
wfc.CVar.Broadcast()
return
}
wfc.LastActive = time.Now()
wfc.Data = append(wfc.Data, pk)
wfc.CVar.Signal()
}
func (wfc *WriteFileContext) setDone() {
wfc.CVar.L.Lock()
defer wfc.CVar.L.Unlock()
wfc.Done = true
wfc.Data = nil
wfc.CVar.Broadcast()
}
func (m *MServer) cleanWriteFileContexts() {
now := time.Now()
var staleWfcs []*WriteFileContext
m.Lock.Lock()
for reqId, wfc := range m.WriteFileContextMap {
if now.Sub(wfc.LastActive) > WriteFileContextTimeout {
staleWfcs = append(staleWfcs, wfc)
delete(m.WriteFileContextMap, reqId)
}
}
m.Lock.Unlock()
// we do this outside of m.Lock just in case there is some lock contention (end of WriteFile could theoretically be slow)
for _, wfc := range staleWfcs {
wfc.setDone()
}
}
func (m *MServer) ProcessCommandPacket(pk packet.CommandPacketType) {
ck := pk.GetCK()
if ck == "" {
m.Sender.SendMessageFmt("received '%s' packet without ck", pk.GetType())
return
}
m.Lock.Lock()
cproc := m.ClientMap[ck]
m.Lock.Unlock()
if cproc == nil {
m.Sender.SendCmdError(ck, fmt.Errorf("no client proc for ck '%s', pk=%s", ck, packet.AsString(pk)))
return
}
cproc.Input.SendPacket(pk)
return
}
func runSingleCompGen(cwd string, compType string, prefix string) ([]string, bool, error) {
if !packet.IsValidCompGenType(compType) {
return nil, false, fmt.Errorf("invalid compgen type '%s'", compType)
}
compGenCmdStr := fmt.Sprintf("cd %s; compgen -A %s -- %s | sort | uniq | head -n %d", shellescape.Quote(cwd), shellescape.Quote(compType), shellescape.Quote(prefix), packet.MaxCompGenValues+1)
ecmd := exec.Command("bash", "-c", compGenCmdStr)
outputBytes, err := ecmd.Output()
if err != nil {
return nil, false, fmt.Errorf("compgen error: %w", err)
}
outputStr := string(outputBytes)
parts := strings.Split(outputStr, "\n")
if len(parts) > 0 && parts[len(parts)-1] == "" {
parts = parts[0 : len(parts)-1]
}
hasMore := false
if len(parts) > packet.MaxCompGenValues {
hasMore = true
parts = parts[0:packet.MaxCompGenValues]
}
return parts, hasMore, nil
}
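For reference, the generated shell command looks roughly like this (values hypothetical; the head limit is MaxCompGenValues+1):
    // cd /home/user; compgen -A file -- doc | sort | uniq | head -n <MaxCompGenValues+1>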
func appendSlashes(comps []string) {
for idx, comp := range comps {
comps[idx] = comp + "/"
}
}
func strArrToMap(strs []string) map[string]bool {
rtn := make(map[string]bool)
for _, s := range strs {
rtn[s] = true
}
return rtn
}
func (m *MServer) runMixedCompGen(compPk *packet.CompGenPacketType) {
// get directories and files, unique them and put slashes on directories for completion
reqId := compPk.GetReqId()
compDirs, hasMoreDirs, err := runSingleCompGen(compPk.Cwd, "directory", compPk.Prefix)
if err != nil {
m.Sender.SendErrorResponse(reqId, err)
return
}
compFiles, hasMoreFiles, err := runSingleCompGen(compPk.Cwd, compPk.CompType, compPk.Prefix)
if err != nil {
m.Sender.SendErrorResponse(reqId, err)
return
}
dirMap := strArrToMap(compDirs)
// seed comps with dirs (but append slashes)
comps := compDirs
appendSlashes(comps)
// add files that are not directories (look up in dirMap)
for _, file := range compFiles {
if dirMap[file] {
continue
}
comps = append(comps, file)
}
sort.Strings(comps) // resort
m.Sender.SendResponse(reqId, map[string]interface{}{"comps": comps, "hasmore": (hasMoreFiles || hasMoreDirs)})
return
}
func (m *MServer) runCompGen(compPk *packet.CompGenPacketType) {
reqId := compPk.GetReqId()
if compPk.CompType == "file" || compPk.CompType == "command" {
m.runMixedCompGen(compPk)
return
}
comps, hasMore, err := runSingleCompGen(compPk.Cwd, compPk.CompType, compPk.Prefix)
if err != nil {
m.Sender.SendErrorResponse(reqId, err)
return
}
if compPk.CompType == "directory" {
appendSlashes(comps)
}
m.Sender.SendResponse(reqId, map[string]interface{}{"comps": comps, "hasmore": hasMore})
return
}
func (m *MServer) setCurrentState(state *packet.ShellState) {
if state == nil {
return
}
hval, _ := state.EncodeAndHash()
m.Lock.Lock()
defer m.Lock.Unlock()
m.StateMap[hval] = state
m.CurrentState = hval
}
func (m *MServer) reinit(reqId string) {
initPk, err := shexec.MakeServerInitPacket()
if err != nil {
m.Sender.SendErrorResponse(reqId, fmt.Errorf("error creating init packet: %w", err))
return
}
m.setCurrentState(initPk.State)
initPk.RespId = reqId
m.Sender.SendPacket(initPk)
}
func makeTemp(path string, mode fs.FileMode) (*os.File, error) {
dirName := filepath.Dir(path)
baseName := filepath.Base(path)
baseTempName := baseName + ".tmp."
writeFd, err := os.CreateTemp(dirName, baseTempName)
if err != nil {
return nil, err
}
err = writeFd.Chmod(mode)
if err != nil {
writeFd.Close()
os.Remove(writeFd.Name())
return nil, fmt.Errorf("error setting tempfile permissions: %w", err)
}
return writeFd, nil
}
func checkFileWritable(path string) error {
finfo, err := os.Stat(path) // ok to follow symlinks
if errors.Is(err, fs.ErrNotExist) {
dirName := filepath.Dir(path)
dirInfo, err := os.Stat(dirName)
if err != nil {
return fmt.Errorf("file does not exist, error trying to stat parent directory: %w", err)
}
if !dirInfo.IsDir() {
return fmt.Errorf("file does not exist, parent path [%s] is not a directory", dirName)
}
return nil
} else {
if err != nil {
return fmt.Errorf("cannot stat: %w", err)
}
if finfo.IsDir() {
return fmt.Errorf("invalid path, cannot write a directory")
}
if (finfo.Mode() & fs.ModeSymlink) != 0 {
return fmt.Errorf("writefile does not support symlinks") // note this shouldn't happen because we're using Stat (not Lstat)
}
if (finfo.Mode() & (fs.ModeNamedPipe | fs.ModeSocket | fs.ModeDevice)) != 0 {
return fmt.Errorf("writefile does not support special files (named pipes, sockets, devices): mode=%v", finfo.Mode())
}
writePerm := (finfo.Mode().Perm() & 0o222)
if writePerm == 0 {
return fmt.Errorf("file is not writable, perms: %v", finfo.Mode().Perm())
}
return nil
}
}
func copyFile(dstName string, srcName string) error {
srcFd, err := os.Open(srcName)
if err != nil {
return err
}
defer srcFd.Close()
dstFd, err := os.OpenFile(dstName, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o666) // use 666 because OpenFile respects umask
if err != nil {
return err
}
// we don't defer dstFd.Close() so we can return an error if dstFd.Close() returns an error
_, err = io.Copy(dstFd, srcFd)
if err != nil {
dstFd.Close()
return err
}
return dstFd.Close()
}
func (m *MServer) writeFile(pk *packet.WriteFilePacketType, wfc *WriteFileContext) {
defer wfc.setDone()
if pk.Path == "" {
resp := packet.MakeWriteFileReadyPacket(pk.ReqId)
resp.Error = "invalid write-file request, no path specified"
m.Sender.SendPacket(resp)
return
}
err := checkFileWritable(pk.Path)
if err != nil {
resp := packet.MakeWriteFileReadyPacket(pk.ReqId)
resp.Error = err.Error()
m.Sender.SendPacket(resp)
return
}
var writeFd *os.File
if pk.UseTemp {
writeFd, err = os.CreateTemp("", "mshell.writefile.*") // "" means make this file in standard TempDir
if err != nil {
resp := packet.MakeWriteFileReadyPacket(pk.ReqId)
resp.Error = fmt.Sprintf("cannot create temp file: %v", err)
m.Sender.SendPacket(resp)
return
}
} else {
writeFd, err = os.OpenFile(pk.Path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o666) // use 666 because OpenFile respects umask
if err != nil {
resp := packet.MakeWriteFileReadyPacket(pk.ReqId)
resp.Error = fmt.Sprintf("write-file could not open file: %v", err)
m.Sender.SendPacket(resp)
return
}
}
// ok, so now writeFd is valid, send the "ready" response
resp := packet.MakeWriteFileReadyPacket(pk.ReqId)
m.Sender.SendPacket(resp)
// now we wait for data (cond var)
// this Unlock() runs first (because it is a later defer) so we can still run wfc.setDone() safely
wfc.CVar.L.Lock()
defer wfc.CVar.L.Unlock()
var doneErr error
for {
if wfc.Done {
break
}
if wfc.Err != nil {
doneErr = wfc.Err
break
}
if len(wfc.Data) == 0 {
wfc.CVar.Wait()
continue
}
dataPk := wfc.Data[0]
wfc.Data = wfc.Data[1:]
if dataPk.Error != "" {
doneErr = fmt.Errorf("error received from client: %v", errors.New(dataPk.Error))
break
}
if len(dataPk.Data) > 0 {
_, err := writeFd.Write(dataPk.Data)
if err != nil {
doneErr = fmt.Errorf("error writing data to file: %v", err)
break
}
}
if dataPk.Eof {
break
}
}
closeErr := writeFd.Close()
if doneErr == nil && closeErr != nil {
doneErr = fmt.Errorf("error closing file: %v", closeErr)
}
if pk.UseTemp {
if doneErr != nil {
os.Remove(writeFd.Name())
} else {
// copy file between writeFd.Name() and pk.Path
copyErr := copyFile(pk.Path, writeFd.Name())
if copyErr != nil {
doneErr = fmt.Errorf("error writing file: %v", copyErr)
}
os.Remove(writeFd.Name())
}
}
donePk := packet.MakeWriteFileDonePacket(pk.ReqId)
if doneErr != nil {
donePk.Error = doneErr.Error()
}
m.Sender.SendPacket(donePk)
}
func (m *MServer) returnStreamFileNewFileResponse(pk *packet.StreamFilePacketType) {
// ok, file doesn't exist, so try to check the directory at least to see if we can write a file here
resp := packet.MakeStreamFileResponse(pk.ReqId)
defer func() {
if resp.Error == "" {
resp.Done = true
}
m.Sender.SendPacket(resp)
}()
dirName := filepath.Dir(pk.Path)
dirInfo, err := os.Stat(dirName)
if err != nil {
resp.Error = fmt.Sprintf("file does not exist, error trying to stat parent directory: %v", err)
return
}
if !dirInfo.IsDir() {
resp.Error = fmt.Sprintf("file does not exist, parent path [%s] is not a directory", dirName)
return
}
resp.Info = &packet.FileInfo{
Name: pk.Path,
Size: 0,
ModTs: 0,
IsDir: false,
Perm: int(dirInfo.Mode().Perm()),
NotFound: true,
}
return
}
func (m *MServer) streamFile(pk *packet.StreamFilePacketType) {
resp := packet.MakeStreamFileResponse(pk.ReqId)
finfo, err := os.Stat(pk.Path)
if errors.Is(err, fs.ErrNotExist) {
// special return
m.returnStreamFileNewFileResponse(pk)
return
}
if err != nil {
resp.Error = fmt.Sprintf("cannot stat file %q: %v", pk.Path, err)
m.Sender.SendPacket(resp)
return
}
resp.Info = &packet.FileInfo{
Name: pk.Path,
Size: finfo.Size(),
ModTs: finfo.ModTime().UnixMilli(),
IsDir: finfo.IsDir(),
Perm: int(finfo.Mode().Perm()),
}
if pk.StatOnly {
resp.Done = true
m.Sender.SendPacket(resp)
return
}
// like the http Range header. range header is end inclusive. for us, endByte is non-inclusive (so we add 1)
var startByte, endByte int64
if len(pk.ByteRange) == 0 {
endByte = finfo.Size()
} else if len(pk.ByteRange) == 1 && pk.ByteRange[0] >= 0 {
startByte = pk.ByteRange[0]
endByte = finfo.Size()
} else if len(pk.ByteRange) == 1 && pk.ByteRange[0] < 0 {
startByte = finfo.Size() + pk.ByteRange[0] // "+" since ByteRange[0] is less than 0
endByte = finfo.Size()
} else if len(pk.ByteRange) == 2 {
startByte = pk.ByteRange[0]
endByte = pk.ByteRange[1] + 1
} else {
resp.Error = fmt.Sprintf("invalid byte range (%d entries)", len(pk.ByteRange))
m.Sender.SendPacket(resp)
return
}
if startByte < 0 {
startByte = 0
}
if endByte > finfo.Size() {
endByte = finfo.Size()
}
if startByte >= endByte {
resp.Done = true
m.Sender.SendPacket(resp)
return
}
fd, err := os.Open(pk.Path)
if err != nil {
resp.Error = fmt.Sprintf("opening file: %v", err)
m.Sender.SendPacket(resp)
return
}
defer fd.Close()
m.Sender.SendPacket(resp)
var buffer [MaxFileDataPacketSize]byte
var sentDone bool
first := true
for ; startByte < endByte; startByte += MaxFileDataPacketSize {
if !first {
// throttle packet sending @ 1000 packets/s, or 16M/s
time.Sleep(1 * time.Millisecond)
}
first = false
readLen := int64Min(MaxFileDataPacketSize, endByte-startByte)
bufSlice := buffer[0:readLen]
nr, err := fd.ReadAt(bufSlice, startByte)
dataPk := packet.MakeFileDataPacket(pk.ReqId)
dataPk.Data = make([]byte, nr)
copy(dataPk.Data, bufSlice)
if err == io.EOF {
dataPk.Eof = true
} else if err != nil {
dataPk.Error = err.Error()
}
m.Sender.SendPacket(dataPk)
if dataPk.GetResponseDone() {
sentDone = true
break
}
}
if !sentDone {
dataPk := packet.MakeFileDataPacket(pk.ReqId)
dataPk.Eof = true
m.Sender.SendPacket(dataPk)
}
return
}
func int64Min(v1 int64, v2 int64) int64 {
if v1 < v2 {
return v1
}
return v2
}
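For reference, a summary of the ByteRange handling in streamFile above, for a hypothetical 100-byte file:
    // ByteRange == []       -> stream bytes [0, 100)   (whole file)
    // ByteRange == [20]     -> stream bytes [20, 100)  (from offset to EOF)
    // ByteRange == [-30]    -> stream bytes [70, 100)  (last 30 bytes)
    // ByteRange == [10, 19] -> stream bytes [10, 20)   (input is end-inclusive, like an HTTP Range header)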
func (m *MServer) ProcessRpcPacket(pk packet.RpcPacketType) {
reqId := pk.GetReqId()
if cdPk, ok := pk.(*packet.CdPacketType); ok {
err := os.Chdir(cdPk.Dir)
if err != nil {
m.Sender.SendErrorResponse(reqId, fmt.Errorf("cannot change directory: %w", err))
return
}
m.Sender.SendResponse(reqId, true)
return
}
if compPk, ok := pk.(*packet.CompGenPacketType); ok {
go m.runCompGen(compPk)
return
}
if _, ok := pk.(*packet.ReInitPacketType); ok {
go m.reinit(reqId)
return
}
if streamPk, ok := pk.(*packet.StreamFilePacketType); ok {
go m.streamFile(streamPk)
return
}
if writePk, ok := pk.(*packet.WriteFilePacketType); ok {
wfc := m.getWriteFileContext(writePk.ReqId)
go m.writeFile(writePk, wfc)
return
}
m.Sender.SendErrorResponse(reqId, fmt.Errorf("invalid rpc type '%s'", pk.GetType()))
return
}
func (m *MServer) getCurrentState() (string, *packet.ShellState) {
m.Lock.Lock()
defer m.Lock.Unlock()
return m.CurrentState, m.StateMap[m.CurrentState]
}
func (m *MServer) clientPacketCallback(pk packet.PacketType) {
if pk.GetType() != packet.CmdDonePacketStr {
return
}
donePk := pk.(*packet.CmdDonePacketType)
if donePk.FinalState == nil {
return
}
stateHash, curState := m.getCurrentState()
if curState == nil {
return
}
diff, err := shexec.MakeShellStateDiff(*curState, stateHash, *donePk.FinalState)
if err != nil {
return
}
donePk.FinalState = nil
donePk.FinalStateDiff = &diff
}
func (m *MServer) runCommand(runPacket *packet.RunPacketType) {
if err := runPacket.CK.Validate("packet"); err != nil {
m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("server run packets require valid ck: %s", err))
return
}
ecmd, err := shexec.SSHOpts{}.MakeMShellSingleCmd(true)
if err != nil {
m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("server run packets require valid ck: %s", err))
return
}
cproc, _, err := shexec.MakeClientProc(context.Background(), ecmd)
if err != nil {
m.Sender.SendErrorResponse(runPacket.ReqId, fmt.Errorf("starting mshell client: %s", err))
return
}
m.Lock.Lock()
m.ClientMap[runPacket.CK] = cproc
m.Lock.Unlock()
go func() {
defer func() {
r := recover()
finalPk := packet.MakeCmdFinalPacket(runPacket.CK)
finalPk.Ts = time.Now().UnixMilli()
if r != nil {
finalPk.Error = fmt.Sprintf("%s", r)
}
m.Sender.SendPacket(finalPk)
m.Lock.Lock()
delete(m.ClientMap, runPacket.CK)
m.Lock.Unlock()
cproc.Close()
}()
shexec.SendRunPacketAndRunData(context.Background(), cproc.Input, runPacket)
cproc.ProxySingleOutput(runPacket.CK, m.Sender, m.clientPacketCallback)
}()
}
func (m *MServer) packetSenderErrorHandler(sender *packet.PacketSender, pk packet.PacketType, err error) {
if serr, ok := err.(*packet.SendError); ok && serr.IsMarshalError {
msg := packet.MakeMessagePacket(err.Error())
if cpk, ok := pk.(packet.CommandPacketType); ok {
msg.CK = cpk.GetCK()
}
sender.SendPacket(msg)
return
} else {
// I/O error: close the WriteErrorCh to signal that we are dead (cannot continue if we can't write output)
m.WriteErrorChOnce.Do(func() {
close(m.WriteErrorCh)
})
}
}
func (server *MServer) runReadLoop() {
builder := packet.MakeRunPacketBuilder()
for pk := range server.MainInput.MainCh {
if server.Debug {
fmt.Printf("PK> %s\n", packet.AsString(pk))
}
ok, runPacket := builder.ProcessPacket(pk)
if ok {
if runPacket != nil {
server.runCommand(runPacket)
continue
}
continue
}
if cmdPk, ok := pk.(packet.CommandPacketType); ok {
server.ProcessCommandPacket(cmdPk)
continue
}
if rpcPk, ok := pk.(packet.RpcPacketType); ok {
server.ProcessRpcPacket(rpcPk)
continue
}
if fileDataPk, ok := pk.(*packet.FileDataPacketType); ok {
server.addFileDataPacket(fileDataPk)
continue
}
server.Sender.SendMessageFmt("invalid packet '%s' sent to mshell server", packet.AsString(pk))
continue
}
}
func RunServer() (int, error) {
debug := false
if len(os.Args) >= 3 && os.Args[2] == "--debug" {
debug = true
}
server := &MServer{
Lock: &sync.Mutex{},
ClientMap: make(map[base.CommandKey]*shexec.ClientProc),
StateMap: make(map[string]*packet.ShellState),
Debug: debug,
WriteErrorCh: make(chan bool),
WriteErrorChOnce: &sync.Once{},
WriteFileContextMap: make(map[string]*WriteFileContext),
}
go func() {
for {
if server.checkDone() {
return
}
time.Sleep(cleanLoopTime)
server.cleanWriteFileContexts()
}
}()
if debug {
packet.GlobalDebug = true
}
server.MainInput = packet.MakePacketParser(os.Stdin, false)
server.Sender = packet.MakePacketSender(os.Stdout, server.packetSenderErrorHandler)
defer server.Close()
var err error
initPacket, err := shexec.MakeServerInitPacket()
if err != nil {
return 1, err
}
server.setCurrentState(initPacket.State)
server.Sender.SendPacket(initPacket)
ticker := time.NewTicker(1 * time.Minute)
go func() {
for range ticker.C {
server.Sender.SendPacket(packet.MakePingPacket())
}
}()
defer ticker.Stop()
readLoopDoneCh := make(chan bool)
go func() {
defer close(readLoopDoneCh)
server.runReadLoop()
}()
select {
case <-readLoopDoneCh:
break
case <-server.WriteErrorCh:
break
}
return 0, nil
}

View File

@ -0,0 +1,132 @@
package shexec
import (
"context"
"fmt"
"io"
"os/exec"
"time"
"github.com/commandlinedev/apishell/pkg/base"
"github.com/commandlinedev/apishell/pkg/packet"
"golang.org/x/mod/semver"
)
// TODO - track buffer sizes for sending input
const NotFoundVersion = "v0.0"
type ClientProc struct {
Cmd *exec.Cmd
InitPk *packet.InitPacketType
StartTs time.Time
StdinWriter io.WriteCloser
StdoutReader io.ReadCloser
StderrReader io.ReadCloser
Input *packet.PacketSender
Output *packet.PacketParser
}
// returns (clientproc, initpk, error)
func MakeClientProc(ctx context.Context, ecmd *exec.Cmd) (*ClientProc, *packet.InitPacketType, error) {
inputWriter, err := ecmd.StdinPipe()
if err != nil {
return nil, nil, fmt.Errorf("creating stdin pipe: %v", err)
}
stdoutReader, err := ecmd.StdoutPipe()
if err != nil {
return nil, nil, fmt.Errorf("creating stdout pipe: %v", err)
}
stderrReader, err := ecmd.StderrPipe()
if err != nil {
return nil, nil, fmt.Errorf("creating stderr pipe: %v", err)
}
startTs := time.Now()
err = ecmd.Start()
if err != nil {
return nil, nil, fmt.Errorf("running local client: %w", err)
}
sender := packet.MakePacketSender(inputWriter, nil)
stdoutPacketParser := packet.MakePacketParser(stdoutReader, false)
stderrPacketParser := packet.MakePacketParser(stderrReader, false)
packetParser := packet.CombinePacketParsers(stdoutPacketParser, stderrPacketParser, true)
cproc := &ClientProc{
Cmd: ecmd,
StartTs: startTs,
StdinWriter: inputWriter,
StdoutReader: stdoutReader,
StderrReader: stderrReader,
Input: sender,
Output: packetParser,
}
var pk packet.PacketType
select {
case pk = <-packetParser.MainCh:
case <-ctx.Done():
cproc.Close()
return nil, nil, ctx.Err()
}
if pk != nil {
if pk.GetType() != packet.InitPacketStr {
cproc.Close()
return nil, nil, fmt.Errorf("invalid packet received from mshell client: %s", packet.AsString(pk))
}
initPk := pk.(*packet.InitPacketType)
if initPk.NotFound {
cproc.Close()
return nil, initPk, fmt.Errorf("mshell client not found")
}
if semver.MajorMinor(initPk.Version) != semver.MajorMinor(base.MShellVersion) {
cproc.Close()
return nil, initPk, fmt.Errorf("invalid remote mshell version '%s', must be '=%s'", initPk.Version, semver.MajorMinor(base.MShellVersion))
}
cproc.InitPk = initPk
}
if cproc.InitPk == nil {
cproc.Close()
return nil, nil, fmt.Errorf("no init packet received from mshell client")
}
return cproc, cproc.InitPk, nil
}
func (cproc *ClientProc) Close() {
if cproc.Input != nil {
cproc.Input.Close()
}
if cproc.StdinWriter != nil {
cproc.StdinWriter.Close()
}
if cproc.StdoutReader != nil {
cproc.StdoutReader.Close()
}
if cproc.StderrReader != nil {
cproc.StderrReader.Close()
}
if cproc.Cmd != nil {
cproc.Cmd.Process.Kill()
}
}
func (cproc *ClientProc) ProxySingleOutput(ck base.CommandKey, sender *packet.PacketSender, packetCallback func(packet.PacketType)) {
sentDonePk := false
for pk := range cproc.Output.MainCh {
if packetCallback != nil {
packetCallback(pk)
}
if pk.GetType() == packet.CmdDonePacketStr {
sentDonePk = true
}
sender.SendPacket(pk)
}
exitErr := cproc.Cmd.Wait()
if !sentDonePk {
endTs := time.Now()
cmdDuration := endTs.Sub(cproc.StartTs)
donePacket := packet.MakeCmdDonePacket(ck)
donePacket.Ts = endTs.UnixMilli()
donePacket.ExitCode = GetExitCode(exitErr)
donePacket.DurationMs = int64(cmdDuration / time.Millisecond)
sender.SendPacket(donePacket)
}
}
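A hedged sketch of how a ClientProc is used (this mirrors runCommand in server.go above; "runPacket" and "sender" are assumed to exist, and error handling is elided):
    // illustrative fragment; imports: context, fmt, plus the shexec package
    ecmd, _ := shexec.SSHOpts{}.MakeMShellSingleCmd(true)
    cproc, initPk, err := shexec.MakeClientProc(context.Background(), ecmd)
    if err != nil {
        return err
    }
    defer cproc.Close()
    fmt.Println("client version:", initPk.Version)
    // cproc.Input.SendPacket(runPacket)                  // send the run packet (plus any run data)
    // cproc.ProxySingleOutput(runPacket.CK, sender, nil) // forward output until cmddone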

View File

@ -0,0 +1,638 @@
package shexec
import (
"bytes"
"fmt"
"io"
"regexp"
"sort"
"strings"
"github.com/alessio/shellescape"
"github.com/commandlinedev/apishell/pkg/packet"
"github.com/commandlinedev/apishell/pkg/simpleexpand"
"github.com/commandlinedev/apishell/pkg/statediff"
"mvdan.cc/sh/v3/expand"
"mvdan.cc/sh/v3/syntax"
)
const (
DeclTypeArray = "array"
DeclTypeAssocArray = "assoc"
DeclTypeInt = "int"
DeclTypeNormal = "normal"
)
type ParseEnviron struct {
Env map[string]string
}
func (e *ParseEnviron) Get(name string) expand.Variable {
val, ok := e.Env[name]
if !ok {
return expand.Variable{}
}
return expand.Variable{
Exported: true,
Kind: expand.String,
Str: val,
}
}
func (e *ParseEnviron) Each(fn func(name string, vr expand.Variable) bool) {
for key, _ := range e.Env {
rtn := fn(key, e.Get(key))
if !rtn {
break
}
}
}
func doCmdSubst(commandStr string, w io.Writer, word *syntax.CmdSubst) error {
return nil
}
func doProcSubst(w *syntax.ProcSubst) (string, error) {
return "", nil
}
func GetParserConfig(envMap map[string]string) *expand.Config {
cfg := &expand.Config{
Env: &ParseEnviron{Env: envMap},
GlobStar: false,
NullGlob: false,
NoUnset: false,
CmdSubst: func(w io.Writer, word *syntax.CmdSubst) error { return doCmdSubst("", w, word) },
ProcSubst: doProcSubst,
ReadDir: nil,
}
return cfg
}
func writeIndent(buf *bytes.Buffer, num int) {
for i := 0; i < num; i++ {
buf.WriteByte(' ')
}
}
func makeSpaceStr(num int) string {
barr := make([]byte, num)
for i := 0; i < num; i++ {
barr[i] = ' '
}
return string(barr)
}
// https://wiki.bash-hackers.org/syntax/shellvars
var NoStoreVarNames = map[string]bool{
"BASH": true,
"BASHOPTS": true,
"BASHPID": true,
"BASH_ALIASES": true,
"BASH_ARGC": true,
"BASH_ARGV": true,
"BASH_ARGV0": true,
"BASH_CMDS": true,
"BASH_COMMAND": true,
"BASH_EXECUTION_STRING": true,
"LINENO": true,
"BASH_LINENO": true,
"BASH_REMATCH": true,
"BASH_SOURCE": true,
"BASH_SUBSHELL": true,
"COPROC": true,
"DIRSTACK": true,
"EPOCHREALTIME": true,
"EPOCHSECONDS": true,
"FUNCNAME": true,
"HISTCMD": true,
"OLDPWD": true,
"PIPESTATUS": true,
"PPID": true,
"PWD": true,
"RANDOM": true,
"SECONDS": true,
"SHLVL": true,
"HISTFILE": true,
"HISTFILESIZE": true,
"HISTCONTROL": true,
"HISTIGNORE": true,
"HISTSIZE": true,
"HISTTIMEFORMAT": true,
"SRANDOM": true,
"COLUMNS": true,
// we want these in our remote state object
// "EUID": true,
// "SHELLOPTS": true,
// "UID": true,
// "BASH_VERSINFO": true,
// "BASH_VERSION": true,
}
type DeclareDeclType struct {
Args string
Name string
// this holds the raw quoted value suitable for bash. this is *not* the real expanded variable value
Value string
}
var declareDeclArgsRe = regexp.MustCompile("^[aAxrifx]*$")
var bashValidIdentifierRe = regexp.MustCompile("^[a-zA-Z_][a-zA-Z0-9_]*$")
func (d *DeclareDeclType) Validate() error {
if len(d.Name) == 0 || !IsValidBashIdentifier(d.Name) {
return fmt.Errorf("invalid shell variable name (invalid bash identifier)")
}
if strings.Index(d.Value, "\x00") >= 0 {
return fmt.Errorf("invalid shell variable value (cannot contain 0 byte)")
}
if !declareDeclArgsRe.MatchString(d.Args) {
return fmt.Errorf("invalid shell variable type %s", shellescape.Quote(d.Args))
}
return nil
}
func (d *DeclareDeclType) Serialize() string {
return fmt.Sprintf("%s|%s=%s\x00", d.Args, d.Name, d.Value)
}
func (d *DeclareDeclType) DeclareStmt() string {
var argsStr string
if d.Args == "" {
argsStr = "--"
} else {
argsStr = "-" + d.Args
}
return fmt.Sprintf("declare %s %s=%s", argsStr, d.Name, d.Value)
}
// envline should be valid
func ParseDeclLine(envLine string) *DeclareDeclType {
eqIdx := strings.Index(envLine, "=")
if eqIdx == -1 {
return nil
}
namePart := envLine[0:eqIdx]
valPart := envLine[eqIdx+1:]
pipeIdx := strings.Index(namePart, "|")
if pipeIdx == -1 {
return nil
}
return &DeclareDeclType{
Args: namePart[0:pipeIdx],
Name: namePart[pipeIdx+1:],
Value: valPart,
}
}
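A hedged round-trip sketch of the "args|name=value" decl line format produced by Serialize above and parsed here (values made up):
    // illustrative fragment; imports: fmt
    decl := &DeclareDeclType{Args: "x", Name: "FOO", Value: "\"bar\""}
    fmt.Println(decl.DeclareStmt()) // declare -x FOO="bar"
    parsed := ParseDeclLine("x|FOO=\"bar\"")
    fmt.Println(parsed.Name, parsed.IsExport()) // FOO true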
// returns name => full-line
func parseDeclLineToKV(envLine string) (string, string) {
decl := ParseDeclLine(envLine)
if decl == nil {
return "", ""
}
return decl.Name, envLine
}
func shellStateVarsToMap(shellVars []byte) map[string]string {
if len(shellVars) == 0 {
return nil
}
rtn := make(map[string]string)
vars := bytes.Split(shellVars, []byte{0})
for _, varLine := range vars {
name, val := parseDeclLineToKV(string(varLine))
if name == "" {
continue
}
rtn[name] = val
}
return rtn
}
func strMapToShellStateVars(varMap map[string]string) []byte {
var buf bytes.Buffer
orderedKeys := getOrderedKeysStrMap(varMap)
for _, key := range orderedKeys {
val := varMap[key]
buf.WriteString(val)
buf.WriteByte(0)
}
return buf.Bytes()
}
func getOrderedKeysStrMap(m map[string]string) []string {
keys := make([]string, 0, len(m))
for key, _ := range m {
keys = append(keys, key)
}
sort.Strings(keys)
return keys
}
func getOrderedKeysDeclMap(m map[string]*DeclareDeclType) []string {
keys := make([]string, 0, len(m))
for key, _ := range m {
keys = append(keys, key)
}
sort.Strings(keys)
return keys
}
func DeclMapFromState(state *packet.ShellState) map[string]*DeclareDeclType {
if state == nil {
return nil
}
rtn := make(map[string]*DeclareDeclType)
vars := bytes.Split(state.ShellVars, []byte{0})
for _, varLine := range vars {
decl := ParseDeclLine(string(varLine))
if decl != nil {
rtn[decl.Name] = decl
}
}
return rtn
}
func SerializeDeclMap(declMap map[string]*DeclareDeclType) []byte {
var rtn bytes.Buffer
orderedKeys := getOrderedKeysDeclMap(declMap)
for _, key := range orderedKeys {
decl := declMap[key]
rtn.WriteString(decl.Serialize())
}
return rtn.Bytes()
}
func EnvMapFromState(state *packet.ShellState) map[string]string {
if state == nil {
return nil
}
rtn := make(map[string]string)
ectx := simpleexpand.SimpleExpandContext{}
vars := bytes.Split(state.ShellVars, []byte{0})
for _, varLine := range vars {
decl := ParseDeclLine(string(varLine))
if decl != nil && decl.IsExport() {
rtn[decl.Name], _ = simpleexpand.SimpleExpandPartialWord(ectx, decl.Value, false)
}
}
return rtn
}
func ShellVarMapFromState(state *packet.ShellState) map[string]string {
if state == nil {
return nil
}
rtn := make(map[string]string)
ectx := simpleexpand.SimpleExpandContext{}
vars := bytes.Split(state.ShellVars, []byte{0})
for _, varLine := range vars {
decl := ParseDeclLine(string(varLine))
if decl != nil {
rtn[decl.Name], _ = simpleexpand.SimpleExpandPartialWord(ectx, decl.Value, false)
}
}
return rtn
}
func DumpVarMapFromState(state *packet.ShellState) {
fmt.Printf("DUMP-STATE-VARS:\n")
if state == nil {
fmt.Printf(" nil\n")
return
}
vars := bytes.Split(state.ShellVars, []byte{0})
for _, varLine := range vars {
fmt.Printf(" %s\n", varLine)
}
}
func VarDeclsFromState(state *packet.ShellState) []*DeclareDeclType {
if state == nil {
return nil
}
var rtn []*DeclareDeclType
vars := bytes.Split(state.ShellVars, []byte{0})
for _, varLine := range vars {
decl := ParseDeclLine(string(varLine))
if decl != nil {
rtn = append(rtn, decl)
}
}
return rtn
}
func IsValidBashIdentifier(s string) bool {
return bashValidIdentifierRe.MatchString(s)
}
func (d *DeclareDeclType) IsExport() bool {
	return strings.Contains(d.Args, "x")
}
func (d *DeclareDeclType) IsReadOnly() bool {
	return strings.Contains(d.Args, "r")
}
func (d *DeclareDeclType) DataType() string {
	if strings.Contains(d.Args, "a") {
		return DeclTypeArray
	}
	if strings.Contains(d.Args, "A") {
		return DeclTypeAssocArray
	}
	if strings.Contains(d.Args, "i") {
		return DeclTypeInt
	}
	return DeclTypeNormal
}
func parseDeclareStmt(stmt *syntax.Stmt, src string) (*DeclareDeclType, error) {
cmd := stmt.Cmd
decl, ok := cmd.(*syntax.DeclClause)
if !ok || decl.Variant.Value != "declare" || len(decl.Args) != 2 {
return nil, fmt.Errorf("invalid declare variant")
}
rtn := &DeclareDeclType{}
declArgs := decl.Args[0]
if !declArgs.Naked || len(declArgs.Value.Parts) != 1 {
return nil, fmt.Errorf("wrong number of declare args parts")
}
declArgsLit, ok := declArgs.Value.Parts[0].(*syntax.Lit)
if !ok {
return nil, fmt.Errorf("declare args is not a literal")
}
if !strings.HasPrefix(declArgsLit.Value, "-") {
return nil, fmt.Errorf("declare args not an argument (does not start with '-')")
}
if declArgsLit.Value == "--" {
rtn.Args = ""
} else {
rtn.Args = declArgsLit.Value[1:]
}
declAssign := decl.Args[1]
if declAssign.Name == nil {
return nil, fmt.Errorf("declare does not have a valid name")
}
rtn.Name = declAssign.Name.Value
if declAssign.Naked || declAssign.Index != nil || declAssign.Append {
return nil, fmt.Errorf("invalid decl format")
}
if declAssign.Value != nil {
rtn.Value = string(src[declAssign.Value.Pos().Offset():declAssign.Value.End().Offset()])
} else if declAssign.Array != nil {
rtn.Value = string(src[declAssign.Array.Pos().Offset():declAssign.Array.End().Offset()])
} else {
return nil, fmt.Errorf("invalid decl, not plain value or array")
}
err := rtn.normalize()
if err != nil {
return nil, err
}
if err = rtn.Validate(); err != nil {
return nil, err
}
return rtn, nil
}
func parseDeclareOutput(state *packet.ShellState, declareBytes []byte, pvarBytes []byte) error {
declareStr := string(declareBytes)
r := bytes.NewReader(declareBytes)
parser := syntax.NewParser(syntax.Variant(syntax.LangBash))
file, err := parser.Parse(r, "aliases")
if err != nil {
return err
}
var firstParseErr error
declMap := make(map[string]*DeclareDeclType)
for _, stmt := range file.Stmts {
decl, err := parseDeclareStmt(stmt, declareStr)
if err != nil {
if firstParseErr == nil {
firstParseErr = err
}
}
if decl != nil && !NoStoreVarNames[decl.Name] {
declMap[decl.Name] = decl
}
}
pvars := bytes.Split(pvarBytes, []byte{0})
for _, pvarBA := range pvars {
pvarStr := string(pvarBA)
pvarFields := strings.SplitN(pvarStr, " ", 2)
if len(pvarFields) != 2 {
continue
}
if pvarFields[0] == "" {
continue
}
decl := &DeclareDeclType{Args: "x"}
decl.Name = "PROMPTVAR_" + pvarFields[0]
decl.Value = shellescape.Quote(pvarFields[1])
declMap[decl.Name] = decl
}
state.ShellVars = SerializeDeclMap(declMap) // this writes out the decls in a canonical order
if firstParseErr != nil {
state.Error = firstParseErr.Error()
}
return nil
}
func ParseShellStateOutput(outputBytes []byte) (*packet.ShellState, error) {
	// 6 fields: version, cwd, vars (declare output), aliases, funcs, prompt vars
fields := bytes.Split(outputBytes, []byte{0, 0})
if len(fields) != 6 {
return nil, fmt.Errorf("invalid shell state output, wrong number of fields, fields=%d", len(fields))
}
rtn := &packet.ShellState{}
rtn.Version = strings.TrimSpace(string(fields[0]))
	if !strings.Contains(rtn.Version, "bash") {
		return nil, fmt.Errorf("invalid shell state output, only bash is supported")
	}
cwdStr := string(fields[1])
if strings.HasSuffix(cwdStr, "\r\n") {
cwdStr = cwdStr[0 : len(cwdStr)-2]
} else if strings.HasSuffix(cwdStr, "\n") {
cwdStr = cwdStr[0 : len(cwdStr)-1]
}
	rtn.Cwd = cwdStr
err := parseDeclareOutput(rtn, fields[2], fields[5])
if err != nil {
return nil, err
}
rtn.Aliases = strings.ReplaceAll(string(fields[3]), "\r\n", "\n")
rtn.Funcs = strings.ReplaceAll(string(fields[4]), "\r\n", "\n")
rtn.Funcs = removeFunc(rtn.Funcs, "_mshell_exittrap")
return rtn, nil
}
func removeFunc(funcs string, toRemove string) string {
lines := strings.Split(funcs, "\n")
var newLines []string
removeLine := fmt.Sprintf("%s ()", toRemove)
doingRemove := false
for _, line := range lines {
if line == removeLine {
doingRemove = true
continue
}
if doingRemove {
if line == "}" {
doingRemove = false
}
continue
}
newLines = append(newLines, line)
}
return strings.Join(newLines, "\n")
}
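// Illustrative sketch of removeFunc (hypothetical, simplified function listing; real
// `declare -f` output may differ in spacing): the exit-trap helper is dropped and the
// remaining lines are kept verbatim.
func exampleRemoveFunc() string {
	funcs := "greet ()\n{\n    echo hi\n}\n_mshell_exittrap ()\n{\n    echo bye\n}"
	return removeFunc(funcs, "_mshell_exittrap") // greet ()\n{\n    echo hi\n}
}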
func (d *DeclareDeclType) normalize() error {
if d.DataType() == DeclTypeAssocArray {
return d.normalizeAssocArrayDecl()
}
return nil
}
// normalizes order of assoc array keys so value is stable
func (d *DeclareDeclType) normalizeAssocArrayDecl() error {
if d.DataType() != DeclTypeAssocArray {
return fmt.Errorf("invalid decltype passed to assocArrayDeclToStr: %s", d.DataType())
}
varMap, err := assocArrayVarToMap(d)
if err != nil {
return err
}
keys := make([]string, 0, len(varMap))
	for key := range varMap {
keys = append(keys, key)
}
sort.Strings(keys)
var buf bytes.Buffer
buf.WriteByte('(')
for _, key := range keys {
buf.WriteByte('[')
buf.WriteString(key)
buf.WriteByte(']')
buf.WriteByte('=')
buf.WriteString(varMap[key])
buf.WriteByte(' ')
}
buf.WriteByte(')')
d.Value = buf.String()
return nil
}
func assocArrayVarToMap(d *DeclareDeclType) (map[string]string, error) {
if d.DataType() != DeclTypeAssocArray {
return nil, fmt.Errorf("decl is not an assoc-array")
}
refStr := "X=" + d.Value
r := strings.NewReader(refStr)
parser := syntax.NewParser(syntax.Variant(syntax.LangBash))
file, err := parser.Parse(r, "assocdecl")
if err != nil {
return nil, err
}
if len(file.Stmts) != 1 {
return nil, fmt.Errorf("invalid assoc-array parse (multiple stmts)")
}
stmt := file.Stmts[0]
callExpr, ok := stmt.Cmd.(*syntax.CallExpr)
if !ok || len(callExpr.Args) != 0 || len(callExpr.Assigns) != 1 {
return nil, fmt.Errorf("invalid assoc-array parse (bad expr)")
}
assign := callExpr.Assigns[0]
arrayExpr := assign.Array
if arrayExpr == nil {
return nil, fmt.Errorf("invalid assoc-array parse (no array expr)")
}
rtn := make(map[string]string)
for _, elem := range arrayExpr.Elems {
indexStr := refStr[elem.Index.Pos().Offset():elem.Index.End().Offset()]
valStr := refStr[elem.Value.Pos().Offset():elem.Value.End().Offset()]
rtn[indexStr] = valStr
}
return rtn, nil
}
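// Illustrative sketch (hypothetical declaration): normalize rewrites assoc-array values
// with keys in sorted order so the serialized shell state stays stable across runs.
func exampleNormalizeAssocArray() (string, error) {
	decl := &DeclareDeclType{Args: "A", Name: "m", Value: `([b]="2" [a]="1" )`}
	err := decl.normalize()
	return decl.Value, err // ([a]="1" [b]="2" ), nil
}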
func strMapsEqual(m1 map[string]string, m2 map[string]string) bool {
if len(m1) != len(m2) {
return false
}
for key, val1 := range m1 {
val2, found := m2[key]
if !found || val1 != val2 {
return false
}
}
	for key := range m2 {
_, found := m1[key]
if !found {
return false
}
}
return true
}
func DeclsEqual(compareName bool, d1 *DeclareDeclType, d2 *DeclareDeclType) bool {
if d1.IsExport() != d2.IsExport() {
return false
}
if d1.DataType() != d2.DataType() {
return false
}
if compareName && d1.Name != d2.Name {
return false
}
return d1.Value == d2.Value // this works even for assoc arrays because we normalize them when parsing
}
func MakeShellStateDiff(oldState packet.ShellState, oldStateHash string, newState packet.ShellState) (packet.ShellStateDiff, error) {
var rtn packet.ShellStateDiff
rtn.BaseHash = oldStateHash
if oldState.Version != newState.Version {
return rtn, fmt.Errorf("cannot diff, states have different versions")
}
rtn.Version = newState.Version
if oldState.Cwd != newState.Cwd {
rtn.Cwd = newState.Cwd
}
rtn.Error = newState.Error
oldVars := shellStateVarsToMap(oldState.ShellVars)
newVars := shellStateVarsToMap(newState.ShellVars)
rtn.VarsDiff = statediff.MakeMapDiff(oldVars, newVars)
rtn.AliasesDiff = statediff.MakeLineDiff(oldState.Aliases, newState.Aliases)
rtn.FuncsDiff = statediff.MakeLineDiff(oldState.Funcs, newState.Funcs)
return rtn, nil
}
func ApplyShellStateDiff(oldState packet.ShellState, diff packet.ShellStateDiff) (packet.ShellState, error) {
var rtnState packet.ShellState
var err error
rtnState.Version = oldState.Version
rtnState.Cwd = oldState.Cwd
if diff.Cwd != "" {
rtnState.Cwd = diff.Cwd
}
rtnState.Error = diff.Error
oldVars := shellStateVarsToMap(oldState.ShellVars)
newVars, err := statediff.ApplyMapDiff(oldVars, diff.VarsDiff)
if err != nil {
return rtnState, fmt.Errorf("applying mapdiff 'vars': %v", err)
}
rtnState.ShellVars = strMapToShellStateVars(newVars)
rtnState.Aliases, err = statediff.ApplyLineDiff(oldState.Aliases, diff.AliasesDiff)
if err != nil {
return rtnState, fmt.Errorf("applying diff 'aliases': %v", err)
}
rtnState.Funcs, err = statediff.ApplyLineDiff(oldState.Funcs, diff.FuncsDiff)
if err != nil {
return rtnState, fmt.Errorf("applying diff 'funcs': %v", err)
}
return rtnState, nil
}
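// Illustrative round trip (hypothetical versions and paths): diff two states and re-apply
// the diff to reconstruct the newer one; the hash argument is treated as an opaque string.
func exampleShellStateDiffRoundTrip() (packet.ShellState, error) {
	oldState := packet.ShellState{Version: "GNU bash, version 5.2.15", Cwd: "/home/user"}
	newState := packet.ShellState{Version: "GNU bash, version 5.2.15", Cwd: "/tmp"}
	diff, err := MakeShellStateDiff(oldState, "example-base-hash", newState)
	if err != nil {
		return packet.ShellState{}, err
	}
	return ApplyShellStateDiff(oldState, diff)
}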

File diff suppressed because it is too large

View File

@ -0,0 +1,222 @@
package simpleexpand
import (
"bytes"
"strings"
"mvdan.cc/sh/v3/expand"
"mvdan.cc/sh/v3/syntax"
)
type SimpleExpandContext struct {
HomeDir string
}
type SimpleExpandInfo struct {
HasTilde bool // only ~ as the first character when SimpleExpandContext.HomeDir is set
HasVar bool // $x, $$, ${...}
HasGlob bool // *, ?, [, {
	HasExtGlob bool // extended glob patterns: ?( *( +( @( !(
HasHistory bool // ! (anywhere)
HasSpecial bool // subshell, arith
}
func expandHomeDir(info *SimpleExpandInfo, litVal string, multiPart bool, homeDir string) string {
if homeDir == "" {
return litVal
}
if litVal == "~" && !multiPart {
return homeDir
}
if strings.HasPrefix(litVal, "~/") {
info.HasTilde = true
return homeDir + litVal[1:]
}
return litVal
}
func expandLiteral(buf *bytes.Buffer, info *SimpleExpandInfo, litVal string) {
var lastBackSlash bool
var lastExtGlob bool
var lastDollar bool
for _, ch := range litVal {
if ch == 0 {
break
}
if lastBackSlash {
lastBackSlash = false
if ch == '\n' {
// special case, backslash *and* newline are ignored
continue
}
buf.WriteRune(ch)
continue
}
if ch == '\\' {
lastBackSlash = true
lastExtGlob = false
lastDollar = false
continue
}
if ch == '*' || ch == '?' || ch == '[' || ch == '{' {
info.HasGlob = true
}
if ch == '`' {
info.HasSpecial = true
}
if ch == '!' {
info.HasHistory = true
}
if lastExtGlob && ch == '(' {
info.HasExtGlob = true
}
		if lastDollar && ch != ' ' && ch != '"' && ch != '\'' && ch != '(' && ch != '[' {
info.HasVar = true
}
if lastDollar && (ch == '(' || ch == '[') {
info.HasSpecial = true
}
lastExtGlob = (ch == '?' || ch == '*' || ch == '+' || ch == '@' || ch == '!')
lastDollar = (ch == '$')
buf.WriteRune(ch)
}
if lastBackSlash {
buf.WriteByte('\\')
}
}
// also expands ~
func expandLiteralPlus(buf *bytes.Buffer, info *SimpleExpandInfo, litVal string, multiPart bool, ectx SimpleExpandContext) {
litVal = expandHomeDir(info, litVal, multiPart, ectx.HomeDir)
expandLiteral(buf, info, litVal)
}
func expandSQANSILiteral(buf *bytes.Buffer, litVal string) {
// no info specials
if strings.HasSuffix(litVal, "'") {
litVal = litVal[0 : len(litVal)-1]
}
str, _, _ := expand.Format(nil, litVal, nil)
buf.WriteString(str)
}
func expandSQLiteral(buf *bytes.Buffer, litVal string) {
// no info specials
if strings.HasSuffix(litVal, "'") {
litVal = litVal[0 : len(litVal)-1]
}
buf.WriteString(litVal)
}
// will also work for partial double quoted strings
func expandDQLiteral(buf *bytes.Buffer, info *SimpleExpandInfo, litVal string) {
var lastBackSlash bool
var lastDollar bool
for _, ch := range litVal {
if ch == 0 {
break
}
if lastBackSlash {
lastBackSlash = false
if ch == '"' || ch == '\\' || ch == '$' || ch == '`' {
buf.WriteRune(ch)
continue
}
buf.WriteRune('\\')
buf.WriteRune(ch)
continue
}
if ch == '\\' {
lastBackSlash = true
lastDollar = false
continue
}
if ch == '"' {
break
}
// similar to expandLiteral, but no globbing
if ch == '`' {
info.HasSpecial = true
}
if ch == '!' {
info.HasHistory = true
}
		if lastDollar && ch != ' ' && ch != '"' && ch != '\'' && ch != '(' && ch != '[' {
info.HasVar = true
}
if lastDollar && (ch == '(' || ch == '[') {
info.HasSpecial = true
}
lastDollar = (ch == '$')
buf.WriteRune(ch)
}
// in a valid parsed DQ string, you cannot have a trailing backslash (because \" would not end the string)
// still putting the case here though in case we ever deal with incomplete strings (e.g. completion)
if lastBackSlash {
buf.WriteByte('\\')
}
}
func simpleExpandWordInternal(buf *bytes.Buffer, info *SimpleExpandInfo, ectx SimpleExpandContext, parts []syntax.WordPart, sourceStr string, inDoubleQuote bool, level int) {
for partIdx, untypedPart := range parts {
switch part := untypedPart.(type) {
case *syntax.Lit:
if !inDoubleQuote && partIdx == 0 && level == 1 && ectx.HomeDir != "" {
expandLiteralPlus(buf, info, part.Value, len(parts) > 1, ectx)
} else if inDoubleQuote {
expandDQLiteral(buf, info, part.Value)
} else {
expandLiteral(buf, info, part.Value)
}
case *syntax.SglQuoted:
if part.Dollar {
expandSQANSILiteral(buf, part.Value)
} else {
expandSQLiteral(buf, part.Value)
}
case *syntax.DblQuoted:
simpleExpandWordInternal(buf, info, ectx, part.Parts, sourceStr, true, level+1)
default:
rawStr := sourceStr[part.Pos().Offset():part.End().Offset()]
buf.WriteString(rawStr)
}
}
}
// simple word expansion
// expands: literals, single-quoted strings, double-quoted strings (recursively)
// does *not* expand: params (variables), command substitution, arithmetic expressions, process substitutions, or globs
// anything not expanded is left as its literal source text
// this is different from expand.Literal, which replaces undefined variables with the empty string
// so "a"'foo'${bar}$x => "afoo${bar}$x", whereas expand.Literal would produce => "afoo"
// note: does ~ expansion (but not ~user expansion)
func SimpleExpandWord(ectx SimpleExpandContext, word *syntax.Word, sourceStr string) (string, SimpleExpandInfo) {
var buf bytes.Buffer
var info SimpleExpandInfo
simpleExpandWordInternal(&buf, &info, ectx, word.Parts, sourceStr, false, 1)
return buf.String(), info
}
func SimpleExpandPartialWord(ectx SimpleExpandContext, partialWord string, multiPart bool) (string, SimpleExpandInfo) {
var buf bytes.Buffer
var info SimpleExpandInfo
if partialWord == "" {
return "", info
}
if strings.HasPrefix(partialWord, "\"") {
expandDQLiteral(&buf, &info, partialWord[1:])
} else if strings.HasPrefix(partialWord, "$\"") {
expandDQLiteral(&buf, &info, partialWord[2:])
} else if strings.HasPrefix(partialWord, "'") {
expandSQLiteral(&buf, partialWord[1:])
} else if strings.HasPrefix(partialWord, "$'") {
expandSQANSILiteral(&buf, partialWord[2:])
} else {
expandLiteralPlus(&buf, &info, partialWord, multiPart, ectx)
}
return buf.String(), info
}
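// Illustrative sketch (hypothetical inputs): partial-word expansion with a home dir set.
// The tilde is expanded, the "$X" reference is left as literal text, and info records
// that a variable reference was seen.
func exampleSimpleExpandPartial() (string, SimpleExpandInfo) {
	ectx := SimpleExpandContext{HomeDir: "/home/user"}
	return SimpleExpandPartialWord(ectx, "~/file-$X", false) // "/home/user/file-$X", HasTilde and HasVar set
}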

View File

@ -0,0 +1,188 @@
package statediff
import (
"bytes"
"encoding/binary"
"fmt"
"strings"
)
const LineDiffVersion = 0
type SingleLineEntry struct {
LineVal int
Run int
}
type LineDiffType struct {
Lines []SingleLineEntry
NewData []string
}
func (diff LineDiffType) Dump() {
fmt.Printf("DIFF:\n")
pos := 1
for _, entry := range diff.Lines {
fmt.Printf(" %d-%d: %d\n", pos, pos+entry.Run, entry.LineVal)
pos += entry.Run
}
for idx, str := range diff.NewData {
fmt.Printf(" n%d: %s\n", idx+1, str)
}
}
// simple encoding
// a LineVal of 0 means take the next Run lines from NewData
// a non-zero LineVal means copy Run lines from the old data starting at that 1-indexed line
func (diff LineDiffType) applyDiff(oldData []string) ([]string, error) {
rtn := make([]string, 0, len(diff.Lines))
newDataPos := 0
for _, entry := range diff.Lines {
if entry.LineVal == 0 {
for i := 0; i < entry.Run; i++ {
if newDataPos >= len(diff.NewData) {
return nil, fmt.Errorf("not enough newdata for diff")
}
rtn = append(rtn, diff.NewData[newDataPos])
newDataPos++
}
} else {
oldDataPos := entry.LineVal - 1 // 1-indexed
for i := 0; i < entry.Run; i++ {
realPos := oldDataPos + i
if realPos < 0 || realPos >= len(oldData) {
return nil, fmt.Errorf("diff index out of bounds %d old-data-len:%d", realPos, len(oldData))
}
rtn = append(rtn, oldData[realPos])
}
}
}
return rtn, nil
}
func putUVarint(buf *bytes.Buffer, viBuf []byte, ival int) {
l := binary.PutUvarint(viBuf, uint64(ival))
buf.Write(viBuf[0:l])
}
// simple encoding
// varints: first the version, then the number of entries, then a (lineval, run) varint pair per entry, then the raw newdata bytes
// [version] [num-entries] ([lineval] [run])xnum-entries newdata (lines joined by '\n')
func (diff LineDiffType) Encode() []byte {
var buf bytes.Buffer
viBuf := make([]byte, binary.MaxVarintLen64)
putUVarint(&buf, viBuf, LineDiffVersion)
putUVarint(&buf, viBuf, len(diff.Lines))
for _, entry := range diff.Lines {
putUVarint(&buf, viBuf, entry.LineVal)
putUVarint(&buf, viBuf, entry.Run)
}
for idx, str := range diff.NewData {
buf.WriteString(str)
if idx != len(diff.NewData)-1 {
buf.WriteByte('\n')
}
}
return buf.Bytes()
}
func (rtn *LineDiffType) Decode(diffBytes []byte) error {
r := bytes.NewBuffer(diffBytes)
version, err := binary.ReadUvarint(r)
if err != nil {
return fmt.Errorf("invalid diff, cannot read version: %v", err)
}
if version != LineDiffVersion {
return fmt.Errorf("invalid diff, bad version: %d", version)
}
linesLen64, err := binary.ReadUvarint(r)
if err != nil {
return fmt.Errorf("invalid diff, cannot read lines length: %v", err)
}
linesLen := int(linesLen64)
rtn.Lines = make([]SingleLineEntry, linesLen)
for idx := 0; idx < linesLen; idx++ {
lineVal, err := binary.ReadUvarint(r)
if err != nil {
return fmt.Errorf("invalid diff, cannot read line %d: %v", idx, err)
}
lineRun, err := binary.ReadUvarint(r)
if err != nil {
return fmt.Errorf("invalid diff, cannot read line-run %d: %v", idx, err)
}
rtn.Lines[idx] = SingleLineEntry{LineVal: int(lineVal), Run: int(lineRun)}
}
restOfInput := string(r.Bytes())
if len(restOfInput) > 0 {
rtn.NewData = strings.Split(restOfInput, "\n")
}
return nil
}
func makeLineDiff(oldData []string, newData []string) LineDiffType {
var rtn LineDiffType
oldDataMap := make(map[string]int) // 1-indexed
for idx, str := range oldData {
if _, found := oldDataMap[str]; found {
continue
}
oldDataMap[str] = idx + 1
}
var cur *SingleLineEntry
rtn.Lines = make([]SingleLineEntry, 0)
for _, str := range newData {
oldIdx, found := oldDataMap[str]
if cur != nil && cur.LineVal != 0 {
checkLine := cur.LineVal + cur.Run - 1
if checkLine < len(oldData) && oldData[checkLine] == str {
cur.Run++
continue
}
} else if cur != nil && cur.LineVal == 0 && !found {
cur.Run++
rtn.NewData = append(rtn.NewData, str)
continue
}
if cur != nil {
rtn.Lines = append(rtn.Lines, *cur)
}
cur = &SingleLineEntry{Run: 1}
if found {
cur.LineVal = oldIdx
} else {
cur.LineVal = 0
rtn.NewData = append(rtn.NewData, str)
}
}
if cur != nil {
rtn.Lines = append(rtn.Lines, *cur)
}
return rtn
}
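// Worked example (hypothetical data): with oldData = ["a","b","c"] and newData = ["a","x","c"],
// makeLineDiff yields Lines = [{LineVal:1 Run:1} {LineVal:0 Run:1} {LineVal:3 Run:1}] and
// NewData = ["x"]: copy old line 1, take one line from NewData, copy old line 3.
func exampleLineDiff() LineDiffType {
	return makeLineDiff([]string{"a", "b", "c"}, []string{"a", "x", "c"})
}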
func MakeLineDiff(str1 string, str2 string) []byte {
if str1 == str2 {
return nil
}
str1Arr := strings.Split(str1, "\n")
str2Arr := strings.Split(str2, "\n")
diff := makeLineDiff(str1Arr, str2Arr)
return diff.Encode()
}
func ApplyLineDiff(str1 string, diffBytes []byte) (string, error) {
if len(diffBytes) == 0 {
return str1, nil
}
var diff LineDiffType
err := diff.Decode(diffBytes)
if err != nil {
return "", err
}
str1Arr := strings.Split(str1, "\n")
str2Arr, err := diff.applyDiff(str1Arr)
if err != nil {
return "", err
}
return strings.Join(str2Arr, "\n"), nil
}

View File

@ -0,0 +1,130 @@
package statediff
import (
"bytes"
"encoding/binary"
"fmt"
)
const MapDiffVersion = 0
// 0-bytes are not allowed in entries or keys (same as bash)
type MapDiffType struct {
ToAdd map[string]string
ToRemove []string
}
func (diff MapDiffType) Dump() {
fmt.Printf("VAR-DIFF\n")
for name, val := range diff.ToAdd {
fmt.Printf(" add[%s] %s\n", name, val)
}
for _, name := range diff.ToRemove {
fmt.Printf(" rem[%s]\n", name)
}
}
func makeMapDiff(oldMap map[string]string, newMap map[string]string) MapDiffType {
var rtn MapDiffType
rtn.ToAdd = make(map[string]string)
for name, newVal := range newMap {
oldVal, found := oldMap[name]
if !found || oldVal != newVal {
rtn.ToAdd[name] = newVal
continue
}
}
	for name := range oldMap {
_, found := newMap[name]
if !found {
rtn.ToRemove = append(rtn.ToRemove, name)
}
}
return rtn
}
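// Worked example (hypothetical maps): diffing {"a":"1","b":"2"} against {"a":"1","b":"3","c":"4"}
// yields ToAdd = {"b":"3","c":"4"} and an empty ToRemove; diffing in the other direction
// yields ToAdd = {"b":"2"} and ToRemove = ["c"].
func exampleMapDiff() MapDiffType {
	return makeMapDiff(
		map[string]string{"a": "1", "b": "2"},
		map[string]string{"a": "1", "b": "3", "c": "4"},
	)
}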
func (diff MapDiffType) apply(oldMap map[string]string) map[string]string {
rtn := make(map[string]string)
for name, val := range oldMap {
rtn[name] = val
}
for name, val := range diff.ToAdd {
rtn[name] = val
}
for _, name := range diff.ToRemove {
delete(rtn, name)
}
return rtn
}
func (diff MapDiffType) Encode() []byte {
var buf bytes.Buffer
viBuf := make([]byte, binary.MaxVarintLen64)
putUVarint(&buf, viBuf, MapDiffVersion)
putUVarint(&buf, viBuf, len(diff.ToAdd))
for key, val := range diff.ToAdd {
buf.WriteString(key)
buf.WriteByte(0)
buf.WriteString(val)
buf.WriteByte(0)
}
for _, val := range diff.ToRemove {
buf.WriteString(val)
buf.WriteByte(0)
}
return buf.Bytes()
}
func (diff *MapDiffType) Decode(diffBytes []byte) error {
r := bytes.NewBuffer(diffBytes)
version, err := binary.ReadUvarint(r)
if err != nil {
return fmt.Errorf("invalid diff, cannot read version: %v", err)
}
if version != MapDiffVersion {
return fmt.Errorf("invalid diff, bad version: %d", version)
}
mapLen64, err := binary.ReadUvarint(r)
if err != nil {
return fmt.Errorf("invalid diff, cannot map length: %v", err)
}
mapLen := int(mapLen64)
fields := bytes.Split(r.Bytes(), []byte{0})
if len(fields) < 2*mapLen {
return fmt.Errorf("invalid diff, not enough fields, maplen:%d fields:%d", mapLen, len(fields))
}
mapFields := fields[0 : 2*mapLen]
removeFields := fields[2*mapLen:]
diff.ToAdd = make(map[string]string)
for i := 0; i < len(mapFields); i += 2 {
diff.ToAdd[string(mapFields[i])] = string(mapFields[i+1])
}
for _, removeVal := range removeFields {
if len(removeVal) == 0 {
continue
}
diff.ToRemove = append(diff.ToRemove, string(removeVal))
}
return nil
}
func MakeMapDiff(m1 map[string]string, m2 map[string]string) []byte {
diff := makeMapDiff(m1, m2)
if len(diff.ToAdd) == 0 && len(diff.ToRemove) == 0 {
return nil
}
return diff.Encode()
}
func ApplyMapDiff(oldMap map[string]string, diffBytes []byte) (map[string]string, error) {
if len(diffBytes) == 0 {
return oldMap, nil
}
var diff MapDiffType
err := diff.Decode(diffBytes)
if err != nil {
return nil, err
}
return diff.apply(oldMap), nil
}

View File

@ -0,0 +1,99 @@
package statediff
import (
"fmt"
"testing"
)
const Str1 = `
hello
line #2
apple
grapes
banana
apple
`
const Str2 = `
line #2
apple
grapes
banana
`
const Str3 = `
more
stuff
banana
coconut
`
const Str4 = `
more
stuff
banana2
coconut
`
func testLineDiff(t *testing.T, str1 string, str2 string) {
diffBytes := MakeLineDiff(str1, str2)
fmt.Printf("diff-len: %d\n", len(diffBytes))
out, err := ApplyLineDiff(str1, diffBytes)
if err != nil {
t.Errorf("error in diff: %v", err)
return
}
if out != str2 {
t.Errorf("bad diff output")
}
var dt LineDiffType
err = dt.Decode(diffBytes)
if err != nil {
t.Errorf("error decoding diff: %v\n", err)
}
}
func TestLineDiff(t *testing.T) {
testLineDiff(t, Str1, Str2)
testLineDiff(t, Str2, Str3)
testLineDiff(t, Str1, Str3)
testLineDiff(t, Str3, Str1)
testLineDiff(t, Str3, Str4)
}
func strMapsEqual(m1 map[string]string, m2 map[string]string) bool {
if len(m1) != len(m2) {
return false
}
for key, val := range m1 {
val2, ok := m2[key]
if !ok || val != val2 {
return false
}
}
for key, val := range m2 {
val2, ok := m1[key]
if !ok || val != val2 {
return false
}
}
return true
}
func TestMapDiff(t *testing.T) {
m1 := map[string]string{"a": "5", "b": "hello", "c": "mike"}
m2 := map[string]string{"a": "5", "b": "goodbye", "d": "more"}
diffBytes := MakeMapDiff(m1, m2)
fmt.Printf("mapdifflen: %d\n", len(diffBytes))
var diff MapDiffType
	if err := diff.Decode(diffBytes); err != nil {
		t.Fatalf("error decoding map diff: %v", err)
	}
diff.Dump()
mcheck, err := ApplyMapDiff(m1, diffBytes)
if err != nil {
t.Fatalf("error applying map diff: %v", err)
}
if !strMapsEqual(m2, mcheck) {
t.Errorf("maps not equal")
}
fmt.Printf("%v\n", mcheck)
}

18
waveshell/scripthaus.md Normal file
View File

@ -0,0 +1,18 @@
```bash
# @scripthaus command build
GO_LDFLAGS="-s -w -X main.BuildTime=$(date +'%Y%m%d%H%M')"
go build -ldflags="$GO_LDFLAGS" -o bin/mshell-v0.3-darwin.amd64 main-waveshell.go
```
```bash
# @scripthaus command fullbuild
GO_LDFLAGS="-s -w -X main.BuildTime=$(date +'%Y%m%d%H%M')"
go build -ldflags="$GO_LDFLAGS" -o ~/.mshell/mshell-v0.2 main-waveshell.go
GOOS=linux GOARCH=amd64 go build -ldflags="$GO_LDFLAGS" -o bin/mshell-v0.3-linux.amd64 main-waveshell.go
GOOS=linux GOARCH=arm64 go build -ldflags="$GO_LDFLAGS" -o bin/mshell-v0.3-linux.arm64 main-waveshell.go
GOOS=darwin GOARCH=amd64 go build -ldflags="$GO_LDFLAGS" -o bin/mshell-v0.3-darwin.amd64 main-waveshell.go
GOOS=darwin GOARCH=arm64 go build -ldflags="$GO_LDFLAGS" -o bin/mshell-v0.3-darwin.arm64 main-waveshell.go
```