diff --git a/pkg/web/ws.go b/pkg/web/ws.go index bde3ec167..5ef0c0344 100644 --- a/pkg/web/ws.go +++ b/pkg/web/ws.go @@ -30,6 +30,9 @@ const wsInitialPingTime = 1 * time.Second const DefaultCommandTimeout = 2 * time.Second +var GlobalLock = &sync.Mutex{} +var RouteToConnMap = map[string]string{} // routeid => connid + func RunWebSocketServer(listener net.Listener) { gr := mux.NewRouter() gr.HandleFunc("/ws", HandleWs) @@ -240,6 +243,31 @@ func WriteLoop(conn *websocket.Conn, outputCh chan any, closeCh chan any, routeI } } +func registerConn(wsConnId string, routeId string, wproxy *wshutil.WshRpcProxy) { + GlobalLock.Lock() + defer GlobalLock.Unlock() + curConnId := RouteToConnMap[routeId] + if curConnId != "" { + log.Printf("[websocket] warning: replacing existing connection for route %q\n", routeId) + wshutil.DefaultRouter.UnregisterRoute(routeId) + } + RouteToConnMap[routeId] = wsConnId + wshutil.DefaultRouter.RegisterRoute(routeId, wproxy) +} + +func unregisterConn(wsConnId string, routeId string) { + GlobalLock.Lock() + defer GlobalLock.Unlock() + curConnId := RouteToConnMap[routeId] + if curConnId != wsConnId { + // only unregister if we are the current connection (otherwise we were already removed) + log.Printf("[websocket] warning: trying to unregister connection %q for route %q but it is not the current connection (ignoring)\n", wsConnId, routeId) + return + } + delete(RouteToConnMap, routeId) + wshutil.DefaultRouter.UnregisterRoute(routeId) +} + func HandleWsInternal(w http.ResponseWriter, r *http.Request) error { windowId := r.URL.Query().Get("windowid") if windowId == "" { @@ -261,23 +289,19 @@ func HandleWsInternal(w http.ResponseWriter, r *http.Request) error { wsConnId := uuid.New().String() outputCh := make(chan any, 100) closeCh := make(chan any) - eventbus.RegisterWSChannel(wsConnId, windowId, outputCh) var routeId string if windowId == wshutil.ElectronRoute { routeId = wshutil.ElectronRoute } else { routeId = wshutil.MakeWindowRouteId(windowId) } - defer eventbus.UnregisterWSChannel(wsConnId) log.Printf("[websocket] new connection: windowid:%s connid:%s routeid:%s\n", windowId, wsConnId, routeId) - // we create a wshproxy to handle rpc messages to/from the window - wproxy := wshutil.MakeRpcProxy() - wshutil.DefaultRouter.RegisterRoute(routeId, wproxy) - defer func() { - wshutil.DefaultRouter.UnregisterRoute(routeId) - close(wproxy.ToRemoteCh) - }() - // WshServerFactoryFn(rpcInputCh, rpcOutputCh, wshrpc.RpcContext{}) + eventbus.RegisterWSChannel(wsConnId, windowId, outputCh) + defer eventbus.UnregisterWSChannel(wsConnId) + wproxy := wshutil.MakeRpcProxy() // we create a wshproxy to handle rpc messages to/from the window + defer close(wproxy.ToRemoteCh) + registerConn(wsConnId, routeId, wproxy) + defer unregisterConn(wsConnId, routeId) wg := &sync.WaitGroup{} wg.Add(2) go func() { diff --git a/pkg/wshutil/wshrouter.go b/pkg/wshutil/wshrouter.go index 26fdb8a39..6389e3d99 100644 --- a/pkg/wshutil/wshrouter.go +++ b/pkg/wshutil/wshrouter.go @@ -274,10 +274,14 @@ func (router *WshRouter) RegisterRoute(routeId string, rpc AbstractRpcClient) { log.Printf("[router] registering wsh route %q\n", routeId) router.Lock.Lock() defer router.Lock.Unlock() + alreadyExists := router.RouteMap[routeId] != nil + if alreadyExists { + log.Printf("[router] warning: route %q already exists (replacing)\n", routeId) + } router.RouteMap[routeId] = rpc go func() { // announce - if router.GetUpstreamClient() != nil { + if !alreadyExists && router.GetUpstreamClient() != nil { announceMsg := RpcMessage{Command: wshrpc.Command_RouteAnnounce, Source: routeId} announceBytes, _ := json.Marshal(announceMsg) router.GetUpstreamClient().SendRpcMessage(announceBytes)