mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-01-17 20:51:55 +01:00
wsh working over domain socket (and assorted bug fixes) (#217)
This commit is contained in:
parent
5e1da4805f
commit
5165d099c2
@ -231,7 +231,7 @@ func main() {
|
||||
}
|
||||
}
|
||||
}()
|
||||
go web.RunWebServer(unixListener)
|
||||
go wshserver.RunWshRpcOverListener(unixListener)
|
||||
web.RunWebServer(webListener) // blocking
|
||||
runtime.KeepAlive(waveLock)
|
||||
}
|
||||
|
@ -4,11 +4,8 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wshrpc"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wshutil"
|
||||
)
|
||||
|
||||
var deleteBlockCmd = &cobra.Command{
|
||||
@ -25,22 +22,21 @@ func init() {
|
||||
func deleteBlockRun(cmd *cobra.Command, args []string) {
|
||||
oref := args[0]
|
||||
if oref == "" {
|
||||
fmt.Println("oref is required")
|
||||
WriteStderr("[error] oref is required\n")
|
||||
return
|
||||
}
|
||||
err := validateEasyORef(oref)
|
||||
if err != nil {
|
||||
fmt.Printf("%v\n", err)
|
||||
WriteStderr("[error]%v\n", err)
|
||||
return
|
||||
}
|
||||
wshutil.SetTermRawModeAndInstallShutdownHandlers(true)
|
||||
fullORef, err := resolveSimpleId(oref)
|
||||
if err != nil {
|
||||
fmt.Printf("error resolving oref: %v\r\n", err)
|
||||
WriteStderr("[error] resolving oref: %v\n", err)
|
||||
return
|
||||
}
|
||||
if fullORef.OType != "block" {
|
||||
fmt.Printf("oref is not a block\r\n")
|
||||
WriteStderr("[error] oref is not a block\n")
|
||||
return
|
||||
}
|
||||
deleteBlockData := &wshrpc.CommandDeleteBlockData{
|
||||
@ -48,8 +44,8 @@ func deleteBlockRun(cmd *cobra.Command, args []string) {
|
||||
}
|
||||
_, err = RpcClient.SendRpcRequest(wshrpc.Command_DeleteBlock, deleteBlockData, 2000)
|
||||
if err != nil {
|
||||
fmt.Printf("error deleting block: %v\r\n", err)
|
||||
WriteStderr("[error] deleting block: %v\n", err)
|
||||
return
|
||||
}
|
||||
fmt.Print("block deleted\r\n")
|
||||
WriteStdout("block deleted\n")
|
||||
}
|
||||
|
@ -5,14 +5,10 @@ package cmd
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wshrpc"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wshrpc/wshclient"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wshutil"
|
||||
)
|
||||
|
||||
var getMetaCmd = &cobra.Command{
|
||||
@ -29,24 +25,22 @@ func init() {
|
||||
func getMetaRun(cmd *cobra.Command, args []string) {
|
||||
oref := args[0]
|
||||
if oref == "" {
|
||||
fmt.Println("oref is required")
|
||||
WriteStderr("[error] oref is required")
|
||||
return
|
||||
}
|
||||
err := validateEasyORef(oref)
|
||||
if err != nil {
|
||||
fmt.Printf("%v\n", err)
|
||||
WriteStderr("[error] %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
wshutil.SetTermRawModeAndInstallShutdownHandlers(true)
|
||||
fullORef, err := resolveSimpleId(oref)
|
||||
if err != nil {
|
||||
fmt.Printf("error resolving oref: %v\r\n", err)
|
||||
WriteStderr("[error] resolving oref: %v\n", err)
|
||||
return
|
||||
}
|
||||
resp, err := wshclient.GetMetaCommand(RpcClient, wshrpc.CommandGetMetaData{ORef: *fullORef}, &wshrpc.WshRpcCommandOpts{Timeout: 2000})
|
||||
if err != nil {
|
||||
log.Printf("error getting metadata: %v\r\n", err)
|
||||
WriteStderr("[error] getting metadata: %v\n", err)
|
||||
return
|
||||
}
|
||||
if len(args) > 1 {
|
||||
@ -56,21 +50,18 @@ func getMetaRun(cmd *cobra.Command, args []string) {
|
||||
}
|
||||
outBArr, err := json.MarshalIndent(val, "", " ")
|
||||
if err != nil {
|
||||
log.Printf("error formatting metadata: %v\r\n", err)
|
||||
}
|
||||
outStr := string(outBArr)
|
||||
outStr = strings.ReplaceAll(outStr, "\n", "\r\n")
|
||||
fmt.Print(outStr)
|
||||
fmt.Print("\r\n")
|
||||
} else {
|
||||
outBArr, err := json.MarshalIndent(resp, "", " ")
|
||||
if err != nil {
|
||||
log.Printf("error formatting metadata: %v\r\n", err)
|
||||
WriteStderr("[error] formatting metadata: %v\n", err)
|
||||
return
|
||||
}
|
||||
outStr := string(outBArr)
|
||||
outStr = strings.ReplaceAll(outStr, "\n", "\r\n")
|
||||
fmt.Print(outStr)
|
||||
fmt.Print("\r\n")
|
||||
WriteStdout(outStr + "\n")
|
||||
} else {
|
||||
outBArr, err := json.MarshalIndent(resp, "", " ")
|
||||
if err != nil {
|
||||
WriteStderr("[error] formatting metadata: %v\n", err)
|
||||
return
|
||||
}
|
||||
outStr := string(outBArr)
|
||||
WriteStdout(outStr + "\n")
|
||||
}
|
||||
}
|
||||
|
@ -5,13 +5,10 @@ package cmd
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wshrpc"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wshrpc/wshclient"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wshutil"
|
||||
)
|
||||
|
||||
var readFileCmd = &cobra.Command{
|
||||
@ -28,30 +25,28 @@ func init() {
|
||||
func runReadFile(cmd *cobra.Command, args []string) {
|
||||
oref := args[0]
|
||||
if oref == "" {
|
||||
fmt.Fprintf(os.Stderr, "oref is required\r\n")
|
||||
WriteStderr("[error] oref is required\n")
|
||||
return
|
||||
}
|
||||
err := validateEasyORef(oref)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%v\n", err)
|
||||
WriteStderr("[error] %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
wshutil.SetTermRawModeAndInstallShutdownHandlers(true)
|
||||
fullORef, err := resolveSimpleId(oref)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error resolving oref: %v\r\n", err)
|
||||
WriteStderr("error resolving oref: %v\n", err)
|
||||
return
|
||||
}
|
||||
resp64, err := wshclient.FileReadCommand(RpcClient, wshrpc.CommandFileData{ZoneId: fullORef.OID, FileName: args[1]}, &wshrpc.WshRpcCommandOpts{Timeout: 5000})
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error reading file: %v\r\n", err)
|
||||
WriteStderr("[error] reading file: %v\n", err)
|
||||
return
|
||||
}
|
||||
resp, err := base64.StdEncoding.DecodeString(resp64)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error decoding file: %v\r\n", err)
|
||||
WriteStderr("[error] decoding file: %v\n", err)
|
||||
return
|
||||
}
|
||||
fmt.Print(string(resp))
|
||||
WriteStdout(string(resp))
|
||||
}
|
||||
|
@ -29,8 +29,9 @@ var (
|
||||
)
|
||||
|
||||
var usingHtmlMode bool
|
||||
var WrappedStdin io.Reader
|
||||
var WrappedStdin io.Reader = os.Stdin
|
||||
var RpcClient *wshutil.WshRpc
|
||||
var UsingTermWshMode bool
|
||||
|
||||
func extraShutdownFn() {
|
||||
if usingHtmlMode {
|
||||
@ -42,15 +43,46 @@ func extraShutdownFn() {
|
||||
}
|
||||
}
|
||||
|
||||
func WriteStderr(fmtStr string, args ...interface{}) {
|
||||
output := fmt.Sprintf(fmtStr, args...)
|
||||
if UsingTermWshMode {
|
||||
output = strings.ReplaceAll(output, "\n", "\r\n")
|
||||
}
|
||||
fmt.Fprint(os.Stderr, output)
|
||||
}
|
||||
|
||||
func WriteStdout(fmtStr string, args ...interface{}) {
|
||||
output := fmt.Sprintf(fmtStr, args...)
|
||||
if UsingTermWshMode {
|
||||
output = strings.ReplaceAll(output, "\n", "\r\n")
|
||||
}
|
||||
fmt.Print(output)
|
||||
}
|
||||
|
||||
// returns the wrapped stdin and a new rpc client (that wraps the stdin input and stdout output)
|
||||
func setupRpcClient(handlerFn wshutil.CommandHandlerFnType) {
|
||||
log.Printf("setup rpc client\r\n")
|
||||
func setupRpcClient(handlerFn wshutil.CommandHandlerFnType) error {
|
||||
jwtToken := os.Getenv("WAVETERM_JWT")
|
||||
if jwtToken == "" {
|
||||
wshutil.SetTermRawModeAndInstallShutdownHandlers(true)
|
||||
UsingTermWshMode = true
|
||||
RpcClient, WrappedStdin = wshutil.SetupTerminalRpcClient(handlerFn)
|
||||
return nil
|
||||
}
|
||||
sockName, err := wshutil.ExtractUnverifiedSocketName(jwtToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error extracting socket name from WAVETERM_JWT: %v", err)
|
||||
}
|
||||
RpcClient, err = wshutil.SetupDomainSocketRpcClient(sockName, handlerFn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error setting up domain socket rpc client: %v", err)
|
||||
}
|
||||
wshclient.AuthenticateCommand(RpcClient, jwtToken, &wshrpc.WshRpcCommandOpts{NoResponse: true})
|
||||
// note we don't modify WrappedStdin here (just use os.Stdin)
|
||||
return nil
|
||||
}
|
||||
|
||||
func setTermHtmlMode() {
|
||||
wshutil.SetExtraShutdownFunc(extraShutdownFn)
|
||||
wshutil.SetTermRawModeAndInstallShutdownHandlers(true)
|
||||
cmd := &wshrpc.CommandSetMetaData{
|
||||
Meta: map[string]any{"term:mode": "html"},
|
||||
}
|
||||
@ -109,8 +141,18 @@ func resolveSimpleId(id string) (*waveobj.ORef, error) {
|
||||
}
|
||||
|
||||
// Execute executes the root command.
|
||||
func Execute() error {
|
||||
func Execute() {
|
||||
defer wshutil.DoShutdown("", 0, false)
|
||||
setupRpcClient(nil)
|
||||
return rootCmd.Execute()
|
||||
err := setupRpcClient(nil)
|
||||
if err != nil {
|
||||
log.Printf("[error] %v\n", err)
|
||||
wshutil.DoShutdown("", 1, true)
|
||||
return
|
||||
}
|
||||
err = rootCmd.Execute()
|
||||
if err != nil {
|
||||
log.Printf("[error] %v\n", err)
|
||||
wshutil.DoShutdown("", 1, true)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -11,7 +11,6 @@ import (
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wshrpc"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wshutil"
|
||||
)
|
||||
|
||||
var setMetaCmd = &cobra.Command{
|
||||
@ -62,23 +61,22 @@ func setMetaRun(cmd *cobra.Command, args []string) {
|
||||
oref := args[0]
|
||||
metaSetsStrs := args[1:]
|
||||
if oref == "" {
|
||||
fmt.Println("oref is required")
|
||||
WriteStderr("[error] oref is required\n")
|
||||
return
|
||||
}
|
||||
err := validateEasyORef(oref)
|
||||
if err != nil {
|
||||
fmt.Printf("%v\n", err)
|
||||
WriteStderr("[error] %v\n", err)
|
||||
return
|
||||
}
|
||||
meta, err := parseMetaSets(metaSetsStrs)
|
||||
if err != nil {
|
||||
fmt.Printf("%v\n", err)
|
||||
WriteStderr("[error] %v\n", err)
|
||||
return
|
||||
}
|
||||
wshutil.SetTermRawModeAndInstallShutdownHandlers(true)
|
||||
fullORef, err := resolveSimpleId(oref)
|
||||
if err != nil {
|
||||
fmt.Printf("error resolving oref: %v\n", err)
|
||||
WriteStderr("[error] resolving oref: %v\n", err)
|
||||
return
|
||||
}
|
||||
setMetaWshCmd := &wshrpc.CommandSetMetaData{
|
||||
@ -87,8 +85,8 @@ func setMetaRun(cmd *cobra.Command, args []string) {
|
||||
}
|
||||
_, err = RpcClient.SendRpcRequest(wshrpc.Command_SetMeta, setMetaWshCmd, 2000)
|
||||
if err != nil {
|
||||
fmt.Printf("error setting metadata: %v\n", err)
|
||||
WriteStderr("[error] setting metadata: %v\n", err)
|
||||
return
|
||||
}
|
||||
fmt.Print("metadata set\r\n")
|
||||
WriteStdout("metadata set\n")
|
||||
}
|
||||
|
@ -4,8 +4,6 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
@ -17,6 +15,6 @@ var versionCmd = &cobra.Command{
|
||||
Use: "version",
|
||||
Short: "Print the version number of wsh",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
fmt.Printf("wsh v0.1.0\n")
|
||||
WriteStdout("wsh v0.1.0\n")
|
||||
},
|
||||
}
|
||||
|
@ -5,13 +5,11 @@ package cmd
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wshrpc"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wshutil"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wstore"
|
||||
)
|
||||
|
||||
@ -33,18 +31,18 @@ func viewRun(cmd *cobra.Command, args []string) {
|
||||
fileArg := args[0]
|
||||
absFile, err := filepath.Abs(fileArg)
|
||||
if err != nil {
|
||||
log.Printf("error getting absolute path: %v\n", err)
|
||||
WriteStderr("[error] getting absolute path: %v\n", err)
|
||||
return
|
||||
}
|
||||
_, err = os.Stat(absFile)
|
||||
if err == fs.ErrNotExist {
|
||||
log.Printf("file does not exist: %q\n", absFile)
|
||||
WriteStderr("[error] file does not exist: %q\n", absFile)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
log.Printf("error getting file info: %v\n", err)
|
||||
WriteStderr("[error] getting file info: %v\n", err)
|
||||
return
|
||||
}
|
||||
wshutil.SetTermRawModeAndInstallShutdownHandlers(true)
|
||||
viewWshCmd := &wshrpc.CommandCreateBlockData{
|
||||
BlockDef: &wstore.BlockDef{
|
||||
Meta: map[string]interface{}{
|
||||
@ -55,7 +53,7 @@ func viewRun(cmd *cobra.Command, args []string) {
|
||||
}
|
||||
_, err = RpcClient.SendRpcRequest(wshrpc.Command_CreateBlock, viewWshCmd, 2000)
|
||||
if err != nil {
|
||||
log.Printf("error running view command: %v\r\n", err)
|
||||
WriteStderr("[error] running view command: %v\r\n", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -9,5 +9,4 @@ import (
|
||||
|
||||
func main() {
|
||||
cmd.Execute()
|
||||
|
||||
}
|
||||
|
@ -267,13 +267,13 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj
|
||||
if err != nil {
|
||||
return fmt.Errorf("error making jwt token: %w", err)
|
||||
}
|
||||
cmdOpts.Env["WAVETERM_JWT"] = jwtStr
|
||||
cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr
|
||||
} else {
|
||||
jwtStr, err := wshutil.MakeClientJWTToken(wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId}, wavebase.GetDomainSocketName())
|
||||
if err != nil {
|
||||
return fmt.Errorf("error making jwt token: %w", err)
|
||||
}
|
||||
cmdOpts.Env["WAVETERM_JWT"] = jwtStr
|
||||
cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr
|
||||
}
|
||||
}
|
||||
if bc.ControllerType == BlockController_Shell {
|
||||
|
@ -94,6 +94,9 @@ type RpcContext struct {
|
||||
|
||||
func HackRpcContextIntoData(dataPtr any, rpcContext RpcContext) {
|
||||
dataVal := reflect.ValueOf(dataPtr).Elem()
|
||||
if dataVal.Kind() != reflect.Struct {
|
||||
return
|
||||
}
|
||||
dataType := dataVal.Type()
|
||||
for i := 0; i < dataVal.NumField(); i++ {
|
||||
field := dataVal.Field(i)
|
||||
|
@ -8,11 +8,9 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"reflect"
|
||||
|
||||
"github.com/wavetermdev/thenextwave/pkg/util/utilfn"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wavebase"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wshrpc"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wshutil"
|
||||
)
|
||||
@ -159,48 +157,18 @@ func mainWshServerHandler(handler *wshutil.RpcResponseHandler) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func MakeUnixListener(sockName string) (net.Listener, error) {
|
||||
os.Remove(sockName) // ignore error
|
||||
rtn, err := net.Listen("unix", sockName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating listener at %v: %v", sockName, err)
|
||||
}
|
||||
os.Chmod(sockName, 0700)
|
||||
log.Printf("Server listening on %s\n", sockName)
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func runWshRpcWithStream(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
inputCh := make(chan []byte, DefaultInputChSize)
|
||||
outputCh := make(chan []byte, DefaultOutputChSize)
|
||||
go wshutil.AdaptMsgChToStream(outputCh, conn)
|
||||
go wshutil.AdaptStreamToMsgCh(conn, inputCh)
|
||||
wshutil.MakeWshRpc(inputCh, outputCh, wshrpc.RpcContext{}, mainWshServerHandler)
|
||||
}
|
||||
|
||||
func RunWshRpcOverListener(listener net.Listener) {
|
||||
go func() {
|
||||
defer log.Printf("domain socket listener shutting down\n")
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
log.Printf("error accepting connection: %v\n", err)
|
||||
continue
|
||||
}
|
||||
go runWshRpcWithStream(conn)
|
||||
log.Print("got domain socket connection\n")
|
||||
// TODO deal with closing connection
|
||||
go wshutil.SetupConnRpcClient(conn, mainWshServerHandler)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func RunDomainSocketWshServer() error {
|
||||
sockName := wavebase.GetDomainSocketName()
|
||||
listener, err := MakeUnixListener(sockName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error starging unix listener for wsh-server: %w", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
RunWshRpcOverListener(listener)
|
||||
return nil
|
||||
}
|
||||
|
||||
func MakeWshServer(inputCh chan []byte, outputCh chan []byte, initialCtx wshrpc.RpcContext) {
|
||||
|
@ -43,28 +43,32 @@ func streamToLines_processBuf(lineBuf *lineBuf, readBuf []byte, lineFn func([]by
|
||||
}
|
||||
}
|
||||
|
||||
func streamToLines(input io.Reader, lineFn func([]byte)) {
|
||||
func streamToLines(input io.Reader, lineFn func([]byte)) error {
|
||||
var lineBuf lineBuf
|
||||
readBuf := make([]byte, 16*1024)
|
||||
for {
|
||||
n, err := input.Read(readBuf)
|
||||
streamToLines_processBuf(&lineBuf, readBuf[:n], lineFn)
|
||||
if err != nil {
|
||||
break
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func AdaptStreamToMsgCh(input io.Reader, output chan []byte) {
|
||||
streamToLines(input, func(line []byte) {
|
||||
func AdaptStreamToMsgCh(input io.Reader, output chan []byte) error {
|
||||
return streamToLines(input, func(line []byte) {
|
||||
output <- line
|
||||
})
|
||||
}
|
||||
|
||||
func AdaptMsgChToStream(outputCh chan []byte, output io.Writer) error {
|
||||
func AdaptOutputChToStream(outputCh chan []byte, output io.Writer) error {
|
||||
for msg := range outputCh {
|
||||
if _, err := output.Write(msg); err != nil {
|
||||
return fmt.Errorf("error writing to output: %w", err)
|
||||
return fmt.Errorf("error writing to output (AdaptOutputChToStream): %w", err)
|
||||
}
|
||||
// write trailing newline
|
||||
if _, err := output.Write([]byte{'\n'}); err != nil {
|
||||
return fmt.Errorf("error writing trailing newline to output (AdaptOutputChToStream): %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
@ -35,6 +36,11 @@ const BEL = 0x07
|
||||
const ST = 0x9c
|
||||
const ESC = 0x1b
|
||||
|
||||
const DefaultOutputChSize = 32
|
||||
const DefaultInputChSize = 32
|
||||
|
||||
const WaveJwtTokenVarName = "WAVETERM_JWT"
|
||||
|
||||
// OSC escape types
|
||||
// OSC 23198 ; (JSON | base64-JSON) ST
|
||||
// JSON = must escape all ASCII control characters ([\x00-\x1F\x7F])
|
||||
@ -181,8 +187,8 @@ func RestoreTermState() {
|
||||
|
||||
// returns (wshRpc, wrappedStdin)
|
||||
func SetupTerminalRpcClient(handlerFn func(*RpcResponseHandler) bool) (*WshRpc, io.Reader) {
|
||||
messageCh := make(chan []byte, 32)
|
||||
outputCh := make(chan []byte, 32)
|
||||
messageCh := make(chan []byte, DefaultInputChSize)
|
||||
outputCh := make(chan []byte, DefaultOutputChSize)
|
||||
ptyBuf := MakePtyBuffer(WaveServerOSCPrefix, os.Stdin, messageCh)
|
||||
rpcClient := MakeWshRpc(messageCh, outputCh, wshrpc.RpcContext{}, handlerFn)
|
||||
go func() {
|
||||
@ -194,6 +200,42 @@ func SetupTerminalRpcClient(handlerFn func(*RpcResponseHandler) bool) (*WshRpc,
|
||||
return rpcClient, ptyBuf
|
||||
}
|
||||
|
||||
func SetupConnRpcClient(conn net.Conn, handlerFn func(*RpcResponseHandler) bool) (*WshRpc, chan error, error) {
|
||||
inputCh := make(chan []byte, DefaultInputChSize)
|
||||
outputCh := make(chan []byte, DefaultOutputChSize)
|
||||
writeErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
writeErr := AdaptOutputChToStream(outputCh, conn)
|
||||
if writeErr != nil {
|
||||
writeErrCh <- writeErr
|
||||
close(writeErrCh)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
// when input is closed, close the connection
|
||||
defer conn.Close()
|
||||
AdaptStreamToMsgCh(conn, inputCh)
|
||||
}()
|
||||
rtn := MakeWshRpc(inputCh, outputCh, wshrpc.RpcContext{}, handlerFn)
|
||||
return rtn, writeErrCh, nil
|
||||
}
|
||||
|
||||
func SetupDomainSocketRpcClient(sockName string, handlerFn func(*RpcResponseHandler) bool) (*WshRpc, error) {
|
||||
conn, err := net.Dial("unix", sockName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to Unix domain socket: %w", err)
|
||||
}
|
||||
rtn, errCh, err := SetupConnRpcClient(conn, handlerFn)
|
||||
go func() {
|
||||
defer conn.Close()
|
||||
err := <-errCh
|
||||
if err != nil && err != io.EOF {
|
||||
log.Printf("error in domain socket connection: %v\n", err)
|
||||
}
|
||||
}()
|
||||
return rtn, err
|
||||
}
|
||||
|
||||
func MakeClientJWTToken(rpcCtx wshrpc.RpcContext, sockName string) (string, error) {
|
||||
claims := jwt.MapClaims{}
|
||||
claims["iat"] = time.Now().Unix()
|
||||
@ -246,9 +288,21 @@ func ValidateAndExtractRpcContextFromToken(tokenStr string) (*wshrpc.RpcContext,
|
||||
return nil, fmt.Errorf("iss claim is missing or invalid")
|
||||
}
|
||||
rpcCtx := &wshrpc.RpcContext{}
|
||||
rpcCtx.BlockId = claims["blockid"].(string)
|
||||
rpcCtx.TabId = claims["tabid"].(string)
|
||||
rpcCtx.WindowId = claims["windowid"].(string)
|
||||
if claims["blockid"] != nil {
|
||||
if blockId, ok := claims["blockid"].(string); ok {
|
||||
rpcCtx.BlockId = blockId
|
||||
}
|
||||
}
|
||||
if claims["tabid"] != nil {
|
||||
if tabId, ok := claims["tabid"].(string); ok {
|
||||
rpcCtx.TabId = tabId
|
||||
}
|
||||
}
|
||||
if claims["windowid"] != nil {
|
||||
if windowId, ok := claims["windowid"].(string); ok {
|
||||
rpcCtx.WindowId = windowId
|
||||
}
|
||||
}
|
||||
return rpcCtx, nil
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user