waveterm/waveshell/pkg/base/base.go

425 lines
10 KiB
Go

// Copyright 2023, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package base
import (
"errors"
"fmt"
"io"
"io/fs"
"log"
"os"
"os/exec"
"path"
"path/filepath"
"strings"
"sync"
"github.com/google/uuid"
"golang.org/x/mod/semver"
)
const HomeVarName = "HOME"
const DefaultMShellHome = "~/.mshell"
const DefaultMShellName = "mshell"
const MShellPathVarName = "MSHELL_PATH"
const MShellHomeVarName = "MSHELL_HOME"
const MShellInstallBinVarName = "MSHELL_INSTALLBIN_PATH"
const SSHCommandVarName = "SSH_COMMAND"
const MShellDebugVarName = "MSHELL_DEBUG"
const SessionsDirBaseName = "sessions"
const RcFilesDirBaseName = "rcfiles"
const MShellVersion = "v0.4.0"
const RemoteIdFile = "remoteid"
const DefaultMShellInstallBinDir = "/opt/mshell/bin"
const LogFileName = "mshell.log"
const ForceDebugLog = false
const DebugFlag_LogRcFile = "logrc"
const LogRcFileName = "debug.rcfile"
const (
ProcessType_Unknown = "unknown"
ProcessType_WaveSrv = "wavesrv"
ProcessType_WaveShellSingle = "wsh-1"
ProcessType_WaveShellServer = "wsh-s"
)
// keys are sessionids (also the key RcFilesDirBaseName)
var ensureDirCache = make(map[string]bool)
var baseLock = &sync.Mutex{}
var DebugLogEnabled = false
var DebugLogger *log.Logger
var BuildTime string = "0"
var ProcessType string = ProcessType_Unknown
type CommandFileNames struct {
PtyOutFile string
StdinFifo string
RunnerOutFile string
}
type CommandKey string
func SetBuildTime(build string) {
BuildTime = build
}
func IsWaveSrv() bool {
return ProcessType == ProcessType_WaveSrv
}
func MakeCommandKey(sessionId string, cmdId string) CommandKey {
if sessionId == "" && cmdId == "" {
return CommandKey("")
}
return CommandKey(fmt.Sprintf("%s/%s", sessionId, cmdId))
}
func (ckey CommandKey) IsEmpty() bool {
return string(ckey) == ""
}
func Logf(fmtStr string, args ...interface{}) {
if (!DebugLogEnabled && !ForceDebugLog) || DebugLogger == nil {
return
}
DebugLogger.Printf(fmtStr, args...)
}
func InitDebugLog(prefix string) {
homeDir := GetMShellHomeDir()
err := os.MkdirAll(homeDir, 0777)
if err != nil {
return
}
logFile := path.Join(homeDir, LogFileName)
fd, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
if err != nil {
return
}
DebugLogger = log.New(fd, prefix+" ", log.LstdFlags)
Logf("logger initialized\n")
}
func SetEnableDebugLog(enable bool) {
DebugLogEnabled = enable
}
// deprecated (use GetGroupId instead)
func (ckey CommandKey) GetSessionId() string {
return ckey.GetGroupId()
}
func (ckey CommandKey) GetGroupId() string {
slashIdx := strings.Index(string(ckey), "/")
if slashIdx == -1 {
return ""
}
return string(ckey[0:slashIdx])
}
func (ckey CommandKey) GetCmdId() string {
slashIdx := strings.Index(string(ckey), "/")
if slashIdx == -1 {
return ""
}
return string(ckey[slashIdx+1:])
}
func (ckey CommandKey) Split() (string, string) {
fields := strings.SplitN(string(ckey), "/", 2)
if len(fields) < 2 {
return "", ""
}
return fields[0], fields[1]
}
func (ckey CommandKey) Validate(typeStr string) error {
if typeStr == "" {
typeStr = "ck"
}
if ckey == "" {
return fmt.Errorf("%s has empty commandkey", typeStr)
}
sessionId, cmdId := ckey.Split()
if sessionId == "" {
return fmt.Errorf("%s does not have sessionid", typeStr)
}
_, err := uuid.Parse(sessionId)
if err != nil {
return fmt.Errorf("%s has invalid sessionid '%s'", typeStr, sessionId)
}
if cmdId == "" {
return fmt.Errorf("%s does not have cmdid", typeStr)
}
_, err = uuid.Parse(cmdId)
if err != nil {
return fmt.Errorf("%s has invalid cmdid '%s'", typeStr, cmdId)
}
return nil
}
func HasDebugFlag(envMap map[string]string, flagName string) bool {
msDebug := envMap[MShellDebugVarName]
flags := strings.Split(msDebug, ",")
for _, flag := range flags {
if strings.TrimSpace(flag) == flagName {
return true
}
}
return false
}
func GetDebugRcFileName() string {
msHome := GetMShellHomeDir()
return path.Join(msHome, LogRcFileName)
}
func GetHomeDir() string {
homeVar := os.Getenv(HomeVarName)
if homeVar == "" {
return "/"
}
return homeVar
}
func GetMShellHomeDir() string {
homeVar := os.Getenv(MShellHomeVarName)
if homeVar != "" {
return homeVar
}
return ExpandHomeDir(DefaultMShellHome)
}
func GetCommandFileNames(ck CommandKey) (*CommandFileNames, error) {
if err := ck.Validate("ck"); err != nil {
return nil, fmt.Errorf("cannot get command files: %w", err)
}
sessionId, cmdId := ck.Split()
sdir, err := EnsureSessionDir(sessionId)
if err != nil {
return nil, err
}
base := path.Join(sdir, cmdId)
return &CommandFileNames{
PtyOutFile: base + ".ptyout",
StdinFifo: base + ".stdin",
RunnerOutFile: base + ".runout",
}, nil
}
func CleanUpCmdFiles(sessionId string, cmdId string) error {
if cmdId == "" {
return fmt.Errorf("bad cmdid, cannot clean up")
}
sdir, err := EnsureSessionDir(sessionId)
if err != nil {
return err
}
cmdFileGlob := path.Join(sdir, cmdId+".*")
matches, err := filepath.Glob(cmdFileGlob)
if err != nil {
return err
}
for _, file := range matches {
rmErr := os.Remove(file)
if err == nil && rmErr != nil {
err = rmErr
}
}
return err
}
func GetSessionsDir() string {
mhome := GetMShellHomeDir()
sdir := path.Join(mhome, SessionsDirBaseName)
return sdir
}
func EnsureRcFilesDir() (string, error) {
mhome := GetMShellHomeDir()
dirName := path.Join(mhome, RcFilesDirBaseName)
err := CacheEnsureDir(dirName, RcFilesDirBaseName, 0700, "rcfiles dir")
if err != nil {
return "", err
}
return dirName, nil
}
func EnsureSessionDir(sessionId string) (string, error) {
if sessionId == "" {
return "", fmt.Errorf("Bad sessionid, cannot be empty")
}
mhome := GetMShellHomeDir()
sdir := path.Join(mhome, SessionsDirBaseName, sessionId)
err := CacheEnsureDir(sdir, sessionId, 0777, "mshell session dir")
if err != nil {
return "", err
}
return sdir, nil
}
func GetMShellPath() (string, error) {
msPath := os.Getenv(MShellPathVarName) // use MSHELL_PATH
if msPath != "" {
return exec.LookPath(msPath)
}
mhome := GetMShellHomeDir()
userMShellPath := path.Join(mhome, DefaultMShellName) // look in ~/.mshell
msPath, err := exec.LookPath(userMShellPath)
if err == nil {
return msPath, nil
}
return exec.LookPath(DefaultMShellName) // standard path lookup for 'mshell'
}
func GetMShellSessionsDir() (string, error) {
mhome := GetMShellHomeDir()
return path.Join(mhome, SessionsDirBaseName), nil
}
func ExpandHomeDir(pathStr string) string {
if pathStr != "~" && !strings.HasPrefix(pathStr, "~/") {
return pathStr
}
homeDir := GetHomeDir()
if pathStr == "~" {
return homeDir
}
return path.Join(homeDir, pathStr[2:])
}
func ValidGoArch(goos string, goarch string) bool {
return (goos == "darwin" || goos == "linux") && (goarch == "amd64" || goarch == "arm64")
}
func GoArchOptFile(version string, goos string, goarch string) string {
installBinDir := os.Getenv(MShellInstallBinVarName)
if installBinDir == "" {
installBinDir = DefaultMShellInstallBinDir
}
versionStr := semver.MajorMinor(version)
if versionStr == "" {
versionStr = "unknown"
}
binBaseName := fmt.Sprintf("mshell-%s-%s.%s", versionStr, goos, goarch)
return fmt.Sprintf(path.Join(installBinDir, binBaseName))
}
func MShellBinaryFromOptDir(version string, goos string, goarch string) (io.ReadCloser, error) {
if !ValidGoArch(goos, goarch) {
return nil, fmt.Errorf("invalid goos/goarch combination: %s/%s", goos, goarch)
}
versionStr := semver.MajorMinor(version)
if versionStr == "" {
return nil, fmt.Errorf("invalid mshell version: %q", version)
}
fileName := GoArchOptFile(version, goos, goarch)
fd, err := os.Open(fileName)
if err != nil {
return nil, fmt.Errorf("cannot open mshell binary %q: %v", fileName, err)
}
return fd, nil
}
func GetRemoteId() (string, error) {
mhome := GetMShellHomeDir()
homeInfo, err := os.Stat(mhome)
if errors.Is(err, fs.ErrNotExist) {
err = os.MkdirAll(mhome, 0777)
if err != nil {
return "", fmt.Errorf("cannot make mshell home directory[%s]: %w", mhome, err)
}
homeInfo, err = os.Stat(mhome)
}
if err != nil {
return "", fmt.Errorf("cannot stat mshell home directory[%s]: %w", mhome, err)
}
if !homeInfo.IsDir() {
return "", fmt.Errorf("mshell home directory[%s] is not a directory", mhome)
}
remoteIdFile := path.Join(mhome, RemoteIdFile)
fd, err := os.Open(remoteIdFile)
if errors.Is(err, fs.ErrNotExist) {
// write the file
remoteId := uuid.New().String()
err = os.WriteFile(remoteIdFile, []byte(remoteId), 0644)
if err != nil {
return "", fmt.Errorf("cannot write remoteid to '%s': %w", remoteIdFile, err)
}
return remoteId, nil
} else if err != nil {
return "", fmt.Errorf("cannot read remoteid file '%s': %w", remoteIdFile, err)
} else {
defer fd.Close()
contents, err := io.ReadAll(fd)
if err != nil {
return "", fmt.Errorf("cannot read remoteid file '%s': %w", remoteIdFile, err)
}
uuidStr := string(contents)
_, err = uuid.Parse(uuidStr)
if err != nil {
return "", fmt.Errorf("invalid uuid read from '%s': %w", remoteIdFile, err)
}
return uuidStr, nil
}
}
func BoundInt(ival int, minVal int, maxVal int) int {
if ival < minVal {
return minVal
}
if ival > maxVal {
return maxVal
}
return ival
}
func BoundInt64(ival int64, minVal int64, maxVal int64) int64 {
if ival < minVal {
return minVal
}
if ival > maxVal {
return maxVal
}
return ival
}
func CacheEnsureDir(dirName string, cacheKey string, perm os.FileMode, dirDesc string) error {
baseLock.Lock()
ok := ensureDirCache[cacheKey]
baseLock.Unlock()
if ok {
return nil
}
err := TryMkdirs(dirName, perm, dirDesc)
if err != nil {
return err
}
baseLock.Lock()
ensureDirCache[cacheKey] = true
baseLock.Unlock()
return nil
}
func TryMkdirs(dirName string, perm os.FileMode, dirDesc string) error {
info, err := os.Stat(dirName)
if errors.Is(err, fs.ErrNotExist) {
err = os.MkdirAll(dirName, perm)
if err != nil {
return fmt.Errorf("cannot make %s %q: %w", dirDesc, dirName, err)
}
info, err = os.Stat(dirName)
}
if err != nil {
return fmt.Errorf("error trying to stat %s: %w", dirDesc, err)
}
if !info.IsDir() {
return fmt.Errorf("%s %q must be a directory", dirDesc, dirName)
}
return nil
}