working impl of new connserver router, many bugs fixed, new functionality added to wshrouter to inject packets and help with proxy auth

This commit is contained in:
sawka 2024-10-22 00:34:39 -07:00
parent 64262cbb74
commit c874a6e302
6 changed files with 223 additions and 116 deletions

View File

@ -14,12 +14,12 @@ import (
"time"
"github.com/spf13/cobra"
"github.com/wavetermdev/waveterm/pkg/util/packetparser"
"github.com/wavetermdev/waveterm/pkg/wavebase"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshremote"
"github.com/wavetermdev/waveterm/pkg/wshutil"
"golang.org/x/crypto/ssh/terminal"
)
var serverCmd = &cobra.Command{
@ -52,7 +52,6 @@ func MakeRemoteUnixListener() (net.Listener, error) {
func handleNewListenerConn(conn net.Conn, router *wshutil.WshRouter) {
var routeIdContainer atomic.Pointer[string]
proxy := wshutil.MakeRpcProxy()
upstreamClient := router.GetUpstreamClient().(*wshutil.WshRpc)
go func() {
writeErr := wshutil.AdaptOutputChToStream(proxy.ToRemoteCh, conn)
if writeErr != nil {
@ -75,12 +74,12 @@ func handleNewListenerConn(conn net.Conn, router *wshutil.WshRouter) {
AuthToken: proxy.GetAuthToken(),
}
disposeBytes, _ := json.Marshal(disposeMsg)
upstreamClient.SendRpcMessage(disposeBytes)
router.InjectMessage(disposeBytes, *routeIdPtr)
}
}()
wshutil.AdaptStreamToMsgCh(conn, proxy.FromRemoteCh)
}()
routeId, err := proxy.HandleClientProxyAuth(upstreamClient)
routeId, err := proxy.HandleClientProxyAuth(router)
if err != nil {
log.Printf("error handling client proxy auth: %v\n", err)
conn.Close()
@ -118,59 +117,54 @@ func setupConnServerRpcClientWithRouter(router *wshutil.WshRouter) (*wshutil.Wsh
if err != nil {
return nil, fmt.Errorf("error extracting rpc context from %s: %v", wshutil.WaveJwtTokenVarName, err)
}
RpcContext = *rpcCtx
authRtn, err := router.HandleProxyAuth(jwtToken)
if err != nil {
return nil, fmt.Errorf("error handling proxy auth: %v", err)
}
inputCh := make(chan []byte, wshutil.DefaultInputChSize)
outputCh := make(chan []byte, wshutil.DefaultOutputChSize)
connServerClient := wshutil.MakeWshRpc(inputCh, outputCh, *rpcCtx, &wshremote.ServerImpl{LogWriter: os.Stdout})
upstreamClient := router.GetUpstreamClient().(*wshutil.WshRpc)
resp, err := wshclient.AuthenticateCommand(upstreamClient, jwtToken, nil)
if err != nil {
return nil, fmt.Errorf("error authenticating connserver: %v", err)
}
if resp.AuthToken == "" {
return nil, fmt.Errorf("no auth token returned from connserver")
}
log.Printf("authenticated connserver route: %s\n", resp.RouteId)
connServerClient.SetAuthToken(resp.AuthToken)
router.RegisterRoute(resp.RouteId, connServerClient, false)
connServerClient.SetAuthToken(authRtn.AuthToken)
router.RegisterRoute(authRtn.RouteId, connServerClient, false)
wshclient.RouteAnnounceCommand(connServerClient, nil)
return connServerClient, nil
}
func serverRunRouter() error {
isTerminal := terminal.IsTerminal(int(os.Stdout.Fd()))
if isTerminal {
wshutil.SetTermRawMode()
}
termClient, reader := wshutil.SetupTerminalRpcClient(&wshremote.ServerImpl{LogWriter: os.Stdout})
router := wshutil.NewWshRouter()
termProxy := wshutil.MakeRpcProxy()
rawCh := make(chan []byte, wshutil.DefaultOutputChSize)
go packetparser.Parse(os.Stdin, termProxy.FromRemoteCh, rawCh)
go func() {
// just ignore and drain the reader
var errorCode int
defer wshutil.DoShutdown("", errorCode, true)
for {
buf := make([]byte, 4096)
_, err := reader.Read(buf)
if err == io.EOF {
break
}
if err != nil {
errorCode = 1
break
}
for msg := range termProxy.ToRemoteCh {
packetparser.WritePacket(os.Stdout, msg)
}
}()
router := wshutil.NewWshRouter()
router.SetUpstreamClient(termClient)
go func() {
// just ignore and drain the rawCh (stdin)
// when stdin is closed, shutdown
defer wshutil.DoShutdown("", 0, true)
for range rawCh {
// ignore
}
}()
go func() {
for msg := range termProxy.FromRemoteCh {
// send this to the router
router.InjectMessage(msg, wshutil.UpstreamRoute)
}
}()
router.SetUpstreamClient(termProxy)
// now set up the domain socket
unixListener, err := MakeRemoteUnixListener()
if err != nil {
return fmt.Errorf("cannot create unix listener: %v", err)
}
go runListener(unixListener, router)
client, err := setupConnServerRpcClientWithRouter(router)
if err != nil {
return fmt.Errorf("error setting up connserver rpc client: %v", err)
}
go runListener(unixListener, router)
// run the sysinfo loop
wshremote.RunSysInfoLoop(client, client.GetRpcContext().Conn)
select {}

View File

@ -6,7 +6,6 @@ package wshutil
import (
"encoding/json"
"fmt"
"log"
"sync"
"github.com/google/uuid"
@ -37,6 +36,15 @@ func MakeRpcMultiProxy() *WshRpcMultiProxy {
}
}
func (p *WshRpcMultiProxy) DisposeRoutes() {
p.Lock.Lock()
defer p.Lock.Unlock()
for authToken, routeInfo := range p.RouteInfo {
DefaultRouter.UnregisterRoute(routeInfo.RouteId)
delete(p.RouteInfo, authToken)
}
}
func (p *WshRpcMultiProxy) getRouteInfo(authToken string) *multiProxyRouteInfo {
p.Lock.Lock()
defer p.Lock.Unlock()
@ -91,7 +99,6 @@ func (p *WshRpcMultiProxy) handleUnauthMessage(msgBytes []byte) {
if msg.Command == wshrpc.Command_Authenticate {
rpcContext, routeId, err := handleAuthenticationCommand(msg)
if err != nil {
log.Printf("error handling authentication command (multiproxy): %v\n", err)
p.sendResponseError(msg, err)
return
}
@ -104,6 +111,11 @@ func (p *WshRpcMultiProxy) handleUnauthMessage(msgBytes []byte) {
routeInfo.Proxy.SetRpcContext(rpcContext)
p.setRouteInfo(routeInfo.AuthToken, routeInfo)
p.sendAuthResponse(msg, routeId, routeInfo.AuthToken)
go func() {
for msgBytes := range routeInfo.Proxy.ToRemoteCh {
p.ToRemoteCh <- msgBytes
}
}()
DefaultRouter.RegisterRoute(routeId, routeInfo.Proxy, true)
return
}
@ -116,13 +128,15 @@ func (p *WshRpcMultiProxy) handleUnauthMessage(msgBytes []byte) {
p.sendResponseError(msg, fmt.Errorf("invalid auth token"))
return
}
if msg.Source != routeInfo.RouteId {
if msg.Command != "" && msg.Source != routeInfo.RouteId {
p.sendResponseError(msg, fmt.Errorf("invalid source route for auth token"))
return
}
if msg.Command == wshrpc.Command_Dispose {
DefaultRouter.UnregisterRoute(routeInfo.RouteId)
p.removeRouteInfo(msg.AuthToken)
close(routeInfo.Proxy.ToRemoteCh)
close(routeInfo.Proxy.FromRemoteCh)
return
}
routeInfo.Proxy.FromRemoteCh <- msgBytes

View File

@ -10,7 +10,6 @@ import (
"sync"
"github.com/google/uuid"
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
)
@ -113,56 +112,44 @@ func handleAuthenticationCommand(msg RpcMessage) (*wshrpc.RpcContext, string, er
}
// runs on the client (stdio client)
func (p *WshRpcProxy) HandleClientProxyAuth(upstream *WshRpc) (string, error) {
func (p *WshRpcProxy) HandleClientProxyAuth(router *WshRouter) (string, error) {
for {
msgBytes, ok := <-p.FromRemoteCh
if !ok {
return "", fmt.Errorf("remote closed, not authenticated")
}
var msg RpcMessage
err := json.Unmarshal(msgBytes, &msg)
var origMsg RpcMessage
err := json.Unmarshal(msgBytes, &origMsg)
if err != nil {
// nothing to do, can't even send a response since we don't have Source or ReqId
continue
}
if msg.Command == "" {
if origMsg.Command == "" {
// this message is not allowed (protocol error at this point), ignore
continue
}
// we only allow one command "authenticate", everything else returns an error
if msg.Command != wshrpc.Command_Authenticate {
if origMsg.Command != wshrpc.Command_Authenticate {
respErr := fmt.Errorf("connection not authenticated")
p.sendResponseError(msg, respErr)
p.sendResponseError(origMsg, respErr)
continue
}
resp, err := upstream.SendRpcRequest(msg.Command, msg.Data, nil)
authRtn, err := router.HandleProxyAuth(origMsg.Data)
if err != nil {
respErr := fmt.Errorf("error authenticating: %w", err)
p.sendResponseError(msg, respErr)
respErr := fmt.Errorf("error handling proxy auth: %w", err)
p.sendResponseError(origMsg, respErr)
return "", respErr
}
var respData wshrpc.CommandAuthenticateRtnData
err = utilfn.ReUnmarshal(&respData, resp)
if err != nil {
respErr := fmt.Errorf("error unmarshalling authenticate response: %w", err)
p.sendResponseError(msg, respErr)
return "", respErr
}
if respData.AuthToken == "" {
respErr := fmt.Errorf("no auth token in authenticate response")
p.sendResponseError(msg, respErr)
return "", respErr
}
p.SetAuthToken(respData.AuthToken)
p.SetAuthToken(authRtn.AuthToken)
announceMsg := RpcMessage{
Command: wshrpc.Command_RouteAnnounce,
Source: respData.RouteId,
AuthToken: respData.AuthToken,
Source: authRtn.RouteId,
AuthToken: authRtn.AuthToken,
}
announceBytes, _ := json.Marshal(announceMsg)
upstream.SendRpcMessage(announceBytes)
p.sendAuthenticateResponse(msg, respData.RouteId)
return respData.RouteId, nil
router.InjectMessage(announceBytes, authRtn.RouteId)
p.sendAuthenticateResponse(origMsg, authRtn.RouteId)
return authRtn.RouteId, nil
}
}

View File

@ -12,11 +12,14 @@ import (
"sync"
"time"
"github.com/google/uuid"
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
"github.com/wavetermdev/waveterm/pkg/wps"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
)
const DefaultRoute = "wavesrv"
const UpstreamRoute = "upstream"
const SysRoute = "sys" // this route doesn't exist, just a placeholder for system messages
const ElectronRoute = "electron"
@ -36,12 +39,13 @@ type msgAndRoute struct {
}
type WshRouter struct {
Lock *sync.Mutex
RouteMap map[string]AbstractRpcClient // routeid => client
UpstreamClient AbstractRpcClient // upstream client (if we are not the terminal router)
AnnouncedRoutes map[string]string // routeid => local routeid
RpcMap map[string]*routeInfo // rpcid => routeinfo
InputCh chan msgAndRoute
Lock *sync.Mutex
RouteMap map[string]AbstractRpcClient // routeid => client
UpstreamClient AbstractRpcClient // upstream client (if we are not the terminal router)
AnnouncedRoutes map[string]string // routeid => local routeid
RpcMap map[string]*routeInfo // rpcid => routeinfo
SimpleRequestMap map[string]chan *RpcMessage // simple reqid => response channel
InputCh chan msgAndRoute
}
func MakeConnectionRouteId(connId string) string {
@ -68,11 +72,12 @@ var DefaultRouter = NewWshRouter()
func NewWshRouter() *WshRouter {
rtn := &WshRouter{
Lock: &sync.Mutex{},
RouteMap: make(map[string]AbstractRpcClient),
AnnouncedRoutes: make(map[string]string),
RpcMap: make(map[string]*routeInfo),
InputCh: make(chan msgAndRoute, DefaultInputChSize),
Lock: &sync.Mutex{},
RouteMap: make(map[string]AbstractRpcClient),
AnnouncedRoutes: make(map[string]string),
RpcMap: make(map[string]*routeInfo),
SimpleRequestMap: make(map[string]chan *RpcMessage),
InputCh: make(chan msgAndRoute, DefaultInputChSize),
}
go rtn.runServer()
return rtn
@ -237,6 +242,10 @@ func (router *WshRouter) runServer() {
router.sendRoutedMessage(msgBytes, routeInfo.DestRouteId)
continue
} else if msg.ResId != "" {
ok := router.trySimpleResponse(&msg)
if ok {
continue
}
routeInfo := router.getRouteInfo(msg.ResId)
if routeInfo == nil {
// no route info, nothing to do
@ -270,9 +279,9 @@ func (router *WshRouter) WaitForRegister(ctx context.Context, routeId string) er
// this will also consume the output channel of the abstract client
func (router *WshRouter) RegisterRoute(routeId string, rpc AbstractRpcClient, shouldAnnounce bool) {
if routeId == SysRoute {
if routeId == SysRoute || routeId == UpstreamRoute {
// cannot register sys route
log.Printf("error: WshRouter cannot register sys route\n")
log.Printf("error: WshRouter cannot register %s route\n", routeId)
return
}
log.Printf("[router] registering wsh route %q\n", routeId)
@ -352,3 +361,97 @@ func (router *WshRouter) GetUpstreamClient() AbstractRpcClient {
defer router.Lock.Unlock()
return router.UpstreamClient
}
func (router *WshRouter) InjectMessage(msgBytes []byte, fromRouteId string) {
router.InputCh <- msgAndRoute{msgBytes: msgBytes, fromRouteId: fromRouteId}
}
func (router *WshRouter) registerSimpleRequest(reqId string) chan *RpcMessage {
router.Lock.Lock()
defer router.Lock.Unlock()
rtn := make(chan *RpcMessage, 1)
router.SimpleRequestMap[reqId] = rtn
return rtn
}
func (router *WshRouter) trySimpleResponse(msg *RpcMessage) bool {
router.Lock.Lock()
defer router.Lock.Unlock()
respCh := router.SimpleRequestMap[msg.ResId]
if respCh == nil {
return false
}
respCh <- msg
delete(router.SimpleRequestMap, msg.ResId)
return true
}
func (router *WshRouter) clearSimpleRequest(reqId string) {
router.Lock.Lock()
defer router.Lock.Unlock()
delete(router.SimpleRequestMap, reqId)
}
func (router *WshRouter) RunSimpleRawCommand(ctx context.Context, msg RpcMessage, fromRouteId string) (*RpcMessage, error) {
if msg.Command == "" {
return nil, errors.New("no command")
}
msgBytes, err := json.Marshal(msg)
if err != nil {
return nil, err
}
var respCh chan *RpcMessage
if msg.ReqId != "" {
respCh = router.registerSimpleRequest(msg.ReqId)
}
router.InjectMessage(msgBytes, fromRouteId)
if respCh == nil {
return nil, nil
}
select {
case <-ctx.Done():
router.clearSimpleRequest(msg.ReqId)
return nil, ctx.Err()
case resp := <-respCh:
if resp.Error != "" {
return nil, errors.New(resp.Error)
}
return resp, nil
}
}
func (router *WshRouter) HandleProxyAuth(jwtTokenAny any) (*wshrpc.CommandAuthenticateRtnData, error) {
if jwtTokenAny == nil {
return nil, errors.New("no jwt token")
}
jwtToken, ok := jwtTokenAny.(string)
if !ok {
return nil, errors.New("jwt token not a string")
}
if jwtToken == "" {
return nil, errors.New("empty jwt token")
}
msg := RpcMessage{
Command: wshrpc.Command_Authenticate,
ReqId: uuid.New().String(),
Data: jwtToken,
}
ctx, cancelFn := context.WithTimeout(context.Background(), DefaultTimeoutMs*time.Millisecond)
defer cancelFn()
resp, err := router.RunSimpleRawCommand(ctx, msg, "")
if err != nil {
return nil, err
}
if resp == nil || resp.Data == nil {
return nil, errors.New("no data in authenticate response")
}
var respData wshrpc.CommandAuthenticateRtnData
err = utilfn.ReUnmarshal(&respData, resp.Data)
if err != nil {
return nil, fmt.Errorf("error unmarshalling authenticate response: %v", err)
}
if respData.AuthToken == "" {
return nil, errors.New("no auth token in authenticate response")
}
return &respData, nil
}

View File

@ -50,6 +50,8 @@ type WshRpc struct {
ServerImpl ServerImpl
EventListener *EventListener
ResponseHandlerMap map[string]*RpcResponseHandler // reqId => handler
Debug bool
DebugName string
}
type wshRpcContextKey struct{}
@ -333,6 +335,9 @@ func (w *WshRpc) handleRequest(req *RpcMessage) {
func (w *WshRpc) runServer() {
defer close(w.OutputCh)
for msgBytes := range w.InputCh {
if w.Debug {
log.Printf("[%s] received message: %s\n", w.DebugName, string(msgBytes))
}
var msg RpcMessage
err := json.Unmarshal(msgBytes, &msg)
if err != nil {
@ -465,8 +470,9 @@ func (handler *RpcRequestHandler) SendCancel() {
}
}()
msg := &RpcMessage{
Cancel: true,
ReqId: handler.reqId,
Cancel: true,
ReqId: handler.reqId,
AuthToken: handler.w.GetAuthToken(),
}
barr, _ := json.Marshal(msg) // will never fail
handler.w.OutputCh <- barr
@ -560,6 +566,7 @@ func (handler *RpcResponseHandler) SendMessage(msg string) {
Data: wshrpc.CommandMessageData{
Message: msg,
},
AuthToken: handler.w.GetAuthToken(),
}
msgBytes, _ := json.Marshal(rpcMsg) // will never fail
handler.w.OutputCh <- msgBytes
@ -583,9 +590,10 @@ func (handler *RpcResponseHandler) SendResponse(data any, done bool) error {
defer handler.close()
}
msg := &RpcMessage{
ResId: handler.reqId,
Data: data,
Cont: !done,
ResId: handler.reqId,
Data: data,
Cont: !done,
AuthToken: handler.w.GetAuthToken(),
}
barr, err := json.Marshal(msg)
if err != nil {
@ -608,8 +616,9 @@ func (handler *RpcResponseHandler) SendResponseError(err error) {
}
defer handler.close()
msg := &RpcMessage{
ResId: handler.reqId,
Error: err.Error(),
ResId: handler.reqId,
Error: err.Error(),
AuthToken: handler.w.GetAuthToken(),
}
barr, _ := json.Marshal(msg) // will never fail
handler.w.OutputCh <- barr
@ -670,11 +679,12 @@ func (w *WshRpc) SendComplexRequest(command string, data any, opts *wshrpc.RpcOp
handler.reqId = uuid.New().String()
}
req := &RpcMessage{
Command: command,
ReqId: handler.reqId,
Data: data,
Timeout: timeoutMs,
Route: opts.Route,
Command: command,
ReqId: handler.reqId,
Data: data,
Timeout: timeoutMs,
Route: opts.Route,
AuthToken: w.GetAuthToken(),
}
barr, err := json.Marshal(req)
if err != nil {

View File

@ -4,7 +4,6 @@
package wshutil
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
@ -13,7 +12,6 @@ import (
"net"
"os"
"os/signal"
"strings"
"sync"
"sync/atomic"
"syscall"
@ -21,6 +19,7 @@ import (
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/wavetermdev/waveterm/pkg/util/packetparser"
"github.com/wavetermdev/waveterm/pkg/wavebase"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
"golang.org/x/term"
@ -198,11 +197,26 @@ func SetupTerminalRpcClient(serverImpl ServerImpl) (*WshRpc, io.Reader) {
for msg := range outputCh {
barr := EncodeWaveOSCBytes(WaveOSC, msg)
os.Stdout.Write(barr)
os.Stdout.Write([]byte{'\n'})
}
}()
return rpcClient, ptyBuf
}
func SetupPacketRpcClient(input io.Reader, output io.Writer, serverImpl ServerImpl) (*WshRpc, chan []byte) {
messageCh := make(chan []byte, DefaultInputChSize)
outputCh := make(chan []byte, DefaultOutputChSize)
rawCh := make(chan []byte, DefaultOutputChSize)
rpcClient := MakeWshRpc(messageCh, outputCh, wshrpc.RpcContext{}, serverImpl)
go packetparser.Parse(input, messageCh, rawCh)
go func() {
for msg := range outputCh {
packetparser.WritePacket(output, msg)
}
}()
return rpcClient, rawCh
}
func SetupConnRpcClient(conn net.Conn, serverImpl ServerImpl) (*WshRpc, chan error, error) {
inputCh := make(chan []byte, DefaultInputChSize)
outputCh := make(chan []byte, DefaultOutputChSize)
@ -364,15 +378,16 @@ type WriteFlusher interface {
// blocking, returns if there is an error, or on EOF of input
func HandleStdIOClient(logName string, input io.Reader, output io.Writer) {
log.Printf("[%s] starting (HandleStdIOClient)\n", logName)
proxy := MakeRpcMultiProxy()
ptyBuffer := MakePtyBuffer(WaveOSCPrefix, input, proxy.FromRemoteRawCh)
rawCh := make(chan []byte, DefaultInputChSize)
go packetparser.Parse(input, proxy.FromRemoteRawCh, rawCh)
doneCh := make(chan struct{})
var doneOnce sync.Once
closeDoneCh := func() {
doneOnce.Do(func() {
close(doneCh)
})
proxy.DisposeRoutes()
}
go func() {
proxy.RunUnauthLoop()
@ -380,9 +395,7 @@ func HandleStdIOClient(logName string, input io.Reader, output io.Writer) {
go func() {
defer closeDoneCh()
for msg := range proxy.ToRemoteCh {
log.Printf("[%s] sending message: %s\n", logName, string(msg))
barr := EncodeWaveOSCBytes(WaveServerOSC, msg)
_, err := output.Write(barr)
err := packetparser.WritePacket(output, msg)
if err != nil {
log.Printf("[%s] error writing to output: %v\n", logName, err)
break
@ -391,22 +404,8 @@ func HandleStdIOClient(logName string, input io.Reader, output io.Writer) {
}()
go func() {
defer closeDoneCh()
br := bufio.NewReader(ptyBuffer.InputReader)
for {
line, err := br.ReadString('\n')
if line != "" {
if !strings.HasSuffix(line, "\n") {
line += "\n"
}
log.Printf("[%s] %s", logName, line)
}
if err == io.EOF {
break
}
if err != nil {
log.Printf("[%s] error reading from pty buffer: %v\n", logName, err)
break
}
for msg := range rawCh {
log.Printf("[%s:stdout] %s", logName, msg)
}
}()
<-doneCh