diff --git a/waveshell/pkg/packet/packet.go b/waveshell/pkg/packet/packet.go index d3c6e38e5..78dbf6e66 100644 --- a/waveshell/pkg/packet/packet.go +++ b/waveshell/pkg/packet/packet.go @@ -65,6 +65,8 @@ const ( LogPacketStr = "log" // logging packet (sent from waveshell back to server) ShellStatePacketStr = "shellstate" RpcInputPacketStr = "rpcinput" // rpc-followup + SudoRequestPacketStr = "sudorequest" + SudoResponsePacketStr = "sudoresponse" OpenAIPacketStr = "openai" // other OpenAICloudReqStr = "openai-cloudreq" @@ -120,6 +122,8 @@ func init() { TypeStrToFactory[ShellStatePacketStr] = reflect.TypeOf(ShellStatePacketType{}) TypeStrToFactory[FileStatPacketStr] = reflect.TypeOf(FileStatPacketType{}) TypeStrToFactory[RpcInputPacketStr] = reflect.TypeOf(RpcInputPacketType{}) + TypeStrToFactory[SudoRequestPacketStr] = reflect.TypeOf(SudoRequestPacketType{}) + TypeStrToFactory[SudoResponsePacketStr] = reflect.TypeOf(SudoResponsePacketType{}) var _ RpcPacketType = (*RunPacketType)(nil) var _ RpcPacketType = (*GetCmdPacketType)(nil) @@ -146,6 +150,7 @@ func init() { var _ CommandPacketType = (*CmdDonePacketType)(nil) var _ CommandPacketType = (*SpecialInputPacketType)(nil) var _ CommandPacketType = (*CmdFinalPacketType)(nil) + var _ CommandPacketType = (*SudoResponsePacketType)(nil) } func RegisterPacketType(typeStr string, rtype reflect.Type) { @@ -827,6 +832,7 @@ type RunPacketType struct { RunData []RunDataType `json:"rundata,omitempty"` Detached bool `json:"detached,omitempty"` ReturnState bool `json:"returnstate,omitempty"` + IsSudo bool `json:"issudo,omitempty"` } func (*RunPacketType) GetType() string { @@ -979,6 +985,55 @@ func MakeOpenAICloudReqPacket() *OpenAICloudReqPacketType { } } +type SudoRequestPacketType struct { + Type string `json:"type"` + CK base.CommandKey `json:"ck"` + ShellPubKey []byte `json:"shellpubkey"` + SudoStatus string `json:"sudostatus"` + ErrStr string `json:"errstr"` +} + +func (*SudoRequestPacketType) GetType() string { + return SudoRequestPacketStr +} + +func (p *SudoRequestPacketType) GetCK() base.CommandKey { + return p.CK +} + +func MakeSudoRequestPacket(ck base.CommandKey, pubKey []byte, sudoStatus string) *SudoRequestPacketType { + return &SudoRequestPacketType{ + Type: SudoRequestPacketStr, + CK: ck, + ShellPubKey: pubKey, + SudoStatus: sudoStatus, + } +} + +type SudoResponsePacketType struct { + Type string `json:"type"` + CK base.CommandKey `json:"ck"` + Secret []byte `json:"secret"` + SrvPubKey []byte `json:"srvpubkey"` +} + +func (*SudoResponsePacketType) GetType() string { + return SudoResponsePacketStr +} + +func (p *SudoResponsePacketType) GetCK() base.CommandKey { + return p.CK +} + +func MakeSudoResponsePacket(ck base.CommandKey, secret []byte, srvPubKey []byte) *SudoResponsePacketType { + return &SudoResponsePacketType{ + Type: SudoResponsePacketStr, + CK: ck, + Secret: secret, + SrvPubKey: srvPubKey, + } +} + type PacketType interface { GetType() string } diff --git a/waveshell/pkg/shexec/shexec.go b/waveshell/pkg/shexec/shexec.go index 7a5bdef51..67b9dee4f 100644 --- a/waveshell/pkg/shexec/shexec.go +++ b/waveshell/pkg/shexec/shexec.go @@ -4,8 +4,12 @@ package shexec import ( + "bufio" "bytes" "context" + "crypto/ecdh" + "crypto/rand" + "crypto/x509" "encoding/base64" "fmt" "io" @@ -33,6 +37,7 @@ import ( "github.com/wavetermdev/waveterm/waveshell/pkg/shellutil" "github.com/wavetermdev/waveterm/waveshell/pkg/utilfn" "github.com/wavetermdev/waveterm/waveshell/pkg/wlog" + "github.com/wavetermdev/waveterm/wavesrv/pkg/waveenc" "golang.org/x/mod/semver" "golang.org/x/sys/unix" ) @@ -120,6 +125,8 @@ type ShExecType struct { Exited bool // locked via Lock TmpRcFileName string // file *or* directory holding temporary rc file(s) SAPI shellapi.ShellApi + ShellPrivKey *ecdh.PrivateKey + SudoWriter *os.File } type StdContext struct{} @@ -191,6 +198,23 @@ func (s *ShExecType) processSpecialInputPacket(pk *packet.SpecialInputPacketType return nil } +func (s ShExecUPR) processSudoResponsePacket(sudoPacket *packet.SudoResponsePacketType) error { + encryptor, err := waveenc.MakeEncryptorEcdh(s.ShExec.ShellPrivKey, sudoPacket.SrvPubKey) + if err != nil { + return err + } + decrypted, err := encryptor.DecryptData(sudoPacket.Secret, "sudopw") + if err != nil { + return fmt.Errorf("decrypt secret: %e", err) + } + decrypted = append(decrypted, '\n') + _, err = s.ShExec.SudoWriter.Write(decrypted) + if err != nil { + return fmt.Errorf("unable to write secret to stdin: %e", err) + } + return nil +} + func (s ShExecUPR) UnknownPacket(pk packet.PacketType) { if pk.GetType() == packet.SpecialInputPacketStr { inputPacket := pk.(*packet.SpecialInputPacketType) @@ -202,6 +226,16 @@ func (s ShExecUPR) UnknownPacket(pk packet.PacketType) { } return } + if pk.GetType() == packet.SudoResponsePacketStr { + sudoPacket := pk.(*packet.SudoResponsePacketType) + err := s.processSudoResponsePacket(sudoPacket) + if err != nil { + sudoRequest := packet.MakeSudoRequestPacket(sudoPacket.CK, nil, "error") + sudoRequest.ErrStr = err.Error() + s.ShExec.MsgSender.SendPacket(sudoRequest) + } + return + } if s.UPR != nil { s.UPR.UnknownPacket(pk) } @@ -904,6 +938,15 @@ func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fro // this ensures that the last command is a shell buitin so we always get our exit trap to run fullCmdStr = fullCmdStr + "\nexit $? 2> /dev/null" } + + var sudoKey uuid.UUID + var sudoErrKey uuid.UUID + if pk.IsSudo { + sudoKey = uuid.New() + sudoErrKey = uuid.New() + fullCmdStr = fmt.Sprintf("sudo -p \"%s\" -S true 2>&7 <&6; if [ $? != 0 ]; then echo %s >&7 && exit; fi; exec 6>&-; exec 7>&-; %s", sudoKey, sudoErrKey, fullCmdStr) + } + cmd.Cmd = sapi.MakeShExecCommand(fullCmdStr, rcFileName, pk.UsePty) if !pk.StateComplete { cmd.Cmd.Env = os.Environ() @@ -974,6 +1017,66 @@ func RunCommandSimple(pk *packet.RunPacketType, sender *packet.PacketSender, fro return nil, err } } + if pk.IsSudo { + readToSudo, writeToSudo, err := os.Pipe() + if err != nil { + return nil, err + } + + readFromSudo, writeFromSudo, err := os.Pipe() + if err != nil { + return nil, err + } + + go func() { + reader := bufio.NewReader(readFromSudo) + buffer := bytes.NewBuffer(make([]byte, 0)) + chunk := make([]byte, 1024) + firstAttempt := true + for { + len, _ := reader.Read(chunk) + buffer.Write(chunk[:len]) + if bytes.Contains(buffer.Bytes(), []byte(sudoKey.String())) { + buffer.Reset() + + // subsequent attempts get an extra \n + sudoStatus := "followup-attempt" + if firstAttempt { + sudoStatus = "first-attempt" + } + firstAttempt = false + + shellPrivKey, err := ecdh.P256().GenerateKey(rand.Reader) + if err != nil { + sudoRequest := packet.MakeSudoRequestPacket(cmd.CK, nil, "error") + sudoRequest.ErrStr = fmt.Sprintf("generate ecdh: %s", err.Error()) + rtnShExec.MsgSender.SendPacket(sudoRequest) + return + } + shellPubKey, err := x509.MarshalPKIXPublicKey(shellPrivKey.PublicKey()) + if err != nil { + sudoRequest := packet.MakeSudoRequestPacket(cmd.CK, nil, "error") + sudoRequest.ErrStr = fmt.Sprintf("marshal pub key: %s", err.Error()) + rtnShExec.MsgSender.SendPacket(sudoRequest) + return + } + rtnShExec.ShellPrivKey = shellPrivKey + rtnShExec.SudoWriter = writeToSudo + sudoRequest := packet.MakeSudoRequestPacket(cmd.CK, shellPubKey, sudoStatus) + rtnShExec.MsgSender.SendPacket(sudoRequest) + } else if bytes.Contains(buffer.Bytes(), []byte(sudoErrKey.String())) { + sudoRequest := packet.MakeSudoRequestPacket(cmd.CK, nil, "failure") + rtnShExec.MsgSender.SendPacket(sudoRequest) + } + } + }() + + if 7 >= len(extraFiles) { + extraFiles = extraFiles[:7+1] + } + extraFiles[6] = readToSudo //todo - make a constant for the 6 + extraFiles[7] = writeFromSudo // todo - same + } for _, rfd := range pk.Fds { if rfd.FdNum >= len(extraFiles) { extraFiles = extraFiles[:rfd.FdNum+1] diff --git a/wavesrv/cmd/main-server.go b/wavesrv/cmd/main-server.go index e139a8163..8c973932b 100644 --- a/wavesrv/cmd/main-server.go +++ b/wavesrv/cmd/main-server.go @@ -39,7 +39,6 @@ import ( "github.com/wavetermdev/waveterm/wavesrv/pkg/cmdrunner" "github.com/wavetermdev/waveterm/wavesrv/pkg/ephemeral" "github.com/wavetermdev/waveterm/wavesrv/pkg/pcloud" - "github.com/wavetermdev/waveterm/wavesrv/pkg/promptenc" "github.com/wavetermdev/waveterm/wavesrv/pkg/releasechecker" "github.com/wavetermdev/waveterm/wavesrv/pkg/remote" "github.com/wavetermdev/waveterm/wavesrv/pkg/rtnstate" @@ -49,6 +48,7 @@ import ( "github.com/wavetermdev/waveterm/wavesrv/pkg/scws" "github.com/wavetermdev/waveterm/wavesrv/pkg/sstore" "github.com/wavetermdev/waveterm/wavesrv/pkg/telemetry" + "github.com/wavetermdev/waveterm/wavesrv/pkg/waveenc" "github.com/wavetermdev/waveterm/wavesrv/pkg/wsshell" ) @@ -830,7 +830,7 @@ func AuthKeyWrapAllowHmac(fn WebFnType) WebFnType { w.Write([]byte("no x-authkey header")) return } - hmacOk, err := promptenc.ValidateUrlHmac([]byte(scbase.WaveAuthKey), r.URL.Path, qvals) + hmacOk, err := waveenc.ValidateUrlHmac([]byte(scbase.WaveAuthKey), r.URL.Path, qvals) if err != nil || !hmacOk { w.WriteHeader(http.StatusInternalServerError) w.Write([]byte(fmt.Sprintf("error validating hmac"))) diff --git a/wavesrv/pkg/bufferedpipe/bufferedpipe.go b/wavesrv/pkg/bufferedpipe/bufferedpipe.go index 3877c443f..3263182ba 100644 --- a/wavesrv/pkg/bufferedpipe/bufferedpipe.go +++ b/wavesrv/pkg/bufferedpipe/bufferedpipe.go @@ -15,8 +15,8 @@ import ( "time" "github.com/google/uuid" - "github.com/wavetermdev/waveterm/wavesrv/pkg/promptenc" "github.com/wavetermdev/waveterm/wavesrv/pkg/scbase" + "github.com/wavetermdev/waveterm/wavesrv/pkg/waveenc" ) const ( @@ -54,7 +54,7 @@ func (pipe *BufferedPipe) GetOutputUrl() (string, error) { qvals := make(url.Values) qvals.Set("key", pipe.Key) qvals.Set("nonce", uuid.New().String()) - hmacStr, err := promptenc.ComputeUrlHmac([]byte(scbase.WaveAuthKey), BufferedPipeGetterUrl, qvals) + hmacStr, err := waveenc.ComputeUrlHmac([]byte(scbase.WaveAuthKey), BufferedPipeGetterUrl, qvals) if err != nil { return "", err } diff --git a/wavesrv/pkg/cmdrunner/cmdrunner.go b/wavesrv/pkg/cmdrunner/cmdrunner.go index e71637f7e..eaebd5562 100644 --- a/wavesrv/pkg/cmdrunner/cmdrunner.go +++ b/wavesrv/pkg/cmdrunner/cmdrunner.go @@ -40,7 +40,6 @@ import ( "github.com/wavetermdev/waveterm/wavesrv/pkg/ephemeral" "github.com/wavetermdev/waveterm/wavesrv/pkg/history" "github.com/wavetermdev/waveterm/wavesrv/pkg/pcloud" - "github.com/wavetermdev/waveterm/wavesrv/pkg/promptenc" "github.com/wavetermdev/waveterm/wavesrv/pkg/releasechecker" "github.com/wavetermdev/waveterm/wavesrv/pkg/remote" "github.com/wavetermdev/waveterm/wavesrv/pkg/remote/openai" @@ -50,6 +49,7 @@ import ( "github.com/wavetermdev/waveterm/wavesrv/pkg/scpacket" "github.com/wavetermdev/waveterm/wavesrv/pkg/sstore" "github.com/wavetermdev/waveterm/wavesrv/pkg/telemetry" + "github.com/wavetermdev/waveterm/wavesrv/pkg/waveenc" "golang.org/x/mod/semver" ) @@ -88,7 +88,6 @@ const OpenAIPacketTimeout = 10 * time.Second const OpenAIStreamTimeout = 5 * time.Minute const OpenAICloudCompletionTelemetryOffErrorMsg = "To ensure responsible usage and prevent misuse, Wave AI requires telemetry to be enabled when using its free AI features.\n\nIf you prefer not to enable telemetry, you can still access Wave AI's features by providing your own OpenAI API key in the Settings menu. Please note that when using your personal API key, requests will be sent directly to the OpenAI API without being proxied through Wave's servers.\n\nIf you wish to continue using Wave AI's free features, you can easily enable telemetry by running the '/telemetry:on' command in the terminal. This will allow you to access the free AI features while helping to protect the platform from abuse." - const ( KwArgRenderer = "renderer" KwArgView = "view" @@ -97,6 +96,7 @@ const ( KwArgLang = "lang" KwArgMinimap = "minimap" KwArgNoHist = "nohist" + KwArgSudo = "sudo" ) var ColorNames = []string{"yellow", "blue", "pink", "mint", "cyan", "violet", "orange", "green", "red", "white"} @@ -296,6 +296,8 @@ func init() { registerCmdFn("csvview", CSVViewCommand) registerCmdFn("_debug:ri", DebugRemoteInstanceCommand) + + registerCmdFn("sudo:clear", ClearSudoCache) } func getValidCommands() []string { @@ -608,6 +610,7 @@ func RunCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (scbus.Up if err != nil { return nil, fmt.Errorf("/run error, invalid lang: %w", err) } + cmdStr := firstArg(pk) expandedCmdStr, err := doCmdHistoryExpansion(ctx, ids, cmdStr) if err != nil { @@ -639,6 +642,11 @@ func RunCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (scbus.Up } runPacket.Command = strings.TrimSpace(cmdStr) runPacket.ReturnState = resolveBool(pk.Kwargs["rtnstate"], isRtnStateCmd) + if sudoArg, ok := pk.Kwargs[KwArgSudo]; ok { + runPacket.IsSudo = resolveBool(sudoArg, false) + } else { + runPacket.IsSudo = IsSudoCommand(cmdStr) + } rcOpts := remote.RunCommandOpts{ SessionId: ids.SessionId, ScreenId: ids.ScreenId, @@ -3916,6 +3924,30 @@ func DebugRemoteInstanceCommand(ctx context.Context, pk *scpacket.FeCommandPacke return update, nil } +func ClearSudoCache(ctx context.Context, pk *scpacket.FeCommandPacketType) (rtnUpdate scbus.UpdatePacket, rtnErr error) { + ids, err := resolveUiIds(ctx, pk, R_Session|R_Screen|R_Remote) + if err != nil { + return nil, err + } + ids.Remote.MShell.ClearCachedSudoPw() + pluralize := "" + + clearAll := resolveBool(pk.Kwargs["all"], false) + if clearAll { + for _, proc := range remote.GetRemoteMap() { + proc.ClearCachedSudoPw() + } + pluralize = "s" + } + + update := scbus.MakeUpdatePacket() + update.AddUpdate(sstore.InfoMsgType{ + InfoMsg: fmt.Sprintf("sudo password%s cleared", pluralize), + TimeoutMs: 2000, + }) + return update, nil +} + func RemoteResetCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (rtnUpdate scbus.UpdatePacket, rtnErr error) { ids, err := resolveUiIds(ctx, pk, R_Session|R_Screen|R_Remote) if err != nil { @@ -5255,7 +5287,7 @@ func MakeReadFileUrl(screenId string, lineId string, filePath string) (string, e qvals.Set("lineid", lineId) qvals.Set("path", filePath) qvals.Set("nonce", uuid.New().String()) - hmacStr, err := promptenc.ComputeUrlHmac([]byte(scbase.WaveAuthKey), "/api/read-file", qvals) + hmacStr, err := waveenc.ComputeUrlHmac([]byte(scbase.WaveAuthKey), "/api/read-file", qvals) if err != nil { return "", fmt.Errorf("error computing hmac-url: %v", err) } diff --git a/wavesrv/pkg/cmdrunner/shparse.go b/wavesrv/pkg/cmdrunner/shparse.go index 61aacfcca..3e8b477ec 100644 --- a/wavesrv/pkg/cmdrunner/shparse.go +++ b/wavesrv/pkg/cmdrunner/shparse.go @@ -311,6 +311,55 @@ func IsReturnStateCommand(cmdStr string) bool { return false } +func checkSimpleSudoCmd(cmdStr string) bool { + cmdStr = strings.TrimSpace(cmdStr) + return strings.HasPrefix(cmdStr, "sudo ") +} + +func isSudoCmd(cmd syntax.Command) bool { + if cmd == nil { + return false + } + if _, ok := cmd.(*syntax.FuncDecl); ok { + return false + } + if blockExpr, ok := cmd.(*syntax.Block); ok { + for _, stmt := range blockExpr.Stmts { + if isSudoCmd(stmt.Cmd) { + return true + } + } + return false + } + if binExpr, ok := cmd.(*syntax.BinaryCmd); ok { + if isSudoCmd(binExpr.X.Cmd) || isSudoCmd(binExpr.Y.Cmd) { + return true + } + } else if callExpr, ok := cmd.(*syntax.CallExpr); ok { + arg0 := getCallExprLitArg(callExpr, 0) + if arg0 != "" && utilfn.ContainsStr([]string{"sudo"}, arg0) { + return true + } + } + return false + +} + +func IsSudoCommand(cmdStr string) bool { + cmdReader := strings.NewReader(cmdStr) + parser := syntax.NewParser(syntax.Variant(syntax.LangBash)) + file, err := parser.Parse(cmdReader, "sudo") + if err != nil { + return checkSimpleSudoCmd(cmdStr) + } + for _, stmt := range file.Stmts { + if isSudoCmd(stmt.Cmd) { + return true + } + } + return false +} + func EvalBracketArgs(origCmdStr string) (map[string]string, string, error) { rtn := make(map[string]string) if strings.HasPrefix(origCmdStr, " ") { diff --git a/wavesrv/pkg/remote/remote.go b/wavesrv/pkg/remote/remote.go index 99de0f9b6..9f96212d1 100644 --- a/wavesrv/pkg/remote/remote.go +++ b/wavesrv/pkg/remote/remote.go @@ -6,6 +6,9 @@ package remote import ( "bytes" "context" + "crypto/ecdh" + "crypto/rand" + "crypto/x509" "encoding/base64" "errors" "fmt" @@ -41,6 +44,7 @@ import ( "github.com/wavetermdev/waveterm/wavesrv/pkg/sstore" "github.com/wavetermdev/waveterm/wavesrv/pkg/telemetry" "github.com/wavetermdev/waveterm/wavesrv/pkg/userinput" + "github.com/wavetermdev/waveterm/wavesrv/pkg/waveenc" "golang.org/x/crypto/ssh" "golang.org/x/mod/semver" @@ -56,6 +60,7 @@ const PtyReadBufSize = 100 const RemoteConnectTimeout = 15 * time.Second const RpcIterChannelSize = 100 const MaxInputDataSize = 1000 +const SudoTimeoutTime = 5 * time.Minute var envVarsToStrip map[string]bool = map[string]bool{ "PROMPT": true, @@ -164,7 +169,9 @@ type MShellProc struct { RunningCmds map[base.CommandKey]*RunCmdType PendingStateCmds map[pendingStateKey]base.CommandKey // key=[remoteinstance name] (in progress commands that might update the state) - Client *ssh.Client + Client *ssh.Client + sudoPw []byte + sudoClearDeadline int64 } type CommandInputSink interface { @@ -2361,6 +2368,35 @@ func (msh *MShellProc) updateRIWithFinalState(ctx context.Context, rct *RunCmdTy return sstore.UpdateRemoteState(ctx, rct.SessionId, rct.ScreenId, rct.RemotePtr, feState, nil, newStateDiff) } +func (msh *MShellProc) handleSudoError(ck base.CommandKey, sudoErr error) { + ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second) + defer cancelFn() + screenId, lineId := ck.Split() + + update := scbus.MakeUpdatePacket() + errOutputStr := fmt.Sprintf("%serror: %v%s\n", utilfn.AnsiRedColor(), sudoErr, utilfn.AnsiResetColor()) + msh.writeToCmdPtyOut(ctx, screenId, lineId, []byte(errOutputStr)) + doneInfo := sstore.CmdDoneDataValues{ + Ts: time.Now().UnixMilli(), + ExitCode: 1, + DurationMs: 0, + } + err := sstore.UpdateCmdDoneInfo(ctx, update, ck, doneInfo, sstore.CmdStatusError) + if err != nil { + log.Printf("error updating cmddone info (in handleSudoError): %v\n", err) + return + } + screen, err := sstore.UpdateScreenFocusForDoneCmd(ctx, screenId, lineId) + if err != nil { + log.Printf("error trying to update screen focus type (in handleSudoError): %v\n", err) + // fall-through (nothing to do) + } + if screen != nil { + update.AddUpdate(*screen) + } + scbus.MainUpdateBus.DoUpdate(update) +} + func (msh *MShellProc) handleCmdStartError(rct *RunCmdType, startErr error) { if rct == nil { log.Printf("handleCmdStartError, no rct\n") @@ -2590,6 +2626,82 @@ func sendScreenUpdates(screens []*sstore.ScreenType) { } } +func (msh *MShellProc) startSudoPwClearChecker() { + for { + shouldExit := false + msh.WithLock(func() { + if msh.sudoClearDeadline > 0 && time.Now().Unix() > msh.sudoClearDeadline { + msh.sudoPw = nil + msh.sudoClearDeadline = 0 + } + if msh.sudoClearDeadline == 0 { + shouldExit = true + } + }) + if shouldExit { + return + } + time.Sleep(time.Second * 2) + } +} + +func (msh *MShellProc) sendSudoPassword(sudoPk *packet.SudoRequestPacketType) error { + var storedPw []byte + var rawSecret []byte + msh.WithLock(func() { + storedPw = msh.sudoPw + }) + if storedPw != nil && sudoPk.SudoStatus == "first-attempt" { + rawSecret = storedPw + } else { + request := &userinput.UserInputRequestType{ + QueryText: "Please enter your password", + ResponseType: "text", + Title: "Sudo Password", + Markdown: false, + } + ctx, cancelFn := context.WithTimeout(context.Background(), 60*time.Second) + defer cancelFn() + guiResponse, err := userinput.GetUserInput(ctx, scbus.MainRpcBus, request) + if err != nil { + return err + } + rawSecret = []byte(guiResponse.Text) + } + //new + msh.WithLock(func() { + msh.sudoPw = rawSecret + if msh.sudoClearDeadline == 0 { + go msh.startSudoPwClearChecker() + } + msh.sudoClearDeadline = time.Now().Add(SudoTimeoutTime).Unix() + }) + + srvPrivKey, err := ecdh.P256().GenerateKey(rand.Reader) + if err != nil { + return fmt.Errorf("generate ecdh: %e", err) + } + encryptor, err := waveenc.MakeEncryptorEcdh(srvPrivKey, sudoPk.ShellPubKey) + if err != nil { + return err + } + encryptedSecret, err := encryptor.EncryptData(rawSecret, "sudopw") + if err != nil { + return fmt.Errorf("encrypt secret: %e", err) + } + srvPubKey, err := x509.MarshalPKIXPublicKey(srvPrivKey.PublicKey()) + if err != nil { + return fmt.Errorf("marshal pub key: %e", err) + } + sudoResponse := packet.MakeSudoResponsePacket(sudoPk.CK, encryptedSecret, srvPubKey) + select { + case msh.ServerProc.Input.SendCh <- sudoResponse: + default: + } + return nil + +} + func (msh *MShellProc) processSinglePacket(pk packet.PacketType) { if _, ok := pk.(*packet.DataAckPacketType); ok { // TODO process ack (need to keep track of buffer size for sending) @@ -2618,6 +2730,25 @@ func (msh *MShellProc) processSinglePacket(pk packet.PacketType) { }) return } + if sudoPk, ok := pk.(*packet.SudoRequestPacketType); ok { + // final failure case -- clear cache + if sudoPk.SudoStatus == "failure" { + msh.sudoPw = nil + msh.handleSudoError(sudoPk.CK, fmt.Errorf("sudo: incorrect password entered")) + return + } + + // handle waveshell errors here + if sudoPk.SudoStatus == "error" { + msh.handleSudoError(sudoPk.CK, fmt.Errorf("sudo: shell: %s", sudoPk.ErrStr)) + return + } + + err := msh.sendSudoPassword(sudoPk) + if err != nil { + msh.handleSudoError(sudoPk.CK, fmt.Errorf("sudo: srv: %s", err)) + } + } if msgPk, ok := pk.(*packet.MessagePacketType); ok { msh.WriteToPtyBuffer("msg> [remote %s] [%s] %s\n", msh.GetRemoteName(), msgPk.CK, msgPk.Message) return @@ -2629,6 +2760,13 @@ func (msh *MShellProc) processSinglePacket(pk packet.PacketType) { msh.WriteToPtyBuffer("*[remote %s] unhandled packet %s\n", msh.GetRemoteName(), packet.AsString(pk)) } +func (msh *MShellProc) ClearCachedSudoPw() { + msh.WithLock(func() { + msh.sudoPw = nil + msh.sudoClearDeadline = 0 + }) +} + func (msh *MShellProc) ProcessPackets() { defer msh.WithLock(func() { if msh.Status == StatusConnected { diff --git a/wavesrv/pkg/userinput/userinput.go b/wavesrv/pkg/userinput/userinput.go index d930dd8d4..6d654b044 100644 --- a/wavesrv/pkg/userinput/userinput.go +++ b/wavesrv/pkg/userinput/userinput.go @@ -13,6 +13,8 @@ import ( "github.com/wavetermdev/waveterm/wavesrv/pkg/scbus" ) +const UserInputRequestStr = "userinputrequest" + // An RpcPacket for requesting user input from the client type UserInputRequestType struct { RequestId string `json:"requestid"` @@ -26,7 +28,7 @@ type UserInputRequestType struct { } func (*UserInputRequestType) GetType() string { - return "userinputrequest" + return UserInputRequestStr } func (req *UserInputRequestType) SetReqId(reqId string) { diff --git a/wavesrv/pkg/promptenc/hmac.go b/wavesrv/pkg/waveenc/hmac.go similarity index 98% rename from wavesrv/pkg/promptenc/hmac.go rename to wavesrv/pkg/waveenc/hmac.go index 30a9e22cb..c949e38b5 100644 --- a/wavesrv/pkg/promptenc/hmac.go +++ b/wavesrv/pkg/waveenc/hmac.go @@ -1,7 +1,7 @@ // Copyright 2024, Command Line Inc. // SPDX-License-Identifier: Apache-2.0 -package promptenc +package waveenc import ( "crypto/hmac" diff --git a/wavesrv/pkg/promptenc/promptenc.go b/wavesrv/pkg/waveenc/waveenc.go similarity index 85% rename from wavesrv/pkg/promptenc/promptenc.go rename to wavesrv/pkg/waveenc/waveenc.go index 2cd3d8226..359c8872d 100644 --- a/wavesrv/pkg/promptenc/promptenc.go +++ b/wavesrv/pkg/waveenc/waveenc.go @@ -1,11 +1,14 @@ // Copyright 2023, Command Line Inc. // SPDX-License-Identifier: Apache-2.0 -package promptenc +package waveenc import ( "crypto/cipher" + "crypto/ecdh" + "crypto/ecdsa" "crypto/rand" + "crypto/x509" "encoding/base64" "encoding/json" "fmt" @@ -82,23 +85,14 @@ func (enc *Encryptor) EncryptData(plainText []byte, odata string) ([]byte, error return rtn, nil } -func (enc *Encryptor) DecryptData(encData []byte, odata string) (map[string]interface{}, error) { +func (enc *Encryptor) DecryptData(encData []byte, odata string) ([]byte, error) { minLen := enc.AEAD.NonceSize() + enc.AEAD.Overhead() if len(encData) < minLen { return nil, fmt.Errorf("invalid encdata, len:%d is less than minimum len:%d", len(encData), minLen) } - m := make(map[string]interface{}) nonce := encData[0:enc.AEAD.NonceSize()] cipherText := encData[enc.AEAD.NonceSize():] - plainText, err := enc.AEAD.Open(nil, nonce, cipherText, []byte(odata)) - if err != nil { - return nil, err - } - err = json.Unmarshal(plainText, &m) - if err != nil { - return nil, err - } - return m, nil + return enc.AEAD.Open(nil, nonce, cipherText, []byte(odata)) } type EncryptMeta struct { @@ -195,7 +189,12 @@ func (enc *Encryptor) DecryptStructFields(v interface{}, odata string) error { rvPtr := reflect.ValueOf(v) rv := rvPtr.Elem() cipherText := rv.FieldByIndex(encMeta.EncField.Index).Bytes() - m, err := enc.DecryptData(cipherText, odata) + decrypted, err := enc.DecryptData(cipherText, odata) + if err != nil { + return err + } + m := make(map[string]interface{}) + err = json.Unmarshal(decrypted, &m) if err != nil { return err } @@ -205,3 +204,24 @@ func (enc *Encryptor) DecryptStructFields(v interface{}, odata string) error { } return nil } + +func MakeEncryptorEcdh(localPrivKey *ecdh.PrivateKey, remotePubKey []byte) (*Encryptor, error) { + shellPubKey, err := x509.ParsePKIXPublicKey(remotePubKey) + if err != nil { + return nil, fmt.Errorf("parse pub key: %e", err) + } + ecdhShellPubKey, err := shellPubKey.(*ecdsa.PublicKey).ECDH() + if err != nil { + return nil, fmt.Errorf("convert pub key from ecdsa to ecdh: %e", err) + } + sharedKey, err := localPrivKey.ECDH(ecdhShellPubKey) + if err != nil { + return nil, fmt.Errorf("compute shared key: %e", err) + } + encryptor, err := MakeEncryptor(sharedKey) + if err != nil { + return nil, fmt.Errorf("create encryptor: %e", err) + } + return encryptor, nil + +}