mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-01-21 21:32:13 +01:00
working impl of new connserver router, many bugs fixed, new functionality added to wshrouter to inject packets and help with proxy auth
This commit is contained in:
parent
64262cbb74
commit
c874a6e302
@ -14,12 +14,12 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/util/packetparser"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wavebase"
|
"github.com/wavetermdev/waveterm/pkg/wavebase"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
|
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshremote"
|
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshremote"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
||||||
"golang.org/x/crypto/ssh/terminal"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var serverCmd = &cobra.Command{
|
var serverCmd = &cobra.Command{
|
||||||
@ -52,7 +52,6 @@ func MakeRemoteUnixListener() (net.Listener, error) {
|
|||||||
func handleNewListenerConn(conn net.Conn, router *wshutil.WshRouter) {
|
func handleNewListenerConn(conn net.Conn, router *wshutil.WshRouter) {
|
||||||
var routeIdContainer atomic.Pointer[string]
|
var routeIdContainer atomic.Pointer[string]
|
||||||
proxy := wshutil.MakeRpcProxy()
|
proxy := wshutil.MakeRpcProxy()
|
||||||
upstreamClient := router.GetUpstreamClient().(*wshutil.WshRpc)
|
|
||||||
go func() {
|
go func() {
|
||||||
writeErr := wshutil.AdaptOutputChToStream(proxy.ToRemoteCh, conn)
|
writeErr := wshutil.AdaptOutputChToStream(proxy.ToRemoteCh, conn)
|
||||||
if writeErr != nil {
|
if writeErr != nil {
|
||||||
@ -75,12 +74,12 @@ func handleNewListenerConn(conn net.Conn, router *wshutil.WshRouter) {
|
|||||||
AuthToken: proxy.GetAuthToken(),
|
AuthToken: proxy.GetAuthToken(),
|
||||||
}
|
}
|
||||||
disposeBytes, _ := json.Marshal(disposeMsg)
|
disposeBytes, _ := json.Marshal(disposeMsg)
|
||||||
upstreamClient.SendRpcMessage(disposeBytes)
|
router.InjectMessage(disposeBytes, *routeIdPtr)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
wshutil.AdaptStreamToMsgCh(conn, proxy.FromRemoteCh)
|
wshutil.AdaptStreamToMsgCh(conn, proxy.FromRemoteCh)
|
||||||
}()
|
}()
|
||||||
routeId, err := proxy.HandleClientProxyAuth(upstreamClient)
|
routeId, err := proxy.HandleClientProxyAuth(router)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("error handling client proxy auth: %v\n", err)
|
log.Printf("error handling client proxy auth: %v\n", err)
|
||||||
conn.Close()
|
conn.Close()
|
||||||
@ -118,59 +117,54 @@ func setupConnServerRpcClientWithRouter(router *wshutil.WshRouter) (*wshutil.Wsh
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error extracting rpc context from %s: %v", wshutil.WaveJwtTokenVarName, err)
|
return nil, fmt.Errorf("error extracting rpc context from %s: %v", wshutil.WaveJwtTokenVarName, err)
|
||||||
}
|
}
|
||||||
RpcContext = *rpcCtx
|
authRtn, err := router.HandleProxyAuth(jwtToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error handling proxy auth: %v", err)
|
||||||
|
}
|
||||||
inputCh := make(chan []byte, wshutil.DefaultInputChSize)
|
inputCh := make(chan []byte, wshutil.DefaultInputChSize)
|
||||||
outputCh := make(chan []byte, wshutil.DefaultOutputChSize)
|
outputCh := make(chan []byte, wshutil.DefaultOutputChSize)
|
||||||
connServerClient := wshutil.MakeWshRpc(inputCh, outputCh, *rpcCtx, &wshremote.ServerImpl{LogWriter: os.Stdout})
|
connServerClient := wshutil.MakeWshRpc(inputCh, outputCh, *rpcCtx, &wshremote.ServerImpl{LogWriter: os.Stdout})
|
||||||
upstreamClient := router.GetUpstreamClient().(*wshutil.WshRpc)
|
connServerClient.SetAuthToken(authRtn.AuthToken)
|
||||||
resp, err := wshclient.AuthenticateCommand(upstreamClient, jwtToken, nil)
|
router.RegisterRoute(authRtn.RouteId, connServerClient, false)
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error authenticating connserver: %v", err)
|
|
||||||
}
|
|
||||||
if resp.AuthToken == "" {
|
|
||||||
return nil, fmt.Errorf("no auth token returned from connserver")
|
|
||||||
}
|
|
||||||
log.Printf("authenticated connserver route: %s\n", resp.RouteId)
|
|
||||||
connServerClient.SetAuthToken(resp.AuthToken)
|
|
||||||
router.RegisterRoute(resp.RouteId, connServerClient, false)
|
|
||||||
wshclient.RouteAnnounceCommand(connServerClient, nil)
|
wshclient.RouteAnnounceCommand(connServerClient, nil)
|
||||||
return connServerClient, nil
|
return connServerClient, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func serverRunRouter() error {
|
func serverRunRouter() error {
|
||||||
isTerminal := terminal.IsTerminal(int(os.Stdout.Fd()))
|
router := wshutil.NewWshRouter()
|
||||||
if isTerminal {
|
termProxy := wshutil.MakeRpcProxy()
|
||||||
wshutil.SetTermRawMode()
|
rawCh := make(chan []byte, wshutil.DefaultOutputChSize)
|
||||||
}
|
go packetparser.Parse(os.Stdin, termProxy.FromRemoteCh, rawCh)
|
||||||
termClient, reader := wshutil.SetupTerminalRpcClient(&wshremote.ServerImpl{LogWriter: os.Stdout})
|
|
||||||
go func() {
|
go func() {
|
||||||
// just ignore and drain the reader
|
for msg := range termProxy.ToRemoteCh {
|
||||||
var errorCode int
|
packetparser.WritePacket(os.Stdout, msg)
|
||||||
defer wshutil.DoShutdown("", errorCode, true)
|
|
||||||
for {
|
|
||||||
buf := make([]byte, 4096)
|
|
||||||
_, err := reader.Read(buf)
|
|
||||||
if err == io.EOF {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
errorCode = 1
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
router := wshutil.NewWshRouter()
|
go func() {
|
||||||
router.SetUpstreamClient(termClient)
|
// just ignore and drain the rawCh (stdin)
|
||||||
|
// when stdin is closed, shutdown
|
||||||
|
defer wshutil.DoShutdown("", 0, true)
|
||||||
|
for range rawCh {
|
||||||
|
// ignore
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
for msg := range termProxy.FromRemoteCh {
|
||||||
|
// send this to the router
|
||||||
|
router.InjectMessage(msg, wshutil.UpstreamRoute)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
router.SetUpstreamClient(termProxy)
|
||||||
// now set up the domain socket
|
// now set up the domain socket
|
||||||
unixListener, err := MakeRemoteUnixListener()
|
unixListener, err := MakeRemoteUnixListener()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("cannot create unix listener: %v", err)
|
return fmt.Errorf("cannot create unix listener: %v", err)
|
||||||
}
|
}
|
||||||
go runListener(unixListener, router)
|
|
||||||
client, err := setupConnServerRpcClientWithRouter(router)
|
client, err := setupConnServerRpcClientWithRouter(router)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error setting up connserver rpc client: %v", err)
|
return fmt.Errorf("error setting up connserver rpc client: %v", err)
|
||||||
}
|
}
|
||||||
|
go runListener(unixListener, router)
|
||||||
// run the sysinfo loop
|
// run the sysinfo loop
|
||||||
wshremote.RunSysInfoLoop(client, client.GetRpcContext().Conn)
|
wshremote.RunSysInfoLoop(client, client.GetRpcContext().Conn)
|
||||||
select {}
|
select {}
|
||||||
|
@ -6,7 +6,6 @@ package wshutil
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
@ -37,6 +36,15 @@ func MakeRpcMultiProxy() *WshRpcMultiProxy {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *WshRpcMultiProxy) DisposeRoutes() {
|
||||||
|
p.Lock.Lock()
|
||||||
|
defer p.Lock.Unlock()
|
||||||
|
for authToken, routeInfo := range p.RouteInfo {
|
||||||
|
DefaultRouter.UnregisterRoute(routeInfo.RouteId)
|
||||||
|
delete(p.RouteInfo, authToken)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (p *WshRpcMultiProxy) getRouteInfo(authToken string) *multiProxyRouteInfo {
|
func (p *WshRpcMultiProxy) getRouteInfo(authToken string) *multiProxyRouteInfo {
|
||||||
p.Lock.Lock()
|
p.Lock.Lock()
|
||||||
defer p.Lock.Unlock()
|
defer p.Lock.Unlock()
|
||||||
@ -91,7 +99,6 @@ func (p *WshRpcMultiProxy) handleUnauthMessage(msgBytes []byte) {
|
|||||||
if msg.Command == wshrpc.Command_Authenticate {
|
if msg.Command == wshrpc.Command_Authenticate {
|
||||||
rpcContext, routeId, err := handleAuthenticationCommand(msg)
|
rpcContext, routeId, err := handleAuthenticationCommand(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("error handling authentication command (multiproxy): %v\n", err)
|
|
||||||
p.sendResponseError(msg, err)
|
p.sendResponseError(msg, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -104,6 +111,11 @@ func (p *WshRpcMultiProxy) handleUnauthMessage(msgBytes []byte) {
|
|||||||
routeInfo.Proxy.SetRpcContext(rpcContext)
|
routeInfo.Proxy.SetRpcContext(rpcContext)
|
||||||
p.setRouteInfo(routeInfo.AuthToken, routeInfo)
|
p.setRouteInfo(routeInfo.AuthToken, routeInfo)
|
||||||
p.sendAuthResponse(msg, routeId, routeInfo.AuthToken)
|
p.sendAuthResponse(msg, routeId, routeInfo.AuthToken)
|
||||||
|
go func() {
|
||||||
|
for msgBytes := range routeInfo.Proxy.ToRemoteCh {
|
||||||
|
p.ToRemoteCh <- msgBytes
|
||||||
|
}
|
||||||
|
}()
|
||||||
DefaultRouter.RegisterRoute(routeId, routeInfo.Proxy, true)
|
DefaultRouter.RegisterRoute(routeId, routeInfo.Proxy, true)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -116,13 +128,15 @@ func (p *WshRpcMultiProxy) handleUnauthMessage(msgBytes []byte) {
|
|||||||
p.sendResponseError(msg, fmt.Errorf("invalid auth token"))
|
p.sendResponseError(msg, fmt.Errorf("invalid auth token"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if msg.Source != routeInfo.RouteId {
|
if msg.Command != "" && msg.Source != routeInfo.RouteId {
|
||||||
p.sendResponseError(msg, fmt.Errorf("invalid source route for auth token"))
|
p.sendResponseError(msg, fmt.Errorf("invalid source route for auth token"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if msg.Command == wshrpc.Command_Dispose {
|
if msg.Command == wshrpc.Command_Dispose {
|
||||||
DefaultRouter.UnregisterRoute(routeInfo.RouteId)
|
DefaultRouter.UnregisterRoute(routeInfo.RouteId)
|
||||||
p.removeRouteInfo(msg.AuthToken)
|
p.removeRouteInfo(msg.AuthToken)
|
||||||
|
close(routeInfo.Proxy.ToRemoteCh)
|
||||||
|
close(routeInfo.Proxy.FromRemoteCh)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
routeInfo.Proxy.FromRemoteCh <- msgBytes
|
routeInfo.Proxy.FromRemoteCh <- msgBytes
|
||||||
|
@ -10,7 +10,6 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
|
|
||||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -113,56 +112,44 @@ func handleAuthenticationCommand(msg RpcMessage) (*wshrpc.RpcContext, string, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
// runs on the client (stdio client)
|
// runs on the client (stdio client)
|
||||||
func (p *WshRpcProxy) HandleClientProxyAuth(upstream *WshRpc) (string, error) {
|
func (p *WshRpcProxy) HandleClientProxyAuth(router *WshRouter) (string, error) {
|
||||||
for {
|
for {
|
||||||
msgBytes, ok := <-p.FromRemoteCh
|
msgBytes, ok := <-p.FromRemoteCh
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", fmt.Errorf("remote closed, not authenticated")
|
return "", fmt.Errorf("remote closed, not authenticated")
|
||||||
}
|
}
|
||||||
var msg RpcMessage
|
var origMsg RpcMessage
|
||||||
err := json.Unmarshal(msgBytes, &msg)
|
err := json.Unmarshal(msgBytes, &origMsg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// nothing to do, can't even send a response since we don't have Source or ReqId
|
// nothing to do, can't even send a response since we don't have Source or ReqId
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if msg.Command == "" {
|
if origMsg.Command == "" {
|
||||||
// this message is not allowed (protocol error at this point), ignore
|
// this message is not allowed (protocol error at this point), ignore
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// we only allow one command "authenticate", everything else returns an error
|
// we only allow one command "authenticate", everything else returns an error
|
||||||
if msg.Command != wshrpc.Command_Authenticate {
|
if origMsg.Command != wshrpc.Command_Authenticate {
|
||||||
respErr := fmt.Errorf("connection not authenticated")
|
respErr := fmt.Errorf("connection not authenticated")
|
||||||
p.sendResponseError(msg, respErr)
|
p.sendResponseError(origMsg, respErr)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
resp, err := upstream.SendRpcRequest(msg.Command, msg.Data, nil)
|
authRtn, err := router.HandleProxyAuth(origMsg.Data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
respErr := fmt.Errorf("error authenticating: %w", err)
|
respErr := fmt.Errorf("error handling proxy auth: %w", err)
|
||||||
p.sendResponseError(msg, respErr)
|
p.sendResponseError(origMsg, respErr)
|
||||||
return "", respErr
|
return "", respErr
|
||||||
}
|
}
|
||||||
var respData wshrpc.CommandAuthenticateRtnData
|
p.SetAuthToken(authRtn.AuthToken)
|
||||||
err = utilfn.ReUnmarshal(&respData, resp)
|
|
||||||
if err != nil {
|
|
||||||
respErr := fmt.Errorf("error unmarshalling authenticate response: %w", err)
|
|
||||||
p.sendResponseError(msg, respErr)
|
|
||||||
return "", respErr
|
|
||||||
}
|
|
||||||
if respData.AuthToken == "" {
|
|
||||||
respErr := fmt.Errorf("no auth token in authenticate response")
|
|
||||||
p.sendResponseError(msg, respErr)
|
|
||||||
return "", respErr
|
|
||||||
}
|
|
||||||
p.SetAuthToken(respData.AuthToken)
|
|
||||||
announceMsg := RpcMessage{
|
announceMsg := RpcMessage{
|
||||||
Command: wshrpc.Command_RouteAnnounce,
|
Command: wshrpc.Command_RouteAnnounce,
|
||||||
Source: respData.RouteId,
|
Source: authRtn.RouteId,
|
||||||
AuthToken: respData.AuthToken,
|
AuthToken: authRtn.AuthToken,
|
||||||
}
|
}
|
||||||
announceBytes, _ := json.Marshal(announceMsg)
|
announceBytes, _ := json.Marshal(announceMsg)
|
||||||
upstream.SendRpcMessage(announceBytes)
|
router.InjectMessage(announceBytes, authRtn.RouteId)
|
||||||
p.sendAuthenticateResponse(msg, respData.RouteId)
|
p.sendAuthenticateResponse(origMsg, authRtn.RouteId)
|
||||||
return respData.RouteId, nil
|
return authRtn.RouteId, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -12,11 +12,14 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wps"
|
"github.com/wavetermdev/waveterm/pkg/wps"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||||
)
|
)
|
||||||
|
|
||||||
const DefaultRoute = "wavesrv"
|
const DefaultRoute = "wavesrv"
|
||||||
|
const UpstreamRoute = "upstream"
|
||||||
const SysRoute = "sys" // this route doesn't exist, just a placeholder for system messages
|
const SysRoute = "sys" // this route doesn't exist, just a placeholder for system messages
|
||||||
const ElectronRoute = "electron"
|
const ElectronRoute = "electron"
|
||||||
|
|
||||||
@ -36,12 +39,13 @@ type msgAndRoute struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type WshRouter struct {
|
type WshRouter struct {
|
||||||
Lock *sync.Mutex
|
Lock *sync.Mutex
|
||||||
RouteMap map[string]AbstractRpcClient // routeid => client
|
RouteMap map[string]AbstractRpcClient // routeid => client
|
||||||
UpstreamClient AbstractRpcClient // upstream client (if we are not the terminal router)
|
UpstreamClient AbstractRpcClient // upstream client (if we are not the terminal router)
|
||||||
AnnouncedRoutes map[string]string // routeid => local routeid
|
AnnouncedRoutes map[string]string // routeid => local routeid
|
||||||
RpcMap map[string]*routeInfo // rpcid => routeinfo
|
RpcMap map[string]*routeInfo // rpcid => routeinfo
|
||||||
InputCh chan msgAndRoute
|
SimpleRequestMap map[string]chan *RpcMessage // simple reqid => response channel
|
||||||
|
InputCh chan msgAndRoute
|
||||||
}
|
}
|
||||||
|
|
||||||
func MakeConnectionRouteId(connId string) string {
|
func MakeConnectionRouteId(connId string) string {
|
||||||
@ -68,11 +72,12 @@ var DefaultRouter = NewWshRouter()
|
|||||||
|
|
||||||
func NewWshRouter() *WshRouter {
|
func NewWshRouter() *WshRouter {
|
||||||
rtn := &WshRouter{
|
rtn := &WshRouter{
|
||||||
Lock: &sync.Mutex{},
|
Lock: &sync.Mutex{},
|
||||||
RouteMap: make(map[string]AbstractRpcClient),
|
RouteMap: make(map[string]AbstractRpcClient),
|
||||||
AnnouncedRoutes: make(map[string]string),
|
AnnouncedRoutes: make(map[string]string),
|
||||||
RpcMap: make(map[string]*routeInfo),
|
RpcMap: make(map[string]*routeInfo),
|
||||||
InputCh: make(chan msgAndRoute, DefaultInputChSize),
|
SimpleRequestMap: make(map[string]chan *RpcMessage),
|
||||||
|
InputCh: make(chan msgAndRoute, DefaultInputChSize),
|
||||||
}
|
}
|
||||||
go rtn.runServer()
|
go rtn.runServer()
|
||||||
return rtn
|
return rtn
|
||||||
@ -237,6 +242,10 @@ func (router *WshRouter) runServer() {
|
|||||||
router.sendRoutedMessage(msgBytes, routeInfo.DestRouteId)
|
router.sendRoutedMessage(msgBytes, routeInfo.DestRouteId)
|
||||||
continue
|
continue
|
||||||
} else if msg.ResId != "" {
|
} else if msg.ResId != "" {
|
||||||
|
ok := router.trySimpleResponse(&msg)
|
||||||
|
if ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
routeInfo := router.getRouteInfo(msg.ResId)
|
routeInfo := router.getRouteInfo(msg.ResId)
|
||||||
if routeInfo == nil {
|
if routeInfo == nil {
|
||||||
// no route info, nothing to do
|
// no route info, nothing to do
|
||||||
@ -270,9 +279,9 @@ func (router *WshRouter) WaitForRegister(ctx context.Context, routeId string) er
|
|||||||
|
|
||||||
// this will also consume the output channel of the abstract client
|
// this will also consume the output channel of the abstract client
|
||||||
func (router *WshRouter) RegisterRoute(routeId string, rpc AbstractRpcClient, shouldAnnounce bool) {
|
func (router *WshRouter) RegisterRoute(routeId string, rpc AbstractRpcClient, shouldAnnounce bool) {
|
||||||
if routeId == SysRoute {
|
if routeId == SysRoute || routeId == UpstreamRoute {
|
||||||
// cannot register sys route
|
// cannot register sys route
|
||||||
log.Printf("error: WshRouter cannot register sys route\n")
|
log.Printf("error: WshRouter cannot register %s route\n", routeId)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Printf("[router] registering wsh route %q\n", routeId)
|
log.Printf("[router] registering wsh route %q\n", routeId)
|
||||||
@ -352,3 +361,97 @@ func (router *WshRouter) GetUpstreamClient() AbstractRpcClient {
|
|||||||
defer router.Lock.Unlock()
|
defer router.Lock.Unlock()
|
||||||
return router.UpstreamClient
|
return router.UpstreamClient
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (router *WshRouter) InjectMessage(msgBytes []byte, fromRouteId string) {
|
||||||
|
router.InputCh <- msgAndRoute{msgBytes: msgBytes, fromRouteId: fromRouteId}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (router *WshRouter) registerSimpleRequest(reqId string) chan *RpcMessage {
|
||||||
|
router.Lock.Lock()
|
||||||
|
defer router.Lock.Unlock()
|
||||||
|
rtn := make(chan *RpcMessage, 1)
|
||||||
|
router.SimpleRequestMap[reqId] = rtn
|
||||||
|
return rtn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (router *WshRouter) trySimpleResponse(msg *RpcMessage) bool {
|
||||||
|
router.Lock.Lock()
|
||||||
|
defer router.Lock.Unlock()
|
||||||
|
respCh := router.SimpleRequestMap[msg.ResId]
|
||||||
|
if respCh == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
respCh <- msg
|
||||||
|
delete(router.SimpleRequestMap, msg.ResId)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (router *WshRouter) clearSimpleRequest(reqId string) {
|
||||||
|
router.Lock.Lock()
|
||||||
|
defer router.Lock.Unlock()
|
||||||
|
delete(router.SimpleRequestMap, reqId)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (router *WshRouter) RunSimpleRawCommand(ctx context.Context, msg RpcMessage, fromRouteId string) (*RpcMessage, error) {
|
||||||
|
if msg.Command == "" {
|
||||||
|
return nil, errors.New("no command")
|
||||||
|
}
|
||||||
|
msgBytes, err := json.Marshal(msg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var respCh chan *RpcMessage
|
||||||
|
if msg.ReqId != "" {
|
||||||
|
respCh = router.registerSimpleRequest(msg.ReqId)
|
||||||
|
}
|
||||||
|
router.InjectMessage(msgBytes, fromRouteId)
|
||||||
|
if respCh == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
router.clearSimpleRequest(msg.ReqId)
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case resp := <-respCh:
|
||||||
|
if resp.Error != "" {
|
||||||
|
return nil, errors.New(resp.Error)
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (router *WshRouter) HandleProxyAuth(jwtTokenAny any) (*wshrpc.CommandAuthenticateRtnData, error) {
|
||||||
|
if jwtTokenAny == nil {
|
||||||
|
return nil, errors.New("no jwt token")
|
||||||
|
}
|
||||||
|
jwtToken, ok := jwtTokenAny.(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("jwt token not a string")
|
||||||
|
}
|
||||||
|
if jwtToken == "" {
|
||||||
|
return nil, errors.New("empty jwt token")
|
||||||
|
}
|
||||||
|
msg := RpcMessage{
|
||||||
|
Command: wshrpc.Command_Authenticate,
|
||||||
|
ReqId: uuid.New().String(),
|
||||||
|
Data: jwtToken,
|
||||||
|
}
|
||||||
|
ctx, cancelFn := context.WithTimeout(context.Background(), DefaultTimeoutMs*time.Millisecond)
|
||||||
|
defer cancelFn()
|
||||||
|
resp, err := router.RunSimpleRawCommand(ctx, msg, "")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if resp == nil || resp.Data == nil {
|
||||||
|
return nil, errors.New("no data in authenticate response")
|
||||||
|
}
|
||||||
|
var respData wshrpc.CommandAuthenticateRtnData
|
||||||
|
err = utilfn.ReUnmarshal(&respData, resp.Data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error unmarshalling authenticate response: %v", err)
|
||||||
|
}
|
||||||
|
if respData.AuthToken == "" {
|
||||||
|
return nil, errors.New("no auth token in authenticate response")
|
||||||
|
}
|
||||||
|
return &respData, nil
|
||||||
|
}
|
||||||
|
@ -50,6 +50,8 @@ type WshRpc struct {
|
|||||||
ServerImpl ServerImpl
|
ServerImpl ServerImpl
|
||||||
EventListener *EventListener
|
EventListener *EventListener
|
||||||
ResponseHandlerMap map[string]*RpcResponseHandler // reqId => handler
|
ResponseHandlerMap map[string]*RpcResponseHandler // reqId => handler
|
||||||
|
Debug bool
|
||||||
|
DebugName string
|
||||||
}
|
}
|
||||||
|
|
||||||
type wshRpcContextKey struct{}
|
type wshRpcContextKey struct{}
|
||||||
@ -333,6 +335,9 @@ func (w *WshRpc) handleRequest(req *RpcMessage) {
|
|||||||
func (w *WshRpc) runServer() {
|
func (w *WshRpc) runServer() {
|
||||||
defer close(w.OutputCh)
|
defer close(w.OutputCh)
|
||||||
for msgBytes := range w.InputCh {
|
for msgBytes := range w.InputCh {
|
||||||
|
if w.Debug {
|
||||||
|
log.Printf("[%s] received message: %s\n", w.DebugName, string(msgBytes))
|
||||||
|
}
|
||||||
var msg RpcMessage
|
var msg RpcMessage
|
||||||
err := json.Unmarshal(msgBytes, &msg)
|
err := json.Unmarshal(msgBytes, &msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -465,8 +470,9 @@ func (handler *RpcRequestHandler) SendCancel() {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
msg := &RpcMessage{
|
msg := &RpcMessage{
|
||||||
Cancel: true,
|
Cancel: true,
|
||||||
ReqId: handler.reqId,
|
ReqId: handler.reqId,
|
||||||
|
AuthToken: handler.w.GetAuthToken(),
|
||||||
}
|
}
|
||||||
barr, _ := json.Marshal(msg) // will never fail
|
barr, _ := json.Marshal(msg) // will never fail
|
||||||
handler.w.OutputCh <- barr
|
handler.w.OutputCh <- barr
|
||||||
@ -560,6 +566,7 @@ func (handler *RpcResponseHandler) SendMessage(msg string) {
|
|||||||
Data: wshrpc.CommandMessageData{
|
Data: wshrpc.CommandMessageData{
|
||||||
Message: msg,
|
Message: msg,
|
||||||
},
|
},
|
||||||
|
AuthToken: handler.w.GetAuthToken(),
|
||||||
}
|
}
|
||||||
msgBytes, _ := json.Marshal(rpcMsg) // will never fail
|
msgBytes, _ := json.Marshal(rpcMsg) // will never fail
|
||||||
handler.w.OutputCh <- msgBytes
|
handler.w.OutputCh <- msgBytes
|
||||||
@ -583,9 +590,10 @@ func (handler *RpcResponseHandler) SendResponse(data any, done bool) error {
|
|||||||
defer handler.close()
|
defer handler.close()
|
||||||
}
|
}
|
||||||
msg := &RpcMessage{
|
msg := &RpcMessage{
|
||||||
ResId: handler.reqId,
|
ResId: handler.reqId,
|
||||||
Data: data,
|
Data: data,
|
||||||
Cont: !done,
|
Cont: !done,
|
||||||
|
AuthToken: handler.w.GetAuthToken(),
|
||||||
}
|
}
|
||||||
barr, err := json.Marshal(msg)
|
barr, err := json.Marshal(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -608,8 +616,9 @@ func (handler *RpcResponseHandler) SendResponseError(err error) {
|
|||||||
}
|
}
|
||||||
defer handler.close()
|
defer handler.close()
|
||||||
msg := &RpcMessage{
|
msg := &RpcMessage{
|
||||||
ResId: handler.reqId,
|
ResId: handler.reqId,
|
||||||
Error: err.Error(),
|
Error: err.Error(),
|
||||||
|
AuthToken: handler.w.GetAuthToken(),
|
||||||
}
|
}
|
||||||
barr, _ := json.Marshal(msg) // will never fail
|
barr, _ := json.Marshal(msg) // will never fail
|
||||||
handler.w.OutputCh <- barr
|
handler.w.OutputCh <- barr
|
||||||
@ -670,11 +679,12 @@ func (w *WshRpc) SendComplexRequest(command string, data any, opts *wshrpc.RpcOp
|
|||||||
handler.reqId = uuid.New().String()
|
handler.reqId = uuid.New().String()
|
||||||
}
|
}
|
||||||
req := &RpcMessage{
|
req := &RpcMessage{
|
||||||
Command: command,
|
Command: command,
|
||||||
ReqId: handler.reqId,
|
ReqId: handler.reqId,
|
||||||
Data: data,
|
Data: data,
|
||||||
Timeout: timeoutMs,
|
Timeout: timeoutMs,
|
||||||
Route: opts.Route,
|
Route: opts.Route,
|
||||||
|
AuthToken: w.GetAuthToken(),
|
||||||
}
|
}
|
||||||
barr, err := json.Marshal(req)
|
barr, err := json.Marshal(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -4,7 +4,6 @@
|
|||||||
package wshutil
|
package wshutil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -13,7 +12,6 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"syscall"
|
"syscall"
|
||||||
@ -21,6 +19,7 @@ import (
|
|||||||
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/util/packetparser"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wavebase"
|
"github.com/wavetermdev/waveterm/pkg/wavebase"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||||
"golang.org/x/term"
|
"golang.org/x/term"
|
||||||
@ -198,11 +197,26 @@ func SetupTerminalRpcClient(serverImpl ServerImpl) (*WshRpc, io.Reader) {
|
|||||||
for msg := range outputCh {
|
for msg := range outputCh {
|
||||||
barr := EncodeWaveOSCBytes(WaveOSC, msg)
|
barr := EncodeWaveOSCBytes(WaveOSC, msg)
|
||||||
os.Stdout.Write(barr)
|
os.Stdout.Write(barr)
|
||||||
|
os.Stdout.Write([]byte{'\n'})
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return rpcClient, ptyBuf
|
return rpcClient, ptyBuf
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func SetupPacketRpcClient(input io.Reader, output io.Writer, serverImpl ServerImpl) (*WshRpc, chan []byte) {
|
||||||
|
messageCh := make(chan []byte, DefaultInputChSize)
|
||||||
|
outputCh := make(chan []byte, DefaultOutputChSize)
|
||||||
|
rawCh := make(chan []byte, DefaultOutputChSize)
|
||||||
|
rpcClient := MakeWshRpc(messageCh, outputCh, wshrpc.RpcContext{}, serverImpl)
|
||||||
|
go packetparser.Parse(input, messageCh, rawCh)
|
||||||
|
go func() {
|
||||||
|
for msg := range outputCh {
|
||||||
|
packetparser.WritePacket(output, msg)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return rpcClient, rawCh
|
||||||
|
}
|
||||||
|
|
||||||
func SetupConnRpcClient(conn net.Conn, serverImpl ServerImpl) (*WshRpc, chan error, error) {
|
func SetupConnRpcClient(conn net.Conn, serverImpl ServerImpl) (*WshRpc, chan error, error) {
|
||||||
inputCh := make(chan []byte, DefaultInputChSize)
|
inputCh := make(chan []byte, DefaultInputChSize)
|
||||||
outputCh := make(chan []byte, DefaultOutputChSize)
|
outputCh := make(chan []byte, DefaultOutputChSize)
|
||||||
@ -364,15 +378,16 @@ type WriteFlusher interface {
|
|||||||
|
|
||||||
// blocking, returns if there is an error, or on EOF of input
|
// blocking, returns if there is an error, or on EOF of input
|
||||||
func HandleStdIOClient(logName string, input io.Reader, output io.Writer) {
|
func HandleStdIOClient(logName string, input io.Reader, output io.Writer) {
|
||||||
log.Printf("[%s] starting (HandleStdIOClient)\n", logName)
|
|
||||||
proxy := MakeRpcMultiProxy()
|
proxy := MakeRpcMultiProxy()
|
||||||
ptyBuffer := MakePtyBuffer(WaveOSCPrefix, input, proxy.FromRemoteRawCh)
|
rawCh := make(chan []byte, DefaultInputChSize)
|
||||||
|
go packetparser.Parse(input, proxy.FromRemoteRawCh, rawCh)
|
||||||
doneCh := make(chan struct{})
|
doneCh := make(chan struct{})
|
||||||
var doneOnce sync.Once
|
var doneOnce sync.Once
|
||||||
closeDoneCh := func() {
|
closeDoneCh := func() {
|
||||||
doneOnce.Do(func() {
|
doneOnce.Do(func() {
|
||||||
close(doneCh)
|
close(doneCh)
|
||||||
})
|
})
|
||||||
|
proxy.DisposeRoutes()
|
||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
proxy.RunUnauthLoop()
|
proxy.RunUnauthLoop()
|
||||||
@ -380,9 +395,7 @@ func HandleStdIOClient(logName string, input io.Reader, output io.Writer) {
|
|||||||
go func() {
|
go func() {
|
||||||
defer closeDoneCh()
|
defer closeDoneCh()
|
||||||
for msg := range proxy.ToRemoteCh {
|
for msg := range proxy.ToRemoteCh {
|
||||||
log.Printf("[%s] sending message: %s\n", logName, string(msg))
|
err := packetparser.WritePacket(output, msg)
|
||||||
barr := EncodeWaveOSCBytes(WaveServerOSC, msg)
|
|
||||||
_, err := output.Write(barr)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[%s] error writing to output: %v\n", logName, err)
|
log.Printf("[%s] error writing to output: %v\n", logName, err)
|
||||||
break
|
break
|
||||||
@ -391,22 +404,8 @@ func HandleStdIOClient(logName string, input io.Reader, output io.Writer) {
|
|||||||
}()
|
}()
|
||||||
go func() {
|
go func() {
|
||||||
defer closeDoneCh()
|
defer closeDoneCh()
|
||||||
br := bufio.NewReader(ptyBuffer.InputReader)
|
for msg := range rawCh {
|
||||||
for {
|
log.Printf("[%s:stdout] %s", logName, msg)
|
||||||
line, err := br.ReadString('\n')
|
|
||||||
if line != "" {
|
|
||||||
if !strings.HasSuffix(line, "\n") {
|
|
||||||
line += "\n"
|
|
||||||
}
|
|
||||||
log.Printf("[%s] %s", logName, line)
|
|
||||||
}
|
|
||||||
if err == io.EOF {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("[%s] error reading from pty buffer: %v\n", logName, err)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
<-doneCh
|
<-doneCh
|
||||||
|
Loading…
Reference in New Issue
Block a user