set up remote connserver (#245)

This commit is contained in:
Mike Sawka 2024-08-18 21:26:44 -07:00 committed by GitHub
parent c30188552f
commit 85874f92ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 357 additions and 212 deletions

View File

@ -7,15 +7,15 @@ import (
"os" "os"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/wavetermdev/thenextwave/pkg/wshrpc/wshclient"
"github.com/wavetermdev/thenextwave/pkg/wshrpc/wshremote" "github.com/wavetermdev/thenextwave/pkg/wshrpc/wshremote"
) )
var serverCmd = &cobra.Command{ var serverCmd = &cobra.Command{
Use: "server", Use: "connserver",
Short: "remote server to power wave blocks", Short: "remote server to power wave blocks",
Args: cobra.NoArgs, Args: cobra.NoArgs,
Run: serverRun, Run: serverRun,
PreRunE: preRunSetupRpcClient,
} }
func init() { func init() {
@ -23,10 +23,8 @@ func init() {
} }
func serverRun(cmd *cobra.Command, args []string) { func serverRun(cmd *cobra.Command, args []string) {
WriteStdout("running wsh server\n") WriteStdout("running wsh connserver\n")
RpcClient.SetServerImpl(&wshremote.ServerImpl{LogWriter: os.Stdout}) RpcClient.SetServerImpl(&wshremote.ServerImpl{LogWriter: os.Stdout})
err := wshclient.TestCommand(RpcClient, "hello", nil)
WriteStdout("got test rtn: %v\n", err)
select {} // run forever select {} // run forever
} }

View File

@ -9,10 +9,11 @@ import (
) )
var deleteBlockCmd = &cobra.Command{ var deleteBlockCmd = &cobra.Command{
Use: "deleteblock", Use: "deleteblock",
Short: "delete a block", Short: "delete a block",
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
Run: deleteBlockRun, Run: deleteBlockRun,
PreRunE: preRunSetupRpcClient,
} }
func init() { func init() {

View File

@ -12,10 +12,11 @@ import (
) )
var getMetaCmd = &cobra.Command{ var getMetaCmd = &cobra.Command{
Use: "getmeta", Use: "getmeta",
Short: "get metadata for an entity", Short: "get metadata for an entity",
Args: cobra.RangeArgs(1, 2), Args: cobra.RangeArgs(1, 2),
Run: getMetaRun, Run: getMetaRun,
PreRunE: preRunSetupRpcClient,
} }
func init() { func init() {

View File

@ -15,9 +15,10 @@ func init() {
} }
var htmlCmd = &cobra.Command{ var htmlCmd = &cobra.Command{
Use: "html", Use: "html",
Short: "Launch a demo html-mode terminal", Short: "Launch a demo html-mode terminal",
Run: htmlRun, Run: htmlRun,
PreRunE: preRunSetupRpcClient,
} }
func htmlRun(cmd *cobra.Command, args []string) { func htmlRun(cmd *cobra.Command, args []string) {

View File

@ -12,10 +12,11 @@ import (
) )
var readFileCmd = &cobra.Command{ var readFileCmd = &cobra.Command{
Use: "readfile", Use: "readfile",
Short: "read a blockfile", Short: "read a blockfile",
Args: cobra.ExactArgs(2), Args: cobra.ExactArgs(2),
Run: runReadFile, Run: runReadFile,
PreRunE: preRunSetupRpcClient,
} }
func init() { func init() {

View File

@ -23,9 +23,10 @@ import (
var ( var (
rootCmd = &cobra.Command{ rootCmd = &cobra.Command{
Use: "wsh", Use: "wsh",
Short: "CLI tool to control Wave Terminal", Short: "CLI tool to control Wave Terminal",
Long: `wsh is a small utility that lets you do cool things with Wave Terminal, right from the command line`, Long: `wsh is a small utility that lets you do cool things with Wave Terminal, right from the command line`,
SilenceUsage: true,
} }
) )
@ -60,9 +61,17 @@ func WriteStdout(fmtStr string, args ...interface{}) {
fmt.Print(output) fmt.Print(output)
} }
func preRunSetupRpcClient(cmd *cobra.Command, args []string) error {
err := setupRpcClient(nil)
if err != nil {
return err
}
return nil
}
// returns the wrapped stdin and a new rpc client (that wraps the stdin input and stdout output) // returns the wrapped stdin and a new rpc client (that wraps the stdin input and stdout output)
func setupRpcClient(serverImpl wshutil.ServerImpl) error { func setupRpcClient(serverImpl wshutil.ServerImpl) error {
jwtToken := os.Getenv("WAVETERM_JWT") jwtToken := os.Getenv(wshutil.WaveJwtTokenVarName)
if jwtToken == "" { if jwtToken == "" {
wshutil.SetTermRawModeAndInstallShutdownHandlers(true) wshutil.SetTermRawModeAndInstallShutdownHandlers(true)
UsingTermWshMode = true UsingTermWshMode = true
@ -71,7 +80,7 @@ func setupRpcClient(serverImpl wshutil.ServerImpl) error {
} }
sockName, err := wshutil.ExtractUnverifiedSocketName(jwtToken) sockName, err := wshutil.ExtractUnverifiedSocketName(jwtToken)
if err != nil { if err != nil {
return fmt.Errorf("error extracting socket name from WAVETERM_JWT: %v", err) return fmt.Errorf("error extracting socket name from %s: %v", wshutil.WaveJwtTokenVarName, err)
} }
RpcClient, err = wshutil.SetupDomainSocketRpcClient(sockName, serverImpl) RpcClient, err = wshutil.SetupDomainSocketRpcClient(sockName, serverImpl)
if err != nil { if err != nil {
@ -158,13 +167,7 @@ func Execute() {
wshutil.DoShutdown("", 0, false) wshutil.DoShutdown("", 0, false)
} }
}() }()
err := setupRpcClient(nil) err := rootCmd.Execute()
if err != nil {
log.Printf("[error] %v\n", err)
wshutil.DoShutdown("", 1, true)
return
}
err = rootCmd.Execute()
if err != nil { if err != nil {
log.Printf("[error] %v\n", err) log.Printf("[error] %v\n", err)
wshutil.DoShutdown("", 1, true) wshutil.DoShutdown("", 1, true)

View File

@ -14,10 +14,11 @@ import (
) )
var setMetaCmd = &cobra.Command{ var setMetaCmd = &cobra.Command{
Use: "setmeta", Use: "setmeta",
Short: "set metadata for an entity", Short: "set metadata for an entity",
Args: cobra.MinimumNArgs(2), Args: cobra.MinimumNArgs(2),
Run: setMetaRun, Run: setMetaRun,
PreRunE: preRunSetupRpcClient,
} }
func init() { func init() {

View File

@ -15,10 +15,11 @@ import (
) )
var termCmd = &cobra.Command{ var termCmd = &cobra.Command{
Use: "term", Use: "term",
Short: "open a terminal in directory", Short: "open a terminal in directory",
Args: cobra.RangeArgs(0, 1), Args: cobra.RangeArgs(0, 1),
Run: termRun, Run: termRun,
PreRunE: preRunSetupRpcClient,
} }
func init() { func init() {

View File

@ -17,10 +17,11 @@ import (
var viewNewBlock bool var viewNewBlock bool
var viewCmd = &cobra.Command{ var viewCmd = &cobra.Command{
Use: "view", Use: "view",
Short: "preview a file or directory", Short: "preview a file or directory",
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
Run: viewRun, Run: viewRun,
PreRunE: preRunSetupRpcClient,
} }
func init() { func init() {

View File

@ -18,6 +18,7 @@ import (
"github.com/wavetermdev/thenextwave/pkg/eventbus" "github.com/wavetermdev/thenextwave/pkg/eventbus"
"github.com/wavetermdev/thenextwave/pkg/filestore" "github.com/wavetermdev/thenextwave/pkg/filestore"
"github.com/wavetermdev/thenextwave/pkg/remote" "github.com/wavetermdev/thenextwave/pkg/remote"
"github.com/wavetermdev/thenextwave/pkg/remote/conncontroller"
"github.com/wavetermdev/thenextwave/pkg/shellexec" "github.com/wavetermdev/thenextwave/pkg/shellexec"
"github.com/wavetermdev/thenextwave/pkg/wavebase" "github.com/wavetermdev/thenextwave/pkg/wavebase"
"github.com/wavetermdev/thenextwave/pkg/waveobj" "github.com/wavetermdev/thenextwave/pkg/waveobj"
@ -277,7 +278,7 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj
if err != nil { if err != nil {
return err return err
} }
conn, err := remote.GetConn(credentialCtx, opts) conn, err := conncontroller.GetConn(credentialCtx, opts)
if err != nil { if err != nil {
return err return err
} }
@ -288,7 +289,7 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj
} }
cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr
} }
shellProc, err = shellexec.StartRemoteShellProc(rc.TermSize, cmdStr, cmdOpts, conn) shellProc, err = shellexec.StartRemoteShellProc(rc.TermSize, cmdStr, cmdOpts, conn.Client)
if err != nil { if err != nil {
return err return err
} }
@ -317,7 +318,7 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj
// we don't need to authenticate this wshProxy since it is coming direct // we don't need to authenticate this wshProxy since it is coming direct
wshProxy := wshutil.MakeRpcProxy() wshProxy := wshutil.MakeRpcProxy()
wshProxy.SetRpcContext(&wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId}) wshProxy.SetRpcContext(&wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId})
wshutil.DefaultRouter.RegisterRoute("controller:"+bc.BlockId, wshProxy) wshutil.DefaultRouter.RegisterRoute(wshutil.MakeControllerRouteId(bc.BlockId), wshProxy)
ptyBuffer := wshutil.MakePtyBuffer(wshutil.WaveOSCPrefix, bc.ShellProc.Cmd, wshProxy.FromRemoteCh) ptyBuffer := wshutil.MakePtyBuffer(wshutil.WaveOSCPrefix, bc.ShellProc.Cmd, wshProxy.FromRemoteCh)
go func() { go func() {
// handles regular output from the pty (goes to the blockfile and xterm) // handles regular output from the pty (goes to the blockfile and xterm)
@ -376,7 +377,7 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj
go func() { go func() {
// wait for the shell to finish // wait for the shell to finish
defer func() { defer func() {
wshutil.DefaultRouter.UnregisterRoute("controller:" + bc.BlockId) wshutil.DefaultRouter.UnregisterRoute(wshutil.MakeControllerRouteId(bc.BlockId))
bc.UpdateControllerAndSendUpdate(func() bool { bc.UpdateControllerAndSendUpdate(func() bool {
bc.ShellProcStatus = Status_Done bc.ShellProcStatus = Status_Done
return true return true

View File

@ -0,0 +1,224 @@
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package conncontroller
import (
"context"
"fmt"
"io"
"log"
"net"
"strings"
"sync"
"github.com/wavetermdev/thenextwave/pkg/remote"
"github.com/wavetermdev/thenextwave/pkg/userinput"
"github.com/wavetermdev/thenextwave/pkg/util/shellutil"
"github.com/wavetermdev/thenextwave/pkg/util/utilfn"
"github.com/wavetermdev/thenextwave/pkg/wavebase"
"github.com/wavetermdev/thenextwave/pkg/wshrpc"
"github.com/wavetermdev/thenextwave/pkg/wshutil"
"golang.org/x/crypto/ssh"
)
var globalLock = &sync.Mutex{}
var clientControllerMap = make(map[remote.SSHOpts]*SSHConn)
type SSHConn struct {
Lock *sync.Mutex
Opts *remote.SSHOpts
Client *ssh.Client
SockName string
DomainSockListener net.Listener
ConnController *ssh.Session
}
func (conn *SSHConn) Close() error {
if conn.DomainSockListener != nil {
conn.DomainSockListener.Close()
conn.DomainSockListener = nil
}
if conn.ConnController != nil {
conn.ConnController.Close()
conn.ConnController = nil
}
err := conn.Client.Close()
conn.Client = nil
return err
}
func (conn *SSHConn) OpenDomainSocketListener() error {
if conn.DomainSockListener != nil {
return nil
}
randStr, err := utilfn.RandomHexString(16) // 64-bits of randomness
if err != nil {
return fmt.Errorf("error generating random string: %w", err)
}
sockName := fmt.Sprintf("/tmp/waveterm-%s.sock", randStr)
log.Printf("remote domain socket %s %q\n", conn.Opts.String(), sockName)
listener, err := conn.Client.ListenUnix(sockName)
if err != nil {
return fmt.Errorf("unable to request connection domain socket: %v", err)
}
conn.SockName = sockName
conn.DomainSockListener = listener
go func() {
defer func() {
conn.Lock.Lock()
defer conn.Lock.Unlock()
conn.DomainSockListener = nil
}()
wshutil.RunWshRpcOverListener(listener)
}()
return nil
}
func (conn *SSHConn) StartConnServer() error {
conn.Lock.Lock()
defer conn.Lock.Unlock()
if conn.ConnController != nil {
return nil
}
wshPath := remote.GetWshPath(conn.Client)
rpcCtx := wshrpc.RpcContext{
Conn: conn.Opts.String(),
}
jwtToken, err := wshutil.MakeClientJWTToken(rpcCtx, conn.SockName)
if err != nil {
return fmt.Errorf("unable to create jwt token for conn controller: %w", err)
}
sshSession, err := conn.Client.NewSession()
if err != nil {
return fmt.Errorf("unable to create ssh session for conn controller: %w", err)
}
pipeRead, pipeWrite := io.Pipe()
sshSession.Stdout = pipeWrite
sshSession.Stderr = pipeWrite
conn.ConnController = sshSession
cmdStr := fmt.Sprintf("%s=\"%s\" %s connserver", wshutil.WaveJwtTokenVarName, jwtToken, wshPath)
log.Printf("starting conn controller: %s\n", cmdStr)
err = sshSession.Start(cmdStr)
if err != nil {
return fmt.Errorf("unable to start conn controller: %w", err)
}
// service the I/O
go func() {
// wait for termination, clear the controller
waitErr := sshSession.Wait()
log.Printf("conn controller (%q) terminated: %v", conn.Opts.String(), waitErr)
conn.Lock.Lock()
defer conn.Lock.Unlock()
conn.ConnController = nil
}()
go func() {
readErr := wshutil.StreamToLines(pipeRead, func(line []byte) {
lineStr := string(line)
if !strings.HasSuffix(lineStr, "\n") {
lineStr += "\n"
}
log.Printf("[conncontroller:%s:output] %s", conn.Opts.String(), lineStr)
})
if readErr != nil && readErr != io.EOF {
log.Printf("[conncontroller:%s] error reading output: %v\n", conn.Opts.String(), readErr)
}
}()
return nil
}
func (conn *SSHConn) checkAndInstallWsh(ctx context.Context) error {
client := conn.Client
// check that correct wsh extensions are installed
expectedVersion := fmt.Sprintf("wsh v%s", wavebase.WaveVersion)
clientVersion, err := remote.GetWshVersion(client)
if err == nil && clientVersion == expectedVersion {
return nil
}
var queryText string
var title string
if err != nil {
queryText = "Waveterm requires `wsh` shell extensions installed on your client to ensure a seamless experience. Would you like to install them?"
title = "Install Wsh Shell Extensions"
} else {
queryText = fmt.Sprintf("Waveterm requires `wsh` shell extensions installed on your client to be updated from %s to %s. Would you like to update?", clientVersion, expectedVersion)
title = "Update Wsh Shell Extensions"
}
request := &userinput.UserInputRequest{
ResponseType: "confirm",
QueryText: queryText,
Title: title,
CheckBoxMsg: "Don't show me this again",
}
response, err := userinput.GetUserInput(ctx, request)
if err != nil || !response.Confirm {
return err
}
log.Printf("attempting to install wsh to `%s@%s`", client.User(), client.RemoteAddr().String())
clientOs, err := remote.GetClientOs(client)
if err != nil {
return err
}
clientArch, err := remote.GetClientArch(client)
if err != nil {
return err
}
// attempt to install extension
wshLocalPath := shellutil.GetWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch)
err = remote.CpHostToRemote(client, wshLocalPath, "~/.waveterm/bin/wsh")
if err != nil {
return err
}
log.Printf("successfully installed wsh on %s\n", conn.Opts.String())
return nil
}
func GetConn(ctx context.Context, opts *remote.SSHOpts) (*SSHConn, error) {
globalLock.Lock()
defer globalLock.Unlock()
// attempt to retrieve if already opened
conn, ok := clientControllerMap[*opts]
if ok {
return conn, nil
}
client, err := remote.ConnectToClient(ctx, opts) //todo specify or remove opts
if err != nil {
return nil, err
}
conn = &SSHConn{Lock: &sync.Mutex{}, Opts: opts, Client: client}
err = conn.OpenDomainSocketListener()
if err != nil {
conn.Close()
return nil, err
}
installErr := conn.checkAndInstallWsh(ctx)
if installErr != nil {
conn.Close()
return nil, fmt.Errorf("conncontroller %s wsh install error: %v", conn.Opts.String(), installErr)
}
csErr := conn.StartConnServer()
if csErr != nil {
conn.Close()
return nil, fmt.Errorf("conncontroller %s start wsh connserver error: %v", conn.Opts.String(), csErr)
}
// save successful connection to map
clientControllerMap[*opts] = conn
return conn, nil
}
func DisconnectClient(opts *remote.SSHOpts) error {
globalLock.Lock()
defer globalLock.Unlock()
client, ok := clientControllerMap[*opts]
if ok {
return client.Close()
}
return fmt.Errorf("client %v not found", opts)
}

View File

@ -2,156 +2,20 @@ package remote
import ( import (
"bytes" "bytes"
"context"
"fmt" "fmt"
"html/template" "html/template"
"io" "io"
"log" "log"
"net"
"os" "os"
"path/filepath" "path/filepath"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
"sync"
"github.com/wavetermdev/thenextwave/pkg/userinput"
"github.com/wavetermdev/thenextwave/pkg/util/shellutil"
"github.com/wavetermdev/thenextwave/pkg/util/utilfn"
"github.com/wavetermdev/thenextwave/pkg/wavebase"
"github.com/wavetermdev/thenextwave/pkg/wshutil"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
var userHostRe = regexp.MustCompile(`^([a-zA-Z0-9][a-zA-Z0-9._@\\-]*@)?([a-z0-9][a-z0-9.-]*)(?::([0-9]+))?$`) var userHostRe = regexp.MustCompile(`^([a-zA-Z0-9][a-zA-Z0-9._@\\-]*@)?([a-z0-9][a-z0-9.-]*)(?::([0-9]+))?$`)
var globalLock = &sync.Mutex{}
var clientControllerMap = make(map[SSHOpts]*SSHConn)
type SSHConn struct {
Lock *sync.Mutex
Opts *SSHOpts
Client *ssh.Client
SockName string
DomainSockListener net.Listener
}
func (conn *SSHConn) Close() error {
if conn.DomainSockListener != nil {
conn.DomainSockListener.Close()
}
return conn.Client.Close()
}
func (conn *SSHConn) OpenDomainSocketListener() error {
if conn.DomainSockListener != nil {
return nil
}
randStr, err := utilfn.RandomHexString(16) // 64-bits of randomness
if err != nil {
return fmt.Errorf("error generating random string: %w", err)
}
sockName := fmt.Sprintf("/tmp/waveterm-%s.sock", randStr)
log.Printf("remote domain socket %s %q\n", conn.Opts.String(), sockName)
listener, err := conn.Client.ListenUnix(sockName)
if err != nil {
return fmt.Errorf("unable to request connection domain socket: %v", err)
}
conn.SockName = sockName
conn.DomainSockListener = listener
go func() {
wshutil.RunWshRpcOverListener(listener)
}()
return nil
}
func GetConn(ctx context.Context, opts *SSHOpts) (*SSHConn, error) {
globalLock.Lock()
defer globalLock.Unlock()
// attempt to retrieve if already opened
conn, ok := clientControllerMap[*opts]
if ok {
return conn, nil
}
client, err := ConnectToClient(ctx, opts) //todo specify or remove opts
if err != nil {
return nil, err
}
conn = &SSHConn{Lock: &sync.Mutex{}, Opts: opts, Client: client}
err = conn.OpenDomainSocketListener()
if err != nil {
conn.Close()
return nil, err
}
// check that correct wsh extensions are installed
expectedVersion := fmt.Sprintf("wsh v%s", wavebase.WaveVersion)
clientVersion, err := getWshVersion(client)
if err == nil && clientVersion == expectedVersion {
// save successful connection to map
clientControllerMap[*opts] = conn
return conn, nil
}
var queryText string
var title string
if err != nil {
queryText = "Waveterm requires `wsh` shell extensions installed on your client to ensure a seamless experience. Would you like to install them?"
title = "Install Wsh Shell Extensions"
} else {
queryText = fmt.Sprintf("Waveterm requires `wsh` shell extensions installed on your client to be updated from %s to %s. Would you like to update?", clientVersion, expectedVersion)
title = "Update Wsh Shell Extensions"
}
request := &userinput.UserInputRequest{
ResponseType: "confirm",
QueryText: queryText,
Title: title,
CheckBoxMsg: "Don't show me this again",
}
response, err := userinput.GetUserInput(ctx, request)
if err != nil || !response.Confirm {
return nil, err
}
log.Printf("attempting to install wsh to `%s@%s`", client.User(), client.RemoteAddr().String())
clientOs, err := getClientOs(client)
if err != nil {
return nil, err
}
clientArch, err := getClientArch(client)
if err != nil {
return nil, err
}
// attempt to install extension
wshLocalPath := shellutil.GetWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch)
err = cpHostToRemote(client, wshLocalPath, "~/.waveterm/bin/wsh")
if err != nil {
return nil, err
}
log.Printf("successful install")
// save successful connection to map
clientControllerMap[*opts] = conn
return conn, nil
}
func DisconnectClient(opts *SSHOpts) error {
globalLock.Lock()
defer globalLock.Unlock()
client, ok := clientControllerMap[*opts]
if ok {
return client.Close()
}
return fmt.Errorf("client %v not found", opts)
}
func ParseOpts(input string) (*SSHOpts, error) { func ParseOpts(input string) (*SSHOpts, error) {
m := userHostRe.FindStringSubmatch(input) m := userHostRe.FindStringSubmatch(input)
@ -173,7 +37,7 @@ func ParseOpts(input string) (*SSHOpts, error) {
} }
func DetectShell(client *ssh.Client) (string, error) { func DetectShell(client *ssh.Client) (string, error) {
wshPath := getWshPath(client) wshPath := GetWshPath(client)
session, err := client.NewSession() session, err := client.NewSession()
if err != nil { if err != nil {
@ -191,8 +55,8 @@ func DetectShell(client *ssh.Client) (string, error) {
return fmt.Sprintf(`"%s"`, strings.TrimSpace(string(out))), nil return fmt.Sprintf(`"%s"`, strings.TrimSpace(string(out))), nil
} }
func getWshVersion(client *ssh.Client) (string, error) { func GetWshVersion(client *ssh.Client) (string, error) {
wshPath := getWshPath(client) wshPath := GetWshPath(client)
session, err := client.NewSession() session, err := client.NewSession()
if err != nil { if err != nil {
@ -207,7 +71,7 @@ func getWshVersion(client *ssh.Client) (string, error) {
return strings.TrimSpace(string(out)), nil return strings.TrimSpace(string(out)), nil
} }
func getWshPath(client *ssh.Client) string { func GetWshPath(client *ssh.Client) string {
defaultPath := filepath.Join("~", ".waveterm", "bin", "wsh") defaultPath := filepath.Join("~", ".waveterm", "bin", "wsh")
session, err := client.NewSession() session, err := client.NewSession()
@ -267,7 +131,7 @@ func hasBashInstalled(client *ssh.Client) (bool, error) {
return false, nil return false, nil
} }
func getClientOs(client *ssh.Client) (string, error) { func GetClientOs(client *ssh.Client) (string, error) {
session, err := client.NewSession() session, err := client.NewSession()
if err != nil { if err != nil {
return "", err return "", err
@ -306,7 +170,7 @@ func getClientOs(client *ssh.Client) (string, error) {
return "", fmt.Errorf("unable to determine os: {unix: %s, cmd: %s, powershell: %s}", unixErr, cmdErr, psErr) return "", fmt.Errorf("unable to determine os: {unix: %s, cmd: %s, powershell: %s}", unixErr, cmdErr, psErr)
} }
func getClientArch(client *ssh.Client) (string, error) { func GetClientArch(client *ssh.Client) (string, error) {
session, err := client.NewSession() session, err := client.NewSession()
if err != nil { if err != nil {
return "", err return "", err
@ -360,7 +224,7 @@ mv {{.tempPath}} {{.installPath}}; \
chmod a+x {{.installPath}}; \ chmod a+x {{.installPath}}; \
` `
func cpHostToRemote(client *ssh.Client, sourcePath string, destPath string) error { func CpHostToRemote(client *ssh.Client, sourcePath string, destPath string) error {
// warning: does not work on windows remote yet // warning: does not work on windows remote yet
bashInstalled, err := hasBashInstalled(client) bashInstalled, err := hasBashInstalled(client)
if err != nil { if err != nil {
@ -414,7 +278,7 @@ func cpHostToRemote(client *ssh.Client, sourcePath string, destPath string) erro
} }
func InstallClientRcFiles(client *ssh.Client) error { func InstallClientRcFiles(client *ssh.Client) error {
path := getWshPath(client) path := GetWshPath(client)
session, err := client.NewSession() session, err := client.NewSession()
if err != nil { if err != nil {

View File

@ -24,6 +24,7 @@ import (
"github.com/wavetermdev/thenextwave/pkg/wavebase" "github.com/wavetermdev/thenextwave/pkg/wavebase"
"github.com/wavetermdev/thenextwave/pkg/wshutil" "github.com/wavetermdev/thenextwave/pkg/wshutil"
"github.com/wavetermdev/thenextwave/pkg/wstore" "github.com/wavetermdev/thenextwave/pkg/wstore"
"golang.org/x/crypto/ssh"
) )
type CommandOptsType struct { type CommandOptsType struct {
@ -149,8 +150,7 @@ func (pp *PipePty) WriteString(s string) (n int, err error) {
return pp.Write([]byte(s)) return pp.Write([]byte(s))
} }
func StartRemoteShellProc(termSize wstore.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *remote.SSHConn) (*ShellProc, error) { func StartRemoteShellProc(termSize wstore.TermSize, cmdStr string, cmdOpts CommandOptsType, client *ssh.Client) (*ShellProc, error) {
client := conn.Client
shellPath, err := remote.DetectShell(client) shellPath, err := remote.DetectShell(client)
if err != nil { if err != nil {
return nil, err return nil, err
@ -195,6 +195,7 @@ func StartRemoteShellProc(termSize wstore.TermSize, cmdStr string, cmdOpts Comma
} }
shellOpts = append(shellOpts, "-c", cmdStr) shellOpts = append(shellOpts, "-c", cmdStr)
cmdCombined = fmt.Sprintf("%s %s", shellPath, strings.Join(shellOpts, " ")) cmdCombined = fmt.Sprintf("%s %s", shellPath, strings.Join(shellOpts, " "))
log.Printf("combined command is: %s", cmdCombined)
} }
session, err := client.NewSession() session, err := client.NewSession()

View File

@ -19,6 +19,7 @@ import (
"github.com/wavetermdev/thenextwave/pkg/util/utilfn" "github.com/wavetermdev/thenextwave/pkg/util/utilfn"
"github.com/wavetermdev/thenextwave/pkg/wavebase" "github.com/wavetermdev/thenextwave/pkg/wavebase"
"github.com/wavetermdev/thenextwave/pkg/wstore"
) )
const DefaultTermType = "xterm-256color" const DefaultTermType = "xterm-256color"
@ -125,6 +126,10 @@ func internalMacUserShell() string {
return m[1] return m[1]
} }
func DefaultTermSize() wstore.TermSize {
return wstore.TermSize{Rows: DefaultTermRows, Cols: DefaultTermCols}
}
func WaveshellLocalEnvVars(termType string) map[string]string { func WaveshellLocalEnvVars(termType string) map[string]string {
rtn := make(map[string]string) rtn := make(map[string]string)
if termType != "" { if termType != "" {

View File

@ -257,10 +257,9 @@ func HandleWsInternal(w http.ResponseWriter, r *http.Request) error {
defer eventbus.UnregisterWSChannel(wsConnId) defer eventbus.UnregisterWSChannel(wsConnId)
// we create a wshproxy to handle rpc messages to/from the window // we create a wshproxy to handle rpc messages to/from the window
wproxy := wshutil.MakeRpcProxy() wproxy := wshutil.MakeRpcProxy()
rpcRouteId := "window:" + windowId wshutil.DefaultRouter.RegisterRoute(wshutil.MakeWindowRouteId(windowId), wproxy)
wshutil.DefaultRouter.RegisterRoute(rpcRouteId, wproxy)
defer func() { defer func() {
wshutil.DefaultRouter.UnregisterRoute(rpcRouteId) wshutil.DefaultRouter.UnregisterRoute(wshutil.MakeWindowRouteId(windowId))
close(wproxy.ToRemoteCh) close(wproxy.ToRemoteCh)
}() }()
// WshServerFactoryFn(rpcInputCh, rpcOutputCh, wshrpc.RpcContext{}) // WshServerFactoryFn(rpcInputCh, rpcOutputCh, wshrpc.RpcContext{})

View File

@ -4,6 +4,8 @@
package wshclient package wshclient
import ( import (
"errors"
"github.com/wavetermdev/thenextwave/pkg/util/utilfn" "github.com/wavetermdev/thenextwave/pkg/util/utilfn"
"github.com/wavetermdev/thenextwave/pkg/wshrpc" "github.com/wavetermdev/thenextwave/pkg/wshrpc"
"github.com/wavetermdev/thenextwave/pkg/wshutil" "github.com/wavetermdev/thenextwave/pkg/wshutil"
@ -14,6 +16,9 @@ func sendRpcRequestCallHelper[T any](w *wshutil.WshRpc, command string, data int
opts = &wshrpc.RpcOpts{} opts = &wshrpc.RpcOpts{}
} }
var respData T var respData T
if w == nil {
return respData, errors.New("nil wshrpc passed to wshclient")
}
if opts.NoResponse { if opts.NoResponse {
err := w.SendCommand(command, data, opts) err := w.SendCommand(command, data, opts)
if err != nil { if err != nil {
@ -32,17 +37,26 @@ func sendRpcRequestCallHelper[T any](w *wshutil.WshRpc, command string, data int
return respData, nil return respData, nil
} }
func rtnErr[T any](ch chan wshrpc.RespOrErrorUnion[T], err error) {
go func() {
ch <- wshrpc.RespOrErrorUnion[T]{Error: err}
close(ch)
}()
}
func sendRpcRequestResponseStreamHelper[T any](w *wshutil.WshRpc, command string, data interface{}, opts *wshrpc.RpcOpts) chan wshrpc.RespOrErrorUnion[T] { func sendRpcRequestResponseStreamHelper[T any](w *wshutil.WshRpc, command string, data interface{}, opts *wshrpc.RpcOpts) chan wshrpc.RespOrErrorUnion[T] {
if opts == nil { if opts == nil {
opts = &wshrpc.RpcOpts{} opts = &wshrpc.RpcOpts{}
} }
respChan := make(chan wshrpc.RespOrErrorUnion[T]) respChan := make(chan wshrpc.RespOrErrorUnion[T])
if w == nil {
rtnErr(respChan, errors.New("nil wshrpc passed to wshclient"))
return respChan
}
reqHandler, err := w.SendComplexRequest(command, data, opts) reqHandler, err := w.SendComplexRequest(command, data, opts)
if err != nil { if err != nil {
go func() { rtnErr(respChan, err)
respChan <- wshrpc.RespOrErrorUnion[T]{Error: err} return respChan
close(respChan)
}()
} else { } else {
go func() { go func() {
defer close(respChan) defer close(respChan)

View File

@ -107,6 +107,7 @@ type RpcOpts struct {
type RpcContext struct { type RpcContext struct {
BlockId string `json:"blockid,omitempty"` BlockId string `json:"blockid,omitempty"`
TabId string `json:"tabid,omitempty"` TabId string `json:"tabid,omitempty"`
Conn string `json:"conn,omitempty"`
} }
func HackRpcContextIntoData(dataPtr any, rpcContext RpcContext) { func HackRpcContextIntoData(dataPtr any, rpcContext RpcContext) {

View File

@ -6,6 +6,7 @@ package wshutil
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"log"
"sync" "sync"
"github.com/google/uuid" "github.com/google/uuid"
@ -81,11 +82,13 @@ func handleAuthenticationCommand(msg RpcMessage) (*wshrpc.RpcContext, error) {
if newCtx == nil { if newCtx == nil {
return nil, fmt.Errorf("no context found in jwt token") return nil, fmt.Errorf("no context found in jwt token")
} }
if newCtx.BlockId == "" { if newCtx.BlockId == "" && newCtx.Conn == "" {
return nil, fmt.Errorf("no blockId found in jwt token") return nil, fmt.Errorf("no blockid or conn found in jwt token")
} }
if _, err := uuid.Parse(newCtx.BlockId); err != nil { if newCtx.BlockId != "" {
return nil, fmt.Errorf("invalid blockId in jwt token") if _, err := uuid.Parse(newCtx.BlockId); err != nil {
return nil, fmt.Errorf("invalid blockId in jwt token")
}
} }
return newCtx, nil return newCtx, nil
} }
@ -114,6 +117,7 @@ func (p *WshRpcProxy) HandleAuthentication() (*wshrpc.RpcContext, error) {
} }
newCtx, err := handleAuthenticationCommand(msg) newCtx, err := handleAuthenticationCommand(msg)
if err != nil { if err != nil {
log.Printf("error handling authentication: %v\n", err)
p.sendResponseError(msg, err) p.sendResponseError(msg, err)
continue continue
} }

View File

@ -42,6 +42,14 @@ func MakeConnectionRouteId(connId string) string {
return "conn:" + connId return "conn:" + connId
} }
func MakeControllerRouteId(blockId string) string {
return "controller:" + blockId
}
func MakeWindowRouteId(windowId string) string {
return "window:" + windowId
}
var DefaultRouter = NewWshRouter() var DefaultRouter = NewWshRouter()
func NewWshRouter() *WshRouter { func NewWshRouter() *WshRouter {
@ -230,6 +238,7 @@ func (router *WshRouter) RegisterRoute(routeId string, rpc AbstractRpcClient) {
log.Printf("error: WshRouter cannot register sys route\n") log.Printf("error: WshRouter cannot register sys route\n")
return return
} }
log.Printf("registering wsh route %q\n", routeId)
router.Lock.Lock() router.Lock.Lock()
defer router.Lock.Unlock() defer router.Lock.Unlock()
router.RouteMap[routeId] = rpc router.RouteMap[routeId] = rpc

View File

@ -43,7 +43,7 @@ func streamToLines_processBuf(lineBuf *lineBuf, readBuf []byte, lineFn func([]by
} }
} }
func streamToLines(input io.Reader, lineFn func([]byte)) error { func StreamToLines(input io.Reader, lineFn func([]byte)) error {
var lineBuf lineBuf var lineBuf lineBuf
readBuf := make([]byte, 16*1024) readBuf := make([]byte, 16*1024)
for { for {
@ -56,7 +56,7 @@ func streamToLines(input io.Reader, lineFn func([]byte)) error {
} }
func AdaptStreamToMsgCh(input io.Reader, output chan []byte) error { func AdaptStreamToMsgCh(input io.Reader, output chan []byte) error {
return streamToLines(input, func(line []byte) { return StreamToLines(input, func(line []byte) {
output <- line output <- line
}) })
} }

View File

@ -248,6 +248,9 @@ func MakeClientJWTToken(rpcCtx wshrpc.RpcContext, sockName string) (string, erro
if rpcCtx.TabId != "" { if rpcCtx.TabId != "" {
claims["tabid"] = rpcCtx.TabId claims["tabid"] = rpcCtx.TabId
} }
if rpcCtx.Conn != "" {
claims["conn"] = rpcCtx.Conn
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenStr, err := token.SignedString([]byte(wavebase.JwtSecret)) tokenStr, err := token.SignedString([]byte(wavebase.JwtSecret))
if err != nil { if err != nil {
@ -299,6 +302,11 @@ func mapClaimsToRpcContext(claims jwt.MapClaims) *wshrpc.RpcContext {
rpcCtx.TabId = tabId rpcCtx.TabId = tabId
} }
} }
if claims["conn"] != nil {
if conn, ok := claims["conn"].(string); ok {
rpcCtx.Conn = conn
}
}
return rpcCtx return rpcCtx
} }
@ -306,9 +314,12 @@ func RunWshRpcOverListener(listener net.Listener) {
defer log.Printf("domain socket listener shutting down\n") defer log.Printf("domain socket listener shutting down\n")
for { for {
conn, err := listener.Accept() conn, err := listener.Accept()
if err == io.EOF {
break
}
if err != nil { if err != nil {
log.Printf("error accepting connection: %v\n", err) log.Printf("error accepting connection: %v\n", err)
continue break
} }
log.Print("got domain socket connection\n") log.Print("got domain socket connection\n")
go handleDomainSocketClient(conn) go handleDomainSocketClient(conn)
@ -337,7 +348,11 @@ func handleDomainSocketClient(conn net.Conn) {
// now that we're authenticated, set the ctx and attach to the router // now that we're authenticated, set the ctx and attach to the router
log.Printf("domain socket connection authenticated: %#v\n", rpcCtx) log.Printf("domain socket connection authenticated: %#v\n", rpcCtx)
proxy.SetRpcContext(rpcCtx) proxy.SetRpcContext(rpcCtx)
DefaultRouter.RegisterRoute("controller:"+rpcCtx.BlockId, proxy) if rpcCtx.BlockId != "" {
DefaultRouter.RegisterRoute(MakeControllerRouteId(rpcCtx.BlockId), proxy)
} else if rpcCtx.Conn != "" {
DefaultRouter.RegisterRoute(MakeConnectionRouteId(rpcCtx.Conn), proxy)
}
} }
// only for use on client // only for use on client