mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-01-17 20:51:55 +01:00
set up remote connserver (#245)
This commit is contained in:
parent
c30188552f
commit
85874f92ca
@ -7,15 +7,15 @@ import (
|
||||
"os"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wshrpc/wshclient"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wshrpc/wshremote"
|
||||
)
|
||||
|
||||
var serverCmd = &cobra.Command{
|
||||
Use: "server",
|
||||
Use: "connserver",
|
||||
Short: "remote server to power wave blocks",
|
||||
Args: cobra.NoArgs,
|
||||
Run: serverRun,
|
||||
PreRunE: preRunSetupRpcClient,
|
||||
}
|
||||
|
||||
func init() {
|
||||
@ -23,10 +23,8 @@ func init() {
|
||||
}
|
||||
|
||||
func serverRun(cmd *cobra.Command, args []string) {
|
||||
WriteStdout("running wsh server\n")
|
||||
WriteStdout("running wsh connserver\n")
|
||||
RpcClient.SetServerImpl(&wshremote.ServerImpl{LogWriter: os.Stdout})
|
||||
err := wshclient.TestCommand(RpcClient, "hello", nil)
|
||||
WriteStdout("got test rtn: %v\n", err)
|
||||
|
||||
select {} // run forever
|
||||
}
|
@ -13,6 +13,7 @@ var deleteBlockCmd = &cobra.Command{
|
||||
Short: "delete a block",
|
||||
Args: cobra.ExactArgs(1),
|
||||
Run: deleteBlockRun,
|
||||
PreRunE: preRunSetupRpcClient,
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
@ -16,6 +16,7 @@ var getMetaCmd = &cobra.Command{
|
||||
Short: "get metadata for an entity",
|
||||
Args: cobra.RangeArgs(1, 2),
|
||||
Run: getMetaRun,
|
||||
PreRunE: preRunSetupRpcClient,
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
@ -18,6 +18,7 @@ var htmlCmd = &cobra.Command{
|
||||
Use: "html",
|
||||
Short: "Launch a demo html-mode terminal",
|
||||
Run: htmlRun,
|
||||
PreRunE: preRunSetupRpcClient,
|
||||
}
|
||||
|
||||
func htmlRun(cmd *cobra.Command, args []string) {
|
||||
|
@ -16,6 +16,7 @@ var readFileCmd = &cobra.Command{
|
||||
Short: "read a blockfile",
|
||||
Args: cobra.ExactArgs(2),
|
||||
Run: runReadFile,
|
||||
PreRunE: preRunSetupRpcClient,
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
@ -26,6 +26,7 @@ var (
|
||||
Use: "wsh",
|
||||
Short: "CLI tool to control Wave Terminal",
|
||||
Long: `wsh is a small utility that lets you do cool things with Wave Terminal, right from the command line`,
|
||||
SilenceUsage: true,
|
||||
}
|
||||
)
|
||||
|
||||
@ -60,9 +61,17 @@ func WriteStdout(fmtStr string, args ...interface{}) {
|
||||
fmt.Print(output)
|
||||
}
|
||||
|
||||
func preRunSetupRpcClient(cmd *cobra.Command, args []string) error {
|
||||
err := setupRpcClient(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// returns the wrapped stdin and a new rpc client (that wraps the stdin input and stdout output)
|
||||
func setupRpcClient(serverImpl wshutil.ServerImpl) error {
|
||||
jwtToken := os.Getenv("WAVETERM_JWT")
|
||||
jwtToken := os.Getenv(wshutil.WaveJwtTokenVarName)
|
||||
if jwtToken == "" {
|
||||
wshutil.SetTermRawModeAndInstallShutdownHandlers(true)
|
||||
UsingTermWshMode = true
|
||||
@ -71,7 +80,7 @@ func setupRpcClient(serverImpl wshutil.ServerImpl) error {
|
||||
}
|
||||
sockName, err := wshutil.ExtractUnverifiedSocketName(jwtToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error extracting socket name from WAVETERM_JWT: %v", err)
|
||||
return fmt.Errorf("error extracting socket name from %s: %v", wshutil.WaveJwtTokenVarName, err)
|
||||
}
|
||||
RpcClient, err = wshutil.SetupDomainSocketRpcClient(sockName, serverImpl)
|
||||
if err != nil {
|
||||
@ -158,13 +167,7 @@ func Execute() {
|
||||
wshutil.DoShutdown("", 0, false)
|
||||
}
|
||||
}()
|
||||
err := setupRpcClient(nil)
|
||||
if err != nil {
|
||||
log.Printf("[error] %v\n", err)
|
||||
wshutil.DoShutdown("", 1, true)
|
||||
return
|
||||
}
|
||||
err = rootCmd.Execute()
|
||||
err := rootCmd.Execute()
|
||||
if err != nil {
|
||||
log.Printf("[error] %v\n", err)
|
||||
wshutil.DoShutdown("", 1, true)
|
||||
|
@ -18,6 +18,7 @@ var setMetaCmd = &cobra.Command{
|
||||
Short: "set metadata for an entity",
|
||||
Args: cobra.MinimumNArgs(2),
|
||||
Run: setMetaRun,
|
||||
PreRunE: preRunSetupRpcClient,
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
@ -19,6 +19,7 @@ var termCmd = &cobra.Command{
|
||||
Short: "open a terminal in directory",
|
||||
Args: cobra.RangeArgs(0, 1),
|
||||
Run: termRun,
|
||||
PreRunE: preRunSetupRpcClient,
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
@ -21,6 +21,7 @@ var viewCmd = &cobra.Command{
|
||||
Short: "preview a file or directory",
|
||||
Args: cobra.ExactArgs(1),
|
||||
Run: viewRun,
|
||||
PreRunE: preRunSetupRpcClient,
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
@ -18,6 +18,7 @@ import (
|
||||
"github.com/wavetermdev/thenextwave/pkg/eventbus"
|
||||
"github.com/wavetermdev/thenextwave/pkg/filestore"
|
||||
"github.com/wavetermdev/thenextwave/pkg/remote"
|
||||
"github.com/wavetermdev/thenextwave/pkg/remote/conncontroller"
|
||||
"github.com/wavetermdev/thenextwave/pkg/shellexec"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wavebase"
|
||||
"github.com/wavetermdev/thenextwave/pkg/waveobj"
|
||||
@ -277,7 +278,7 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conn, err := remote.GetConn(credentialCtx, opts)
|
||||
conn, err := conncontroller.GetConn(credentialCtx, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -288,7 +289,7 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj
|
||||
}
|
||||
cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr
|
||||
}
|
||||
shellProc, err = shellexec.StartRemoteShellProc(rc.TermSize, cmdStr, cmdOpts, conn)
|
||||
shellProc, err = shellexec.StartRemoteShellProc(rc.TermSize, cmdStr, cmdOpts, conn.Client)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -317,7 +318,7 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj
|
||||
// we don't need to authenticate this wshProxy since it is coming direct
|
||||
wshProxy := wshutil.MakeRpcProxy()
|
||||
wshProxy.SetRpcContext(&wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId})
|
||||
wshutil.DefaultRouter.RegisterRoute("controller:"+bc.BlockId, wshProxy)
|
||||
wshutil.DefaultRouter.RegisterRoute(wshutil.MakeControllerRouteId(bc.BlockId), wshProxy)
|
||||
ptyBuffer := wshutil.MakePtyBuffer(wshutil.WaveOSCPrefix, bc.ShellProc.Cmd, wshProxy.FromRemoteCh)
|
||||
go func() {
|
||||
// handles regular output from the pty (goes to the blockfile and xterm)
|
||||
@ -376,7 +377,7 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj
|
||||
go func() {
|
||||
// wait for the shell to finish
|
||||
defer func() {
|
||||
wshutil.DefaultRouter.UnregisterRoute("controller:" + bc.BlockId)
|
||||
wshutil.DefaultRouter.UnregisterRoute(wshutil.MakeControllerRouteId(bc.BlockId))
|
||||
bc.UpdateControllerAndSendUpdate(func() bool {
|
||||
bc.ShellProcStatus = Status_Done
|
||||
return true
|
||||
|
224
pkg/remote/conncontroller/conncontroller.go
Normal file
224
pkg/remote/conncontroller/conncontroller.go
Normal file
@ -0,0 +1,224 @@
|
||||
// Copyright 2024, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package conncontroller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/wavetermdev/thenextwave/pkg/remote"
|
||||
"github.com/wavetermdev/thenextwave/pkg/userinput"
|
||||
"github.com/wavetermdev/thenextwave/pkg/util/shellutil"
|
||||
"github.com/wavetermdev/thenextwave/pkg/util/utilfn"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wavebase"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wshrpc"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wshutil"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
var globalLock = &sync.Mutex{}
|
||||
var clientControllerMap = make(map[remote.SSHOpts]*SSHConn)
|
||||
|
||||
type SSHConn struct {
|
||||
Lock *sync.Mutex
|
||||
Opts *remote.SSHOpts
|
||||
Client *ssh.Client
|
||||
SockName string
|
||||
DomainSockListener net.Listener
|
||||
ConnController *ssh.Session
|
||||
}
|
||||
|
||||
func (conn *SSHConn) Close() error {
|
||||
if conn.DomainSockListener != nil {
|
||||
conn.DomainSockListener.Close()
|
||||
conn.DomainSockListener = nil
|
||||
}
|
||||
if conn.ConnController != nil {
|
||||
conn.ConnController.Close()
|
||||
conn.ConnController = nil
|
||||
}
|
||||
err := conn.Client.Close()
|
||||
conn.Client = nil
|
||||
return err
|
||||
}
|
||||
|
||||
func (conn *SSHConn) OpenDomainSocketListener() error {
|
||||
if conn.DomainSockListener != nil {
|
||||
return nil
|
||||
}
|
||||
randStr, err := utilfn.RandomHexString(16) // 64-bits of randomness
|
||||
if err != nil {
|
||||
return fmt.Errorf("error generating random string: %w", err)
|
||||
}
|
||||
sockName := fmt.Sprintf("/tmp/waveterm-%s.sock", randStr)
|
||||
log.Printf("remote domain socket %s %q\n", conn.Opts.String(), sockName)
|
||||
listener, err := conn.Client.ListenUnix(sockName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to request connection domain socket: %v", err)
|
||||
}
|
||||
conn.SockName = sockName
|
||||
conn.DomainSockListener = listener
|
||||
go func() {
|
||||
defer func() {
|
||||
conn.Lock.Lock()
|
||||
defer conn.Lock.Unlock()
|
||||
conn.DomainSockListener = nil
|
||||
}()
|
||||
wshutil.RunWshRpcOverListener(listener)
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (conn *SSHConn) StartConnServer() error {
|
||||
conn.Lock.Lock()
|
||||
defer conn.Lock.Unlock()
|
||||
if conn.ConnController != nil {
|
||||
return nil
|
||||
}
|
||||
wshPath := remote.GetWshPath(conn.Client)
|
||||
rpcCtx := wshrpc.RpcContext{
|
||||
Conn: conn.Opts.String(),
|
||||
}
|
||||
jwtToken, err := wshutil.MakeClientJWTToken(rpcCtx, conn.SockName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create jwt token for conn controller: %w", err)
|
||||
}
|
||||
sshSession, err := conn.Client.NewSession()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create ssh session for conn controller: %w", err)
|
||||
}
|
||||
pipeRead, pipeWrite := io.Pipe()
|
||||
sshSession.Stdout = pipeWrite
|
||||
sshSession.Stderr = pipeWrite
|
||||
conn.ConnController = sshSession
|
||||
cmdStr := fmt.Sprintf("%s=\"%s\" %s connserver", wshutil.WaveJwtTokenVarName, jwtToken, wshPath)
|
||||
log.Printf("starting conn controller: %s\n", cmdStr)
|
||||
err = sshSession.Start(cmdStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to start conn controller: %w", err)
|
||||
}
|
||||
// service the I/O
|
||||
go func() {
|
||||
// wait for termination, clear the controller
|
||||
waitErr := sshSession.Wait()
|
||||
log.Printf("conn controller (%q) terminated: %v", conn.Opts.String(), waitErr)
|
||||
conn.Lock.Lock()
|
||||
defer conn.Lock.Unlock()
|
||||
conn.ConnController = nil
|
||||
}()
|
||||
go func() {
|
||||
readErr := wshutil.StreamToLines(pipeRead, func(line []byte) {
|
||||
lineStr := string(line)
|
||||
if !strings.HasSuffix(lineStr, "\n") {
|
||||
lineStr += "\n"
|
||||
}
|
||||
log.Printf("[conncontroller:%s:output] %s", conn.Opts.String(), lineStr)
|
||||
})
|
||||
if readErr != nil && readErr != io.EOF {
|
||||
log.Printf("[conncontroller:%s] error reading output: %v\n", conn.Opts.String(), readErr)
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (conn *SSHConn) checkAndInstallWsh(ctx context.Context) error {
|
||||
client := conn.Client
|
||||
// check that correct wsh extensions are installed
|
||||
expectedVersion := fmt.Sprintf("wsh v%s", wavebase.WaveVersion)
|
||||
clientVersion, err := remote.GetWshVersion(client)
|
||||
if err == nil && clientVersion == expectedVersion {
|
||||
return nil
|
||||
}
|
||||
var queryText string
|
||||
var title string
|
||||
if err != nil {
|
||||
queryText = "Waveterm requires `wsh` shell extensions installed on your client to ensure a seamless experience. Would you like to install them?"
|
||||
title = "Install Wsh Shell Extensions"
|
||||
} else {
|
||||
queryText = fmt.Sprintf("Waveterm requires `wsh` shell extensions installed on your client to be updated from %s to %s. Would you like to update?", clientVersion, expectedVersion)
|
||||
title = "Update Wsh Shell Extensions"
|
||||
}
|
||||
request := &userinput.UserInputRequest{
|
||||
ResponseType: "confirm",
|
||||
QueryText: queryText,
|
||||
Title: title,
|
||||
CheckBoxMsg: "Don't show me this again",
|
||||
}
|
||||
response, err := userinput.GetUserInput(ctx, request)
|
||||
if err != nil || !response.Confirm {
|
||||
return err
|
||||
}
|
||||
log.Printf("attempting to install wsh to `%s@%s`", client.User(), client.RemoteAddr().String())
|
||||
clientOs, err := remote.GetClientOs(client)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
clientArch, err := remote.GetClientArch(client)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// attempt to install extension
|
||||
wshLocalPath := shellutil.GetWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch)
|
||||
err = remote.CpHostToRemote(client, wshLocalPath, "~/.waveterm/bin/wsh")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("successfully installed wsh on %s\n", conn.Opts.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetConn(ctx context.Context, opts *remote.SSHOpts) (*SSHConn, error) {
|
||||
globalLock.Lock()
|
||||
defer globalLock.Unlock()
|
||||
|
||||
// attempt to retrieve if already opened
|
||||
conn, ok := clientControllerMap[*opts]
|
||||
if ok {
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
client, err := remote.ConnectToClient(ctx, opts) //todo specify or remove opts
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn = &SSHConn{Lock: &sync.Mutex{}, Opts: opts, Client: client}
|
||||
err = conn.OpenDomainSocketListener()
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
installErr := conn.checkAndInstallWsh(ctx)
|
||||
if installErr != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("conncontroller %s wsh install error: %v", conn.Opts.String(), installErr)
|
||||
}
|
||||
|
||||
csErr := conn.StartConnServer()
|
||||
if csErr != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("conncontroller %s start wsh connserver error: %v", conn.Opts.String(), csErr)
|
||||
}
|
||||
|
||||
// save successful connection to map
|
||||
clientControllerMap[*opts] = conn
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func DisconnectClient(opts *remote.SSHOpts) error {
|
||||
globalLock.Lock()
|
||||
defer globalLock.Unlock()
|
||||
|
||||
client, ok := clientControllerMap[*opts]
|
||||
if ok {
|
||||
return client.Close()
|
||||
}
|
||||
return fmt.Errorf("client %v not found", opts)
|
||||
}
|
@ -2,156 +2,20 @@ package remote
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/wavetermdev/thenextwave/pkg/userinput"
|
||||
"github.com/wavetermdev/thenextwave/pkg/util/shellutil"
|
||||
"github.com/wavetermdev/thenextwave/pkg/util/utilfn"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wavebase"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wshutil"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
var userHostRe = regexp.MustCompile(`^([a-zA-Z0-9][a-zA-Z0-9._@\\-]*@)?([a-z0-9][a-z0-9.-]*)(?::([0-9]+))?$`)
|
||||
var globalLock = &sync.Mutex{}
|
||||
var clientControllerMap = make(map[SSHOpts]*SSHConn)
|
||||
|
||||
type SSHConn struct {
|
||||
Lock *sync.Mutex
|
||||
Opts *SSHOpts
|
||||
Client *ssh.Client
|
||||
SockName string
|
||||
DomainSockListener net.Listener
|
||||
}
|
||||
|
||||
func (conn *SSHConn) Close() error {
|
||||
if conn.DomainSockListener != nil {
|
||||
conn.DomainSockListener.Close()
|
||||
}
|
||||
return conn.Client.Close()
|
||||
}
|
||||
|
||||
func (conn *SSHConn) OpenDomainSocketListener() error {
|
||||
if conn.DomainSockListener != nil {
|
||||
return nil
|
||||
}
|
||||
randStr, err := utilfn.RandomHexString(16) // 64-bits of randomness
|
||||
if err != nil {
|
||||
return fmt.Errorf("error generating random string: %w", err)
|
||||
}
|
||||
sockName := fmt.Sprintf("/tmp/waveterm-%s.sock", randStr)
|
||||
log.Printf("remote domain socket %s %q\n", conn.Opts.String(), sockName)
|
||||
listener, err := conn.Client.ListenUnix(sockName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to request connection domain socket: %v", err)
|
||||
}
|
||||
conn.SockName = sockName
|
||||
conn.DomainSockListener = listener
|
||||
go func() {
|
||||
wshutil.RunWshRpcOverListener(listener)
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetConn(ctx context.Context, opts *SSHOpts) (*SSHConn, error) {
|
||||
globalLock.Lock()
|
||||
defer globalLock.Unlock()
|
||||
|
||||
// attempt to retrieve if already opened
|
||||
conn, ok := clientControllerMap[*opts]
|
||||
if ok {
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
client, err := ConnectToClient(ctx, opts) //todo specify or remove opts
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn = &SSHConn{Lock: &sync.Mutex{}, Opts: opts, Client: client}
|
||||
err = conn.OpenDomainSocketListener()
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// check that correct wsh extensions are installed
|
||||
expectedVersion := fmt.Sprintf("wsh v%s", wavebase.WaveVersion)
|
||||
clientVersion, err := getWshVersion(client)
|
||||
if err == nil && clientVersion == expectedVersion {
|
||||
// save successful connection to map
|
||||
clientControllerMap[*opts] = conn
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
var queryText string
|
||||
var title string
|
||||
if err != nil {
|
||||
queryText = "Waveterm requires `wsh` shell extensions installed on your client to ensure a seamless experience. Would you like to install them?"
|
||||
title = "Install Wsh Shell Extensions"
|
||||
} else {
|
||||
queryText = fmt.Sprintf("Waveterm requires `wsh` shell extensions installed on your client to be updated from %s to %s. Would you like to update?", clientVersion, expectedVersion)
|
||||
title = "Update Wsh Shell Extensions"
|
||||
|
||||
}
|
||||
|
||||
request := &userinput.UserInputRequest{
|
||||
ResponseType: "confirm",
|
||||
QueryText: queryText,
|
||||
Title: title,
|
||||
CheckBoxMsg: "Don't show me this again",
|
||||
}
|
||||
response, err := userinput.GetUserInput(ctx, request)
|
||||
if err != nil || !response.Confirm {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Printf("attempting to install wsh to `%s@%s`", client.User(), client.RemoteAddr().String())
|
||||
|
||||
clientOs, err := getClientOs(client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
clientArch, err := getClientArch(client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// attempt to install extension
|
||||
wshLocalPath := shellutil.GetWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch)
|
||||
err = cpHostToRemote(client, wshLocalPath, "~/.waveterm/bin/wsh")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.Printf("successful install")
|
||||
|
||||
// save successful connection to map
|
||||
clientControllerMap[*opts] = conn
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func DisconnectClient(opts *SSHOpts) error {
|
||||
globalLock.Lock()
|
||||
defer globalLock.Unlock()
|
||||
|
||||
client, ok := clientControllerMap[*opts]
|
||||
if ok {
|
||||
return client.Close()
|
||||
}
|
||||
return fmt.Errorf("client %v not found", opts)
|
||||
}
|
||||
|
||||
func ParseOpts(input string) (*SSHOpts, error) {
|
||||
m := userHostRe.FindStringSubmatch(input)
|
||||
@ -173,7 +37,7 @@ func ParseOpts(input string) (*SSHOpts, error) {
|
||||
}
|
||||
|
||||
func DetectShell(client *ssh.Client) (string, error) {
|
||||
wshPath := getWshPath(client)
|
||||
wshPath := GetWshPath(client)
|
||||
|
||||
session, err := client.NewSession()
|
||||
if err != nil {
|
||||
@ -191,8 +55,8 @@ func DetectShell(client *ssh.Client) (string, error) {
|
||||
return fmt.Sprintf(`"%s"`, strings.TrimSpace(string(out))), nil
|
||||
}
|
||||
|
||||
func getWshVersion(client *ssh.Client) (string, error) {
|
||||
wshPath := getWshPath(client)
|
||||
func GetWshVersion(client *ssh.Client) (string, error) {
|
||||
wshPath := GetWshPath(client)
|
||||
|
||||
session, err := client.NewSession()
|
||||
if err != nil {
|
||||
@ -207,7 +71,7 @@ func getWshVersion(client *ssh.Client) (string, error) {
|
||||
return strings.TrimSpace(string(out)), nil
|
||||
}
|
||||
|
||||
func getWshPath(client *ssh.Client) string {
|
||||
func GetWshPath(client *ssh.Client) string {
|
||||
defaultPath := filepath.Join("~", ".waveterm", "bin", "wsh")
|
||||
|
||||
session, err := client.NewSession()
|
||||
@ -267,7 +131,7 @@ func hasBashInstalled(client *ssh.Client) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func getClientOs(client *ssh.Client) (string, error) {
|
||||
func GetClientOs(client *ssh.Client) (string, error) {
|
||||
session, err := client.NewSession()
|
||||
if err != nil {
|
||||
return "", err
|
||||
@ -306,7 +170,7 @@ func getClientOs(client *ssh.Client) (string, error) {
|
||||
return "", fmt.Errorf("unable to determine os: {unix: %s, cmd: %s, powershell: %s}", unixErr, cmdErr, psErr)
|
||||
}
|
||||
|
||||
func getClientArch(client *ssh.Client) (string, error) {
|
||||
func GetClientArch(client *ssh.Client) (string, error) {
|
||||
session, err := client.NewSession()
|
||||
if err != nil {
|
||||
return "", err
|
||||
@ -360,7 +224,7 @@ mv {{.tempPath}} {{.installPath}}; \
|
||||
chmod a+x {{.installPath}}; \
|
||||
`
|
||||
|
||||
func cpHostToRemote(client *ssh.Client, sourcePath string, destPath string) error {
|
||||
func CpHostToRemote(client *ssh.Client, sourcePath string, destPath string) error {
|
||||
// warning: does not work on windows remote yet
|
||||
bashInstalled, err := hasBashInstalled(client)
|
||||
if err != nil {
|
||||
@ -414,7 +278,7 @@ func cpHostToRemote(client *ssh.Client, sourcePath string, destPath string) erro
|
||||
}
|
||||
|
||||
func InstallClientRcFiles(client *ssh.Client) error {
|
||||
path := getWshPath(client)
|
||||
path := GetWshPath(client)
|
||||
|
||||
session, err := client.NewSession()
|
||||
if err != nil {
|
@ -24,6 +24,7 @@ import (
|
||||
"github.com/wavetermdev/thenextwave/pkg/wavebase"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wshutil"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wstore"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type CommandOptsType struct {
|
||||
@ -149,8 +150,7 @@ func (pp *PipePty) WriteString(s string) (n int, err error) {
|
||||
return pp.Write([]byte(s))
|
||||
}
|
||||
|
||||
func StartRemoteShellProc(termSize wstore.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *remote.SSHConn) (*ShellProc, error) {
|
||||
client := conn.Client
|
||||
func StartRemoteShellProc(termSize wstore.TermSize, cmdStr string, cmdOpts CommandOptsType, client *ssh.Client) (*ShellProc, error) {
|
||||
shellPath, err := remote.DetectShell(client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -195,6 +195,7 @@ func StartRemoteShellProc(termSize wstore.TermSize, cmdStr string, cmdOpts Comma
|
||||
}
|
||||
shellOpts = append(shellOpts, "-c", cmdStr)
|
||||
cmdCombined = fmt.Sprintf("%s %s", shellPath, strings.Join(shellOpts, " "))
|
||||
log.Printf("combined command is: %s", cmdCombined)
|
||||
}
|
||||
|
||||
session, err := client.NewSession()
|
||||
|
@ -19,6 +19,7 @@ import (
|
||||
|
||||
"github.com/wavetermdev/thenextwave/pkg/util/utilfn"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wavebase"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wstore"
|
||||
)
|
||||
|
||||
const DefaultTermType = "xterm-256color"
|
||||
@ -125,6 +126,10 @@ func internalMacUserShell() string {
|
||||
return m[1]
|
||||
}
|
||||
|
||||
func DefaultTermSize() wstore.TermSize {
|
||||
return wstore.TermSize{Rows: DefaultTermRows, Cols: DefaultTermCols}
|
||||
}
|
||||
|
||||
func WaveshellLocalEnvVars(termType string) map[string]string {
|
||||
rtn := make(map[string]string)
|
||||
if termType != "" {
|
||||
|
@ -257,10 +257,9 @@ func HandleWsInternal(w http.ResponseWriter, r *http.Request) error {
|
||||
defer eventbus.UnregisterWSChannel(wsConnId)
|
||||
// we create a wshproxy to handle rpc messages to/from the window
|
||||
wproxy := wshutil.MakeRpcProxy()
|
||||
rpcRouteId := "window:" + windowId
|
||||
wshutil.DefaultRouter.RegisterRoute(rpcRouteId, wproxy)
|
||||
wshutil.DefaultRouter.RegisterRoute(wshutil.MakeWindowRouteId(windowId), wproxy)
|
||||
defer func() {
|
||||
wshutil.DefaultRouter.UnregisterRoute(rpcRouteId)
|
||||
wshutil.DefaultRouter.UnregisterRoute(wshutil.MakeWindowRouteId(windowId))
|
||||
close(wproxy.ToRemoteCh)
|
||||
}()
|
||||
// WshServerFactoryFn(rpcInputCh, rpcOutputCh, wshrpc.RpcContext{})
|
||||
|
@ -4,6 +4,8 @@
|
||||
package wshclient
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/wavetermdev/thenextwave/pkg/util/utilfn"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wshrpc"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wshutil"
|
||||
@ -14,6 +16,9 @@ func sendRpcRequestCallHelper[T any](w *wshutil.WshRpc, command string, data int
|
||||
opts = &wshrpc.RpcOpts{}
|
||||
}
|
||||
var respData T
|
||||
if w == nil {
|
||||
return respData, errors.New("nil wshrpc passed to wshclient")
|
||||
}
|
||||
if opts.NoResponse {
|
||||
err := w.SendCommand(command, data, opts)
|
||||
if err != nil {
|
||||
@ -32,17 +37,26 @@ func sendRpcRequestCallHelper[T any](w *wshutil.WshRpc, command string, data int
|
||||
return respData, nil
|
||||
}
|
||||
|
||||
func rtnErr[T any](ch chan wshrpc.RespOrErrorUnion[T], err error) {
|
||||
go func() {
|
||||
ch <- wshrpc.RespOrErrorUnion[T]{Error: err}
|
||||
close(ch)
|
||||
}()
|
||||
}
|
||||
|
||||
func sendRpcRequestResponseStreamHelper[T any](w *wshutil.WshRpc, command string, data interface{}, opts *wshrpc.RpcOpts) chan wshrpc.RespOrErrorUnion[T] {
|
||||
if opts == nil {
|
||||
opts = &wshrpc.RpcOpts{}
|
||||
}
|
||||
respChan := make(chan wshrpc.RespOrErrorUnion[T])
|
||||
if w == nil {
|
||||
rtnErr(respChan, errors.New("nil wshrpc passed to wshclient"))
|
||||
return respChan
|
||||
}
|
||||
reqHandler, err := w.SendComplexRequest(command, data, opts)
|
||||
if err != nil {
|
||||
go func() {
|
||||
respChan <- wshrpc.RespOrErrorUnion[T]{Error: err}
|
||||
close(respChan)
|
||||
}()
|
||||
rtnErr(respChan, err)
|
||||
return respChan
|
||||
} else {
|
||||
go func() {
|
||||
defer close(respChan)
|
||||
|
@ -107,6 +107,7 @@ type RpcOpts struct {
|
||||
type RpcContext struct {
|
||||
BlockId string `json:"blockid,omitempty"`
|
||||
TabId string `json:"tabid,omitempty"`
|
||||
Conn string `json:"conn,omitempty"`
|
||||
}
|
||||
|
||||
func HackRpcContextIntoData(dataPtr any, rpcContext RpcContext) {
|
||||
|
@ -6,6 +6,7 @@ package wshutil
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@ -81,12 +82,14 @@ func handleAuthenticationCommand(msg RpcMessage) (*wshrpc.RpcContext, error) {
|
||||
if newCtx == nil {
|
||||
return nil, fmt.Errorf("no context found in jwt token")
|
||||
}
|
||||
if newCtx.BlockId == "" {
|
||||
return nil, fmt.Errorf("no blockId found in jwt token")
|
||||
if newCtx.BlockId == "" && newCtx.Conn == "" {
|
||||
return nil, fmt.Errorf("no blockid or conn found in jwt token")
|
||||
}
|
||||
if newCtx.BlockId != "" {
|
||||
if _, err := uuid.Parse(newCtx.BlockId); err != nil {
|
||||
return nil, fmt.Errorf("invalid blockId in jwt token")
|
||||
}
|
||||
}
|
||||
return newCtx, nil
|
||||
}
|
||||
|
||||
@ -114,6 +117,7 @@ func (p *WshRpcProxy) HandleAuthentication() (*wshrpc.RpcContext, error) {
|
||||
}
|
||||
newCtx, err := handleAuthenticationCommand(msg)
|
||||
if err != nil {
|
||||
log.Printf("error handling authentication: %v\n", err)
|
||||
p.sendResponseError(msg, err)
|
||||
continue
|
||||
}
|
||||
|
@ -42,6 +42,14 @@ func MakeConnectionRouteId(connId string) string {
|
||||
return "conn:" + connId
|
||||
}
|
||||
|
||||
func MakeControllerRouteId(blockId string) string {
|
||||
return "controller:" + blockId
|
||||
}
|
||||
|
||||
func MakeWindowRouteId(windowId string) string {
|
||||
return "window:" + windowId
|
||||
}
|
||||
|
||||
var DefaultRouter = NewWshRouter()
|
||||
|
||||
func NewWshRouter() *WshRouter {
|
||||
@ -230,6 +238,7 @@ func (router *WshRouter) RegisterRoute(routeId string, rpc AbstractRpcClient) {
|
||||
log.Printf("error: WshRouter cannot register sys route\n")
|
||||
return
|
||||
}
|
||||
log.Printf("registering wsh route %q\n", routeId)
|
||||
router.Lock.Lock()
|
||||
defer router.Lock.Unlock()
|
||||
router.RouteMap[routeId] = rpc
|
||||
|
@ -43,7 +43,7 @@ func streamToLines_processBuf(lineBuf *lineBuf, readBuf []byte, lineFn func([]by
|
||||
}
|
||||
}
|
||||
|
||||
func streamToLines(input io.Reader, lineFn func([]byte)) error {
|
||||
func StreamToLines(input io.Reader, lineFn func([]byte)) error {
|
||||
var lineBuf lineBuf
|
||||
readBuf := make([]byte, 16*1024)
|
||||
for {
|
||||
@ -56,7 +56,7 @@ func streamToLines(input io.Reader, lineFn func([]byte)) error {
|
||||
}
|
||||
|
||||
func AdaptStreamToMsgCh(input io.Reader, output chan []byte) error {
|
||||
return streamToLines(input, func(line []byte) {
|
||||
return StreamToLines(input, func(line []byte) {
|
||||
output <- line
|
||||
})
|
||||
}
|
||||
|
@ -248,6 +248,9 @@ func MakeClientJWTToken(rpcCtx wshrpc.RpcContext, sockName string) (string, erro
|
||||
if rpcCtx.TabId != "" {
|
||||
claims["tabid"] = rpcCtx.TabId
|
||||
}
|
||||
if rpcCtx.Conn != "" {
|
||||
claims["conn"] = rpcCtx.Conn
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenStr, err := token.SignedString([]byte(wavebase.JwtSecret))
|
||||
if err != nil {
|
||||
@ -299,6 +302,11 @@ func mapClaimsToRpcContext(claims jwt.MapClaims) *wshrpc.RpcContext {
|
||||
rpcCtx.TabId = tabId
|
||||
}
|
||||
}
|
||||
if claims["conn"] != nil {
|
||||
if conn, ok := claims["conn"].(string); ok {
|
||||
rpcCtx.Conn = conn
|
||||
}
|
||||
}
|
||||
return rpcCtx
|
||||
}
|
||||
|
||||
@ -306,9 +314,12 @@ func RunWshRpcOverListener(listener net.Listener) {
|
||||
defer log.Printf("domain socket listener shutting down\n")
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
log.Printf("error accepting connection: %v\n", err)
|
||||
continue
|
||||
break
|
||||
}
|
||||
log.Print("got domain socket connection\n")
|
||||
go handleDomainSocketClient(conn)
|
||||
@ -337,7 +348,11 @@ func handleDomainSocketClient(conn net.Conn) {
|
||||
// now that we're authenticated, set the ctx and attach to the router
|
||||
log.Printf("domain socket connection authenticated: %#v\n", rpcCtx)
|
||||
proxy.SetRpcContext(rpcCtx)
|
||||
DefaultRouter.RegisterRoute("controller:"+rpcCtx.BlockId, proxy)
|
||||
if rpcCtx.BlockId != "" {
|
||||
DefaultRouter.RegisterRoute(MakeControllerRouteId(rpcCtx.BlockId), proxy)
|
||||
} else if rpcCtx.Conn != "" {
|
||||
DefaultRouter.RegisterRoute(MakeConnectionRouteId(rpcCtx.Conn), proxy)
|
||||
}
|
||||
}
|
||||
|
||||
// only for use on client
|
||||
|
Loading…
Reference in New Issue
Block a user