set up remote connserver (#245)

This commit is contained in:
Mike Sawka 2024-08-18 21:26:44 -07:00 committed by GitHub
parent c30188552f
commit 85874f92ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 357 additions and 212 deletions

View File

@ -7,15 +7,15 @@ import (
"os"
"github.com/spf13/cobra"
"github.com/wavetermdev/thenextwave/pkg/wshrpc/wshclient"
"github.com/wavetermdev/thenextwave/pkg/wshrpc/wshremote"
)
var serverCmd = &cobra.Command{
Use: "server",
Use: "connserver",
Short: "remote server to power wave blocks",
Args: cobra.NoArgs,
Run: serverRun,
PreRunE: preRunSetupRpcClient,
}
func init() {
@ -23,10 +23,8 @@ func init() {
}
func serverRun(cmd *cobra.Command, args []string) {
WriteStdout("running wsh server\n")
WriteStdout("running wsh connserver\n")
RpcClient.SetServerImpl(&wshremote.ServerImpl{LogWriter: os.Stdout})
err := wshclient.TestCommand(RpcClient, "hello", nil)
WriteStdout("got test rtn: %v\n", err)
select {} // run forever
}

View File

@ -13,6 +13,7 @@ var deleteBlockCmd = &cobra.Command{
Short: "delete a block",
Args: cobra.ExactArgs(1),
Run: deleteBlockRun,
PreRunE: preRunSetupRpcClient,
}
func init() {

View File

@ -16,6 +16,7 @@ var getMetaCmd = &cobra.Command{
Short: "get metadata for an entity",
Args: cobra.RangeArgs(1, 2),
Run: getMetaRun,
PreRunE: preRunSetupRpcClient,
}
func init() {

View File

@ -18,6 +18,7 @@ var htmlCmd = &cobra.Command{
Use: "html",
Short: "Launch a demo html-mode terminal",
Run: htmlRun,
PreRunE: preRunSetupRpcClient,
}
func htmlRun(cmd *cobra.Command, args []string) {

View File

@ -16,6 +16,7 @@ var readFileCmd = &cobra.Command{
Short: "read a blockfile",
Args: cobra.ExactArgs(2),
Run: runReadFile,
PreRunE: preRunSetupRpcClient,
}
func init() {

View File

@ -26,6 +26,7 @@ var (
Use: "wsh",
Short: "CLI tool to control Wave Terminal",
Long: `wsh is a small utility that lets you do cool things with Wave Terminal, right from the command line`,
SilenceUsage: true,
}
)
@ -60,9 +61,17 @@ func WriteStdout(fmtStr string, args ...interface{}) {
fmt.Print(output)
}
func preRunSetupRpcClient(cmd *cobra.Command, args []string) error {
err := setupRpcClient(nil)
if err != nil {
return err
}
return nil
}
// returns the wrapped stdin and a new rpc client (that wraps the stdin input and stdout output)
func setupRpcClient(serverImpl wshutil.ServerImpl) error {
jwtToken := os.Getenv("WAVETERM_JWT")
jwtToken := os.Getenv(wshutil.WaveJwtTokenVarName)
if jwtToken == "" {
wshutil.SetTermRawModeAndInstallShutdownHandlers(true)
UsingTermWshMode = true
@ -71,7 +80,7 @@ func setupRpcClient(serverImpl wshutil.ServerImpl) error {
}
sockName, err := wshutil.ExtractUnverifiedSocketName(jwtToken)
if err != nil {
return fmt.Errorf("error extracting socket name from WAVETERM_JWT: %v", err)
return fmt.Errorf("error extracting socket name from %s: %v", wshutil.WaveJwtTokenVarName, err)
}
RpcClient, err = wshutil.SetupDomainSocketRpcClient(sockName, serverImpl)
if err != nil {
@ -158,13 +167,7 @@ func Execute() {
wshutil.DoShutdown("", 0, false)
}
}()
err := setupRpcClient(nil)
if err != nil {
log.Printf("[error] %v\n", err)
wshutil.DoShutdown("", 1, true)
return
}
err = rootCmd.Execute()
err := rootCmd.Execute()
if err != nil {
log.Printf("[error] %v\n", err)
wshutil.DoShutdown("", 1, true)

View File

@ -18,6 +18,7 @@ var setMetaCmd = &cobra.Command{
Short: "set metadata for an entity",
Args: cobra.MinimumNArgs(2),
Run: setMetaRun,
PreRunE: preRunSetupRpcClient,
}
func init() {

View File

@ -19,6 +19,7 @@ var termCmd = &cobra.Command{
Short: "open a terminal in directory",
Args: cobra.RangeArgs(0, 1),
Run: termRun,
PreRunE: preRunSetupRpcClient,
}
func init() {

View File

@ -21,6 +21,7 @@ var viewCmd = &cobra.Command{
Short: "preview a file or directory",
Args: cobra.ExactArgs(1),
Run: viewRun,
PreRunE: preRunSetupRpcClient,
}
func init() {

View File

@ -18,6 +18,7 @@ import (
"github.com/wavetermdev/thenextwave/pkg/eventbus"
"github.com/wavetermdev/thenextwave/pkg/filestore"
"github.com/wavetermdev/thenextwave/pkg/remote"
"github.com/wavetermdev/thenextwave/pkg/remote/conncontroller"
"github.com/wavetermdev/thenextwave/pkg/shellexec"
"github.com/wavetermdev/thenextwave/pkg/wavebase"
"github.com/wavetermdev/thenextwave/pkg/waveobj"
@ -277,7 +278,7 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj
if err != nil {
return err
}
conn, err := remote.GetConn(credentialCtx, opts)
conn, err := conncontroller.GetConn(credentialCtx, opts)
if err != nil {
return err
}
@ -288,7 +289,7 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj
}
cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr
}
shellProc, err = shellexec.StartRemoteShellProc(rc.TermSize, cmdStr, cmdOpts, conn)
shellProc, err = shellexec.StartRemoteShellProc(rc.TermSize, cmdStr, cmdOpts, conn.Client)
if err != nil {
return err
}
@ -317,7 +318,7 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj
// we don't need to authenticate this wshProxy since it is coming direct
wshProxy := wshutil.MakeRpcProxy()
wshProxy.SetRpcContext(&wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId})
wshutil.DefaultRouter.RegisterRoute("controller:"+bc.BlockId, wshProxy)
wshutil.DefaultRouter.RegisterRoute(wshutil.MakeControllerRouteId(bc.BlockId), wshProxy)
ptyBuffer := wshutil.MakePtyBuffer(wshutil.WaveOSCPrefix, bc.ShellProc.Cmd, wshProxy.FromRemoteCh)
go func() {
// handles regular output from the pty (goes to the blockfile and xterm)
@ -376,7 +377,7 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj
go func() {
// wait for the shell to finish
defer func() {
wshutil.DefaultRouter.UnregisterRoute("controller:" + bc.BlockId)
wshutil.DefaultRouter.UnregisterRoute(wshutil.MakeControllerRouteId(bc.BlockId))
bc.UpdateControllerAndSendUpdate(func() bool {
bc.ShellProcStatus = Status_Done
return true

View File

@ -0,0 +1,224 @@
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package conncontroller
import (
"context"
"fmt"
"io"
"log"
"net"
"strings"
"sync"
"github.com/wavetermdev/thenextwave/pkg/remote"
"github.com/wavetermdev/thenextwave/pkg/userinput"
"github.com/wavetermdev/thenextwave/pkg/util/shellutil"
"github.com/wavetermdev/thenextwave/pkg/util/utilfn"
"github.com/wavetermdev/thenextwave/pkg/wavebase"
"github.com/wavetermdev/thenextwave/pkg/wshrpc"
"github.com/wavetermdev/thenextwave/pkg/wshutil"
"golang.org/x/crypto/ssh"
)
var globalLock = &sync.Mutex{}
var clientControllerMap = make(map[remote.SSHOpts]*SSHConn)
type SSHConn struct {
Lock *sync.Mutex
Opts *remote.SSHOpts
Client *ssh.Client
SockName string
DomainSockListener net.Listener
ConnController *ssh.Session
}
func (conn *SSHConn) Close() error {
if conn.DomainSockListener != nil {
conn.DomainSockListener.Close()
conn.DomainSockListener = nil
}
if conn.ConnController != nil {
conn.ConnController.Close()
conn.ConnController = nil
}
err := conn.Client.Close()
conn.Client = nil
return err
}
func (conn *SSHConn) OpenDomainSocketListener() error {
if conn.DomainSockListener != nil {
return nil
}
randStr, err := utilfn.RandomHexString(16) // 64-bits of randomness
if err != nil {
return fmt.Errorf("error generating random string: %w", err)
}
sockName := fmt.Sprintf("/tmp/waveterm-%s.sock", randStr)
log.Printf("remote domain socket %s %q\n", conn.Opts.String(), sockName)
listener, err := conn.Client.ListenUnix(sockName)
if err != nil {
return fmt.Errorf("unable to request connection domain socket: %v", err)
}
conn.SockName = sockName
conn.DomainSockListener = listener
go func() {
defer func() {
conn.Lock.Lock()
defer conn.Lock.Unlock()
conn.DomainSockListener = nil
}()
wshutil.RunWshRpcOverListener(listener)
}()
return nil
}
func (conn *SSHConn) StartConnServer() error {
conn.Lock.Lock()
defer conn.Lock.Unlock()
if conn.ConnController != nil {
return nil
}
wshPath := remote.GetWshPath(conn.Client)
rpcCtx := wshrpc.RpcContext{
Conn: conn.Opts.String(),
}
jwtToken, err := wshutil.MakeClientJWTToken(rpcCtx, conn.SockName)
if err != nil {
return fmt.Errorf("unable to create jwt token for conn controller: %w", err)
}
sshSession, err := conn.Client.NewSession()
if err != nil {
return fmt.Errorf("unable to create ssh session for conn controller: %w", err)
}
pipeRead, pipeWrite := io.Pipe()
sshSession.Stdout = pipeWrite
sshSession.Stderr = pipeWrite
conn.ConnController = sshSession
cmdStr := fmt.Sprintf("%s=\"%s\" %s connserver", wshutil.WaveJwtTokenVarName, jwtToken, wshPath)
log.Printf("starting conn controller: %s\n", cmdStr)
err = sshSession.Start(cmdStr)
if err != nil {
return fmt.Errorf("unable to start conn controller: %w", err)
}
// service the I/O
go func() {
// wait for termination, clear the controller
waitErr := sshSession.Wait()
log.Printf("conn controller (%q) terminated: %v", conn.Opts.String(), waitErr)
conn.Lock.Lock()
defer conn.Lock.Unlock()
conn.ConnController = nil
}()
go func() {
readErr := wshutil.StreamToLines(pipeRead, func(line []byte) {
lineStr := string(line)
if !strings.HasSuffix(lineStr, "\n") {
lineStr += "\n"
}
log.Printf("[conncontroller:%s:output] %s", conn.Opts.String(), lineStr)
})
if readErr != nil && readErr != io.EOF {
log.Printf("[conncontroller:%s] error reading output: %v\n", conn.Opts.String(), readErr)
}
}()
return nil
}
func (conn *SSHConn) checkAndInstallWsh(ctx context.Context) error {
client := conn.Client
// check that correct wsh extensions are installed
expectedVersion := fmt.Sprintf("wsh v%s", wavebase.WaveVersion)
clientVersion, err := remote.GetWshVersion(client)
if err == nil && clientVersion == expectedVersion {
return nil
}
var queryText string
var title string
if err != nil {
queryText = "Waveterm requires `wsh` shell extensions installed on your client to ensure a seamless experience. Would you like to install them?"
title = "Install Wsh Shell Extensions"
} else {
queryText = fmt.Sprintf("Waveterm requires `wsh` shell extensions installed on your client to be updated from %s to %s. Would you like to update?", clientVersion, expectedVersion)
title = "Update Wsh Shell Extensions"
}
request := &userinput.UserInputRequest{
ResponseType: "confirm",
QueryText: queryText,
Title: title,
CheckBoxMsg: "Don't show me this again",
}
response, err := userinput.GetUserInput(ctx, request)
if err != nil || !response.Confirm {
return err
}
log.Printf("attempting to install wsh to `%s@%s`", client.User(), client.RemoteAddr().String())
clientOs, err := remote.GetClientOs(client)
if err != nil {
return err
}
clientArch, err := remote.GetClientArch(client)
if err != nil {
return err
}
// attempt to install extension
wshLocalPath := shellutil.GetWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch)
err = remote.CpHostToRemote(client, wshLocalPath, "~/.waveterm/bin/wsh")
if err != nil {
return err
}
log.Printf("successfully installed wsh on %s\n", conn.Opts.String())
return nil
}
func GetConn(ctx context.Context, opts *remote.SSHOpts) (*SSHConn, error) {
globalLock.Lock()
defer globalLock.Unlock()
// attempt to retrieve if already opened
conn, ok := clientControllerMap[*opts]
if ok {
return conn, nil
}
client, err := remote.ConnectToClient(ctx, opts) //todo specify or remove opts
if err != nil {
return nil, err
}
conn = &SSHConn{Lock: &sync.Mutex{}, Opts: opts, Client: client}
err = conn.OpenDomainSocketListener()
if err != nil {
conn.Close()
return nil, err
}
installErr := conn.checkAndInstallWsh(ctx)
if installErr != nil {
conn.Close()
return nil, fmt.Errorf("conncontroller %s wsh install error: %v", conn.Opts.String(), installErr)
}
csErr := conn.StartConnServer()
if csErr != nil {
conn.Close()
return nil, fmt.Errorf("conncontroller %s start wsh connserver error: %v", conn.Opts.String(), csErr)
}
// save successful connection to map
clientControllerMap[*opts] = conn
return conn, nil
}
func DisconnectClient(opts *remote.SSHOpts) error {
globalLock.Lock()
defer globalLock.Unlock()
client, ok := clientControllerMap[*opts]
if ok {
return client.Close()
}
return fmt.Errorf("client %v not found", opts)
}

View File

@ -2,156 +2,20 @@ package remote
import (
"bytes"
"context"
"fmt"
"html/template"
"io"
"log"
"net"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
"sync"
"github.com/wavetermdev/thenextwave/pkg/userinput"
"github.com/wavetermdev/thenextwave/pkg/util/shellutil"
"github.com/wavetermdev/thenextwave/pkg/util/utilfn"
"github.com/wavetermdev/thenextwave/pkg/wavebase"
"github.com/wavetermdev/thenextwave/pkg/wshutil"
"golang.org/x/crypto/ssh"
)
var userHostRe = regexp.MustCompile(`^([a-zA-Z0-9][a-zA-Z0-9._@\\-]*@)?([a-z0-9][a-z0-9.-]*)(?::([0-9]+))?$`)
var globalLock = &sync.Mutex{}
var clientControllerMap = make(map[SSHOpts]*SSHConn)
type SSHConn struct {
Lock *sync.Mutex
Opts *SSHOpts
Client *ssh.Client
SockName string
DomainSockListener net.Listener
}
func (conn *SSHConn) Close() error {
if conn.DomainSockListener != nil {
conn.DomainSockListener.Close()
}
return conn.Client.Close()
}
func (conn *SSHConn) OpenDomainSocketListener() error {
if conn.DomainSockListener != nil {
return nil
}
randStr, err := utilfn.RandomHexString(16) // 64-bits of randomness
if err != nil {
return fmt.Errorf("error generating random string: %w", err)
}
sockName := fmt.Sprintf("/tmp/waveterm-%s.sock", randStr)
log.Printf("remote domain socket %s %q\n", conn.Opts.String(), sockName)
listener, err := conn.Client.ListenUnix(sockName)
if err != nil {
return fmt.Errorf("unable to request connection domain socket: %v", err)
}
conn.SockName = sockName
conn.DomainSockListener = listener
go func() {
wshutil.RunWshRpcOverListener(listener)
}()
return nil
}
func GetConn(ctx context.Context, opts *SSHOpts) (*SSHConn, error) {
globalLock.Lock()
defer globalLock.Unlock()
// attempt to retrieve if already opened
conn, ok := clientControllerMap[*opts]
if ok {
return conn, nil
}
client, err := ConnectToClient(ctx, opts) //todo specify or remove opts
if err != nil {
return nil, err
}
conn = &SSHConn{Lock: &sync.Mutex{}, Opts: opts, Client: client}
err = conn.OpenDomainSocketListener()
if err != nil {
conn.Close()
return nil, err
}
// check that correct wsh extensions are installed
expectedVersion := fmt.Sprintf("wsh v%s", wavebase.WaveVersion)
clientVersion, err := getWshVersion(client)
if err == nil && clientVersion == expectedVersion {
// save successful connection to map
clientControllerMap[*opts] = conn
return conn, nil
}
var queryText string
var title string
if err != nil {
queryText = "Waveterm requires `wsh` shell extensions installed on your client to ensure a seamless experience. Would you like to install them?"
title = "Install Wsh Shell Extensions"
} else {
queryText = fmt.Sprintf("Waveterm requires `wsh` shell extensions installed on your client to be updated from %s to %s. Would you like to update?", clientVersion, expectedVersion)
title = "Update Wsh Shell Extensions"
}
request := &userinput.UserInputRequest{
ResponseType: "confirm",
QueryText: queryText,
Title: title,
CheckBoxMsg: "Don't show me this again",
}
response, err := userinput.GetUserInput(ctx, request)
if err != nil || !response.Confirm {
return nil, err
}
log.Printf("attempting to install wsh to `%s@%s`", client.User(), client.RemoteAddr().String())
clientOs, err := getClientOs(client)
if err != nil {
return nil, err
}
clientArch, err := getClientArch(client)
if err != nil {
return nil, err
}
// attempt to install extension
wshLocalPath := shellutil.GetWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch)
err = cpHostToRemote(client, wshLocalPath, "~/.waveterm/bin/wsh")
if err != nil {
return nil, err
}
log.Printf("successful install")
// save successful connection to map
clientControllerMap[*opts] = conn
return conn, nil
}
func DisconnectClient(opts *SSHOpts) error {
globalLock.Lock()
defer globalLock.Unlock()
client, ok := clientControllerMap[*opts]
if ok {
return client.Close()
}
return fmt.Errorf("client %v not found", opts)
}
func ParseOpts(input string) (*SSHOpts, error) {
m := userHostRe.FindStringSubmatch(input)
@ -173,7 +37,7 @@ func ParseOpts(input string) (*SSHOpts, error) {
}
func DetectShell(client *ssh.Client) (string, error) {
wshPath := getWshPath(client)
wshPath := GetWshPath(client)
session, err := client.NewSession()
if err != nil {
@ -191,8 +55,8 @@ func DetectShell(client *ssh.Client) (string, error) {
return fmt.Sprintf(`"%s"`, strings.TrimSpace(string(out))), nil
}
func getWshVersion(client *ssh.Client) (string, error) {
wshPath := getWshPath(client)
func GetWshVersion(client *ssh.Client) (string, error) {
wshPath := GetWshPath(client)
session, err := client.NewSession()
if err != nil {
@ -207,7 +71,7 @@ func getWshVersion(client *ssh.Client) (string, error) {
return strings.TrimSpace(string(out)), nil
}
func getWshPath(client *ssh.Client) string {
func GetWshPath(client *ssh.Client) string {
defaultPath := filepath.Join("~", ".waveterm", "bin", "wsh")
session, err := client.NewSession()
@ -267,7 +131,7 @@ func hasBashInstalled(client *ssh.Client) (bool, error) {
return false, nil
}
func getClientOs(client *ssh.Client) (string, error) {
func GetClientOs(client *ssh.Client) (string, error) {
session, err := client.NewSession()
if err != nil {
return "", err
@ -306,7 +170,7 @@ func getClientOs(client *ssh.Client) (string, error) {
return "", fmt.Errorf("unable to determine os: {unix: %s, cmd: %s, powershell: %s}", unixErr, cmdErr, psErr)
}
func getClientArch(client *ssh.Client) (string, error) {
func GetClientArch(client *ssh.Client) (string, error) {
session, err := client.NewSession()
if err != nil {
return "", err
@ -360,7 +224,7 @@ mv {{.tempPath}} {{.installPath}}; \
chmod a+x {{.installPath}}; \
`
func cpHostToRemote(client *ssh.Client, sourcePath string, destPath string) error {
func CpHostToRemote(client *ssh.Client, sourcePath string, destPath string) error {
// warning: does not work on windows remote yet
bashInstalled, err := hasBashInstalled(client)
if err != nil {
@ -414,7 +278,7 @@ func cpHostToRemote(client *ssh.Client, sourcePath string, destPath string) erro
}
func InstallClientRcFiles(client *ssh.Client) error {
path := getWshPath(client)
path := GetWshPath(client)
session, err := client.NewSession()
if err != nil {

View File

@ -24,6 +24,7 @@ import (
"github.com/wavetermdev/thenextwave/pkg/wavebase"
"github.com/wavetermdev/thenextwave/pkg/wshutil"
"github.com/wavetermdev/thenextwave/pkg/wstore"
"golang.org/x/crypto/ssh"
)
type CommandOptsType struct {
@ -149,8 +150,7 @@ func (pp *PipePty) WriteString(s string) (n int, err error) {
return pp.Write([]byte(s))
}
func StartRemoteShellProc(termSize wstore.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *remote.SSHConn) (*ShellProc, error) {
client := conn.Client
func StartRemoteShellProc(termSize wstore.TermSize, cmdStr string, cmdOpts CommandOptsType, client *ssh.Client) (*ShellProc, error) {
shellPath, err := remote.DetectShell(client)
if err != nil {
return nil, err
@ -195,6 +195,7 @@ func StartRemoteShellProc(termSize wstore.TermSize, cmdStr string, cmdOpts Comma
}
shellOpts = append(shellOpts, "-c", cmdStr)
cmdCombined = fmt.Sprintf("%s %s", shellPath, strings.Join(shellOpts, " "))
log.Printf("combined command is: %s", cmdCombined)
}
session, err := client.NewSession()

View File

@ -19,6 +19,7 @@ import (
"github.com/wavetermdev/thenextwave/pkg/util/utilfn"
"github.com/wavetermdev/thenextwave/pkg/wavebase"
"github.com/wavetermdev/thenextwave/pkg/wstore"
)
const DefaultTermType = "xterm-256color"
@ -125,6 +126,10 @@ func internalMacUserShell() string {
return m[1]
}
func DefaultTermSize() wstore.TermSize {
return wstore.TermSize{Rows: DefaultTermRows, Cols: DefaultTermCols}
}
func WaveshellLocalEnvVars(termType string) map[string]string {
rtn := make(map[string]string)
if termType != "" {

View File

@ -257,10 +257,9 @@ func HandleWsInternal(w http.ResponseWriter, r *http.Request) error {
defer eventbus.UnregisterWSChannel(wsConnId)
// we create a wshproxy to handle rpc messages to/from the window
wproxy := wshutil.MakeRpcProxy()
rpcRouteId := "window:" + windowId
wshutil.DefaultRouter.RegisterRoute(rpcRouteId, wproxy)
wshutil.DefaultRouter.RegisterRoute(wshutil.MakeWindowRouteId(windowId), wproxy)
defer func() {
wshutil.DefaultRouter.UnregisterRoute(rpcRouteId)
wshutil.DefaultRouter.UnregisterRoute(wshutil.MakeWindowRouteId(windowId))
close(wproxy.ToRemoteCh)
}()
// WshServerFactoryFn(rpcInputCh, rpcOutputCh, wshrpc.RpcContext{})

View File

@ -4,6 +4,8 @@
package wshclient
import (
"errors"
"github.com/wavetermdev/thenextwave/pkg/util/utilfn"
"github.com/wavetermdev/thenextwave/pkg/wshrpc"
"github.com/wavetermdev/thenextwave/pkg/wshutil"
@ -14,6 +16,9 @@ func sendRpcRequestCallHelper[T any](w *wshutil.WshRpc, command string, data int
opts = &wshrpc.RpcOpts{}
}
var respData T
if w == nil {
return respData, errors.New("nil wshrpc passed to wshclient")
}
if opts.NoResponse {
err := w.SendCommand(command, data, opts)
if err != nil {
@ -32,17 +37,26 @@ func sendRpcRequestCallHelper[T any](w *wshutil.WshRpc, command string, data int
return respData, nil
}
func rtnErr[T any](ch chan wshrpc.RespOrErrorUnion[T], err error) {
go func() {
ch <- wshrpc.RespOrErrorUnion[T]{Error: err}
close(ch)
}()
}
func sendRpcRequestResponseStreamHelper[T any](w *wshutil.WshRpc, command string, data interface{}, opts *wshrpc.RpcOpts) chan wshrpc.RespOrErrorUnion[T] {
if opts == nil {
opts = &wshrpc.RpcOpts{}
}
respChan := make(chan wshrpc.RespOrErrorUnion[T])
if w == nil {
rtnErr(respChan, errors.New("nil wshrpc passed to wshclient"))
return respChan
}
reqHandler, err := w.SendComplexRequest(command, data, opts)
if err != nil {
go func() {
respChan <- wshrpc.RespOrErrorUnion[T]{Error: err}
close(respChan)
}()
rtnErr(respChan, err)
return respChan
} else {
go func() {
defer close(respChan)

View File

@ -107,6 +107,7 @@ type RpcOpts struct {
type RpcContext struct {
BlockId string `json:"blockid,omitempty"`
TabId string `json:"tabid,omitempty"`
Conn string `json:"conn,omitempty"`
}
func HackRpcContextIntoData(dataPtr any, rpcContext RpcContext) {

View File

@ -6,6 +6,7 @@ package wshutil
import (
"encoding/json"
"fmt"
"log"
"sync"
"github.com/google/uuid"
@ -81,12 +82,14 @@ func handleAuthenticationCommand(msg RpcMessage) (*wshrpc.RpcContext, error) {
if newCtx == nil {
return nil, fmt.Errorf("no context found in jwt token")
}
if newCtx.BlockId == "" {
return nil, fmt.Errorf("no blockId found in jwt token")
if newCtx.BlockId == "" && newCtx.Conn == "" {
return nil, fmt.Errorf("no blockid or conn found in jwt token")
}
if newCtx.BlockId != "" {
if _, err := uuid.Parse(newCtx.BlockId); err != nil {
return nil, fmt.Errorf("invalid blockId in jwt token")
}
}
return newCtx, nil
}
@ -114,6 +117,7 @@ func (p *WshRpcProxy) HandleAuthentication() (*wshrpc.RpcContext, error) {
}
newCtx, err := handleAuthenticationCommand(msg)
if err != nil {
log.Printf("error handling authentication: %v\n", err)
p.sendResponseError(msg, err)
continue
}

View File

@ -42,6 +42,14 @@ func MakeConnectionRouteId(connId string) string {
return "conn:" + connId
}
func MakeControllerRouteId(blockId string) string {
return "controller:" + blockId
}
func MakeWindowRouteId(windowId string) string {
return "window:" + windowId
}
var DefaultRouter = NewWshRouter()
func NewWshRouter() *WshRouter {
@ -230,6 +238,7 @@ func (router *WshRouter) RegisterRoute(routeId string, rpc AbstractRpcClient) {
log.Printf("error: WshRouter cannot register sys route\n")
return
}
log.Printf("registering wsh route %q\n", routeId)
router.Lock.Lock()
defer router.Lock.Unlock()
router.RouteMap[routeId] = rpc

View File

@ -43,7 +43,7 @@ func streamToLines_processBuf(lineBuf *lineBuf, readBuf []byte, lineFn func([]by
}
}
func streamToLines(input io.Reader, lineFn func([]byte)) error {
func StreamToLines(input io.Reader, lineFn func([]byte)) error {
var lineBuf lineBuf
readBuf := make([]byte, 16*1024)
for {
@ -56,7 +56,7 @@ func streamToLines(input io.Reader, lineFn func([]byte)) error {
}
func AdaptStreamToMsgCh(input io.Reader, output chan []byte) error {
return streamToLines(input, func(line []byte) {
return StreamToLines(input, func(line []byte) {
output <- line
})
}

View File

@ -248,6 +248,9 @@ func MakeClientJWTToken(rpcCtx wshrpc.RpcContext, sockName string) (string, erro
if rpcCtx.TabId != "" {
claims["tabid"] = rpcCtx.TabId
}
if rpcCtx.Conn != "" {
claims["conn"] = rpcCtx.Conn
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenStr, err := token.SignedString([]byte(wavebase.JwtSecret))
if err != nil {
@ -299,6 +302,11 @@ func mapClaimsToRpcContext(claims jwt.MapClaims) *wshrpc.RpcContext {
rpcCtx.TabId = tabId
}
}
if claims["conn"] != nil {
if conn, ok := claims["conn"].(string); ok {
rpcCtx.Conn = conn
}
}
return rpcCtx
}
@ -306,9 +314,12 @@ func RunWshRpcOverListener(listener net.Listener) {
defer log.Printf("domain socket listener shutting down\n")
for {
conn, err := listener.Accept()
if err == io.EOF {
break
}
if err != nil {
log.Printf("error accepting connection: %v\n", err)
continue
break
}
log.Print("got domain socket connection\n")
go handleDomainSocketClient(conn)
@ -337,7 +348,11 @@ func handleDomainSocketClient(conn net.Conn) {
// now that we're authenticated, set the ctx and attach to the router
log.Printf("domain socket connection authenticated: %#v\n", rpcCtx)
proxy.SetRpcContext(rpcCtx)
DefaultRouter.RegisterRoute("controller:"+rpcCtx.BlockId, proxy)
if rpcCtx.BlockId != "" {
DefaultRouter.RegisterRoute(MakeControllerRouteId(rpcCtx.BlockId), proxy)
} else if rpcCtx.Conn != "" {
DefaultRouter.RegisterRoute(MakeConnectionRouteId(rpcCtx.Conn), proxy)
}
}
// only for use on client