From e0ffa4fa860654b1781c845743cb164f4cdd95e8 Mon Sep 17 00:00:00 2001 From: sawka Date: Mon, 21 Oct 2024 16:38:54 -0700 Subject: [PATCH] protocol fixups --- cmd/wsh/cmd/wshcmd-connserver.go | 1 + pkg/wshutil/wshproxy.go | 13 ++++++++++--- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/cmd/wsh/cmd/wshcmd-connserver.go b/cmd/wsh/cmd/wshcmd-connserver.go index 2b1508621..01fd94cce 100644 --- a/cmd/wsh/cmd/wshcmd-connserver.go +++ b/cmd/wsh/cmd/wshcmd-connserver.go @@ -82,6 +82,7 @@ func handleNewListenerConn(conn net.Conn, router *wshutil.WshRouter) { routeId, err := proxy.HandleClientProxyAuth(upstreamClient) if err != nil { log.Printf("error handling client proxy auth: %v\n", err) + conn.Close() return } router.RegisterRoute(routeId, proxy, false) diff --git a/pkg/wshutil/wshproxy.go b/pkg/wshutil/wshproxy.go index 637d9b4ee..cfda2872a 100644 --- a/pkg/wshutil/wshproxy.go +++ b/pkg/wshutil/wshproxy.go @@ -137,15 +137,21 @@ func (p *WshRpcProxy) HandleClientProxyAuth(upstream *WshRpc) (string, error) { } resp, err := upstream.SendRpcRequest(msg.Command, msg.Data, nil) if err != nil { - return "", fmt.Errorf("error authenticating: %w", err) + respErr := fmt.Errorf("error authenticating: %w", err) + p.sendResponseError(msg, respErr) + return "", respErr } var respData wshrpc.CommandAuthenticateRtnData err = utilfn.ReUnmarshal(&respData, resp) if err != nil { - return "", fmt.Errorf("error unmarshalling authenticate response: %w", err) + respErr := fmt.Errorf("error unmarshalling authenticate response: %w", err) + p.sendResponseError(msg, respErr) + return "", respErr } if respData.AuthToken == "" { - return "", fmt.Errorf("no auth token in authenticate response") + respErr := fmt.Errorf("no auth token in authenticate response") + p.sendResponseError(msg, respErr) + return "", respErr } p.SetAuthToken(respData.AuthToken) announceMsg := RpcMessage{ @@ -155,6 +161,7 @@ func (p *WshRpcProxy) HandleClientProxyAuth(upstream *WshRpc) (string, error) { } announceBytes, _ := json.Marshal(announceMsg) upstream.SendRpcMessage(announceBytes) + p.sendAuthenticateResponse(msg, respData.RouteId) return respData.RouteId, nil } }