Add Unix Domain Socket Listener when Establishing Connections (#243)

This makes it possible to send wsh commands from wsh on a remote session
to wavesrv running locally. The exact behavior of running those commands
isn't implemented, but the underlying interface is added here.
This commit is contained in:
Sylvie Crowe 2024-08-17 11:21:25 -07:00 committed by GitHub
parent 6df50a5790
commit c30188552f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 164 additions and 119 deletions

View File

@ -244,7 +244,7 @@ func main() {
}
}
}()
go wshserver.RunWshRpcOverListener(unixListener)
go wshutil.RunWshRpcOverListener(unixListener)
web.RunWebServer(webListener) // blocking
runtime.KeepAlive(waveLock)
}

View File

@ -398,7 +398,7 @@ declare global {
enabled: boolean;
};
// shellexec.TermSize
// wstore.TermSize
type TermSize = {
rows: number;
cols: number;

View File

@ -6,9 +6,7 @@ package blockcontroller
import (
"bytes"
"context"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"io"
@ -55,9 +53,9 @@ var globalLock = &sync.Mutex{}
var blockControllerMap = make(map[string]*BlockController)
type BlockInputUnion struct {
InputData []byte `json:"inputdata,omitempty"`
SigName string `json:"signame,omitempty"`
TermSize *shellexec.TermSize `json:"termsize,omitempty"`
InputData []byte `json:"inputdata,omitempty"`
SigName string `json:"signame,omitempty"`
TermSize *wstore.TermSize `json:"termsize,omitempty"`
}
type BlockController struct {
@ -116,7 +114,7 @@ func (bc *BlockController) getShellProc() *shellexec.ShellProc {
}
type RunShellOpts struct {
TermSize shellexec.TermSize `json:"termsize,omitempty"`
TermSize wstore.TermSize `json:"termsize,omitempty"`
}
func (bc *BlockController) UpdateControllerAndSendUpdate(updateFn func() bool) {
@ -205,18 +203,6 @@ func (bc *BlockController) resetTerminalState() {
}
}
// every byte is 4-bits of randomness
func randomHexString(numHexDigits int) (string, error) {
numBytes := (numHexDigits + 1) / 2 // Calculate the number of bytes needed
bytes := make([]byte, numBytes)
if _, err := rand.Read(bytes); err != nil {
return "", err
}
hexStr := hex.EncodeToString(bytes)
return hexStr[:numHexDigits], nil // Return the exact number of hex digits
}
func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj.MetaMapType) error {
// create a circular blockfile for the output
ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second)
@ -244,35 +230,11 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj
if shellProcErr != nil {
return shellProcErr
}
var remoteDomainSocketName string
remoteName := blockMeta.GetString(wstore.MetaKey_Connection, "")
isRemote := remoteName != ""
if isRemote {
randStr, err := randomHexString(16) // 64-bits of randomness
if err != nil {
return fmt.Errorf("error generating random string: %w", err)
}
remoteDomainSocketName = fmt.Sprintf("/tmp/waveterm-%s.sock", randStr)
}
var cmdStr string
cmdOpts := shellexec.CommandOptsType{
Env: make(map[string]string),
}
if !blockMeta.GetBool(wstore.MetaKey_CmdNoWsh, false) {
if isRemote {
jwtStr, err := wshutil.MakeClientJWTToken(wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId}, remoteDomainSocketName)
if err != nil {
return fmt.Errorf("error making jwt token: %w", err)
}
cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr
} else {
jwtStr, err := wshutil.MakeClientJWTToken(wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId}, wavebase.GetDomainSocketName())
if err != nil {
return fmt.Errorf("error making jwt token: %w", err)
}
cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr
}
}
if bc.ControllerType == BlockController_Shell {
cmdOpts.Interactive = true
cmdOpts.Login = true
@ -315,16 +277,29 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj
if err != nil {
return err
}
client, err := remote.GetClient(credentialCtx, opts)
conn, err := remote.GetConn(credentialCtx, opts)
if err != nil {
return err
}
shellProc, err = shellexec.StartRemoteShellProc(rc.TermSize, cmdStr, cmdOpts, client)
if !blockMeta.GetBool(wstore.MetaKey_CmdNoWsh, false) {
jwtStr, err := wshutil.MakeClientJWTToken(wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId}, conn.SockName)
if err != nil {
return fmt.Errorf("error making jwt token: %w", err)
}
cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr
}
shellProc, err = shellexec.StartRemoteShellProc(rc.TermSize, cmdStr, cmdOpts, conn)
if err != nil {
return err
}
} else {
if !blockMeta.GetBool(wstore.MetaKey_CmdNoWsh, false) {
jwtStr, err := wshutil.MakeClientJWTToken(wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId}, wavebase.GetDomainSocketName())
if err != nil {
return fmt.Errorf("error making jwt token: %w", err)
}
cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr
}
shellProc, err = shellexec.StartShellProc(rc.TermSize, cmdStr, cmdOpts)
if err != nil {
return err
@ -428,11 +403,11 @@ func getBoolFromMeta(meta map[string]any, key string, def bool) bool {
return def
}
func getTermSize(bdata *wstore.Block) shellexec.TermSize {
func getTermSize(bdata *wstore.Block) wstore.TermSize {
if bdata.RuntimeOpts != nil {
return bdata.RuntimeOpts.TermSize
} else {
return shellexec.TermSize{
return wstore.TermSize{
Rows: 25,
Cols: 80,
}

View File

@ -7,6 +7,7 @@ import (
"html/template"
"io"
"log"
"net"
"os"
"path/filepath"
"regexp"
@ -16,37 +17,81 @@ import (
"github.com/wavetermdev/thenextwave/pkg/userinput"
"github.com/wavetermdev/thenextwave/pkg/util/shellutil"
"github.com/wavetermdev/thenextwave/pkg/util/utilfn"
"github.com/wavetermdev/thenextwave/pkg/wavebase"
"github.com/wavetermdev/thenextwave/pkg/wshutil"
"golang.org/x/crypto/ssh"
)
var userHostRe = regexp.MustCompile(`^([a-zA-Z0-9][a-zA-Z0-9._@\\-]*@)?([a-z0-9][a-z0-9.-]*)(?::([0-9]+))?$`)
var globalLock = &sync.Mutex{}
var clientControllerMap = make(map[SSHOpts]*ssh.Client)
var clientControllerMap = make(map[SSHOpts]*SSHConn)
func GetClient(ctx context.Context, opts *SSHOpts) (*ssh.Client, error) {
type SSHConn struct {
Lock *sync.Mutex
Opts *SSHOpts
Client *ssh.Client
SockName string
DomainSockListener net.Listener
}
func (conn *SSHConn) Close() error {
if conn.DomainSockListener != nil {
conn.DomainSockListener.Close()
}
return conn.Client.Close()
}
func (conn *SSHConn) OpenDomainSocketListener() error {
if conn.DomainSockListener != nil {
return nil
}
randStr, err := utilfn.RandomHexString(16) // 64-bits of randomness
if err != nil {
return fmt.Errorf("error generating random string: %w", err)
}
sockName := fmt.Sprintf("/tmp/waveterm-%s.sock", randStr)
log.Printf("remote domain socket %s %q\n", conn.Opts.String(), sockName)
listener, err := conn.Client.ListenUnix(sockName)
if err != nil {
return fmt.Errorf("unable to request connection domain socket: %v", err)
}
conn.SockName = sockName
conn.DomainSockListener = listener
go func() {
wshutil.RunWshRpcOverListener(listener)
}()
return nil
}
func GetConn(ctx context.Context, opts *SSHOpts) (*SSHConn, error) {
globalLock.Lock()
defer globalLock.Unlock()
// attempt to retrieve if already opened
client, ok := clientControllerMap[*opts]
conn, ok := clientControllerMap[*opts]
if ok {
return client, nil
return conn, nil
}
client, err := ConnectToClient(ctx, opts) //todo specify or remove opts
if err != nil {
return nil, err
}
conn = &SSHConn{Lock: &sync.Mutex{}, Opts: opts, Client: client}
err = conn.OpenDomainSocketListener()
if err != nil {
conn.Close()
return nil, err
}
// check that correct wsh extensions are installed
expectedVersion := fmt.Sprintf("wsh v%s", wavebase.WaveVersion)
clientVersion, err := getWshVersion(client)
if err == nil && clientVersion == expectedVersion {
// save successful connection to map
clientControllerMap[*opts] = client
return client, nil
clientControllerMap[*opts] = conn
return conn, nil
}
var queryText string
@ -92,9 +137,9 @@ func GetClient(ctx context.Context, opts *SSHOpts) (*ssh.Client, error) {
log.Printf("successful install")
// save successful connection to map
clientControllerMap[*opts] = client
clientControllerMap[*opts] = conn
return client, nil
return conn, nil
}
func DisconnectClient(opts *SSHOpts) error {

View File

@ -707,3 +707,10 @@ type SSHOpts struct {
SSHUser string `json:"sshuser"`
SSHPort int `json:"sshport,omitempty"`
}
func (opts SSHOpts) String() string {
if opts.SSHPort == 0 {
return fmt.Sprintf("%s@%s", opts.SSHUser, opts.SSHHost)
}
return fmt.Sprintf("%s@%s:%d", opts.SSHUser, opts.SSHHost, opts.SSHPort)
}

View File

@ -22,14 +22,10 @@ import (
"github.com/wavetermdev/thenextwave/pkg/remote"
"github.com/wavetermdev/thenextwave/pkg/util/shellutil"
"github.com/wavetermdev/thenextwave/pkg/wavebase"
"golang.org/x/crypto/ssh"
"github.com/wavetermdev/thenextwave/pkg/wshutil"
"github.com/wavetermdev/thenextwave/pkg/wstore"
)
type TermSize struct {
Rows int `json:"rows"`
Cols int `json:"cols"`
}
type CommandOptsType struct {
Interactive bool `json:"interactive,omitempty"`
Login bool `json:"login,omitempty"`
@ -153,7 +149,8 @@ func (pp *PipePty) WriteString(s string) (n int, err error) {
return pp.Write([]byte(s))
}
func StartRemoteShellProc(termSize TermSize, cmdStr string, cmdOpts CommandOptsType, client *ssh.Client) (*ShellProc, error) {
func StartRemoteShellProc(termSize wstore.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *remote.SSHConn) (*ShellProc, error) {
client := conn.Client
shellPath, err := remote.DetectShell(client)
if err != nil {
return nil, err
@ -239,6 +236,12 @@ func StartRemoteShellProc(termSize TermSize, cmdStr string, cmdOpts CommandOptsT
cmdCombined = fmt.Sprintf(`ZDOTDIR="%s/.waveterm/zsh-integration" %s`, homeDir, cmdCombined)
}
jwtToken, ok := cmdOpts.Env[wshutil.WaveJwtTokenVarName]
if !ok {
return nil, fmt.Errorf("no jwt token provided to connection")
}
cmdCombined = fmt.Sprintf(`%s=%s %s`, wshutil.WaveJwtTokenVarName, jwtToken, cmdCombined)
session.RequestPty("xterm-256color", termSize.Rows, termSize.Cols, nil)
sessionWrap := SessionWrap{session, cmdCombined, pipePty, pipePty}
@ -262,7 +265,7 @@ func isBashShell(shellPath string) bool {
return strings.Contains(shellBase, "bash")
}
func StartShellProc(termSize TermSize, cmdStr string, cmdOpts CommandOptsType) (*ShellProc, error) {
func StartShellProc(termSize wstore.TermSize, cmdStr string, cmdOpts CommandOptsType) (*ShellProc, error) {
shellutil.InitCustomShellStartupFiles()
var ecmd *exec.Cmd
var shellOpts []string
@ -324,7 +327,7 @@ func StartShellProc(termSize TermSize, cmdStr string, cmdOpts CommandOptsType) (
return &ShellProc{Cmd: CmdWrap{ecmd, cmdPty}, CloseOnce: &sync.Once{}, DoneCh: make(chan any)}, nil
}
func RunSimpleCmdInPty(ecmd *exec.Cmd, termSize TermSize) ([]byte, error) {
func RunSimpleCmdInPty(ecmd *exec.Cmd, termSize wstore.TermSize) ([]byte, error) {
ecmd.Env = os.Environ()
shellutil.UpdateCmdEnv(ecmd, shellutil.WaveshellLocalEnvVars(shellutil.DefaultTermType))
if termSize.Rows == 0 || termSize.Cols == 0 {

View File

@ -5,8 +5,10 @@ package utilfn
import (
"bytes"
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
@ -885,3 +887,15 @@ func WriteTemplateToFile(fileName string, templateText string, vars map[string]s
template.Must(template.New("").Parse(templateText)).Execute(outBuffer, vars)
return os.WriteFile(fileName, outBuffer.Bytes(), 0644)
}
// every byte is 4-bits of randomness
func RandomHexString(numHexDigits int) (string, error) {
numBytes := (numHexDigits + 1) / 2 // Calculate the number of bytes needed
bytes := make([]byte, numBytes)
if _, err := rand.Read(bytes); err != nil {
return "", err
}
hexStr := hex.EncodeToString(bytes)
return hexStr[:numHexDigits], nil // Return the exact number of hex digits
}

View File

@ -7,10 +7,10 @@ import (
"fmt"
"reflect"
"github.com/wavetermdev/thenextwave/pkg/shellexec"
"github.com/wavetermdev/thenextwave/pkg/tsgen/tsgenmeta"
"github.com/wavetermdev/thenextwave/pkg/util/utilfn"
"github.com/wavetermdev/thenextwave/pkg/wshutil"
"github.com/wavetermdev/thenextwave/pkg/wstore"
)
const (
@ -45,9 +45,9 @@ func (cmd *WSRpcCommand) GetWSCommand() string {
}
type SetBlockTermSizeWSCommand struct {
WSCommand string `json:"wscommand" tstype:"\"setblocktermsize\""`
BlockId string `json:"blockid"`
TermSize shellexec.TermSize `json:"termsize"`
WSCommand string `json:"wscommand" tstype:"\"setblocktermsize\""`
BlockId string `json:"blockid"`
TermSize wstore.TermSize `json:"termsize"`
}
func (cmd *SetBlockTermSizeWSCommand) GetWSCommand() string {

View File

@ -11,7 +11,6 @@ import (
"reflect"
"github.com/wavetermdev/thenextwave/pkg/ijson"
"github.com/wavetermdev/thenextwave/pkg/shellexec"
"github.com/wavetermdev/thenextwave/pkg/waveobj"
"github.com/wavetermdev/thenextwave/pkg/wstore"
)
@ -180,10 +179,10 @@ type CommandBlockRestartData struct {
}
type CommandBlockInputData struct {
BlockId string `json:"blockid" wshcontext:"BlockId"`
InputData64 string `json:"inputdata64,omitempty"`
SigName string `json:"signame,omitempty"`
TermSize *shellexec.TermSize `json:"termsize,omitempty"`
BlockId string `json:"blockid" wshcontext:"BlockId"`
InputData64 string `json:"inputdata64,omitempty"`
SigName string `json:"signame,omitempty"`
TermSize *wstore.TermSize `json:"termsize,omitempty"`
}
type CommandFileData struct {

View File

@ -4,8 +4,6 @@
package wshserver
import (
"log"
"net"
"sync"
"github.com/wavetermdev/thenextwave/pkg/wshrpc"
@ -17,44 +15,6 @@ const (
DefaultInputChSize = 32
)
func handleDomainSocketClient(conn net.Conn) {
proxy := wshutil.MakeRpcProxy()
go func() {
writeErr := wshutil.AdaptOutputChToStream(proxy.ToRemoteCh, conn)
if writeErr != nil {
log.Printf("error writing to domain socket: %v\n", writeErr)
}
}()
go func() {
// when input is closed, close the connection
defer conn.Close()
wshutil.AdaptStreamToMsgCh(conn, proxy.FromRemoteCh)
}()
rpcCtx, err := proxy.HandleAuthentication()
if err != nil {
conn.Close()
log.Printf("error handling authentication: %v\n", err)
return
}
// now that we're authenticated, set the ctx and attach to the router
log.Printf("domain socket connection authenticated: %#v\n", rpcCtx)
proxy.SetRpcContext(rpcCtx)
wshutil.DefaultRouter.RegisterRoute("controller:"+rpcCtx.BlockId, proxy)
}
func RunWshRpcOverListener(listener net.Listener) {
defer log.Printf("domain socket listener shutting down\n")
for {
conn, err := listener.Accept()
if err != nil {
log.Printf("error accepting connection: %v\n", err)
continue
}
log.Print("got domain socket connection\n")
go handleDomainSocketClient(conn)
}
}
var waveSrvClient_Singleton *wshutil.WshRpc
var waveSrvClient_Once = &sync.Once{}

View File

@ -302,6 +302,44 @@ func mapClaimsToRpcContext(claims jwt.MapClaims) *wshrpc.RpcContext {
return rpcCtx
}
func RunWshRpcOverListener(listener net.Listener) {
defer log.Printf("domain socket listener shutting down\n")
for {
conn, err := listener.Accept()
if err != nil {
log.Printf("error accepting connection: %v\n", err)
continue
}
log.Print("got domain socket connection\n")
go handleDomainSocketClient(conn)
}
}
func handleDomainSocketClient(conn net.Conn) {
proxy := MakeRpcProxy()
go func() {
writeErr := AdaptOutputChToStream(proxy.ToRemoteCh, conn)
if writeErr != nil {
log.Printf("error writing to domain socket: %v\n", writeErr)
}
}()
go func() {
// when input is closed, close the connection
defer conn.Close()
AdaptStreamToMsgCh(conn, proxy.FromRemoteCh)
}()
rpcCtx, err := proxy.HandleAuthentication()
if err != nil {
conn.Close()
log.Printf("error handling authentication: %v\n", err)
return
}
// now that we're authenticated, set the ctx and attach to the router
log.Printf("domain socket connection authenticated: %#v\n", rpcCtx)
proxy.SetRpcContext(rpcCtx)
DefaultRouter.RegisterRoute("controller:"+rpcCtx.BlockId, proxy)
}
// only for use on client
func ExtractUnverifiedRpcContext(tokenStr string) (*wshrpc.RpcContext, error) {
// this happens on the client who does not have access to the secret key

View File

@ -8,7 +8,6 @@ import (
"fmt"
"reflect"
"github.com/wavetermdev/thenextwave/pkg/shellexec"
"github.com/wavetermdev/thenextwave/pkg/waveobj"
)
@ -220,8 +219,8 @@ type StickerType struct {
}
type RuntimeOpts struct {
TermSize shellexec.TermSize `json:"termsize,omitempty"`
WinSize WinSize `json:"winsize,omitempty"`
TermSize TermSize `json:"termsize,omitempty"`
WinSize WinSize `json:"winsize,omitempty"`
}
type Point struct {
@ -257,3 +256,8 @@ func AllWaveObjTypes() []reflect.Type {
reflect.TypeOf(&LayoutState{}),
}
}
type TermSize struct {
Rows int `json:"rows"`
Cols int `json:"cols"`
}