bash v3 has a bug with reading large rcfiles from a /dev/fd pipe. fallback to writing a tmp file in that case to allow for large (128k+) rcfiles

This commit is contained in:
sawka 2023-10-25 15:20:25 -07:00
parent 7c966e2660
commit 737e08b583
5 changed files with 133 additions and 39 deletions

View File

@ -144,6 +144,7 @@ go build -ldflags="$GO_LDFLAGS" -o bin/mshell-v0.3-darwin.amd64 main-waveshell.g
```bash
# @scripthaus command fullbuild-waveshell
set -e
cd waveshell
GO_LDFLAGS="-s -w -X main.BuildTime=$(date +'%Y%m%d%H%M')"
go build -ldflags="$GO_LDFLAGS" -o ~/.mshell/mshell-v0.2 main-waveshell.go

View File

@ -542,15 +542,6 @@ func main() {
} else if firstArg == "--version" {
fmt.Printf("mshell %s+%s\n", base.MShellVersion, base.BuildTime)
return
} else if firstArg == "--test-env" {
state, err := shexec.GetShellState()
if state != nil {
}
if err != nil {
fmt.Fprintf(os.Stderr, "[error] %v\n", err)
os.Exit(1)
}
} else if firstArg == "--single" {
base.InitDebugLog("single")
handleSingle(false)

View File

@ -29,6 +29,7 @@ const MShellInstallBinVarName = "MSHELL_INSTALLBIN_PATH"
const SSHCommandVarName = "SSH_COMMAND"
const MShellDebugVarName = "MSHELL_DEBUG"
const SessionsDirBaseName = "sessions"
const RcFilesDirBaseName = "rcfiles"
const MShellVersion = "v0.3.0"
const RemoteIdFile = "remoteid"
const DefaultMShellInstallBinDir = "/opt/mshell/bin"
@ -38,7 +39,8 @@ const ForceDebugLog = false
const DebugFlag_LogRcFile = "logrc"
const LogRcFileName = "debug.rcfile"
var sessionDirCache = make(map[string]string)
// keys are sessionids (also the key RcFilesDirBaseName)
var ensureDirCache = make(map[string]bool)
var baseLock = &sync.Mutex{}
var DebugLogEnabled = false
var DebugLogger *log.Logger
@ -225,35 +227,26 @@ func GetSessionsDir() string {
return sdir
}
func EnsureRcFilesDir() (string, error) {
mhome := GetMShellHomeDir()
dirName := path.Join(mhome, RcFilesDirBaseName)
err := CacheEnsureDir(dirName, RcFilesDirBaseName, 0700, "rcfiles dir")
if err != nil {
return "", err
}
return dirName, nil
}
func EnsureSessionDir(sessionId string) (string, error) {
if sessionId == "" {
return "", fmt.Errorf("Bad sessionid, cannot be empty")
}
baseLock.Lock()
sdir, ok := sessionDirCache[sessionId]
baseLock.Unlock()
if ok {
return sdir, nil
}
mhome := GetMShellHomeDir()
sdir = path.Join(mhome, SessionsDirBaseName, sessionId)
info, err := os.Stat(sdir)
if errors.Is(err, fs.ErrNotExist) {
err = os.MkdirAll(sdir, 0777)
if err != nil {
return "", fmt.Errorf("cannot make mshell session directory[%s]: %w", sdir, err)
}
info, err = os.Stat(sdir)
}
sdir := path.Join(mhome, SessionsDirBaseName, sessionId)
err := CacheEnsureDir(sdir, sessionId, 0777, "mshell session dir")
if err != nil {
return "", err
}
if !info.IsDir() {
return "", fmt.Errorf("session dir '%s' must be a directory", sdir)
}
baseLock.Lock()
sessionDirCache[sessionId] = sdir
baseLock.Unlock()
return sdir, nil
}
@ -382,3 +375,38 @@ func BoundInt64(ival int64, minVal int64, maxVal int64) int64 {
}
return ival
}
func CacheEnsureDir(dirName string, cacheKey string, perm os.FileMode, dirDesc string) error {
baseLock.Lock()
ok := ensureDirCache[cacheKey]
baseLock.Unlock()
if ok {
return nil
}
err := TryMkdirs(dirName, perm, dirDesc)
if err != nil {
return err
}
baseLock.Lock()
ensureDirCache[cacheKey] = true
baseLock.Unlock()
return nil
}
func TryMkdirs(dirName string, perm os.FileMode, dirDesc string) error {
info, err := os.Stat(dirName)
if errors.Is(err, fs.ErrNotExist) {
err = os.MkdirAll(dirName, perm)
if err != nil {
return fmt.Errorf("cannot make %s %q: %w", dirDesc, dirName, err)
}
info, err = os.Stat(dirName)
}
if err != nil {
return fmt.Errorf("error trying to stat %s: %w", dirDesc, err)
}
if !info.IsDir() {
return fmt.Errorf("%s %q must be a directory", dirDesc, dirName)
}
return nil
}

View File

@ -9,9 +9,11 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"strings"
"github.com/wavetermdev/waveterm/waveshell/pkg/binpack"
"github.com/wavetermdev/waveterm/waveshell/pkg/statediff"
"golang.org/x/mod/semver"
)
const ShellStatePackVersion = 0
@ -63,6 +65,18 @@ func (state ShellState) EncodeAndHash() (string, []byte) {
return sha1Hash(buf.Bytes()), buf.Bytes()
}
// returns a string like "v4" ("" is an unparseable version)
func GetBashMajorVersion(versionStr string) string {
if versionStr == "" {
return ""
}
fields := strings.Split(versionStr, " ")
if len(fields) < 2 {
return ""
}
return semver.Major(fields[1])
}
func (state ShellState) MarshalJSON() ([]byte, error) {
_, encodedBytes := state.EncodeAndHash()
return json.Marshal(encodedBytes)

View File

@ -13,6 +13,7 @@ import (
"os/exec"
"os/signal"
"os/user"
"path"
"runtime"
"strconv"
"strings"
@ -21,11 +22,12 @@ import (
"time"
"github.com/alessio/shellescape"
"github.com/creack/pty"
"github.com/google/uuid"
"github.com/wavetermdev/waveterm/waveshell/pkg/base"
"github.com/wavetermdev/waveterm/waveshell/pkg/cirfile"
"github.com/wavetermdev/waveterm/waveshell/pkg/mpio"
"github.com/wavetermdev/waveterm/waveshell/pkg/packet"
"github.com/creack/pty"
"golang.org/x/mod/semver"
"golang.org/x/sys/unix"
)
@ -50,6 +52,12 @@ const GetStateTimeout = 5 * time.Second
const BaseBashOpts = `set +m; set +H; shopt -s extglob`
const ShellVersionCmdStr = `echo bash v${BASH_VERSINFO[0]}.${BASH_VERSINFO[1]}.${BASH_VERSINFO[2]}`
// do not use these directly, call GetLocalBashMajorVersion()
var LocalBashMajorVersionOnce = &sync.Once{}
var LocalBashMajorVersion = ""
var GetShellStateCmds = []string{
`echo bash v${BASH_VERSINFO[0]}.${BASH_VERSINFO[1]}.${BASH_VERSINFO[2]};`,
`pwd;`,
@ -125,6 +133,7 @@ type ShExecType struct {
MsgSender *packet.PacketSender // where to send out-of-band messages back to calling proceess
ReturnState *ReturnStateBuf
Exited bool // locked via Lock
TmpRcFileName string
}
type StdContext struct{}
@ -240,6 +249,9 @@ func (c *ShExecType) Close() {
if c.ReturnState != nil {
c.ReturnState.Reader.Close()
}
if c.TmpRcFileName != "" {
os.Remove(c.TmpRcFileName)
}
}
func (c *ShExecType) MakeCmdStartPacket(reqId string) *packet.CmdStartPacketType {
@ -1087,14 +1099,37 @@ func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fro
base.Logf("error writing %s: %v\n", debugRcFileName, err)
}
}
rcFileFdNum, err := AddRunData(pk, rcFileStr, "rcfile")
if err != nil {
return nil, err
bashVersion := GetLocalBashMajorVersion()
isOldBashVersion := (semver.Compare(bashVersion, "v4") < 0)
var rcFileName string
if isOldBashVersion {
rcFileDir, err := base.EnsureRcFilesDir()
if err != nil {
return nil, err
}
rcFileName = path.Join(rcFileDir, uuid.New().String())
err = os.WriteFile(rcFileName, []byte(rcFileStr), 0600)
if err != nil {
return nil, fmt.Errorf("could not write temp rcfile: %w", err)
}
cmd.TmpRcFileName = rcFileName
go func() {
// cmd.Close() will also remove rcFileName
// adding this to also try to proactively clean up after 1-second.
time.Sleep(1 * time.Second)
os.Remove(rcFileName)
}()
} else {
rcFileFdNum, err := AddRunData(pk, rcFileStr, "rcfile")
if err != nil {
return nil, err
}
rcFileName = fmt.Sprintf("/dev/fd/%d", rcFileFdNum)
}
if pk.UsePty {
cmd.Cmd = exec.Command("bash", "--rcfile", fmt.Sprintf("/dev/fd/%d", rcFileFdNum), "-i", "-c", pk.Command)
cmd.Cmd = exec.Command("bash", "--rcfile", rcFileName, "-i", "-c", pk.Command)
} else {
cmd.Cmd = exec.Command("bash", "--rcfile", fmt.Sprintf("/dev/fd/%d", rcFileFdNum), "-c", pk.Command)
cmd.Cmd = exec.Command("bash", "--rcfile", rcFileName, "-c", pk.Command)
}
if !pk.StateComplete {
cmd.Cmd.Env = os.Environ()
@ -1103,7 +1138,7 @@ func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fro
if state.Cwd != "" {
cmd.Cmd.Dir = base.ExpandHomeDir(state.Cwd)
}
err = ValidateRemoteFds(pk.Fds)
err := ValidateRemoteFds(pk.Fds)
if err != nil {
return nil, err
}
@ -1524,7 +1559,8 @@ func GetShellStateRedirectCommandStr(outputFdNum int) string {
}
func GetShellState() (*packet.ShellState, error) {
ctx, _ := context.WithTimeout(context.Background(), GetStateTimeout)
ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout)
defer cancelFn()
cmdStr := BaseBashOpts + "; " + GetShellStateCmd()
ecmd := exec.CommandContext(ctx, "bash", "-l", "-i", "-c", cmdStr)
outputBytes, err := runSimpleCmdInPty(ecmd)
@ -1543,3 +1579,27 @@ func MShellEnvVars(termType string) map[string]string {
rtn["MSHELL_VERSION"] = base.MShellVersion
return rtn
}
func ExecGetLocalShellVersion() string {
ctx, cancelFn := context.WithTimeout(context.Background(), GetStateTimeout)
defer cancelFn()
ecmd := exec.CommandContext(ctx, "bash", "-c", ShellVersionCmdStr)
out, err := ecmd.Output()
if err != nil {
return ""
}
versionStr := strings.TrimSpace(string(out))
if strings.Index(versionStr, "bash ") == -1 {
// invalid shell version (only bash is supported)
return ""
}
return versionStr
}
func GetLocalBashMajorVersion() string {
LocalBashMajorVersionOnce.Do(func() {
fullVersion := ExecGetLocalShellVersion()
LocalBashMajorVersion = packet.GetBashMajorVersion(fullVersion)
})
return LocalBashMajorVersion
}