mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-01-04 18:59:08 +01:00
set up remote connserver (#245)
This commit is contained in:
parent
c30188552f
commit
85874f92ca
@ -7,15 +7,15 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"github.com/wavetermdev/thenextwave/pkg/wshrpc/wshclient"
|
|
||||||
"github.com/wavetermdev/thenextwave/pkg/wshrpc/wshremote"
|
"github.com/wavetermdev/thenextwave/pkg/wshrpc/wshremote"
|
||||||
)
|
)
|
||||||
|
|
||||||
var serverCmd = &cobra.Command{
|
var serverCmd = &cobra.Command{
|
||||||
Use: "server",
|
Use: "connserver",
|
||||||
Short: "remote server to power wave blocks",
|
Short: "remote server to power wave blocks",
|
||||||
Args: cobra.NoArgs,
|
Args: cobra.NoArgs,
|
||||||
Run: serverRun,
|
Run: serverRun,
|
||||||
|
PreRunE: preRunSetupRpcClient,
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@ -23,10 +23,8 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func serverRun(cmd *cobra.Command, args []string) {
|
func serverRun(cmd *cobra.Command, args []string) {
|
||||||
WriteStdout("running wsh server\n")
|
WriteStdout("running wsh connserver\n")
|
||||||
RpcClient.SetServerImpl(&wshremote.ServerImpl{LogWriter: os.Stdout})
|
RpcClient.SetServerImpl(&wshremote.ServerImpl{LogWriter: os.Stdout})
|
||||||
err := wshclient.TestCommand(RpcClient, "hello", nil)
|
|
||||||
WriteStdout("got test rtn: %v\n", err)
|
|
||||||
|
|
||||||
select {} // run forever
|
select {} // run forever
|
||||||
}
|
}
|
@ -13,6 +13,7 @@ var deleteBlockCmd = &cobra.Command{
|
|||||||
Short: "delete a block",
|
Short: "delete a block",
|
||||||
Args: cobra.ExactArgs(1),
|
Args: cobra.ExactArgs(1),
|
||||||
Run: deleteBlockRun,
|
Run: deleteBlockRun,
|
||||||
|
PreRunE: preRunSetupRpcClient,
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
@ -16,6 +16,7 @@ var getMetaCmd = &cobra.Command{
|
|||||||
Short: "get metadata for an entity",
|
Short: "get metadata for an entity",
|
||||||
Args: cobra.RangeArgs(1, 2),
|
Args: cobra.RangeArgs(1, 2),
|
||||||
Run: getMetaRun,
|
Run: getMetaRun,
|
||||||
|
PreRunE: preRunSetupRpcClient,
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
@ -18,6 +18,7 @@ var htmlCmd = &cobra.Command{
|
|||||||
Use: "html",
|
Use: "html",
|
||||||
Short: "Launch a demo html-mode terminal",
|
Short: "Launch a demo html-mode terminal",
|
||||||
Run: htmlRun,
|
Run: htmlRun,
|
||||||
|
PreRunE: preRunSetupRpcClient,
|
||||||
}
|
}
|
||||||
|
|
||||||
func htmlRun(cmd *cobra.Command, args []string) {
|
func htmlRun(cmd *cobra.Command, args []string) {
|
||||||
|
@ -16,6 +16,7 @@ var readFileCmd = &cobra.Command{
|
|||||||
Short: "read a blockfile",
|
Short: "read a blockfile",
|
||||||
Args: cobra.ExactArgs(2),
|
Args: cobra.ExactArgs(2),
|
||||||
Run: runReadFile,
|
Run: runReadFile,
|
||||||
|
PreRunE: preRunSetupRpcClient,
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
@ -26,6 +26,7 @@ var (
|
|||||||
Use: "wsh",
|
Use: "wsh",
|
||||||
Short: "CLI tool to control Wave Terminal",
|
Short: "CLI tool to control Wave Terminal",
|
||||||
Long: `wsh is a small utility that lets you do cool things with Wave Terminal, right from the command line`,
|
Long: `wsh is a small utility that lets you do cool things with Wave Terminal, right from the command line`,
|
||||||
|
SilenceUsage: true,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -60,9 +61,17 @@ func WriteStdout(fmtStr string, args ...interface{}) {
|
|||||||
fmt.Print(output)
|
fmt.Print(output)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func preRunSetupRpcClient(cmd *cobra.Command, args []string) error {
|
||||||
|
err := setupRpcClient(nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// returns the wrapped stdin and a new rpc client (that wraps the stdin input and stdout output)
|
// returns the wrapped stdin and a new rpc client (that wraps the stdin input and stdout output)
|
||||||
func setupRpcClient(serverImpl wshutil.ServerImpl) error {
|
func setupRpcClient(serverImpl wshutil.ServerImpl) error {
|
||||||
jwtToken := os.Getenv("WAVETERM_JWT")
|
jwtToken := os.Getenv(wshutil.WaveJwtTokenVarName)
|
||||||
if jwtToken == "" {
|
if jwtToken == "" {
|
||||||
wshutil.SetTermRawModeAndInstallShutdownHandlers(true)
|
wshutil.SetTermRawModeAndInstallShutdownHandlers(true)
|
||||||
UsingTermWshMode = true
|
UsingTermWshMode = true
|
||||||
@ -71,7 +80,7 @@ func setupRpcClient(serverImpl wshutil.ServerImpl) error {
|
|||||||
}
|
}
|
||||||
sockName, err := wshutil.ExtractUnverifiedSocketName(jwtToken)
|
sockName, err := wshutil.ExtractUnverifiedSocketName(jwtToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error extracting socket name from WAVETERM_JWT: %v", err)
|
return fmt.Errorf("error extracting socket name from %s: %v", wshutil.WaveJwtTokenVarName, err)
|
||||||
}
|
}
|
||||||
RpcClient, err = wshutil.SetupDomainSocketRpcClient(sockName, serverImpl)
|
RpcClient, err = wshutil.SetupDomainSocketRpcClient(sockName, serverImpl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -158,13 +167,7 @@ func Execute() {
|
|||||||
wshutil.DoShutdown("", 0, false)
|
wshutil.DoShutdown("", 0, false)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
err := setupRpcClient(nil)
|
err := rootCmd.Execute()
|
||||||
if err != nil {
|
|
||||||
log.Printf("[error] %v\n", err)
|
|
||||||
wshutil.DoShutdown("", 1, true)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = rootCmd.Execute()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[error] %v\n", err)
|
log.Printf("[error] %v\n", err)
|
||||||
wshutil.DoShutdown("", 1, true)
|
wshutil.DoShutdown("", 1, true)
|
||||||
|
@ -18,6 +18,7 @@ var setMetaCmd = &cobra.Command{
|
|||||||
Short: "set metadata for an entity",
|
Short: "set metadata for an entity",
|
||||||
Args: cobra.MinimumNArgs(2),
|
Args: cobra.MinimumNArgs(2),
|
||||||
Run: setMetaRun,
|
Run: setMetaRun,
|
||||||
|
PreRunE: preRunSetupRpcClient,
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
@ -19,6 +19,7 @@ var termCmd = &cobra.Command{
|
|||||||
Short: "open a terminal in directory",
|
Short: "open a terminal in directory",
|
||||||
Args: cobra.RangeArgs(0, 1),
|
Args: cobra.RangeArgs(0, 1),
|
||||||
Run: termRun,
|
Run: termRun,
|
||||||
|
PreRunE: preRunSetupRpcClient,
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
@ -21,6 +21,7 @@ var viewCmd = &cobra.Command{
|
|||||||
Short: "preview a file or directory",
|
Short: "preview a file or directory",
|
||||||
Args: cobra.ExactArgs(1),
|
Args: cobra.ExactArgs(1),
|
||||||
Run: viewRun,
|
Run: viewRun,
|
||||||
|
PreRunE: preRunSetupRpcClient,
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
@ -18,6 +18,7 @@ import (
|
|||||||
"github.com/wavetermdev/thenextwave/pkg/eventbus"
|
"github.com/wavetermdev/thenextwave/pkg/eventbus"
|
||||||
"github.com/wavetermdev/thenextwave/pkg/filestore"
|
"github.com/wavetermdev/thenextwave/pkg/filestore"
|
||||||
"github.com/wavetermdev/thenextwave/pkg/remote"
|
"github.com/wavetermdev/thenextwave/pkg/remote"
|
||||||
|
"github.com/wavetermdev/thenextwave/pkg/remote/conncontroller"
|
||||||
"github.com/wavetermdev/thenextwave/pkg/shellexec"
|
"github.com/wavetermdev/thenextwave/pkg/shellexec"
|
||||||
"github.com/wavetermdev/thenextwave/pkg/wavebase"
|
"github.com/wavetermdev/thenextwave/pkg/wavebase"
|
||||||
"github.com/wavetermdev/thenextwave/pkg/waveobj"
|
"github.com/wavetermdev/thenextwave/pkg/waveobj"
|
||||||
@ -277,7 +278,7 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
conn, err := remote.GetConn(credentialCtx, opts)
|
conn, err := conncontroller.GetConn(credentialCtx, opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -288,7 +289,7 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj
|
|||||||
}
|
}
|
||||||
cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr
|
cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr
|
||||||
}
|
}
|
||||||
shellProc, err = shellexec.StartRemoteShellProc(rc.TermSize, cmdStr, cmdOpts, conn)
|
shellProc, err = shellexec.StartRemoteShellProc(rc.TermSize, cmdStr, cmdOpts, conn.Client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -317,7 +318,7 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj
|
|||||||
// we don't need to authenticate this wshProxy since it is coming direct
|
// we don't need to authenticate this wshProxy since it is coming direct
|
||||||
wshProxy := wshutil.MakeRpcProxy()
|
wshProxy := wshutil.MakeRpcProxy()
|
||||||
wshProxy.SetRpcContext(&wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId})
|
wshProxy.SetRpcContext(&wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId})
|
||||||
wshutil.DefaultRouter.RegisterRoute("controller:"+bc.BlockId, wshProxy)
|
wshutil.DefaultRouter.RegisterRoute(wshutil.MakeControllerRouteId(bc.BlockId), wshProxy)
|
||||||
ptyBuffer := wshutil.MakePtyBuffer(wshutil.WaveOSCPrefix, bc.ShellProc.Cmd, wshProxy.FromRemoteCh)
|
ptyBuffer := wshutil.MakePtyBuffer(wshutil.WaveOSCPrefix, bc.ShellProc.Cmd, wshProxy.FromRemoteCh)
|
||||||
go func() {
|
go func() {
|
||||||
// handles regular output from the pty (goes to the blockfile and xterm)
|
// handles regular output from the pty (goes to the blockfile and xterm)
|
||||||
@ -376,7 +377,7 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj
|
|||||||
go func() {
|
go func() {
|
||||||
// wait for the shell to finish
|
// wait for the shell to finish
|
||||||
defer func() {
|
defer func() {
|
||||||
wshutil.DefaultRouter.UnregisterRoute("controller:" + bc.BlockId)
|
wshutil.DefaultRouter.UnregisterRoute(wshutil.MakeControllerRouteId(bc.BlockId))
|
||||||
bc.UpdateControllerAndSendUpdate(func() bool {
|
bc.UpdateControllerAndSendUpdate(func() bool {
|
||||||
bc.ShellProcStatus = Status_Done
|
bc.ShellProcStatus = Status_Done
|
||||||
return true
|
return true
|
||||||
|
224
pkg/remote/conncontroller/conncontroller.go
Normal file
224
pkg/remote/conncontroller/conncontroller.go
Normal file
@ -0,0 +1,224 @@
|
|||||||
|
// Copyright 2024, Command Line Inc.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package conncontroller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/wavetermdev/thenextwave/pkg/remote"
|
||||||
|
"github.com/wavetermdev/thenextwave/pkg/userinput"
|
||||||
|
"github.com/wavetermdev/thenextwave/pkg/util/shellutil"
|
||||||
|
"github.com/wavetermdev/thenextwave/pkg/util/utilfn"
|
||||||
|
"github.com/wavetermdev/thenextwave/pkg/wavebase"
|
||||||
|
"github.com/wavetermdev/thenextwave/pkg/wshrpc"
|
||||||
|
"github.com/wavetermdev/thenextwave/pkg/wshutil"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
var globalLock = &sync.Mutex{}
|
||||||
|
var clientControllerMap = make(map[remote.SSHOpts]*SSHConn)
|
||||||
|
|
||||||
|
type SSHConn struct {
|
||||||
|
Lock *sync.Mutex
|
||||||
|
Opts *remote.SSHOpts
|
||||||
|
Client *ssh.Client
|
||||||
|
SockName string
|
||||||
|
DomainSockListener net.Listener
|
||||||
|
ConnController *ssh.Session
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *SSHConn) Close() error {
|
||||||
|
if conn.DomainSockListener != nil {
|
||||||
|
conn.DomainSockListener.Close()
|
||||||
|
conn.DomainSockListener = nil
|
||||||
|
}
|
||||||
|
if conn.ConnController != nil {
|
||||||
|
conn.ConnController.Close()
|
||||||
|
conn.ConnController = nil
|
||||||
|
}
|
||||||
|
err := conn.Client.Close()
|
||||||
|
conn.Client = nil
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *SSHConn) OpenDomainSocketListener() error {
|
||||||
|
if conn.DomainSockListener != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
randStr, err := utilfn.RandomHexString(16) // 64-bits of randomness
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error generating random string: %w", err)
|
||||||
|
}
|
||||||
|
sockName := fmt.Sprintf("/tmp/waveterm-%s.sock", randStr)
|
||||||
|
log.Printf("remote domain socket %s %q\n", conn.Opts.String(), sockName)
|
||||||
|
listener, err := conn.Client.ListenUnix(sockName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to request connection domain socket: %v", err)
|
||||||
|
}
|
||||||
|
conn.SockName = sockName
|
||||||
|
conn.DomainSockListener = listener
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
conn.Lock.Lock()
|
||||||
|
defer conn.Lock.Unlock()
|
||||||
|
conn.DomainSockListener = nil
|
||||||
|
}()
|
||||||
|
wshutil.RunWshRpcOverListener(listener)
|
||||||
|
}()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *SSHConn) StartConnServer() error {
|
||||||
|
conn.Lock.Lock()
|
||||||
|
defer conn.Lock.Unlock()
|
||||||
|
if conn.ConnController != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
wshPath := remote.GetWshPath(conn.Client)
|
||||||
|
rpcCtx := wshrpc.RpcContext{
|
||||||
|
Conn: conn.Opts.String(),
|
||||||
|
}
|
||||||
|
jwtToken, err := wshutil.MakeClientJWTToken(rpcCtx, conn.SockName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to create jwt token for conn controller: %w", err)
|
||||||
|
}
|
||||||
|
sshSession, err := conn.Client.NewSession()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to create ssh session for conn controller: %w", err)
|
||||||
|
}
|
||||||
|
pipeRead, pipeWrite := io.Pipe()
|
||||||
|
sshSession.Stdout = pipeWrite
|
||||||
|
sshSession.Stderr = pipeWrite
|
||||||
|
conn.ConnController = sshSession
|
||||||
|
cmdStr := fmt.Sprintf("%s=\"%s\" %s connserver", wshutil.WaveJwtTokenVarName, jwtToken, wshPath)
|
||||||
|
log.Printf("starting conn controller: %s\n", cmdStr)
|
||||||
|
err = sshSession.Start(cmdStr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to start conn controller: %w", err)
|
||||||
|
}
|
||||||
|
// service the I/O
|
||||||
|
go func() {
|
||||||
|
// wait for termination, clear the controller
|
||||||
|
waitErr := sshSession.Wait()
|
||||||
|
log.Printf("conn controller (%q) terminated: %v", conn.Opts.String(), waitErr)
|
||||||
|
conn.Lock.Lock()
|
||||||
|
defer conn.Lock.Unlock()
|
||||||
|
conn.ConnController = nil
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
readErr := wshutil.StreamToLines(pipeRead, func(line []byte) {
|
||||||
|
lineStr := string(line)
|
||||||
|
if !strings.HasSuffix(lineStr, "\n") {
|
||||||
|
lineStr += "\n"
|
||||||
|
}
|
||||||
|
log.Printf("[conncontroller:%s:output] %s", conn.Opts.String(), lineStr)
|
||||||
|
})
|
||||||
|
if readErr != nil && readErr != io.EOF {
|
||||||
|
log.Printf("[conncontroller:%s] error reading output: %v\n", conn.Opts.String(), readErr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *SSHConn) checkAndInstallWsh(ctx context.Context) error {
|
||||||
|
client := conn.Client
|
||||||
|
// check that correct wsh extensions are installed
|
||||||
|
expectedVersion := fmt.Sprintf("wsh v%s", wavebase.WaveVersion)
|
||||||
|
clientVersion, err := remote.GetWshVersion(client)
|
||||||
|
if err == nil && clientVersion == expectedVersion {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var queryText string
|
||||||
|
var title string
|
||||||
|
if err != nil {
|
||||||
|
queryText = "Waveterm requires `wsh` shell extensions installed on your client to ensure a seamless experience. Would you like to install them?"
|
||||||
|
title = "Install Wsh Shell Extensions"
|
||||||
|
} else {
|
||||||
|
queryText = fmt.Sprintf("Waveterm requires `wsh` shell extensions installed on your client to be updated from %s to %s. Would you like to update?", clientVersion, expectedVersion)
|
||||||
|
title = "Update Wsh Shell Extensions"
|
||||||
|
}
|
||||||
|
request := &userinput.UserInputRequest{
|
||||||
|
ResponseType: "confirm",
|
||||||
|
QueryText: queryText,
|
||||||
|
Title: title,
|
||||||
|
CheckBoxMsg: "Don't show me this again",
|
||||||
|
}
|
||||||
|
response, err := userinput.GetUserInput(ctx, request)
|
||||||
|
if err != nil || !response.Confirm {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Printf("attempting to install wsh to `%s@%s`", client.User(), client.RemoteAddr().String())
|
||||||
|
clientOs, err := remote.GetClientOs(client)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
clientArch, err := remote.GetClientArch(client)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// attempt to install extension
|
||||||
|
wshLocalPath := shellutil.GetWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch)
|
||||||
|
err = remote.CpHostToRemote(client, wshLocalPath, "~/.waveterm/bin/wsh")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Printf("successfully installed wsh on %s\n", conn.Opts.String())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetConn(ctx context.Context, opts *remote.SSHOpts) (*SSHConn, error) {
|
||||||
|
globalLock.Lock()
|
||||||
|
defer globalLock.Unlock()
|
||||||
|
|
||||||
|
// attempt to retrieve if already opened
|
||||||
|
conn, ok := clientControllerMap[*opts]
|
||||||
|
if ok {
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := remote.ConnectToClient(ctx, opts) //todo specify or remove opts
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
conn = &SSHConn{Lock: &sync.Mutex{}, Opts: opts, Client: client}
|
||||||
|
err = conn.OpenDomainSocketListener()
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
installErr := conn.checkAndInstallWsh(ctx)
|
||||||
|
if installErr != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, fmt.Errorf("conncontroller %s wsh install error: %v", conn.Opts.String(), installErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
csErr := conn.StartConnServer()
|
||||||
|
if csErr != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, fmt.Errorf("conncontroller %s start wsh connserver error: %v", conn.Opts.String(), csErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// save successful connection to map
|
||||||
|
clientControllerMap[*opts] = conn
|
||||||
|
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func DisconnectClient(opts *remote.SSHOpts) error {
|
||||||
|
globalLock.Lock()
|
||||||
|
defer globalLock.Unlock()
|
||||||
|
|
||||||
|
client, ok := clientControllerMap[*opts]
|
||||||
|
if ok {
|
||||||
|
return client.Close()
|
||||||
|
}
|
||||||
|
return fmt.Errorf("client %v not found", opts)
|
||||||
|
}
|
@ -2,156 +2,20 @@ package remote
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"html/template"
|
"html/template"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/wavetermdev/thenextwave/pkg/userinput"
|
|
||||||
"github.com/wavetermdev/thenextwave/pkg/util/shellutil"
|
|
||||||
"github.com/wavetermdev/thenextwave/pkg/util/utilfn"
|
|
||||||
"github.com/wavetermdev/thenextwave/pkg/wavebase"
|
|
||||||
"github.com/wavetermdev/thenextwave/pkg/wshutil"
|
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
var userHostRe = regexp.MustCompile(`^([a-zA-Z0-9][a-zA-Z0-9._@\\-]*@)?([a-z0-9][a-z0-9.-]*)(?::([0-9]+))?$`)
|
var userHostRe = regexp.MustCompile(`^([a-zA-Z0-9][a-zA-Z0-9._@\\-]*@)?([a-z0-9][a-z0-9.-]*)(?::([0-9]+))?$`)
|
||||||
var globalLock = &sync.Mutex{}
|
|
||||||
var clientControllerMap = make(map[SSHOpts]*SSHConn)
|
|
||||||
|
|
||||||
type SSHConn struct {
|
|
||||||
Lock *sync.Mutex
|
|
||||||
Opts *SSHOpts
|
|
||||||
Client *ssh.Client
|
|
||||||
SockName string
|
|
||||||
DomainSockListener net.Listener
|
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *SSHConn) Close() error {
|
|
||||||
if conn.DomainSockListener != nil {
|
|
||||||
conn.DomainSockListener.Close()
|
|
||||||
}
|
|
||||||
return conn.Client.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *SSHConn) OpenDomainSocketListener() error {
|
|
||||||
if conn.DomainSockListener != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
randStr, err := utilfn.RandomHexString(16) // 64-bits of randomness
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error generating random string: %w", err)
|
|
||||||
}
|
|
||||||
sockName := fmt.Sprintf("/tmp/waveterm-%s.sock", randStr)
|
|
||||||
log.Printf("remote domain socket %s %q\n", conn.Opts.String(), sockName)
|
|
||||||
listener, err := conn.Client.ListenUnix(sockName)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to request connection domain socket: %v", err)
|
|
||||||
}
|
|
||||||
conn.SockName = sockName
|
|
||||||
conn.DomainSockListener = listener
|
|
||||||
go func() {
|
|
||||||
wshutil.RunWshRpcOverListener(listener)
|
|
||||||
}()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetConn(ctx context.Context, opts *SSHOpts) (*SSHConn, error) {
|
|
||||||
globalLock.Lock()
|
|
||||||
defer globalLock.Unlock()
|
|
||||||
|
|
||||||
// attempt to retrieve if already opened
|
|
||||||
conn, ok := clientControllerMap[*opts]
|
|
||||||
if ok {
|
|
||||||
return conn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
client, err := ConnectToClient(ctx, opts) //todo specify or remove opts
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
conn = &SSHConn{Lock: &sync.Mutex{}, Opts: opts, Client: client}
|
|
||||||
err = conn.OpenDomainSocketListener()
|
|
||||||
if err != nil {
|
|
||||||
conn.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// check that correct wsh extensions are installed
|
|
||||||
expectedVersion := fmt.Sprintf("wsh v%s", wavebase.WaveVersion)
|
|
||||||
clientVersion, err := getWshVersion(client)
|
|
||||||
if err == nil && clientVersion == expectedVersion {
|
|
||||||
// save successful connection to map
|
|
||||||
clientControllerMap[*opts] = conn
|
|
||||||
return conn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var queryText string
|
|
||||||
var title string
|
|
||||||
if err != nil {
|
|
||||||
queryText = "Waveterm requires `wsh` shell extensions installed on your client to ensure a seamless experience. Would you like to install them?"
|
|
||||||
title = "Install Wsh Shell Extensions"
|
|
||||||
} else {
|
|
||||||
queryText = fmt.Sprintf("Waveterm requires `wsh` shell extensions installed on your client to be updated from %s to %s. Would you like to update?", clientVersion, expectedVersion)
|
|
||||||
title = "Update Wsh Shell Extensions"
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
request := &userinput.UserInputRequest{
|
|
||||||
ResponseType: "confirm",
|
|
||||||
QueryText: queryText,
|
|
||||||
Title: title,
|
|
||||||
CheckBoxMsg: "Don't show me this again",
|
|
||||||
}
|
|
||||||
response, err := userinput.GetUserInput(ctx, request)
|
|
||||||
if err != nil || !response.Confirm {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("attempting to install wsh to `%s@%s`", client.User(), client.RemoteAddr().String())
|
|
||||||
|
|
||||||
clientOs, err := getClientOs(client)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
clientArch, err := getClientArch(client)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// attempt to install extension
|
|
||||||
wshLocalPath := shellutil.GetWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch)
|
|
||||||
err = cpHostToRemote(client, wshLocalPath, "~/.waveterm/bin/wsh")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
log.Printf("successful install")
|
|
||||||
|
|
||||||
// save successful connection to map
|
|
||||||
clientControllerMap[*opts] = conn
|
|
||||||
|
|
||||||
return conn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func DisconnectClient(opts *SSHOpts) error {
|
|
||||||
globalLock.Lock()
|
|
||||||
defer globalLock.Unlock()
|
|
||||||
|
|
||||||
client, ok := clientControllerMap[*opts]
|
|
||||||
if ok {
|
|
||||||
return client.Close()
|
|
||||||
}
|
|
||||||
return fmt.Errorf("client %v not found", opts)
|
|
||||||
}
|
|
||||||
|
|
||||||
func ParseOpts(input string) (*SSHOpts, error) {
|
func ParseOpts(input string) (*SSHOpts, error) {
|
||||||
m := userHostRe.FindStringSubmatch(input)
|
m := userHostRe.FindStringSubmatch(input)
|
||||||
@ -173,7 +37,7 @@ func ParseOpts(input string) (*SSHOpts, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func DetectShell(client *ssh.Client) (string, error) {
|
func DetectShell(client *ssh.Client) (string, error) {
|
||||||
wshPath := getWshPath(client)
|
wshPath := GetWshPath(client)
|
||||||
|
|
||||||
session, err := client.NewSession()
|
session, err := client.NewSession()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -191,8 +55,8 @@ func DetectShell(client *ssh.Client) (string, error) {
|
|||||||
return fmt.Sprintf(`"%s"`, strings.TrimSpace(string(out))), nil
|
return fmt.Sprintf(`"%s"`, strings.TrimSpace(string(out))), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getWshVersion(client *ssh.Client) (string, error) {
|
func GetWshVersion(client *ssh.Client) (string, error) {
|
||||||
wshPath := getWshPath(client)
|
wshPath := GetWshPath(client)
|
||||||
|
|
||||||
session, err := client.NewSession()
|
session, err := client.NewSession()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -207,7 +71,7 @@ func getWshVersion(client *ssh.Client) (string, error) {
|
|||||||
return strings.TrimSpace(string(out)), nil
|
return strings.TrimSpace(string(out)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getWshPath(client *ssh.Client) string {
|
func GetWshPath(client *ssh.Client) string {
|
||||||
defaultPath := filepath.Join("~", ".waveterm", "bin", "wsh")
|
defaultPath := filepath.Join("~", ".waveterm", "bin", "wsh")
|
||||||
|
|
||||||
session, err := client.NewSession()
|
session, err := client.NewSession()
|
||||||
@ -267,7 +131,7 @@ func hasBashInstalled(client *ssh.Client) (bool, error) {
|
|||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getClientOs(client *ssh.Client) (string, error) {
|
func GetClientOs(client *ssh.Client) (string, error) {
|
||||||
session, err := client.NewSession()
|
session, err := client.NewSession()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@ -306,7 +170,7 @@ func getClientOs(client *ssh.Client) (string, error) {
|
|||||||
return "", fmt.Errorf("unable to determine os: {unix: %s, cmd: %s, powershell: %s}", unixErr, cmdErr, psErr)
|
return "", fmt.Errorf("unable to determine os: {unix: %s, cmd: %s, powershell: %s}", unixErr, cmdErr, psErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getClientArch(client *ssh.Client) (string, error) {
|
func GetClientArch(client *ssh.Client) (string, error) {
|
||||||
session, err := client.NewSession()
|
session, err := client.NewSession()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@ -360,7 +224,7 @@ mv {{.tempPath}} {{.installPath}}; \
|
|||||||
chmod a+x {{.installPath}}; \
|
chmod a+x {{.installPath}}; \
|
||||||
`
|
`
|
||||||
|
|
||||||
func cpHostToRemote(client *ssh.Client, sourcePath string, destPath string) error {
|
func CpHostToRemote(client *ssh.Client, sourcePath string, destPath string) error {
|
||||||
// warning: does not work on windows remote yet
|
// warning: does not work on windows remote yet
|
||||||
bashInstalled, err := hasBashInstalled(client)
|
bashInstalled, err := hasBashInstalled(client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -414,7 +278,7 @@ func cpHostToRemote(client *ssh.Client, sourcePath string, destPath string) erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
func InstallClientRcFiles(client *ssh.Client) error {
|
func InstallClientRcFiles(client *ssh.Client) error {
|
||||||
path := getWshPath(client)
|
path := GetWshPath(client)
|
||||||
|
|
||||||
session, err := client.NewSession()
|
session, err := client.NewSession()
|
||||||
if err != nil {
|
if err != nil {
|
@ -24,6 +24,7 @@ import (
|
|||||||
"github.com/wavetermdev/thenextwave/pkg/wavebase"
|
"github.com/wavetermdev/thenextwave/pkg/wavebase"
|
||||||
"github.com/wavetermdev/thenextwave/pkg/wshutil"
|
"github.com/wavetermdev/thenextwave/pkg/wshutil"
|
||||||
"github.com/wavetermdev/thenextwave/pkg/wstore"
|
"github.com/wavetermdev/thenextwave/pkg/wstore"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
type CommandOptsType struct {
|
type CommandOptsType struct {
|
||||||
@ -149,8 +150,7 @@ func (pp *PipePty) WriteString(s string) (n int, err error) {
|
|||||||
return pp.Write([]byte(s))
|
return pp.Write([]byte(s))
|
||||||
}
|
}
|
||||||
|
|
||||||
func StartRemoteShellProc(termSize wstore.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *remote.SSHConn) (*ShellProc, error) {
|
func StartRemoteShellProc(termSize wstore.TermSize, cmdStr string, cmdOpts CommandOptsType, client *ssh.Client) (*ShellProc, error) {
|
||||||
client := conn.Client
|
|
||||||
shellPath, err := remote.DetectShell(client)
|
shellPath, err := remote.DetectShell(client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -195,6 +195,7 @@ func StartRemoteShellProc(termSize wstore.TermSize, cmdStr string, cmdOpts Comma
|
|||||||
}
|
}
|
||||||
shellOpts = append(shellOpts, "-c", cmdStr)
|
shellOpts = append(shellOpts, "-c", cmdStr)
|
||||||
cmdCombined = fmt.Sprintf("%s %s", shellPath, strings.Join(shellOpts, " "))
|
cmdCombined = fmt.Sprintf("%s %s", shellPath, strings.Join(shellOpts, " "))
|
||||||
|
log.Printf("combined command is: %s", cmdCombined)
|
||||||
}
|
}
|
||||||
|
|
||||||
session, err := client.NewSession()
|
session, err := client.NewSession()
|
||||||
|
@ -19,6 +19,7 @@ import (
|
|||||||
|
|
||||||
"github.com/wavetermdev/thenextwave/pkg/util/utilfn"
|
"github.com/wavetermdev/thenextwave/pkg/util/utilfn"
|
||||||
"github.com/wavetermdev/thenextwave/pkg/wavebase"
|
"github.com/wavetermdev/thenextwave/pkg/wavebase"
|
||||||
|
"github.com/wavetermdev/thenextwave/pkg/wstore"
|
||||||
)
|
)
|
||||||
|
|
||||||
const DefaultTermType = "xterm-256color"
|
const DefaultTermType = "xterm-256color"
|
||||||
@ -125,6 +126,10 @@ func internalMacUserShell() string {
|
|||||||
return m[1]
|
return m[1]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func DefaultTermSize() wstore.TermSize {
|
||||||
|
return wstore.TermSize{Rows: DefaultTermRows, Cols: DefaultTermCols}
|
||||||
|
}
|
||||||
|
|
||||||
func WaveshellLocalEnvVars(termType string) map[string]string {
|
func WaveshellLocalEnvVars(termType string) map[string]string {
|
||||||
rtn := make(map[string]string)
|
rtn := make(map[string]string)
|
||||||
if termType != "" {
|
if termType != "" {
|
||||||
|
@ -257,10 +257,9 @@ func HandleWsInternal(w http.ResponseWriter, r *http.Request) error {
|
|||||||
defer eventbus.UnregisterWSChannel(wsConnId)
|
defer eventbus.UnregisterWSChannel(wsConnId)
|
||||||
// we create a wshproxy to handle rpc messages to/from the window
|
// we create a wshproxy to handle rpc messages to/from the window
|
||||||
wproxy := wshutil.MakeRpcProxy()
|
wproxy := wshutil.MakeRpcProxy()
|
||||||
rpcRouteId := "window:" + windowId
|
wshutil.DefaultRouter.RegisterRoute(wshutil.MakeWindowRouteId(windowId), wproxy)
|
||||||
wshutil.DefaultRouter.RegisterRoute(rpcRouteId, wproxy)
|
|
||||||
defer func() {
|
defer func() {
|
||||||
wshutil.DefaultRouter.UnregisterRoute(rpcRouteId)
|
wshutil.DefaultRouter.UnregisterRoute(wshutil.MakeWindowRouteId(windowId))
|
||||||
close(wproxy.ToRemoteCh)
|
close(wproxy.ToRemoteCh)
|
||||||
}()
|
}()
|
||||||
// WshServerFactoryFn(rpcInputCh, rpcOutputCh, wshrpc.RpcContext{})
|
// WshServerFactoryFn(rpcInputCh, rpcOutputCh, wshrpc.RpcContext{})
|
||||||
|
@ -4,6 +4,8 @@
|
|||||||
package wshclient
|
package wshclient
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
"github.com/wavetermdev/thenextwave/pkg/util/utilfn"
|
"github.com/wavetermdev/thenextwave/pkg/util/utilfn"
|
||||||
"github.com/wavetermdev/thenextwave/pkg/wshrpc"
|
"github.com/wavetermdev/thenextwave/pkg/wshrpc"
|
||||||
"github.com/wavetermdev/thenextwave/pkg/wshutil"
|
"github.com/wavetermdev/thenextwave/pkg/wshutil"
|
||||||
@ -14,6 +16,9 @@ func sendRpcRequestCallHelper[T any](w *wshutil.WshRpc, command string, data int
|
|||||||
opts = &wshrpc.RpcOpts{}
|
opts = &wshrpc.RpcOpts{}
|
||||||
}
|
}
|
||||||
var respData T
|
var respData T
|
||||||
|
if w == nil {
|
||||||
|
return respData, errors.New("nil wshrpc passed to wshclient")
|
||||||
|
}
|
||||||
if opts.NoResponse {
|
if opts.NoResponse {
|
||||||
err := w.SendCommand(command, data, opts)
|
err := w.SendCommand(command, data, opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -32,17 +37,26 @@ func sendRpcRequestCallHelper[T any](w *wshutil.WshRpc, command string, data int
|
|||||||
return respData, nil
|
return respData, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func rtnErr[T any](ch chan wshrpc.RespOrErrorUnion[T], err error) {
|
||||||
|
go func() {
|
||||||
|
ch <- wshrpc.RespOrErrorUnion[T]{Error: err}
|
||||||
|
close(ch)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
func sendRpcRequestResponseStreamHelper[T any](w *wshutil.WshRpc, command string, data interface{}, opts *wshrpc.RpcOpts) chan wshrpc.RespOrErrorUnion[T] {
|
func sendRpcRequestResponseStreamHelper[T any](w *wshutil.WshRpc, command string, data interface{}, opts *wshrpc.RpcOpts) chan wshrpc.RespOrErrorUnion[T] {
|
||||||
if opts == nil {
|
if opts == nil {
|
||||||
opts = &wshrpc.RpcOpts{}
|
opts = &wshrpc.RpcOpts{}
|
||||||
}
|
}
|
||||||
respChan := make(chan wshrpc.RespOrErrorUnion[T])
|
respChan := make(chan wshrpc.RespOrErrorUnion[T])
|
||||||
|
if w == nil {
|
||||||
|
rtnErr(respChan, errors.New("nil wshrpc passed to wshclient"))
|
||||||
|
return respChan
|
||||||
|
}
|
||||||
reqHandler, err := w.SendComplexRequest(command, data, opts)
|
reqHandler, err := w.SendComplexRequest(command, data, opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
go func() {
|
rtnErr(respChan, err)
|
||||||
respChan <- wshrpc.RespOrErrorUnion[T]{Error: err}
|
return respChan
|
||||||
close(respChan)
|
|
||||||
}()
|
|
||||||
} else {
|
} else {
|
||||||
go func() {
|
go func() {
|
||||||
defer close(respChan)
|
defer close(respChan)
|
||||||
|
@ -107,6 +107,7 @@ type RpcOpts struct {
|
|||||||
type RpcContext struct {
|
type RpcContext struct {
|
||||||
BlockId string `json:"blockid,omitempty"`
|
BlockId string `json:"blockid,omitempty"`
|
||||||
TabId string `json:"tabid,omitempty"`
|
TabId string `json:"tabid,omitempty"`
|
||||||
|
Conn string `json:"conn,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func HackRpcContextIntoData(dataPtr any, rpcContext RpcContext) {
|
func HackRpcContextIntoData(dataPtr any, rpcContext RpcContext) {
|
||||||
|
@ -6,6 +6,7 @@ package wshutil
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
@ -81,12 +82,14 @@ func handleAuthenticationCommand(msg RpcMessage) (*wshrpc.RpcContext, error) {
|
|||||||
if newCtx == nil {
|
if newCtx == nil {
|
||||||
return nil, fmt.Errorf("no context found in jwt token")
|
return nil, fmt.Errorf("no context found in jwt token")
|
||||||
}
|
}
|
||||||
if newCtx.BlockId == "" {
|
if newCtx.BlockId == "" && newCtx.Conn == "" {
|
||||||
return nil, fmt.Errorf("no blockId found in jwt token")
|
return nil, fmt.Errorf("no blockid or conn found in jwt token")
|
||||||
}
|
}
|
||||||
|
if newCtx.BlockId != "" {
|
||||||
if _, err := uuid.Parse(newCtx.BlockId); err != nil {
|
if _, err := uuid.Parse(newCtx.BlockId); err != nil {
|
||||||
return nil, fmt.Errorf("invalid blockId in jwt token")
|
return nil, fmt.Errorf("invalid blockId in jwt token")
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return newCtx, nil
|
return newCtx, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -114,6 +117,7 @@ func (p *WshRpcProxy) HandleAuthentication() (*wshrpc.RpcContext, error) {
|
|||||||
}
|
}
|
||||||
newCtx, err := handleAuthenticationCommand(msg)
|
newCtx, err := handleAuthenticationCommand(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Printf("error handling authentication: %v\n", err)
|
||||||
p.sendResponseError(msg, err)
|
p.sendResponseError(msg, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -42,6 +42,14 @@ func MakeConnectionRouteId(connId string) string {
|
|||||||
return "conn:" + connId
|
return "conn:" + connId
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func MakeControllerRouteId(blockId string) string {
|
||||||
|
return "controller:" + blockId
|
||||||
|
}
|
||||||
|
|
||||||
|
func MakeWindowRouteId(windowId string) string {
|
||||||
|
return "window:" + windowId
|
||||||
|
}
|
||||||
|
|
||||||
var DefaultRouter = NewWshRouter()
|
var DefaultRouter = NewWshRouter()
|
||||||
|
|
||||||
func NewWshRouter() *WshRouter {
|
func NewWshRouter() *WshRouter {
|
||||||
@ -230,6 +238,7 @@ func (router *WshRouter) RegisterRoute(routeId string, rpc AbstractRpcClient) {
|
|||||||
log.Printf("error: WshRouter cannot register sys route\n")
|
log.Printf("error: WshRouter cannot register sys route\n")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
log.Printf("registering wsh route %q\n", routeId)
|
||||||
router.Lock.Lock()
|
router.Lock.Lock()
|
||||||
defer router.Lock.Unlock()
|
defer router.Lock.Unlock()
|
||||||
router.RouteMap[routeId] = rpc
|
router.RouteMap[routeId] = rpc
|
||||||
|
@ -43,7 +43,7 @@ func streamToLines_processBuf(lineBuf *lineBuf, readBuf []byte, lineFn func([]by
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func streamToLines(input io.Reader, lineFn func([]byte)) error {
|
func StreamToLines(input io.Reader, lineFn func([]byte)) error {
|
||||||
var lineBuf lineBuf
|
var lineBuf lineBuf
|
||||||
readBuf := make([]byte, 16*1024)
|
readBuf := make([]byte, 16*1024)
|
||||||
for {
|
for {
|
||||||
@ -56,7 +56,7 @@ func streamToLines(input io.Reader, lineFn func([]byte)) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func AdaptStreamToMsgCh(input io.Reader, output chan []byte) error {
|
func AdaptStreamToMsgCh(input io.Reader, output chan []byte) error {
|
||||||
return streamToLines(input, func(line []byte) {
|
return StreamToLines(input, func(line []byte) {
|
||||||
output <- line
|
output <- line
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -248,6 +248,9 @@ func MakeClientJWTToken(rpcCtx wshrpc.RpcContext, sockName string) (string, erro
|
|||||||
if rpcCtx.TabId != "" {
|
if rpcCtx.TabId != "" {
|
||||||
claims["tabid"] = rpcCtx.TabId
|
claims["tabid"] = rpcCtx.TabId
|
||||||
}
|
}
|
||||||
|
if rpcCtx.Conn != "" {
|
||||||
|
claims["conn"] = rpcCtx.Conn
|
||||||
|
}
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
tokenStr, err := token.SignedString([]byte(wavebase.JwtSecret))
|
tokenStr, err := token.SignedString([]byte(wavebase.JwtSecret))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -299,6 +302,11 @@ func mapClaimsToRpcContext(claims jwt.MapClaims) *wshrpc.RpcContext {
|
|||||||
rpcCtx.TabId = tabId
|
rpcCtx.TabId = tabId
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if claims["conn"] != nil {
|
||||||
|
if conn, ok := claims["conn"].(string); ok {
|
||||||
|
rpcCtx.Conn = conn
|
||||||
|
}
|
||||||
|
}
|
||||||
return rpcCtx
|
return rpcCtx
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -306,9 +314,12 @@ func RunWshRpcOverListener(listener net.Listener) {
|
|||||||
defer log.Printf("domain socket listener shutting down\n")
|
defer log.Printf("domain socket listener shutting down\n")
|
||||||
for {
|
for {
|
||||||
conn, err := listener.Accept()
|
conn, err := listener.Accept()
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("error accepting connection: %v\n", err)
|
log.Printf("error accepting connection: %v\n", err)
|
||||||
continue
|
break
|
||||||
}
|
}
|
||||||
log.Print("got domain socket connection\n")
|
log.Print("got domain socket connection\n")
|
||||||
go handleDomainSocketClient(conn)
|
go handleDomainSocketClient(conn)
|
||||||
@ -337,7 +348,11 @@ func handleDomainSocketClient(conn net.Conn) {
|
|||||||
// now that we're authenticated, set the ctx and attach to the router
|
// now that we're authenticated, set the ctx and attach to the router
|
||||||
log.Printf("domain socket connection authenticated: %#v\n", rpcCtx)
|
log.Printf("domain socket connection authenticated: %#v\n", rpcCtx)
|
||||||
proxy.SetRpcContext(rpcCtx)
|
proxy.SetRpcContext(rpcCtx)
|
||||||
DefaultRouter.RegisterRoute("controller:"+rpcCtx.BlockId, proxy)
|
if rpcCtx.BlockId != "" {
|
||||||
|
DefaultRouter.RegisterRoute(MakeControllerRouteId(rpcCtx.BlockId), proxy)
|
||||||
|
} else if rpcCtx.Conn != "" {
|
||||||
|
DefaultRouter.RegisterRoute(MakeConnectionRouteId(rpcCtx.Conn), proxy)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// only for use on client
|
// only for use on client
|
||||||
|
Loading…
Reference in New Issue
Block a user