wsh working over domain socket (and assorted bug fixes) (#217)

This commit is contained in:
Mike Sawka 2024-08-09 17:46:52 -07:00 committed by GitHub
parent 5e1da4805f
commit 5165d099c2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 172 additions and 126 deletions

View File

@ -231,7 +231,7 @@ func main() {
}
}
}()
go web.RunWebServer(unixListener)
go wshserver.RunWshRpcOverListener(unixListener)
web.RunWebServer(webListener) // blocking
runtime.KeepAlive(waveLock)
}

View File

@ -4,11 +4,8 @@
package cmd
import (
"fmt"
"github.com/spf13/cobra"
"github.com/wavetermdev/thenextwave/pkg/wshrpc"
"github.com/wavetermdev/thenextwave/pkg/wshutil"
)
var deleteBlockCmd = &cobra.Command{
@ -25,22 +22,21 @@ func init() {
func deleteBlockRun(cmd *cobra.Command, args []string) {
oref := args[0]
if oref == "" {
fmt.Println("oref is required")
WriteStderr("[error] oref is required\n")
return
}
err := validateEasyORef(oref)
if err != nil {
fmt.Printf("%v\n", err)
WriteStderr("[error]%v\n", err)
return
}
wshutil.SetTermRawModeAndInstallShutdownHandlers(true)
fullORef, err := resolveSimpleId(oref)
if err != nil {
fmt.Printf("error resolving oref: %v\r\n", err)
WriteStderr("[error] resolving oref: %v\n", err)
return
}
if fullORef.OType != "block" {
fmt.Printf("oref is not a block\r\n")
WriteStderr("[error] oref is not a block\n")
return
}
deleteBlockData := &wshrpc.CommandDeleteBlockData{
@ -48,8 +44,8 @@ func deleteBlockRun(cmd *cobra.Command, args []string) {
}
_, err = RpcClient.SendRpcRequest(wshrpc.Command_DeleteBlock, deleteBlockData, 2000)
if err != nil {
fmt.Printf("error deleting block: %v\r\n", err)
WriteStderr("[error] deleting block: %v\n", err)
return
}
fmt.Print("block deleted\r\n")
WriteStdout("block deleted\n")
}

View File

@ -5,14 +5,10 @@ package cmd
import (
"encoding/json"
"fmt"
"log"
"strings"
"github.com/spf13/cobra"
"github.com/wavetermdev/thenextwave/pkg/wshrpc"
"github.com/wavetermdev/thenextwave/pkg/wshrpc/wshclient"
"github.com/wavetermdev/thenextwave/pkg/wshutil"
)
var getMetaCmd = &cobra.Command{
@ -29,24 +25,22 @@ func init() {
func getMetaRun(cmd *cobra.Command, args []string) {
oref := args[0]
if oref == "" {
fmt.Println("oref is required")
WriteStderr("[error] oref is required")
return
}
err := validateEasyORef(oref)
if err != nil {
fmt.Printf("%v\n", err)
WriteStderr("[error] %v\n", err)
return
}
wshutil.SetTermRawModeAndInstallShutdownHandlers(true)
fullORef, err := resolveSimpleId(oref)
if err != nil {
fmt.Printf("error resolving oref: %v\r\n", err)
WriteStderr("[error] resolving oref: %v\n", err)
return
}
resp, err := wshclient.GetMetaCommand(RpcClient, wshrpc.CommandGetMetaData{ORef: *fullORef}, &wshrpc.WshRpcCommandOpts{Timeout: 2000})
if err != nil {
log.Printf("error getting metadata: %v\r\n", err)
WriteStderr("[error] getting metadata: %v\n", err)
return
}
if len(args) > 1 {
@ -56,21 +50,18 @@ func getMetaRun(cmd *cobra.Command, args []string) {
}
outBArr, err := json.MarshalIndent(val, "", " ")
if err != nil {
log.Printf("error formatting metadata: %v\r\n", err)
}
outStr := string(outBArr)
outStr = strings.ReplaceAll(outStr, "\n", "\r\n")
fmt.Print(outStr)
fmt.Print("\r\n")
} else {
outBArr, err := json.MarshalIndent(resp, "", " ")
if err != nil {
log.Printf("error formatting metadata: %v\r\n", err)
WriteStderr("[error] formatting metadata: %v\n", err)
return
}
outStr := string(outBArr)
outStr = strings.ReplaceAll(outStr, "\n", "\r\n")
fmt.Print(outStr)
fmt.Print("\r\n")
WriteStdout(outStr + "\n")
} else {
outBArr, err := json.MarshalIndent(resp, "", " ")
if err != nil {
WriteStderr("[error] formatting metadata: %v\n", err)
return
}
outStr := string(outBArr)
WriteStdout(outStr + "\n")
}
}

View File

@ -5,13 +5,10 @@ package cmd
import (
"encoding/base64"
"fmt"
"os"
"github.com/spf13/cobra"
"github.com/wavetermdev/thenextwave/pkg/wshrpc"
"github.com/wavetermdev/thenextwave/pkg/wshrpc/wshclient"
"github.com/wavetermdev/thenextwave/pkg/wshutil"
)
var readFileCmd = &cobra.Command{
@ -28,30 +25,28 @@ func init() {
func runReadFile(cmd *cobra.Command, args []string) {
oref := args[0]
if oref == "" {
fmt.Fprintf(os.Stderr, "oref is required\r\n")
WriteStderr("[error] oref is required\n")
return
}
err := validateEasyORef(oref)
if err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
WriteStderr("[error] %v\n", err)
return
}
wshutil.SetTermRawModeAndInstallShutdownHandlers(true)
fullORef, err := resolveSimpleId(oref)
if err != nil {
fmt.Fprintf(os.Stderr, "error resolving oref: %v\r\n", err)
WriteStderr("error resolving oref: %v\n", err)
return
}
resp64, err := wshclient.FileReadCommand(RpcClient, wshrpc.CommandFileData{ZoneId: fullORef.OID, FileName: args[1]}, &wshrpc.WshRpcCommandOpts{Timeout: 5000})
if err != nil {
fmt.Fprintf(os.Stderr, "error reading file: %v\r\n", err)
WriteStderr("[error] reading file: %v\n", err)
return
}
resp, err := base64.StdEncoding.DecodeString(resp64)
if err != nil {
fmt.Fprintf(os.Stderr, "error decoding file: %v\r\n", err)
WriteStderr("[error] decoding file: %v\n", err)
return
}
fmt.Print(string(resp))
WriteStdout(string(resp))
}

View File

@ -29,8 +29,9 @@ var (
)
var usingHtmlMode bool
var WrappedStdin io.Reader
var WrappedStdin io.Reader = os.Stdin
var RpcClient *wshutil.WshRpc
var UsingTermWshMode bool
func extraShutdownFn() {
if usingHtmlMode {
@ -42,15 +43,46 @@ func extraShutdownFn() {
}
}
func WriteStderr(fmtStr string, args ...interface{}) {
output := fmt.Sprintf(fmtStr, args...)
if UsingTermWshMode {
output = strings.ReplaceAll(output, "\n", "\r\n")
}
fmt.Fprint(os.Stderr, output)
}
func WriteStdout(fmtStr string, args ...interface{}) {
output := fmt.Sprintf(fmtStr, args...)
if UsingTermWshMode {
output = strings.ReplaceAll(output, "\n", "\r\n")
}
fmt.Print(output)
}
// returns the wrapped stdin and a new rpc client (that wraps the stdin input and stdout output)
func setupRpcClient(handlerFn wshutil.CommandHandlerFnType) {
log.Printf("setup rpc client\r\n")
RpcClient, WrappedStdin = wshutil.SetupTerminalRpcClient(handlerFn)
func setupRpcClient(handlerFn wshutil.CommandHandlerFnType) error {
jwtToken := os.Getenv("WAVETERM_JWT")
if jwtToken == "" {
wshutil.SetTermRawModeAndInstallShutdownHandlers(true)
UsingTermWshMode = true
RpcClient, WrappedStdin = wshutil.SetupTerminalRpcClient(handlerFn)
return nil
}
sockName, err := wshutil.ExtractUnverifiedSocketName(jwtToken)
if err != nil {
return fmt.Errorf("error extracting socket name from WAVETERM_JWT: %v", err)
}
RpcClient, err = wshutil.SetupDomainSocketRpcClient(sockName, handlerFn)
if err != nil {
return fmt.Errorf("error setting up domain socket rpc client: %v", err)
}
wshclient.AuthenticateCommand(RpcClient, jwtToken, &wshrpc.WshRpcCommandOpts{NoResponse: true})
// note we don't modify WrappedStdin here (just use os.Stdin)
return nil
}
func setTermHtmlMode() {
wshutil.SetExtraShutdownFunc(extraShutdownFn)
wshutil.SetTermRawModeAndInstallShutdownHandlers(true)
cmd := &wshrpc.CommandSetMetaData{
Meta: map[string]any{"term:mode": "html"},
}
@ -109,8 +141,18 @@ func resolveSimpleId(id string) (*waveobj.ORef, error) {
}
// Execute executes the root command.
func Execute() error {
func Execute() {
defer wshutil.DoShutdown("", 0, false)
setupRpcClient(nil)
return rootCmd.Execute()
err := setupRpcClient(nil)
if err != nil {
log.Printf("[error] %v\n", err)
wshutil.DoShutdown("", 1, true)
return
}
err = rootCmd.Execute()
if err != nil {
log.Printf("[error] %v\n", err)
wshutil.DoShutdown("", 1, true)
return
}
}

View File

@ -11,7 +11,6 @@ import (
"github.com/spf13/cobra"
"github.com/wavetermdev/thenextwave/pkg/wshrpc"
"github.com/wavetermdev/thenextwave/pkg/wshutil"
)
var setMetaCmd = &cobra.Command{
@ -62,23 +61,22 @@ func setMetaRun(cmd *cobra.Command, args []string) {
oref := args[0]
metaSetsStrs := args[1:]
if oref == "" {
fmt.Println("oref is required")
WriteStderr("[error] oref is required\n")
return
}
err := validateEasyORef(oref)
if err != nil {
fmt.Printf("%v\n", err)
WriteStderr("[error] %v\n", err)
return
}
meta, err := parseMetaSets(metaSetsStrs)
if err != nil {
fmt.Printf("%v\n", err)
WriteStderr("[error] %v\n", err)
return
}
wshutil.SetTermRawModeAndInstallShutdownHandlers(true)
fullORef, err := resolveSimpleId(oref)
if err != nil {
fmt.Printf("error resolving oref: %v\n", err)
WriteStderr("[error] resolving oref: %v\n", err)
return
}
setMetaWshCmd := &wshrpc.CommandSetMetaData{
@ -87,8 +85,8 @@ func setMetaRun(cmd *cobra.Command, args []string) {
}
_, err = RpcClient.SendRpcRequest(wshrpc.Command_SetMeta, setMetaWshCmd, 2000)
if err != nil {
fmt.Printf("error setting metadata: %v\n", err)
WriteStderr("[error] setting metadata: %v\n", err)
return
}
fmt.Print("metadata set\r\n")
WriteStdout("metadata set\n")
}

View File

@ -4,8 +4,6 @@
package cmd
import (
"fmt"
"github.com/spf13/cobra"
)
@ -17,6 +15,6 @@ var versionCmd = &cobra.Command{
Use: "version",
Short: "Print the version number of wsh",
Run: func(cmd *cobra.Command, args []string) {
fmt.Printf("wsh v0.1.0\n")
WriteStdout("wsh v0.1.0\n")
},
}

View File

@ -5,13 +5,11 @@ package cmd
import (
"io/fs"
"log"
"os"
"path/filepath"
"github.com/spf13/cobra"
"github.com/wavetermdev/thenextwave/pkg/wshrpc"
"github.com/wavetermdev/thenextwave/pkg/wshutil"
"github.com/wavetermdev/thenextwave/pkg/wstore"
)
@ -33,18 +31,18 @@ func viewRun(cmd *cobra.Command, args []string) {
fileArg := args[0]
absFile, err := filepath.Abs(fileArg)
if err != nil {
log.Printf("error getting absolute path: %v\n", err)
WriteStderr("[error] getting absolute path: %v\n", err)
return
}
_, err = os.Stat(absFile)
if err == fs.ErrNotExist {
log.Printf("file does not exist: %q\n", absFile)
WriteStderr("[error] file does not exist: %q\n", absFile)
return
}
if err != nil {
log.Printf("error getting file info: %v\n", err)
WriteStderr("[error] getting file info: %v\n", err)
return
}
wshutil.SetTermRawModeAndInstallShutdownHandlers(true)
viewWshCmd := &wshrpc.CommandCreateBlockData{
BlockDef: &wstore.BlockDef{
Meta: map[string]interface{}{
@ -55,7 +53,7 @@ func viewRun(cmd *cobra.Command, args []string) {
}
_, err = RpcClient.SendRpcRequest(wshrpc.Command_CreateBlock, viewWshCmd, 2000)
if err != nil {
log.Printf("error running view command: %v\r\n", err)
WriteStderr("[error] running view command: %v\r\n", err)
return
}
}

View File

@ -9,5 +9,4 @@ import (
func main() {
cmd.Execute()
}

View File

@ -267,13 +267,13 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj
if err != nil {
return fmt.Errorf("error making jwt token: %w", err)
}
cmdOpts.Env["WAVETERM_JWT"] = jwtStr
cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr
} else {
jwtStr, err := wshutil.MakeClientJWTToken(wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId}, wavebase.GetDomainSocketName())
if err != nil {
return fmt.Errorf("error making jwt token: %w", err)
}
cmdOpts.Env["WAVETERM_JWT"] = jwtStr
cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr
}
}
if bc.ControllerType == BlockController_Shell {

View File

@ -94,6 +94,9 @@ type RpcContext struct {
func HackRpcContextIntoData(dataPtr any, rpcContext RpcContext) {
dataVal := reflect.ValueOf(dataPtr).Elem()
if dataVal.Kind() != reflect.Struct {
return
}
dataType := dataVal.Type()
for i := 0; i < dataVal.NumField(); i++ {
field := dataVal.Field(i)

View File

@ -8,11 +8,9 @@ import (
"fmt"
"log"
"net"
"os"
"reflect"
"github.com/wavetermdev/thenextwave/pkg/util/utilfn"
"github.com/wavetermdev/thenextwave/pkg/wavebase"
"github.com/wavetermdev/thenextwave/pkg/wshrpc"
"github.com/wavetermdev/thenextwave/pkg/wshutil"
)
@ -159,48 +157,18 @@ func mainWshServerHandler(handler *wshutil.RpcResponseHandler) bool {
}
}
func MakeUnixListener(sockName string) (net.Listener, error) {
os.Remove(sockName) // ignore error
rtn, err := net.Listen("unix", sockName)
if err != nil {
return nil, fmt.Errorf("error creating listener at %v: %v", sockName, err)
}
os.Chmod(sockName, 0700)
log.Printf("Server listening on %s\n", sockName)
return rtn, nil
}
func runWshRpcWithStream(conn net.Conn) {
defer conn.Close()
inputCh := make(chan []byte, DefaultInputChSize)
outputCh := make(chan []byte, DefaultOutputChSize)
go wshutil.AdaptMsgChToStream(outputCh, conn)
go wshutil.AdaptStreamToMsgCh(conn, inputCh)
wshutil.MakeWshRpc(inputCh, outputCh, wshrpc.RpcContext{}, mainWshServerHandler)
}
func RunWshRpcOverListener(listener net.Listener) {
go func() {
for {
conn, err := listener.Accept()
if err != nil {
log.Printf("error accepting connection: %v\n", err)
continue
}
go runWshRpcWithStream(conn)
defer log.Printf("domain socket listener shutting down\n")
for {
conn, err := listener.Accept()
if err != nil {
log.Printf("error accepting connection: %v\n", err)
continue
}
}()
}
func RunDomainSocketWshServer() error {
sockName := wavebase.GetDomainSocketName()
listener, err := MakeUnixListener(sockName)
if err != nil {
return fmt.Errorf("error starging unix listener for wsh-server: %w", err)
log.Print("got domain socket connection\n")
// TODO deal with closing connection
go wshutil.SetupConnRpcClient(conn, mainWshServerHandler)
}
defer listener.Close()
RunWshRpcOverListener(listener)
return nil
}
func MakeWshServer(inputCh chan []byte, outputCh chan []byte, initialCtx wshrpc.RpcContext) {

View File

@ -43,28 +43,32 @@ func streamToLines_processBuf(lineBuf *lineBuf, readBuf []byte, lineFn func([]by
}
}
func streamToLines(input io.Reader, lineFn func([]byte)) {
func streamToLines(input io.Reader, lineFn func([]byte)) error {
var lineBuf lineBuf
readBuf := make([]byte, 16*1024)
for {
n, err := input.Read(readBuf)
streamToLines_processBuf(&lineBuf, readBuf[:n], lineFn)
if err != nil {
break
return err
}
}
}
func AdaptStreamToMsgCh(input io.Reader, output chan []byte) {
streamToLines(input, func(line []byte) {
func AdaptStreamToMsgCh(input io.Reader, output chan []byte) error {
return streamToLines(input, func(line []byte) {
output <- line
})
}
func AdaptMsgChToStream(outputCh chan []byte, output io.Writer) error {
func AdaptOutputChToStream(outputCh chan []byte, output io.Writer) error {
for msg := range outputCh {
if _, err := output.Write(msg); err != nil {
return fmt.Errorf("error writing to output: %w", err)
return fmt.Errorf("error writing to output (AdaptOutputChToStream): %w", err)
}
// write trailing newline
if _, err := output.Write([]byte{'\n'}); err != nil {
return fmt.Errorf("error writing trailing newline to output (AdaptOutputChToStream): %w", err)
}
}
return nil

View File

@ -9,6 +9,7 @@ import (
"fmt"
"io"
"log"
"net"
"os"
"os/signal"
"sync"
@ -35,6 +36,11 @@ const BEL = 0x07
const ST = 0x9c
const ESC = 0x1b
const DefaultOutputChSize = 32
const DefaultInputChSize = 32
const WaveJwtTokenVarName = "WAVETERM_JWT"
// OSC escape types
// OSC 23198 ; (JSON | base64-JSON) ST
// JSON = must escape all ASCII control characters ([\x00-\x1F\x7F])
@ -181,8 +187,8 @@ func RestoreTermState() {
// returns (wshRpc, wrappedStdin)
func SetupTerminalRpcClient(handlerFn func(*RpcResponseHandler) bool) (*WshRpc, io.Reader) {
messageCh := make(chan []byte, 32)
outputCh := make(chan []byte, 32)
messageCh := make(chan []byte, DefaultInputChSize)
outputCh := make(chan []byte, DefaultOutputChSize)
ptyBuf := MakePtyBuffer(WaveServerOSCPrefix, os.Stdin, messageCh)
rpcClient := MakeWshRpc(messageCh, outputCh, wshrpc.RpcContext{}, handlerFn)
go func() {
@ -194,6 +200,42 @@ func SetupTerminalRpcClient(handlerFn func(*RpcResponseHandler) bool) (*WshRpc,
return rpcClient, ptyBuf
}
func SetupConnRpcClient(conn net.Conn, handlerFn func(*RpcResponseHandler) bool) (*WshRpc, chan error, error) {
inputCh := make(chan []byte, DefaultInputChSize)
outputCh := make(chan []byte, DefaultOutputChSize)
writeErrCh := make(chan error, 1)
go func() {
writeErr := AdaptOutputChToStream(outputCh, conn)
if writeErr != nil {
writeErrCh <- writeErr
close(writeErrCh)
}
}()
go func() {
// when input is closed, close the connection
defer conn.Close()
AdaptStreamToMsgCh(conn, inputCh)
}()
rtn := MakeWshRpc(inputCh, outputCh, wshrpc.RpcContext{}, handlerFn)
return rtn, writeErrCh, nil
}
func SetupDomainSocketRpcClient(sockName string, handlerFn func(*RpcResponseHandler) bool) (*WshRpc, error) {
conn, err := net.Dial("unix", sockName)
if err != nil {
return nil, fmt.Errorf("failed to connect to Unix domain socket: %w", err)
}
rtn, errCh, err := SetupConnRpcClient(conn, handlerFn)
go func() {
defer conn.Close()
err := <-errCh
if err != nil && err != io.EOF {
log.Printf("error in domain socket connection: %v\n", err)
}
}()
return rtn, err
}
func MakeClientJWTToken(rpcCtx wshrpc.RpcContext, sockName string) (string, error) {
claims := jwt.MapClaims{}
claims["iat"] = time.Now().Unix()
@ -246,9 +288,21 @@ func ValidateAndExtractRpcContextFromToken(tokenStr string) (*wshrpc.RpcContext,
return nil, fmt.Errorf("iss claim is missing or invalid")
}
rpcCtx := &wshrpc.RpcContext{}
rpcCtx.BlockId = claims["blockid"].(string)
rpcCtx.TabId = claims["tabid"].(string)
rpcCtx.WindowId = claims["windowid"].(string)
if claims["blockid"] != nil {
if blockId, ok := claims["blockid"].(string); ok {
rpcCtx.BlockId = blockId
}
}
if claims["tabid"] != nil {
if tabId, ok := claims["tabid"].(string); ok {
rpcCtx.TabId = tabId
}
}
if claims["windowid"] != nil {
if windowId, ok := claims["windowid"].(string); ok {
rpcCtx.WindowId = windowId
}
}
return rpcCtx, nil
}