diff --git a/.gitattributes b/.gitattributes index 212566614..94f480de9 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1 +1 @@ -* text=auto \ No newline at end of file +* text=auto eol=lf \ No newline at end of file diff --git a/cmd/server/main-server.go b/cmd/server/main-server.go index b560b2766..07c2a66d3 100644 --- a/cmd/server/main-server.go +++ b/cmd/server/main-server.go @@ -159,11 +159,11 @@ func shutdownActivityUpdate() { func createMainWshClient() { rpc := wshserver.GetMainRpcClient() - wshutil.DefaultRouter.RegisterRoute(wshutil.DefaultRoute, rpc) + wshutil.DefaultRouter.RegisterRoute(wshutil.DefaultRoute, rpc, true) wps.Broker.SetClient(wshutil.DefaultRouter) localConnWsh := wshutil.MakeWshRpc(nil, nil, wshrpc.RpcContext{Conn: wshrpc.LocalConnName}, &wshremote.ServerImpl{}) go wshremote.RunSysInfoLoop(localConnWsh, wshrpc.LocalConnName) - wshutil.DefaultRouter.RegisterRoute(wshutil.MakeConnectionRouteId(wshrpc.LocalConnName), localConnWsh) + wshutil.DefaultRouter.RegisterRoute(wshutil.MakeConnectionRouteId(wshrpc.LocalConnName), localConnWsh, true) } func main() { diff --git a/cmd/wsh/cmd/wshcmd-conn.go b/cmd/wsh/cmd/wshcmd-conn.go index c7f991056..73d3ee565 100644 --- a/cmd/wsh/cmd/wshcmd-conn.go +++ b/cmd/wsh/cmd/wshcmd-conn.go @@ -5,6 +5,7 @@ package cmd import ( "fmt" + "strings" "github.com/spf13/cobra" "github.com/wavetermdev/waveterm/pkg/remote" @@ -25,17 +26,24 @@ func init() { } func connStatus() error { - resp, err := wshclient.ConnStatusCommand(RpcClient, nil) + var allResp []wshrpc.ConnStatus + sshResp, err := wshclient.ConnStatusCommand(RpcClient, nil) if err != nil { - return fmt.Errorf("getting connection status: %w", err) + return fmt.Errorf("getting ssh connection status: %w", err) } - if len(resp) == 0 { + allResp = append(allResp, sshResp...) + wslResp, err := wshclient.WslStatusCommand(RpcClient, nil) + if err != nil { + return fmt.Errorf("getting wsl connection status: %w", err) + } + allResp = append(allResp, wslResp...) + if len(allResp) == 0 { WriteStdout("no connections\n") return nil } WriteStdout("%-30s %-12s\n", "connection", "status") WriteStdout("----------------------------------------------\n") - for _, conn := range resp { + for _, conn := range allResp { str := fmt.Sprintf("%-30s %-12s", conn.Connection, conn.Status) if conn.Error != "" { str += fmt.Sprintf(" (%s)", conn.Error) @@ -110,7 +118,7 @@ func connRun(cmd *cobra.Command, args []string) error { } connName = args[1] _, err := remote.ParseOpts(connName) - if err != nil { + if err != nil && !strings.HasPrefix(connName, "wsl://") { return fmt.Errorf("cannot parse connection name: %w", err) } } diff --git a/cmd/wsh/cmd/wshcmd-connserver.go b/cmd/wsh/cmd/wshcmd-connserver.go index cc00a694e..4f82c1067 100644 --- a/cmd/wsh/cmd/wshcmd-connserver.go +++ b/cmd/wsh/cmd/wshcmd-connserver.go @@ -4,29 +4,186 @@ package cmd import ( + "encoding/json" + "fmt" + "io" + "log" + "net" "os" + "sync/atomic" + "time" "github.com/spf13/cobra" + "github.com/wavetermdev/waveterm/pkg/util/packetparser" + "github.com/wavetermdev/waveterm/pkg/wavebase" + "github.com/wavetermdev/waveterm/pkg/wshrpc" + "github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient" "github.com/wavetermdev/waveterm/pkg/wshrpc/wshremote" + "github.com/wavetermdev/waveterm/pkg/wshutil" ) var serverCmd = &cobra.Command{ - Use: "connserver", - Hidden: true, - Short: "remote server to power wave blocks", - Args: cobra.NoArgs, - Run: serverRun, - PreRunE: preRunSetupRpcClient, + Use: "connserver", + Hidden: true, + Short: "remote server to power wave blocks", + Args: cobra.NoArgs, + RunE: serverRun, } +var connServerRouter bool + func init() { + serverCmd.Flags().BoolVar(&connServerRouter, "router", false, "run in local router mode") rootCmd.AddCommand(serverCmd) } -func serverRun(cmd *cobra.Command, args []string) { +func MakeRemoteUnixListener() (net.Listener, error) { + serverAddr := wavebase.GetRemoteDomainSocketName() + os.Remove(serverAddr) // ignore error + rtn, err := net.Listen("unix", serverAddr) + if err != nil { + return nil, fmt.Errorf("error creating listener at %v: %v", serverAddr, err) + } + os.Chmod(serverAddr, 0700) + log.Printf("Server [unix-domain] listening on %s\n", serverAddr) + return rtn, nil +} + +func handleNewListenerConn(conn net.Conn, router *wshutil.WshRouter) { + var routeIdContainer atomic.Pointer[string] + proxy := wshutil.MakeRpcProxy() + go func() { + writeErr := wshutil.AdaptOutputChToStream(proxy.ToRemoteCh, conn) + if writeErr != nil { + log.Printf("error writing to domain socket: %v\n", writeErr) + } + }() + go func() { + // when input is closed, close the connection + defer func() { + conn.Close() + routeIdPtr := routeIdContainer.Load() + if routeIdPtr != nil && *routeIdPtr != "" { + router.UnregisterRoute(*routeIdPtr) + disposeMsg := &wshutil.RpcMessage{ + Command: wshrpc.Command_Dispose, + Data: wshrpc.CommandDisposeData{ + RouteId: *routeIdPtr, + }, + Source: *routeIdPtr, + AuthToken: proxy.GetAuthToken(), + } + disposeBytes, _ := json.Marshal(disposeMsg) + router.InjectMessage(disposeBytes, *routeIdPtr) + } + }() + wshutil.AdaptStreamToMsgCh(conn, proxy.FromRemoteCh) + }() + routeId, err := proxy.HandleClientProxyAuth(router) + if err != nil { + log.Printf("error handling client proxy auth: %v\n", err) + conn.Close() + return + } + router.RegisterRoute(routeId, proxy, false) + routeIdContainer.Store(&routeId) +} + +func runListener(listener net.Listener, router *wshutil.WshRouter) { + defer func() { + log.Printf("listener closed, exiting\n") + time.Sleep(500 * time.Millisecond) + wshutil.DoShutdown("", 1, true) + }() + for { + conn, err := listener.Accept() + if err == io.EOF { + break + } + if err != nil { + log.Printf("error accepting connection: %v\n", err) + continue + } + go handleNewListenerConn(conn, router) + } +} + +func setupConnServerRpcClientWithRouter(router *wshutil.WshRouter) (*wshutil.WshRpc, error) { + jwtToken := os.Getenv(wshutil.WaveJwtTokenVarName) + if jwtToken == "" { + return nil, fmt.Errorf("no jwt token found for connserver") + } + rpcCtx, err := wshutil.ExtractUnverifiedRpcContext(jwtToken) + if err != nil { + return nil, fmt.Errorf("error extracting rpc context from %s: %v", wshutil.WaveJwtTokenVarName, err) + } + authRtn, err := router.HandleProxyAuth(jwtToken) + if err != nil { + return nil, fmt.Errorf("error handling proxy auth: %v", err) + } + inputCh := make(chan []byte, wshutil.DefaultInputChSize) + outputCh := make(chan []byte, wshutil.DefaultOutputChSize) + connServerClient := wshutil.MakeWshRpc(inputCh, outputCh, *rpcCtx, &wshremote.ServerImpl{LogWriter: os.Stdout}) + connServerClient.SetAuthToken(authRtn.AuthToken) + router.RegisterRoute(authRtn.RouteId, connServerClient, false) + wshclient.RouteAnnounceCommand(connServerClient, nil) + return connServerClient, nil +} + +func serverRunRouter() error { + router := wshutil.NewWshRouter() + termProxy := wshutil.MakeRpcProxy() + rawCh := make(chan []byte, wshutil.DefaultOutputChSize) + go packetparser.Parse(os.Stdin, termProxy.FromRemoteCh, rawCh) + go func() { + for msg := range termProxy.ToRemoteCh { + packetparser.WritePacket(os.Stdout, msg) + } + }() + go func() { + // just ignore and drain the rawCh (stdin) + // when stdin is closed, shutdown + defer wshutil.DoShutdown("", 0, true) + for range rawCh { + // ignore + } + }() + go func() { + for msg := range termProxy.FromRemoteCh { + // send this to the router + router.InjectMessage(msg, wshutil.UpstreamRoute) + } + }() + router.SetUpstreamClient(termProxy) + // now set up the domain socket + unixListener, err := MakeRemoteUnixListener() + if err != nil { + return fmt.Errorf("cannot create unix listener: %v", err) + } + client, err := setupConnServerRpcClientWithRouter(router) + if err != nil { + return fmt.Errorf("error setting up connserver rpc client: %v", err) + } + go runListener(unixListener, router) + // run the sysinfo loop + wshremote.RunSysInfoLoop(client, client.GetRpcContext().Conn) + select {} +} + +func serverRunNormal() error { + err := setupRpcClient(&wshremote.ServerImpl{LogWriter: os.Stdout}) + if err != nil { + return err + } WriteStdout("running wsh connserver (%s)\n", RpcContext.Conn) go wshremote.RunSysInfoLoop(RpcClient, RpcContext.Conn) - RpcClient.SetServerImpl(&wshremote.ServerImpl{LogWriter: os.Stdout}) - select {} // run forever } + +func serverRun(cmd *cobra.Command, args []string) error { + if connServerRouter { + return serverRunRouter() + } else { + return serverRunNormal() + } +} diff --git a/cmd/wsh/cmd/wshcmd-wsl.go b/cmd/wsh/cmd/wshcmd-wsl.go new file mode 100644 index 000000000..bad95ba21 --- /dev/null +++ b/cmd/wsh/cmd/wshcmd-wsl.go @@ -0,0 +1,60 @@ +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package cmd + +import ( + "strings" + + "github.com/spf13/cobra" + "github.com/wavetermdev/waveterm/pkg/waveobj" + "github.com/wavetermdev/waveterm/pkg/wshrpc" + "github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient" +) + +var distroName string + +var wslCmd = &cobra.Command{ + Use: "wsl [-d ]", + Short: "connect this terminal to a local wsl connection", + Args: cobra.NoArgs, + Run: wslRun, + PreRunE: preRunSetupRpcClient, +} + +func init() { + wslCmd.Flags().StringVarP(&distroName, "distribution", "d", "", "Run the specified distribution") + rootCmd.AddCommand(wslCmd) +} + +func wslRun(cmd *cobra.Command, args []string) { + var err error + if distroName == "" { + // get default distro from the host + distroName, err = wshclient.WslDefaultDistroCommand(RpcClient, nil) + if err != nil { + WriteStderr("[error] %s\n", err) + return + } + } + if !strings.HasPrefix(distroName, "wsl://") { + distroName = "wsl://" + distroName + } + blockId := RpcContext.BlockId + if blockId == "" { + WriteStderr("[error] cannot determine blockid (not in JWT)\n") + return + } + data := wshrpc.CommandSetMetaData{ + ORef: waveobj.MakeORef(waveobj.OType_Block, blockId), + Meta: map[string]any{ + waveobj.MetaKey_Connection: distroName, + }, + } + err = wshclient.SetMetaCommand(RpcClient, data, nil) + if err != nil { + WriteStderr("[error] setting switching connection: %v\n", err) + return + } + WriteStderr("switched connection to %q\n", distroName) +} diff --git a/frontend/app/block/blockframe.tsx b/frontend/app/block/blockframe.tsx index 0eb791a7c..b713bc88c 100644 --- a/frontend/app/block/blockframe.tsx +++ b/frontend/app/block/blockframe.tsx @@ -521,6 +521,7 @@ const ChangeConnectionBlockModal = React.memo( const connStatusAtom = getConnStatusAtom(connection); const connStatus = jotai.useAtomValue(connStatusAtom); const [connList, setConnList] = React.useState>([]); + const [wslList, setWslList] = React.useState>([]); const allConnStatus = jotai.useAtomValue(atoms.allConnStatus); const [rowIndex, setRowIndex] = React.useState(0); const connStatusMap = new Map(); @@ -540,6 +541,18 @@ const ChangeConnectionBlockModal = React.memo( prtn.then((newConnList) => { setConnList(newConnList ?? []); }).catch((e) => console.log("unable to load conn list from backend. using blank list: ", e)); + const p2rtn = RpcApi.WslListCommand(TabRpcClient, { timeout: 2000 }); + p2rtn + .then((newWslList) => { + console.log(newWslList); + setWslList(newWslList ?? []); + }) + .catch((e) => { + // removing this log and failing silentyly since it will happen + // if a system isn't using the wsl. and would happen every time the + // typeahead was opened. good candidate for verbose log level. + //console.log("unable to load wsl list from backend. using blank list: ", e) + }); }, [changeConnModalOpen, setConnList]); const changeConnection = React.useCallback( @@ -588,6 +601,15 @@ const ChangeConnectionBlockModal = React.memo( filteredList.push(conn); } } + const filteredWslList: Array = []; + for (const conn of wslList) { + if (conn === connSelected) { + createNew = false; + } + if (conn.includes(connSelected)) { + filteredWslList.push(conn); + } + } // priority handles special suggestions when necessary // for instance, when reconnecting const newConnectionSuggestion: SuggestionConnectionItem = { @@ -637,6 +659,20 @@ const ChangeConnectionBlockModal = React.memo( label: localName, }); } + for (const wslConn of filteredWslList) { + const connStatus = connStatusMap.get(wslConn); + const connColorNum = computeConnColorNum(connStatus); + localSuggestion.items.push({ + status: "connected", + icon: "arrow-right-arrow-left", + iconColor: + connStatus?.status == "connected" + ? `var(--conn-icon-color-${connColorNum})` + : "var(--grey-text-color)", + value: "wsl://" + wslConn, + label: "wsl://" + wslConn, + }); + } const remoteItems = filteredList.map((connName) => { const connStatus = connStatusMap.get(connName); const connColorNum = computeConnColorNum(connStatus); diff --git a/frontend/app/store/wshclientapi.ts b/frontend/app/store/wshclientapi.ts index 733dd6df0..f4a949ce0 100644 --- a/frontend/app/store/wshclientapi.ts +++ b/frontend/app/store/wshclientapi.ts @@ -72,6 +72,11 @@ class RpcApiType { return client.wshRpcCall("deleteblock", data, opts); } + // command "dispose" [call] + DisposeCommand(client: WshClient, data: CommandDisposeData, opts?: RpcOpts): Promise { + return client.wshRpcCall("dispose", data, opts); + } + // command "eventpublish" [call] EventPublishCommand(client: WshClient, data: WaveEvent, opts?: RpcOpts): Promise { return client.wshRpcCall("eventpublish", data, opts); @@ -237,6 +242,21 @@ class RpcApiType { return client.wshRpcCall("webselector", data, opts); } + // command "wsldefaultdistro" [call] + WslDefaultDistroCommand(client: WshClient, opts?: RpcOpts): Promise { + return client.wshRpcCall("wsldefaultdistro", null, opts); + } + + // command "wsllist" [call] + WslListCommand(client: WshClient, opts?: RpcOpts): Promise { + return client.wshRpcCall("wsllist", null, opts); + } + + // command "wslstatus" [call] + WslStatusCommand(client: WshClient, opts?: RpcOpts): Promise { + return client.wshRpcCall("wslstatus", null, opts); + } + } export const RpcApi = new RpcApiType(); diff --git a/frontend/types/gotypes.d.ts b/frontend/types/gotypes.d.ts index 6a0e57363..0a8652b36 100644 --- a/frontend/types/gotypes.d.ts +++ b/frontend/types/gotypes.d.ts @@ -63,6 +63,7 @@ declare global { // wshrpc.CommandAuthenticateRtnData type CommandAuthenticateRtnData = { routeid: string; + authtoken?: string; }; // wshrpc.CommandBlockInputData @@ -100,6 +101,11 @@ declare global { blockid: string; }; + // wshrpc.CommandDisposeData + type CommandDisposeData = { + routeid: string; + }; + // wshrpc.CommandEventReadHistoryData type CommandEventReadHistoryData = { event: string; @@ -416,6 +422,7 @@ declare global { resid?: string; timeout?: number; route?: string; + authtoken?: string; source?: string; cont?: boolean; cancel?: boolean; diff --git a/go.mod b/go.mod index 202dccedf..9c8cd95b3 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,7 @@ require ( github.com/shirou/gopsutil/v4 v4.24.9 github.com/skeema/knownhosts v1.3.0 github.com/spf13/cobra v1.8.1 + github.com/ubuntu/gowsl v0.0.0-20240906163211-049fd49bd93b github.com/wavetermdev/htmltoken v0.1.0 golang.org/x/crypto v0.28.0 golang.org/x/sys v0.26.0 @@ -36,9 +37,11 @@ require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect + github.com/sirupsen/logrus v1.9.3 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/numcpus v0.6.1 // indirect + github.com/ubuntu/decorate v0.0.0-20230125165522-2d5b0a9bb117 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect go.uber.org/atomic v1.7.0 // indirect golang.org/x/net v0.29.0 // indirect diff --git a/go.sum b/go.sum index 590fed68d..1ccb890df 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +github.com/0xrawsec/golang-utils v1.3.2 h1:ww4jrtHRSnX9xrGzJYbalx5nXoZewy4zPxiY+ubJgtg= +github.com/0xrawsec/golang-utils v1.3.2/go.mod h1:m7AzHXgdSAkFCD9tWWsApxNVxMlyy7anpPVOyT/yM7E= github.com/alexflint/go-filemutex v1.3.0 h1:LgE+nTUWnQCyRKbpoceKZsPQbs84LivvgwUymZXdOcM= github.com/alexflint/go-filemutex v1.3.0/go.mod h1:U0+VA/i30mGBlLCrFPGtTe9y6wGQfNAWPBTekHQ+c8A= github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= @@ -62,6 +64,8 @@ github.com/sawka/txwrap v0.2.0 h1:V3LfvKVLULxcYSxdMguLwFyQFMEU9nFDJopg0ZkL+94= github.com/sawka/txwrap v0.2.0/go.mod h1:wwQ2SQiN4U+6DU/iVPhbvr7OzXAtgZlQCIGuvOswEfA= github.com/shirou/gopsutil/v4 v4.24.9 h1:KIV+/HaHD5ka5f570RZq+2SaeFsb/pq+fp2DGNWYoOI= github.com/shirou/gopsutil/v4 v4.24.9/go.mod h1:3fkaHNeYsUFCGZ8+9vZVWtbyM1k2eRnlL+bWO8Bxa/Q= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/skeema/knownhosts v1.3.0 h1:AM+y0rI04VksttfwjkSTNQorvGqmwATnvnAHpSgc0LY= github.com/skeema/knownhosts v1.3.0/go.mod h1:sPINvnADmT/qYH1kfv+ePMmOBTH6Tbl7b5LvTDjFK7M= github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= @@ -71,12 +75,17 @@ github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk= github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY= +github.com/ubuntu/decorate v0.0.0-20230125165522-2d5b0a9bb117 h1:XQpsQG5lqRJlx4mUVHcJvyyc1rdTI9nHvwrdfcuy8aM= +github.com/ubuntu/decorate v0.0.0-20230125165522-2d5b0a9bb117/go.mod h1:mx0TjbqsaDD9DUT5gA1s3hw47U6RIbbIBfvGzR85K0g= +github.com/ubuntu/gowsl v0.0.0-20240906163211-049fd49bd93b h1:wFBKF5k5xbJQU8bYgcSoQ/ScvmYyq6KHUabAuVUjOWM= +github.com/ubuntu/gowsl v0.0.0-20240906163211-049fd49bd93b/go.mod h1:N1CYNinssZru+ikvYTgVbVeSi21thHUTCoJ9xMvWe+s= github.com/wavetermdev/htmltoken v0.1.0 h1:RMdA9zTfnYa5jRC4RRG3XNoV5NOP8EDxpaVPjuVz//Q= github.com/wavetermdev/htmltoken v0.1.0/go.mod h1:5FM0XV6zNYiNza2iaTcFGj+hnMtgqumFHO31Z8euquk= github.com/wavetermdev/ssh_config v0.0.0-20240306041034-17e2087ebde2 h1:onqZrJVap1sm15AiIGTfWzdr6cEF0KdtddeuuOVhzyY= @@ -91,6 +100,7 @@ golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo= golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220721230656-c6bc011c0c49/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -102,5 +112,6 @@ golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/pkg/blockcontroller/blockcontroller.go b/pkg/blockcontroller/blockcontroller.go index 6b187a342..09a86b2c4 100644 --- a/pkg/blockcontroller/blockcontroller.go +++ b/pkg/blockcontroller/blockcontroller.go @@ -11,6 +11,7 @@ import ( "io" "io/fs" "log" + "strings" "sync" "time" @@ -24,6 +25,7 @@ import ( "github.com/wavetermdev/waveterm/pkg/wps" "github.com/wavetermdev/waveterm/pkg/wshrpc" "github.com/wavetermdev/waveterm/pkg/wshutil" + "github.com/wavetermdev/waveterm/pkg/wsl" "github.com/wavetermdev/waveterm/pkg/wstore" ) @@ -262,7 +264,30 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj return fmt.Errorf("unknown controller type %q", bc.ControllerType) } var shellProc *shellexec.ShellProc - if remoteName != "" { + if strings.HasPrefix(remoteName, "wsl://") { + wslName := strings.TrimPrefix(remoteName, "wsl://") + credentialCtx, cancelFunc := context.WithTimeout(context.Background(), 60*time.Second) + defer cancelFunc() + + wslConn := wsl.GetWslConn(credentialCtx, wslName, false) + connStatus := wslConn.DeriveConnStatus() + if connStatus.Status != conncontroller.Status_Connected { + return fmt.Errorf("not connected, cannot start shellproc") + } + + // create jwt + if !blockMeta.GetBool(waveobj.MetaKey_CmdNoWsh, false) { + jwtStr, err := wshutil.MakeClientJWTToken(wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId, Conn: wslConn.GetName()}, wslConn.GetDomainSocketName()) + if err != nil { + return fmt.Errorf("error making jwt token: %w", err) + } + cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr + } + shellProc, err = shellexec.StartWslShellProc(ctx, rc.TermSize, cmdStr, cmdOpts, wslConn) + if err != nil { + return err + } + } else if remoteName != "" { credentialCtx, cancelFunc := context.WithTimeout(context.Background(), 60*time.Second) defer cancelFunc() @@ -325,7 +350,7 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj // we don't need to authenticate this wshProxy since it is coming direct wshProxy := wshutil.MakeRpcProxy() wshProxy.SetRpcContext(&wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId}) - wshutil.DefaultRouter.RegisterRoute(wshutil.MakeControllerRouteId(bc.BlockId), wshProxy) + wshutil.DefaultRouter.RegisterRoute(wshutil.MakeControllerRouteId(bc.BlockId), wshProxy, true) ptyBuffer := wshutil.MakePtyBuffer(wshutil.WaveOSCPrefix, bc.ShellProc.Cmd, wshProxy.FromRemoteCh) go func() { // handles regular output from the pty (goes to the blockfile and xterm) @@ -494,6 +519,15 @@ func CheckConnStatus(blockId string) error { if connName == "" { return nil } + if strings.HasPrefix(connName, "wsl://") { + distroName := strings.TrimPrefix(connName, "wsl://") + conn := wsl.GetWslConn(context.Background(), distroName, false) + connStatus := conn.DeriveConnStatus() + if connStatus.Status != conncontroller.Status_Connected { + return fmt.Errorf("not connected: %s", connStatus.Status) + } + return nil + } opts, err := remote.ParseOpts(connName) if err != nil { return fmt.Errorf("error parsing connection name: %w", err) diff --git a/pkg/remote/connutil.go b/pkg/remote/connutil.go index 5c224e880..b5d1841b8 100644 --- a/pkg/remote/connutil.go +++ b/pkg/remote/connutil.go @@ -1,3 +1,6 @@ +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + package remote import ( diff --git a/pkg/service/clientservice/clientservice.go b/pkg/service/clientservice/clientservice.go index f2a8d9b68..b8ec51fbd 100644 --- a/pkg/service/clientservice/clientservice.go +++ b/pkg/service/clientservice/clientservice.go @@ -17,6 +17,7 @@ import ( "github.com/wavetermdev/waveterm/pkg/wcore" "github.com/wavetermdev/waveterm/pkg/wlayout" "github.com/wavetermdev/waveterm/pkg/wshrpc" + "github.com/wavetermdev/waveterm/pkg/wsl" "github.com/wavetermdev/waveterm/pkg/wstore" ) @@ -77,7 +78,9 @@ func (cs *ClientService) MakeWindow(ctx context.Context) (*waveobj.Window, error } func (cs *ClientService) GetAllConnStatus(ctx context.Context) ([]wshrpc.ConnStatus, error) { - return conncontroller.GetAllConnStatus(), nil + sshStatuses := conncontroller.GetAllConnStatus() + wslStatuses := wsl.GetAllConnStatus() + return append(sshStatuses, wslStatuses...), nil } // moves the window to the front of the windowId stack diff --git a/pkg/shellexec/conninterface.go b/pkg/shellexec/conninterface.go index e601f8e1d..fce23f242 100644 --- a/pkg/shellexec/conninterface.go +++ b/pkg/shellexec/conninterface.go @@ -7,6 +7,7 @@ import ( "time" "github.com/creack/pty" + "github.com/wavetermdev/waveterm/pkg/wsl" "golang.org/x/crypto/ssh" ) @@ -129,3 +130,42 @@ func (sw SessionWrap) StderrPipe() (io.ReadCloser, error) { func (sw SessionWrap) SetSize(h int, w int) error { return sw.Session.WindowChange(h, w) } + +type WslCmdWrap struct { + *wsl.WslCmd + Tty pty.Tty + pty.Pty +} + +func (wcw WslCmdWrap) Kill() { + wcw.Tty.Close() + wcw.Close() +} + +func (wcw WslCmdWrap) KillGraceful(timeout time.Duration) { + process := wcw.WslCmd.GetProcess() + if process == nil { + return + } + processState := wcw.WslCmd.GetProcessState() + if processState != nil && processState.Exited() { + return + } + process.Signal(os.Interrupt) + go func() { + time.Sleep(timeout) + process := wcw.WslCmd.GetProcess() + processState := wcw.WslCmd.GetProcessState() + if processState == nil || !processState.Exited() { + process.Kill() // force kill if it is already not exited + } + }() +} + +/** + * SetSize does nothing for WslCmdWrap as there + * is no pty to manage. +**/ +func (wcw WslCmdWrap) SetSize(w int, h int) error { + return nil +} diff --git a/pkg/shellexec/shellexec.go b/pkg/shellexec/shellexec.go index c11faf619..ff985352d 100644 --- a/pkg/shellexec/shellexec.go +++ b/pkg/shellexec/shellexec.go @@ -5,6 +5,7 @@ package shellexec import ( "bytes" + "context" "fmt" "io" "log" @@ -25,6 +26,7 @@ import ( "github.com/wavetermdev/waveterm/pkg/wavebase" "github.com/wavetermdev/waveterm/pkg/waveobj" "github.com/wavetermdev/waveterm/pkg/wshutil" + "github.com/wavetermdev/waveterm/pkg/wsl" ) const DefaultGracefulKillWait = 400 * time.Millisecond @@ -141,6 +143,96 @@ func (pp *PipePty) WriteString(s string) (n int, err error) { return pp.Write([]byte(s)) } +func StartWslShellProc(ctx context.Context, termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *wsl.WslConn) (*ShellProc, error) { + client := conn.GetClient() + shellPath := cmdOpts.ShellPath + if shellPath == "" { + remoteShellPath, err := wsl.DetectShell(conn.Context, client) + if err != nil { + return nil, err + } + shellPath = remoteShellPath + } + var shellOpts []string + log.Printf("detected shell: %s", shellPath) + + err := wsl.InstallClientRcFiles(conn.Context, client) + if err != nil { + log.Printf("error installing rc files: %v", err) + return nil, err + } + + homeDir := wsl.GetHomeDir(conn.Context, client) + shellOpts = append(shellOpts, "~", "-d", client.Name()) + + if isZshShell(shellPath) { + shellOpts = append(shellOpts, fmt.Sprintf(`ZDOTDIR="%s/.waveterm/%s"`, homeDir, shellutil.ZshIntegrationDir)) + } + var subShellOpts []string + + if cmdStr == "" { + /* transform command in order to inject environment vars */ + if isBashShell(shellPath) { + log.Printf("recognized as bash shell") + // add --rcfile + // cant set -l or -i with --rcfile + subShellOpts = append(subShellOpts, "--rcfile", fmt.Sprintf(`%s/.waveterm/%s/.bashrc`, homeDir, shellutil.BashIntegrationDir)) + } else if isFishShell(shellPath) { + carg := fmt.Sprintf(`"set -x PATH \"%s\"/.waveterm/%s $PATH"`, homeDir, shellutil.WaveHomeBinDir) + subShellOpts = append(subShellOpts, "-C", carg) + } else if wsl.IsPowershell(shellPath) { + // powershell is weird about quoted path executables and requires an ampersand first + shellPath = "& " + shellPath + subShellOpts = append(subShellOpts, "-ExecutionPolicy", "Bypass", "-NoExit", "-File", homeDir+fmt.Sprintf("/.waveterm/%s/wavepwsh.ps1", shellutil.PwshIntegrationDir)) + } else { + if cmdOpts.Login { + subShellOpts = append(subShellOpts, "-l") + } + if cmdOpts.Interactive { + subShellOpts = append(subShellOpts, "-i") + } + // can't set environment vars this way + // will try to do later if possible + } + } else { + shellPath = cmdStr + if cmdOpts.Login { + subShellOpts = append(subShellOpts, "-l") + } + if cmdOpts.Interactive { + subShellOpts = append(subShellOpts, "-i") + } + subShellOpts = append(subShellOpts, "-c", cmdStr) + } + + jwtToken, ok := cmdOpts.Env[wshutil.WaveJwtTokenVarName] + if !ok { + return nil, fmt.Errorf("no jwt token provided to connection") + } + if remote.IsPowershell(shellPath) { + shellOpts = append(shellOpts, "--", fmt.Sprintf(`$env:%s=%s;`, wshutil.WaveJwtTokenVarName, jwtToken)) + } else { + shellOpts = append(shellOpts, "--", fmt.Sprintf(`%s=%s`, wshutil.WaveJwtTokenVarName, jwtToken)) + } + shellOpts = append(shellOpts, shellPath) + shellOpts = append(shellOpts, subShellOpts...) + log.Printf("full cmd is: %s %s", "wsl.exe", strings.Join(shellOpts, " ")) + + ecmd := exec.Command("wsl.exe", shellOpts...) + if termSize.Rows == 0 || termSize.Cols == 0 { + termSize.Rows = shellutil.DefaultTermRows + termSize.Cols = shellutil.DefaultTermCols + } + if termSize.Rows <= 0 || termSize.Cols <= 0 { + return nil, fmt.Errorf("invalid term size: %v", termSize) + } + cmdPty, err := pty.StartWithSize(ecmd, &pty.Winsize{Rows: uint16(termSize.Rows), Cols: uint16(termSize.Cols)}) + if err != nil { + return nil, err + } + return &ShellProc{Cmd: CmdWrap{ecmd, cmdPty}, ConnName: conn.GetName(), CloseOnce: &sync.Once{}, DoneCh: make(chan any)}, nil +} + func StartRemoteShellProc(termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *conncontroller.SSHConn) (*ShellProc, error) { client := conn.GetClient() shellPath := cmdOpts.ShellPath diff --git a/pkg/util/packetparser/packetparser.go b/pkg/util/packetparser/packetparser.go new file mode 100644 index 000000000..51df1666d --- /dev/null +++ b/pkg/util/packetparser/packetparser.go @@ -0,0 +1,58 @@ +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package packetparser + +import ( + "bufio" + "bytes" + "fmt" + "io" +) + +type PacketParser struct { + Reader io.Reader + Ch chan []byte +} + +func Parse(input io.Reader, packetCh chan []byte, rawCh chan []byte) error { + bufReader := bufio.NewReader(input) + defer close(packetCh) + defer close(rawCh) + for { + line, err := bufReader.ReadBytes('\n') + if err == io.EOF { + return nil + } + if err != nil { + return err + } + if len(line) <= 1 { + // just a blank line + continue + } + if bytes.HasPrefix(line, []byte{'#', '#', 'N', '{'}) && bytes.HasSuffix(line, []byte{'}', '\n'}) { + // strip off the leading "##" and trailing "\n" (single byte) + packetCh <- line[3 : len(line)-1] + } else { + rawCh <- line + } + } +} + +func WritePacket(output io.Writer, packet []byte) error { + if len(packet) < 2 { + return nil + } + if packet[0] != '{' || packet[len(packet)-1] != '}' { + return fmt.Errorf("invalid packet, must start with '{' and end with '}'") + } + fullPacket := make([]byte, 0, len(packet)+5) + // we add the extra newline to make sure the ## appears at the beginning of the line + // since writer isn't buffered, we want to send this all at once + fullPacket = append(fullPacket, '\n', '#', '#', 'N') + fullPacket = append(fullPacket, packet...) + fullPacket = append(fullPacket, '\n') + _, err := output.Write(fullPacket) + return err +} diff --git a/pkg/wavebase/wavebase.go b/pkg/wavebase/wavebase.go index 44f80b7a7..805386d52 100644 --- a/pkg/wavebase/wavebase.go +++ b/pkg/wavebase/wavebase.go @@ -30,10 +30,13 @@ const WaveDataHomeEnvVar = "WAVETERM_DATA_HOME" const WaveDevVarName = "WAVETERM_DEV" const WaveLockFile = "wave.lock" const DomainSocketBaseName = "wave.sock" +const RemoteDomainSocketBaseName = "wave-remote.sock" const WaveDBDir = "db" const JwtSecret = "waveterm" // TODO generate and store this const ConfigDir = "config" +var RemoteWaveHome = ExpandHomeDirSafe("~/.waveterm") + const WaveAppPathVarName = "WAVETERM_APP_PATH" const AppPathBinDir = "bin" @@ -101,6 +104,10 @@ func GetDomainSocketName() string { return filepath.Join(GetWaveDataDir(), DomainSocketBaseName) } +func GetRemoteDomainSocketName() string { + return filepath.Join(RemoteWaveHome, RemoteDomainSocketBaseName) +} + func GetWaveDataDir() string { retVal, found := os.LookupEnv(WaveDataHomeEnvVar) if !found { diff --git a/pkg/web/web.go b/pkg/web/web.go index 92d99ed77..695c73bbf 100644 --- a/pkg/web/web.go +++ b/pkg/web/web.go @@ -431,7 +431,7 @@ func MakeTCPListener(serviceName string) (net.Listener, error) { } func MakeUnixListener() (net.Listener, error) { - serverAddr := wavebase.GetWaveDataDir() + "/wave.sock" + serverAddr := wavebase.GetDomainSocketName() os.Remove(serverAddr) // ignore error rtn, err := net.Listen("unix", serverAddr) if err != nil { diff --git a/pkg/web/ws.go b/pkg/web/ws.go index 4cb6c1dcf..c0374efff 100644 --- a/pkg/web/ws.go +++ b/pkg/web/ws.go @@ -252,7 +252,7 @@ func registerConn(wsConnId string, routeId string, wproxy *wshutil.WshRpcProxy) wshutil.DefaultRouter.UnregisterRoute(routeId) } RouteToConnMap[routeId] = wsConnId - wshutil.DefaultRouter.RegisterRoute(routeId, wproxy) + wshutil.DefaultRouter.RegisterRoute(routeId, wproxy, true) } func unregisterConn(wsConnId string, routeId string) { diff --git a/pkg/wshrpc/wshclient/wshclient.go b/pkg/wshrpc/wshclient/wshclient.go index e8743ade3..5a8d553df 100644 --- a/pkg/wshrpc/wshclient/wshclient.go +++ b/pkg/wshrpc/wshclient/wshclient.go @@ -92,6 +92,12 @@ func DeleteBlockCommand(w *wshutil.WshRpc, data wshrpc.CommandDeleteBlockData, o return err } +// command "dispose", wshserver.DisposeCommand +func DisposeCommand(w *wshutil.WshRpc, data wshrpc.CommandDisposeData, opts *wshrpc.RpcOpts) error { + _, err := sendRpcRequestCallHelper[any](w, "dispose", data, opts) + return err +} + // command "eventpublish", wshserver.EventPublishCommand func EventPublishCommand(w *wshutil.WshRpc, data wps.WaveEvent, opts *wshrpc.RpcOpts) error { _, err := sendRpcRequestCallHelper[any](w, "eventpublish", data, opts) @@ -285,4 +291,22 @@ func WebSelectorCommand(w *wshutil.WshRpc, data wshrpc.CommandWebSelectorData, o return resp, err } +// command "wsldefaultdistro", wshserver.WslDefaultDistroCommand +func WslDefaultDistroCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) (string, error) { + resp, err := sendRpcRequestCallHelper[string](w, "wsldefaultdistro", nil, opts) + return resp, err +} + +// command "wsllist", wshserver.WslListCommand +func WslListCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) ([]string, error) { + resp, err := sendRpcRequestCallHelper[[]string](w, "wsllist", nil, opts) + return resp, err +} + +// command "wslstatus", wshserver.WslStatusCommand +func WslStatusCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) ([]wshrpc.ConnStatus, error) { + resp, err := sendRpcRequestCallHelper[[]wshrpc.ConnStatus](w, "wslstatus", nil, opts) + return resp, err +} + diff --git a/pkg/wshrpc/wshrpctypes.go b/pkg/wshrpc/wshrpctypes.go index 55dbc18ce..40ae627de 100644 --- a/pkg/wshrpc/wshrpctypes.go +++ b/pkg/wshrpc/wshrpctypes.go @@ -28,6 +28,7 @@ const ( const ( Command_Authenticate = "authenticate" // special + Command_Dispose = "dispose" // special (disposes of the route, for multiproxy only) Command_RouteAnnounce = "routeannounce" // special (for routing) Command_RouteUnannounce = "routeunannounce" // special (for routing) Command_Message = "message" @@ -62,11 +63,15 @@ const ( Command_RemoteFileDelete = "remotefiledelete" Command_RemoteFileJoiin = "remotefilejoin" + Command_ConnStatus = "connstatus" + Command_WslStatus = "wslstatus" Command_ConnEnsure = "connensure" Command_ConnReinstallWsh = "connreinstallwsh" Command_ConnConnect = "connconnect" Command_ConnDisconnect = "conndisconnect" Command_ConnList = "connlist" + Command_WslList = "wsllist" + Command_WslDefaultDistro = "wsldefaultdistro" Command_WebSelector = "webselector" Command_Notify = "notify" @@ -83,6 +88,7 @@ type RespOrErrorUnion[T any] struct { type WshRpcInterface interface { AuthenticateCommand(ctx context.Context, data string) (CommandAuthenticateRtnData, error) + DisposeCommand(ctx context.Context, data CommandDisposeData) error RouteAnnounceCommand(ctx context.Context) error // (special) announces a new route to the main router RouteUnannounceCommand(ctx context.Context) error // (special) unannounces a route to the main router @@ -114,11 +120,14 @@ type WshRpcInterface interface { // connection functions ConnStatusCommand(ctx context.Context) ([]ConnStatus, error) + WslStatusCommand(ctx context.Context) ([]ConnStatus, error) ConnEnsureCommand(ctx context.Context, connName string) error ConnReinstallWshCommand(ctx context.Context, connName string) error ConnConnectCommand(ctx context.Context, connName string) error ConnDisconnectCommand(ctx context.Context, connName string) error ConnListCommand(ctx context.Context) ([]string, error) + WslListCommand(ctx context.Context) ([]string, error) + WslDefaultDistroCommand(ctx context.Context) (string, error) // eventrecv is special, it's handled internally by WshRpc with EventListener EventRecvCommand(ctx context.Context, data wps.WaveEvent) error @@ -200,7 +209,13 @@ func HackRpcContextIntoData(dataPtr any, rpcContext RpcContext) { } type CommandAuthenticateRtnData struct { + RouteId string `json:"routeid"` + AuthToken string `json:"authtoken,omitempty"` +} + +type CommandDisposeData struct { RouteId string `json:"routeid"` + // auth token travels in the packet directly } type CommandMessageData struct { diff --git a/pkg/wshrpc/wshserver/wshserver.go b/pkg/wshrpc/wshserver/wshserver.go index d9afe7889..df1670b19 100644 --- a/pkg/wshrpc/wshserver/wshserver.go +++ b/pkg/wshrpc/wshserver/wshserver.go @@ -21,6 +21,7 @@ import ( "github.com/wavetermdev/waveterm/pkg/filestore" "github.com/wavetermdev/waveterm/pkg/remote" "github.com/wavetermdev/waveterm/pkg/remote/conncontroller" + "github.com/wavetermdev/waveterm/pkg/util/utilfn" "github.com/wavetermdev/waveterm/pkg/waveai" "github.com/wavetermdev/waveterm/pkg/waveobj" "github.com/wavetermdev/waveterm/pkg/wconfig" @@ -29,6 +30,7 @@ import ( "github.com/wavetermdev/waveterm/pkg/wps" "github.com/wavetermdev/waveterm/pkg/wshrpc" "github.com/wavetermdev/waveterm/pkg/wshutil" + "github.com/wavetermdev/waveterm/pkg/wsl" "github.com/wavetermdev/waveterm/pkg/wstore" ) @@ -36,6 +38,7 @@ const SimpleId_This = "this" const SimpleId_Tab = "tab" var SimpleId_BlockNum_Regex = regexp.MustCompile(`^\d+$`) +var InvalidWslDistroNames = []string{"docker-desktop", "docker-desktop-data"} type WshServer struct{} @@ -463,11 +466,28 @@ func (ws *WshServer) ConnStatusCommand(ctx context.Context) ([]wshrpc.ConnStatus return rtn, nil } +func (ws *WshServer) WslStatusCommand(ctx context.Context) ([]wshrpc.ConnStatus, error) { + rtn := wsl.GetAllConnStatus() + return rtn, nil +} + func (ws *WshServer) ConnEnsureCommand(ctx context.Context, connName string) error { + if strings.HasPrefix(connName, "wsl://") { + distroName := strings.TrimPrefix(connName, "wsl://") + return wsl.EnsureConnection(ctx, distroName) + } return conncontroller.EnsureConnection(ctx, connName) } func (ws *WshServer) ConnDisconnectCommand(ctx context.Context, connName string) error { + if strings.HasPrefix(connName, "wsl://") { + distroName := strings.TrimPrefix(connName, "wsl://") + conn := wsl.GetWslConn(ctx, distroName, false) + if conn == nil { + return fmt.Errorf("distro not found: %s", connName) + } + return conn.Close() + } connOpts, err := remote.ParseOpts(connName) if err != nil { return fmt.Errorf("error parsing connection name: %w", err) @@ -480,6 +500,14 @@ func (ws *WshServer) ConnDisconnectCommand(ctx context.Context, connName string) } func (ws *WshServer) ConnConnectCommand(ctx context.Context, connName string) error { + if strings.HasPrefix(connName, "wsl://") { + distroName := strings.TrimPrefix(connName, "wsl://") + conn := wsl.GetWslConn(ctx, distroName, false) + if conn == nil { + return fmt.Errorf("connection not found: %s", connName) + } + return conn.Connect(ctx) + } connOpts, err := remote.ParseOpts(connName) if err != nil { return fmt.Errorf("error parsing connection name: %w", err) @@ -492,6 +520,14 @@ func (ws *WshServer) ConnConnectCommand(ctx context.Context, connName string) er } func (ws *WshServer) ConnReinstallWshCommand(ctx context.Context, connName string) error { + if strings.HasPrefix(connName, "wsl://") { + distroName := strings.TrimPrefix(connName, "wsl://") + conn := wsl.GetWslConn(ctx, distroName, false) + if conn == nil { + return fmt.Errorf("connection not found: %s", connName) + } + return conn.CheckAndInstallWsh(ctx, connName, &wsl.WshInstallOpts{Force: true, NoUserPrompt: true}) + } connOpts, err := remote.ParseOpts(connName) if err != nil { return fmt.Errorf("error parsing connection name: %w", err) @@ -507,6 +543,33 @@ func (ws *WshServer) ConnListCommand(ctx context.Context) ([]string, error) { return conncontroller.GetConnectionsList() } +func (ws *WshServer) WslListCommand(ctx context.Context) ([]string, error) { + distros, err := wsl.RegisteredDistros(ctx) + if err != nil { + return nil, err + } + var distroNames []string + for _, distro := range distros { + distroName := distro.Name() + if utilfn.ContainsStr(InvalidWslDistroNames, distroName) { + continue + } + distroNames = append(distroNames, distroName) + } + return distroNames, nil +} + +func (ws *WshServer) WslDefaultDistroCommand(ctx context.Context) (string, error) { + distro, ok, err := wsl.DefaultDistro(ctx) + if err != nil { + return "", fmt.Errorf("unable to determine default distro: %w", err) + } + if !ok { + return "", fmt.Errorf("unable to determine default distro") + } + return distro.Name(), nil +} + func (ws *WshServer) BlockInfoCommand(ctx context.Context, blockId string) (*wshrpc.BlockInfoData, error) { blockData, err := wstore.DBMustGet[*waveobj.Block](ctx, blockId) if err != nil { diff --git a/pkg/wshutil/wshmultiproxy.go b/pkg/wshutil/wshmultiproxy.go new file mode 100644 index 000000000..be2888bf1 --- /dev/null +++ b/pkg/wshutil/wshmultiproxy.go @@ -0,0 +1,151 @@ +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package wshutil + +import ( + "encoding/json" + "fmt" + "sync" + + "github.com/google/uuid" + "github.com/wavetermdev/waveterm/pkg/wshrpc" +) + +type multiProxyRouteInfo struct { + RouteId string + AuthToken string + Proxy *WshRpcProxy + RpcContext *wshrpc.RpcContext +} + +// handles messages from multiple unauthenitcated clients +type WshRpcMultiProxy struct { + Lock *sync.Mutex + RouteInfo map[string]*multiProxyRouteInfo // authtoken to info + ToRemoteCh chan []byte + FromRemoteRawCh chan []byte // raw message from the remote +} + +func MakeRpcMultiProxy() *WshRpcMultiProxy { + return &WshRpcMultiProxy{ + Lock: &sync.Mutex{}, + RouteInfo: make(map[string]*multiProxyRouteInfo), + ToRemoteCh: make(chan []byte, DefaultInputChSize), + FromRemoteRawCh: make(chan []byte, DefaultOutputChSize), + } +} + +func (p *WshRpcMultiProxy) DisposeRoutes() { + p.Lock.Lock() + defer p.Lock.Unlock() + for authToken, routeInfo := range p.RouteInfo { + DefaultRouter.UnregisterRoute(routeInfo.RouteId) + delete(p.RouteInfo, authToken) + } +} + +func (p *WshRpcMultiProxy) getRouteInfo(authToken string) *multiProxyRouteInfo { + p.Lock.Lock() + defer p.Lock.Unlock() + return p.RouteInfo[authToken] +} + +func (p *WshRpcMultiProxy) setRouteInfo(authToken string, routeInfo *multiProxyRouteInfo) { + p.Lock.Lock() + defer p.Lock.Unlock() + p.RouteInfo[authToken] = routeInfo +} + +func (p *WshRpcMultiProxy) removeRouteInfo(authToken string) { + p.Lock.Lock() + defer p.Lock.Unlock() + delete(p.RouteInfo, authToken) +} + +func (p *WshRpcMultiProxy) sendResponseError(msg RpcMessage, sendErr error) { + if msg.ReqId == "" { + // no response needed + return + } + resp := RpcMessage{ + ResId: msg.ReqId, + Error: sendErr.Error(), + } + respBytes, _ := json.Marshal(resp) + p.ToRemoteCh <- respBytes +} + +func (p *WshRpcMultiProxy) sendAuthResponse(msg RpcMessage, routeId string, authToken string) { + if msg.ReqId == "" { + // no response needed + return + } + resp := RpcMessage{ + ResId: msg.ReqId, + Data: wshrpc.CommandAuthenticateRtnData{RouteId: routeId, AuthToken: authToken}, + } + respBytes, _ := json.Marshal(resp) + p.ToRemoteCh <- respBytes +} + +func (p *WshRpcMultiProxy) handleUnauthMessage(msgBytes []byte) { + var msg RpcMessage + err := json.Unmarshal(msgBytes, &msg) + if err != nil { + // nothing to do here, malformed message + return + } + if msg.Command == wshrpc.Command_Authenticate { + rpcContext, routeId, err := handleAuthenticationCommand(msg) + if err != nil { + p.sendResponseError(msg, err) + return + } + routeInfo := &multiProxyRouteInfo{ + RouteId: routeId, + AuthToken: uuid.New().String(), + RpcContext: rpcContext, + } + routeInfo.Proxy = MakeRpcProxy() + routeInfo.Proxy.SetRpcContext(rpcContext) + p.setRouteInfo(routeInfo.AuthToken, routeInfo) + p.sendAuthResponse(msg, routeId, routeInfo.AuthToken) + go func() { + for msgBytes := range routeInfo.Proxy.ToRemoteCh { + p.ToRemoteCh <- msgBytes + } + }() + DefaultRouter.RegisterRoute(routeId, routeInfo.Proxy, true) + return + } + if msg.AuthToken == "" { + p.sendResponseError(msg, fmt.Errorf("no auth token")) + return + } + routeInfo := p.getRouteInfo(msg.AuthToken) + if routeInfo == nil { + p.sendResponseError(msg, fmt.Errorf("invalid auth token")) + return + } + if msg.Command != "" && msg.Source != routeInfo.RouteId { + p.sendResponseError(msg, fmt.Errorf("invalid source route for auth token")) + return + } + if msg.Command == wshrpc.Command_Dispose { + DefaultRouter.UnregisterRoute(routeInfo.RouteId) + p.removeRouteInfo(msg.AuthToken) + close(routeInfo.Proxy.ToRemoteCh) + close(routeInfo.Proxy.FromRemoteCh) + return + } + routeInfo.Proxy.FromRemoteCh <- msgBytes +} + +func (p *WshRpcMultiProxy) RunUnauthLoop() { + // loop over unauthenticated message + // handle Authenicate commands, and pass authenticated messages to the AuthCh + for msgBytes := range p.FromRemoteRawCh { + p.handleUnauthMessage(msgBytes) + } +} diff --git a/pkg/wshutil/wshproxy.go b/pkg/wshutil/wshproxy.go index c6a1ecf9f..c919b5d07 100644 --- a/pkg/wshutil/wshproxy.go +++ b/pkg/wshutil/wshproxy.go @@ -6,7 +6,6 @@ package wshutil import ( "encoding/json" "fmt" - "log" "sync" "github.com/google/uuid" @@ -18,6 +17,7 @@ type WshRpcProxy struct { RpcContext *wshrpc.RpcContext ToRemoteCh chan []byte FromRemoteCh chan []byte + AuthToken string } func MakeRpcProxy() *WshRpcProxy { @@ -40,6 +40,18 @@ func (p *WshRpcProxy) GetRpcContext() *wshrpc.RpcContext { return p.RpcContext } +func (p *WshRpcProxy) SetAuthToken(authToken string) { + p.Lock.Lock() + defer p.Lock.Unlock() + p.AuthToken = authToken +} + +func (p *WshRpcProxy) GetAuthToken() string { + p.Lock.Lock() + defer p.Lock.Unlock() + return p.AuthToken +} + func (p *WshRpcProxy) sendResponseError(msg RpcMessage, sendErr error) { if msg.ReqId == "" { // no response needed @@ -54,7 +66,7 @@ func (p *WshRpcProxy) sendResponseError(msg RpcMessage, sendErr error) { p.SendRpcMessage(respBytes) } -func (p *WshRpcProxy) sendResponse(msg RpcMessage, routeId string) { +func (p *WshRpcProxy) sendAuthenticateResponse(msg RpcMessage, routeId string) { if msg.ReqId == "" { // no response needed return @@ -98,6 +110,49 @@ func handleAuthenticationCommand(msg RpcMessage) (*wshrpc.RpcContext, string, er return newCtx, routeId, nil } +// runs on the client (stdio client) +func (p *WshRpcProxy) HandleClientProxyAuth(router *WshRouter) (string, error) { + for { + msgBytes, ok := <-p.FromRemoteCh + if !ok { + return "", fmt.Errorf("remote closed, not authenticated") + } + var origMsg RpcMessage + err := json.Unmarshal(msgBytes, &origMsg) + if err != nil { + // nothing to do, can't even send a response since we don't have Source or ReqId + continue + } + if origMsg.Command == "" { + // this message is not allowed (protocol error at this point), ignore + continue + } + // we only allow one command "authenticate", everything else returns an error + if origMsg.Command != wshrpc.Command_Authenticate { + respErr := fmt.Errorf("connection not authenticated") + p.sendResponseError(origMsg, respErr) + continue + } + authRtn, err := router.HandleProxyAuth(origMsg.Data) + if err != nil { + respErr := fmt.Errorf("error handling proxy auth: %w", err) + p.sendResponseError(origMsg, respErr) + return "", respErr + } + p.SetAuthToken(authRtn.AuthToken) + announceMsg := RpcMessage{ + Command: wshrpc.Command_RouteAnnounce, + Source: authRtn.RouteId, + AuthToken: authRtn.AuthToken, + } + announceBytes, _ := json.Marshal(announceMsg) + router.InjectMessage(announceBytes, authRtn.RouteId) + p.sendAuthenticateResponse(origMsg, authRtn.RouteId) + return authRtn.RouteId, nil + } +} + +// runs on the server func (p *WshRpcProxy) HandleAuthentication() (*wshrpc.RpcContext, error) { for { msgBytes, ok := <-p.FromRemoteCh @@ -122,11 +177,10 @@ func (p *WshRpcProxy) HandleAuthentication() (*wshrpc.RpcContext, error) { } newCtx, routeId, err := handleAuthenticationCommand(msg) if err != nil { - log.Printf("error handling authentication: %v\n", err) p.sendResponseError(msg, err) continue } - p.sendResponse(msg, routeId) + p.sendAuthenticateResponse(msg, routeId) return newCtx, nil } } @@ -136,9 +190,10 @@ func (p *WshRpcProxy) SendRpcMessage(msg []byte) { } func (p *WshRpcProxy) RecvRpcMessage() ([]byte, bool) { - msgBytes, ok := <-p.FromRemoteCh - if !ok || p.RpcContext == nil { - return msgBytes, ok + msgBytes, more := <-p.FromRemoteCh + authToken := p.GetAuthToken() + if !more || (p.RpcContext == nil && authToken == "") { + return msgBytes, more } var msg RpcMessage err := json.Unmarshal(msgBytes, &msg) @@ -146,10 +201,15 @@ func (p *WshRpcProxy) RecvRpcMessage() ([]byte, bool) { // nothing to do here -- will error out at another level return msgBytes, true } - msg.Data, err = recodeCommandData(msg.Command, msg.Data, p.RpcContext) - if err != nil { - // nothing to do here -- will error out at another level - return msgBytes, true + if p.RpcContext != nil { + msg.Data, err = recodeCommandData(msg.Command, msg.Data, p.RpcContext) + if err != nil { + // nothing to do here -- will error out at another level + return msgBytes, true + } + } + if msg.AuthToken == "" { + msg.AuthToken = authToken } newBytes, err := json.Marshal(msg) if err != nil { diff --git a/pkg/wshutil/wshrouter.go b/pkg/wshutil/wshrouter.go index da213943b..64479f498 100644 --- a/pkg/wshutil/wshrouter.go +++ b/pkg/wshutil/wshrouter.go @@ -12,11 +12,14 @@ import ( "sync" "time" + "github.com/google/uuid" + "github.com/wavetermdev/waveterm/pkg/util/utilfn" "github.com/wavetermdev/waveterm/pkg/wps" "github.com/wavetermdev/waveterm/pkg/wshrpc" ) const DefaultRoute = "wavesrv" +const UpstreamRoute = "upstream" const SysRoute = "sys" // this route doesn't exist, just a placeholder for system messages const ElectronRoute = "electron" @@ -36,12 +39,13 @@ type msgAndRoute struct { } type WshRouter struct { - Lock *sync.Mutex - RouteMap map[string]AbstractRpcClient // routeid => client - UpstreamClient AbstractRpcClient // upstream client (if we are not the terminal router) - AnnouncedRoutes map[string]string // routeid => local routeid - RpcMap map[string]*routeInfo // rpcid => routeinfo - InputCh chan msgAndRoute + Lock *sync.Mutex + RouteMap map[string]AbstractRpcClient // routeid => client + UpstreamClient AbstractRpcClient // upstream client (if we are not the terminal router) + AnnouncedRoutes map[string]string // routeid => local routeid + RpcMap map[string]*routeInfo // rpcid => routeinfo + SimpleRequestMap map[string]chan *RpcMessage // simple reqid => response channel + InputCh chan msgAndRoute } func MakeConnectionRouteId(connId string) string { @@ -68,11 +72,12 @@ var DefaultRouter = NewWshRouter() func NewWshRouter() *WshRouter { rtn := &WshRouter{ - Lock: &sync.Mutex{}, - RouteMap: make(map[string]AbstractRpcClient), - AnnouncedRoutes: make(map[string]string), - RpcMap: make(map[string]*routeInfo), - InputCh: make(chan msgAndRoute, DefaultInputChSize), + Lock: &sync.Mutex{}, + RouteMap: make(map[string]AbstractRpcClient), + AnnouncedRoutes: make(map[string]string), + RpcMap: make(map[string]*routeInfo), + SimpleRequestMap: make(map[string]chan *RpcMessage), + InputCh: make(chan msgAndRoute, DefaultInputChSize), } go rtn.runServer() return rtn @@ -237,6 +242,10 @@ func (router *WshRouter) runServer() { router.sendRoutedMessage(msgBytes, routeInfo.DestRouteId) continue } else if msg.ResId != "" { + ok := router.trySimpleResponse(&msg) + if ok { + continue + } routeInfo := router.getRouteInfo(msg.ResId) if routeInfo == nil { // no route info, nothing to do @@ -269,10 +278,10 @@ func (router *WshRouter) WaitForRegister(ctx context.Context, routeId string) er } // this will also consume the output channel of the abstract client -func (router *WshRouter) RegisterRoute(routeId string, rpc AbstractRpcClient) { - if routeId == SysRoute { +func (router *WshRouter) RegisterRoute(routeId string, rpc AbstractRpcClient, shouldAnnounce bool) { + if routeId == SysRoute || routeId == UpstreamRoute { // cannot register sys route - log.Printf("error: WshRouter cannot register sys route\n") + log.Printf("error: WshRouter cannot register %s route\n", routeId) return } log.Printf("[router] registering wsh route %q\n", routeId) @@ -285,7 +294,7 @@ func (router *WshRouter) RegisterRoute(routeId string, rpc AbstractRpcClient) { router.RouteMap[routeId] = rpc go func() { // announce - if !alreadyExists && router.GetUpstreamClient() != nil { + if shouldAnnounce && !alreadyExists && router.GetUpstreamClient() != nil { announceMsg := RpcMessage{Command: wshrpc.Command_RouteAnnounce, Source: routeId} announceBytes, _ := json.Marshal(announceMsg) router.GetUpstreamClient().SendRpcMessage(announceBytes) @@ -352,3 +361,97 @@ func (router *WshRouter) GetUpstreamClient() AbstractRpcClient { defer router.Lock.Unlock() return router.UpstreamClient } + +func (router *WshRouter) InjectMessage(msgBytes []byte, fromRouteId string) { + router.InputCh <- msgAndRoute{msgBytes: msgBytes, fromRouteId: fromRouteId} +} + +func (router *WshRouter) registerSimpleRequest(reqId string) chan *RpcMessage { + router.Lock.Lock() + defer router.Lock.Unlock() + rtn := make(chan *RpcMessage, 1) + router.SimpleRequestMap[reqId] = rtn + return rtn +} + +func (router *WshRouter) trySimpleResponse(msg *RpcMessage) bool { + router.Lock.Lock() + defer router.Lock.Unlock() + respCh := router.SimpleRequestMap[msg.ResId] + if respCh == nil { + return false + } + respCh <- msg + delete(router.SimpleRequestMap, msg.ResId) + return true +} + +func (router *WshRouter) clearSimpleRequest(reqId string) { + router.Lock.Lock() + defer router.Lock.Unlock() + delete(router.SimpleRequestMap, reqId) +} + +func (router *WshRouter) RunSimpleRawCommand(ctx context.Context, msg RpcMessage, fromRouteId string) (*RpcMessage, error) { + if msg.Command == "" { + return nil, errors.New("no command") + } + msgBytes, err := json.Marshal(msg) + if err != nil { + return nil, err + } + var respCh chan *RpcMessage + if msg.ReqId != "" { + respCh = router.registerSimpleRequest(msg.ReqId) + } + router.InjectMessage(msgBytes, fromRouteId) + if respCh == nil { + return nil, nil + } + select { + case <-ctx.Done(): + router.clearSimpleRequest(msg.ReqId) + return nil, ctx.Err() + case resp := <-respCh: + if resp.Error != "" { + return nil, errors.New(resp.Error) + } + return resp, nil + } +} + +func (router *WshRouter) HandleProxyAuth(jwtTokenAny any) (*wshrpc.CommandAuthenticateRtnData, error) { + if jwtTokenAny == nil { + return nil, errors.New("no jwt token") + } + jwtToken, ok := jwtTokenAny.(string) + if !ok { + return nil, errors.New("jwt token not a string") + } + if jwtToken == "" { + return nil, errors.New("empty jwt token") + } + msg := RpcMessage{ + Command: wshrpc.Command_Authenticate, + ReqId: uuid.New().String(), + Data: jwtToken, + } + ctx, cancelFn := context.WithTimeout(context.Background(), DefaultTimeoutMs*time.Millisecond) + defer cancelFn() + resp, err := router.RunSimpleRawCommand(ctx, msg, "") + if err != nil { + return nil, err + } + if resp == nil || resp.Data == nil { + return nil, errors.New("no data in authenticate response") + } + var respData wshrpc.CommandAuthenticateRtnData + err = utilfn.ReUnmarshal(&respData, resp.Data) + if err != nil { + return nil, fmt.Errorf("error unmarshalling authenticate response: %v", err) + } + if respData.AuthToken == "" { + return nil, errors.New("no auth token in authenticate response") + } + return &respData, nil +} diff --git a/pkg/wshutil/wshrpc.go b/pkg/wshutil/wshrpc.go index cccc353e7..7d31246ac 100644 --- a/pkg/wshutil/wshrpc.go +++ b/pkg/wshutil/wshrpc.go @@ -45,10 +45,13 @@ type WshRpc struct { InputCh chan []byte OutputCh chan []byte RpcContext *atomic.Pointer[wshrpc.RpcContext] + AuthToken string RpcMap map[string]*rpcData ServerImpl ServerImpl EventListener *EventListener ResponseHandlerMap map[string]*RpcResponseHandler // reqId => handler + Debug bool + DebugName string } type wshRpcContextKey struct{} @@ -104,17 +107,18 @@ func (w *WshRpc) RecvRpcMessage() ([]byte, bool) { } type RpcMessage struct { - Command string `json:"command,omitempty"` - ReqId string `json:"reqid,omitempty"` - ResId string `json:"resid,omitempty"` - Timeout int `json:"timeout,omitempty"` - Route string `json:"route,omitempty"` // to route/forward requests to alternate servers - Source string `json:"source,omitempty"` // source route id - Cont bool `json:"cont,omitempty"` // flag if additional requests/responses are forthcoming - Cancel bool `json:"cancel,omitempty"` // used to cancel a streaming request or response (sent from the side that is not streaming) - Error string `json:"error,omitempty"` - DataType string `json:"datatype,omitempty"` - Data any `json:"data,omitempty"` + Command string `json:"command,omitempty"` + ReqId string `json:"reqid,omitempty"` + ResId string `json:"resid,omitempty"` + Timeout int `json:"timeout,omitempty"` + Route string `json:"route,omitempty"` // to route/forward requests to alternate servers + AuthToken string `json:"authtoken,omitempty"` // needed for routing unauthenticated requests (WshRpcMultiProxy) + Source string `json:"source,omitempty"` // source route id + Cont bool `json:"cont,omitempty"` // flag if additional requests/responses are forthcoming + Cancel bool `json:"cancel,omitempty"` // used to cancel a streaming request or response (sent from the side that is not streaming) + Error string `json:"error,omitempty"` + DataType string `json:"datatype,omitempty"` + Data any `json:"data,omitempty"` } func (r *RpcMessage) IsRpcRequest() bool { @@ -226,6 +230,14 @@ func (w *WshRpc) SetRpcContext(ctx wshrpc.RpcContext) { w.RpcContext.Store(&ctx) } +func (w *WshRpc) SetAuthToken(token string) { + w.AuthToken = token +} + +func (w *WshRpc) GetAuthToken() string { + return w.AuthToken +} + func (w *WshRpc) registerResponseHandler(reqId string, handler *RpcResponseHandler) { w.Lock.Lock() defer w.Lock.Unlock() @@ -323,6 +335,9 @@ func (w *WshRpc) handleRequest(req *RpcMessage) { func (w *WshRpc) runServer() { defer close(w.OutputCh) for msgBytes := range w.InputCh { + if w.Debug { + log.Printf("[%s] received message: %s\n", w.DebugName, string(msgBytes)) + } var msg RpcMessage err := json.Unmarshal(msgBytes, &msg) if err != nil { @@ -455,8 +470,9 @@ func (handler *RpcRequestHandler) SendCancel() { } }() msg := &RpcMessage{ - Cancel: true, - ReqId: handler.reqId, + Cancel: true, + ReqId: handler.reqId, + AuthToken: handler.w.GetAuthToken(), } barr, _ := json.Marshal(msg) // will never fail handler.w.OutputCh <- barr @@ -550,6 +566,7 @@ func (handler *RpcResponseHandler) SendMessage(msg string) { Data: wshrpc.CommandMessageData{ Message: msg, }, + AuthToken: handler.w.GetAuthToken(), } msgBytes, _ := json.Marshal(rpcMsg) // will never fail handler.w.OutputCh <- msgBytes @@ -573,9 +590,10 @@ func (handler *RpcResponseHandler) SendResponse(data any, done bool) error { defer handler.close() } msg := &RpcMessage{ - ResId: handler.reqId, - Data: data, - Cont: !done, + ResId: handler.reqId, + Data: data, + Cont: !done, + AuthToken: handler.w.GetAuthToken(), } barr, err := json.Marshal(msg) if err != nil { @@ -598,8 +616,9 @@ func (handler *RpcResponseHandler) SendResponseError(err error) { } defer handler.close() msg := &RpcMessage{ - ResId: handler.reqId, - Error: err.Error(), + ResId: handler.reqId, + Error: err.Error(), + AuthToken: handler.w.GetAuthToken(), } barr, _ := json.Marshal(msg) // will never fail handler.w.OutputCh <- barr @@ -660,11 +679,12 @@ func (w *WshRpc) SendComplexRequest(command string, data any, opts *wshrpc.RpcOp handler.reqId = uuid.New().String() } req := &RpcMessage{ - Command: command, - ReqId: handler.reqId, - Data: data, - Timeout: timeoutMs, - Route: opts.Route, + Command: command, + ReqId: handler.reqId, + Data: data, + Timeout: timeoutMs, + Route: opts.Route, + AuthToken: w.GetAuthToken(), } barr, err := json.Marshal(req) if err != nil { diff --git a/pkg/wshutil/wshutil.go b/pkg/wshutil/wshutil.go index 79cdc6080..8be9c908a 100644 --- a/pkg/wshutil/wshutil.go +++ b/pkg/wshutil/wshutil.go @@ -19,6 +19,7 @@ import ( "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" + "github.com/wavetermdev/waveterm/pkg/util/packetparser" "github.com/wavetermdev/waveterm/pkg/wavebase" "github.com/wavetermdev/waveterm/pkg/wshrpc" "golang.org/x/term" @@ -204,11 +205,26 @@ func SetupTerminalRpcClient(serverImpl ServerImpl) (*WshRpc, io.Reader) { continue } os.Stdout.Write(barr) + os.Stdout.Write([]byte{'\n'}) } }() return rpcClient, ptyBuf } +func SetupPacketRpcClient(input io.Reader, output io.Writer, serverImpl ServerImpl) (*WshRpc, chan []byte) { + messageCh := make(chan []byte, DefaultInputChSize) + outputCh := make(chan []byte, DefaultOutputChSize) + rawCh := make(chan []byte, DefaultOutputChSize) + rpcClient := MakeWshRpc(messageCh, outputCh, wshrpc.RpcContext{}, serverImpl) + go packetparser.Parse(input, messageCh, rawCh) + go func() { + for msg := range outputCh { + packetparser.WritePacket(output, msg) + } + }() + return rpcClient, rawCh +} + func SetupConnRpcClient(conn net.Conn, serverImpl ServerImpl) (*WshRpc, chan error, error) { inputCh := make(chan []byte, DefaultInputChSize) outputCh := make(chan []byte, DefaultOutputChSize) @@ -229,10 +245,22 @@ func SetupConnRpcClient(conn net.Conn, serverImpl ServerImpl) (*WshRpc, chan err return rtn, writeErrCh, nil } -func SetupDomainSocketRpcClient(sockName string, serverImpl ServerImpl) (*WshRpc, error) { - conn, err := net.Dial("unix", sockName) +func tryTcpSocket(sockName string) (net.Conn, error) { + addr, err := net.ResolveTCPAddr("tcp", sockName) if err != nil { - return nil, fmt.Errorf("failed to connect to Unix domain socket: %w", err) + return nil, err + } + return net.DialTCP("tcp", nil, addr) +} + +func SetupDomainSocketRpcClient(sockName string, serverImpl ServerImpl) (*WshRpc, error) { + conn, tcpErr := tryTcpSocket(sockName) + var unixErr error + if tcpErr != nil { + conn, unixErr = net.Dial("unix", sockName) + } + if tcpErr != nil && unixErr != nil { + return nil, fmt.Errorf("failed to connect to tcp or unix domain socket: tcp err:%w: unix socket err: %w", tcpErr, unixErr) } rtn, errCh, err := SetupConnRpcClient(conn, serverImpl) go func() { @@ -363,6 +391,46 @@ func MakeRouteIdFromCtx(rpcCtx *wshrpc.RpcContext) (string, error) { return MakeProcRouteId(procId), nil } +type WriteFlusher interface { + Write([]byte) (int, error) + Flush() error +} + +// blocking, returns if there is an error, or on EOF of input +func HandleStdIOClient(logName string, input io.Reader, output io.Writer) { + proxy := MakeRpcMultiProxy() + rawCh := make(chan []byte, DefaultInputChSize) + go packetparser.Parse(input, proxy.FromRemoteRawCh, rawCh) + doneCh := make(chan struct{}) + var doneOnce sync.Once + closeDoneCh := func() { + doneOnce.Do(func() { + close(doneCh) + }) + proxy.DisposeRoutes() + } + go func() { + proxy.RunUnauthLoop() + }() + go func() { + defer closeDoneCh() + for msg := range proxy.ToRemoteCh { + err := packetparser.WritePacket(output, msg) + if err != nil { + log.Printf("[%s] error writing to output: %v\n", logName, err) + break + } + } + }() + go func() { + defer closeDoneCh() + for msg := range rawCh { + log.Printf("[%s:stdout] %s", logName, msg) + } + }() + <-doneCh +} + func handleDomainSocketClient(conn net.Conn) { var routeIdContainer atomic.Pointer[string] proxy := MakeRpcProxy() @@ -399,7 +467,7 @@ func handleDomainSocketClient(conn net.Conn) { return } routeIdContainer.Store(&routeId) - DefaultRouter.RegisterRoute(routeId, proxy) + DefaultRouter.RegisterRoute(routeId, proxy, true) } // only for use on client @@ -433,5 +501,6 @@ func ExtractUnverifiedSocketName(tokenStr string) (string, error) { if !ok { return "", fmt.Errorf("sock claim is missing or invalid") } + sockName = wavebase.ExpandHomeDirSafe(sockName) return sockName, nil } diff --git a/pkg/wsl/wsl-unix.go b/pkg/wsl/wsl-unix.go new file mode 100644 index 000000000..055e46669 --- /dev/null +++ b/pkg/wsl/wsl-unix.go @@ -0,0 +1,67 @@ +//go:build !windows + +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package wsl + +import ( + "context" + "fmt" + "io" + "os" + "os/exec" +) + +func RegisteredDistros(ctx context.Context) (distros []Distro, err error) { + return nil, fmt.Errorf("RegisteredDistros not implemented on this system") +} + +func DefaultDistro(ctx context.Context) (d Distro, ok bool, err error) { + return d, false, fmt.Errorf("DefaultDistro not implemented on this system") +} + +type Distro struct{} + +func (d *Distro) Name() string { + return "" +} + +func (d *Distro) WslCommand(ctx context.Context, cmd string) *WslCmd { + return nil +} + +// just use the regular cmd since it's +// similar enough to not cause issues +// type WslCmd = exec.Cmd +type WslCmd struct { + exec.Cmd +} + +func (wc *WslCmd) GetProcess() *os.Process { + return nil +} + +func (wc *WslCmd) GetProcessState() *os.ProcessState { + return nil +} + +func (c *WslCmd) SetStdin(stdin io.Reader) { + c.Stdin = stdin +} + +func (c *WslCmd) SetStdout(stdout io.Writer) { + c.Stdout = stdout +} + +func (c *WslCmd) SetStderr(stderr io.Writer) { + c.Stdout = stderr +} + +func GetDistroCmd(ctx context.Context, wslDistroName string, cmd string) (*WslCmd, error) { + return nil, fmt.Errorf("GetDistroCmd not implemented on this system") +} + +func GetDistro(ctx context.Context, wslDistroName WslName) (*Distro, error) { + return nil, fmt.Errorf("GetDistro not implemented on this system") +} diff --git a/pkg/wsl/wsl-util.go b/pkg/wsl/wsl-util.go new file mode 100644 index 000000000..5d1f70d35 --- /dev/null +++ b/pkg/wsl/wsl-util.go @@ -0,0 +1,296 @@ +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package wsl + +import ( + "bytes" + "context" + "errors" + "fmt" + "html/template" + "io" + "log" + "os" + "path/filepath" + "strings" + "time" +) + +func DetectShell(ctx context.Context, client *Distro) (string, error) { + wshPath := GetWshPath(ctx, client) + + cmd := client.WslCommand(ctx, wshPath+" shell") + log.Printf("shell detecting using command: %s shell", wshPath) + out, err := cmd.Output() + if err != nil { + log.Printf("unable to determine shell. defaulting to /bin/bash: %s", err) + return "/bin/bash", nil + } + log.Printf("detecting shell: %s", out) + + // quoting breaks this particular case + return strings.TrimSpace(string(out)), nil +} + +func GetWshVersion(ctx context.Context, client *Distro) (string, error) { + wshPath := GetWshPath(ctx, client) + + cmd := client.WslCommand(ctx, wshPath+" version") + out, err := cmd.Output() + if err != nil { + return "", err + } + + return strings.TrimSpace(string(out)), nil +} + +func GetWshPath(ctx context.Context, client *Distro) string { + defaultPath := "~/.waveterm/bin/wsh" + + cmd := client.WslCommand(ctx, "which wsh") + out, whichErr := cmd.Output() + if whichErr == nil { + return strings.TrimSpace(string(out)) + } + + cmd = client.WslCommand(ctx, "where.exe wsh") + out, whereErr := cmd.Output() + if whereErr == nil { + return strings.TrimSpace(string(out)) + } + + // check cmd on windows since it requires an absolute path with backslashes + cmd = client.WslCommand(ctx, "(dir 2>&1 *``|echo %userprofile%\\.waveterm%\\.waveterm\\bin\\wsh.exe);&<# rem #>echo none") + out, cmdErr := cmd.Output() + if cmdErr == nil && strings.TrimSpace(string(out)) != "none" { + return strings.TrimSpace(string(out)) + } + + // no custom install, use default path + return defaultPath +} + +func hasBashInstalled(ctx context.Context, client *Distro) (bool, error) { + cmd := client.WslCommand(ctx, "which bash") + out, whichErr := cmd.Output() + if whichErr == nil && len(out) != 0 { + return true, nil + } + + cmd = client.WslCommand(ctx, "where.exe bash") + out, whereErr := cmd.Output() + if whereErr == nil && len(out) != 0 { + return true, nil + } + + // note: we could also check in /bin/bash explicitly + // just in case that wasn't added to the path. but if + // that's true, we will most likely have worse + // problems going forward + + return false, nil +} + +func GetClientOs(ctx context.Context, client *Distro) (string, error) { + cmd := client.WslCommand(ctx, "uname -s") + out, unixErr := cmd.Output() + if unixErr == nil { + formatted := strings.ToLower(string(out)) + formatted = strings.TrimSpace(formatted) + return formatted, nil + } + + cmd = client.WslCommand(ctx, "echo %OS%") + out, cmdErr := cmd.Output() + if cmdErr == nil && strings.TrimSpace(string(out)) != "%OS%" { + formatted := strings.ToLower(string(out)) + formatted = strings.TrimSpace(formatted) + return strings.Split(formatted, "_")[0], nil + } + + cmd = client.WslCommand(ctx, "echo $env:OS") + out, psErr := cmd.Output() + if psErr == nil && strings.TrimSpace(string(out)) != "$env:OS" { + formatted := strings.ToLower(string(out)) + formatted = strings.TrimSpace(formatted) + return strings.Split(formatted, "_")[0], nil + } + return "", fmt.Errorf("unable to determine os: {unix: %s, cmd: %s, powershell: %s}", unixErr, cmdErr, psErr) +} + +func GetClientArch(ctx context.Context, client *Distro) (string, error) { + cmd := client.WslCommand(ctx, "uname -m") + out, unixErr := cmd.Output() + if unixErr == nil { + formatted := strings.ToLower(string(out)) + formatted = strings.TrimSpace(formatted) + if formatted == "x86_64" { + return "x64", nil + } + return formatted, nil + } + + cmd = client.WslCommand(ctx, "echo %PROCESSOR_ARCHITECTURE%") + out, cmdErr := cmd.Output() + if cmdErr == nil && strings.TrimSpace(string(out)) != "%PROCESSOR_ARCHITECTURE%" { + formatted := strings.ToLower(string(out)) + return strings.TrimSpace(formatted), nil + } + + cmd = client.WslCommand(ctx, "echo $env:PROCESSOR_ARCHITECTURE") + out, psErr := cmd.Output() + if psErr == nil && strings.TrimSpace(string(out)) != "$env:PROCESSOR_ARCHITECTURE" { + formatted := strings.ToLower(string(out)) + return strings.TrimSpace(formatted), nil + } + return "", fmt.Errorf("unable to determine architecture: {unix: %s, cmd: %s, powershell: %s}", unixErr, cmdErr, psErr) +} + +type CancellableCmd struct { + Cmd *WslCmd + Cancel func() +} + +var installTemplatesRawBash = map[string]string{ + "mkdir": `bash -c 'mkdir -p {{.installDir}}'`, + "cat": `bash -c 'cat > {{.tempPath}}'`, + "mv": `bash -c 'mv {{.tempPath}} {{.installPath}}'`, + "chmod": `bash -c 'chmod a+x {{.installPath}}'`, +} + +var installTemplatesRawDefault = map[string]string{ + "mkdir": `mkdir -p {{.installDir}}`, + "cat": `cat > {{.tempPath}}`, + "mv": `mv {{.tempPath}} {{.installPath}}`, + "chmod": `chmod a+x {{.installPath}}`, +} + +func makeCancellableCommand(ctx context.Context, client *Distro, cmdTemplateRaw string, words map[string]string) (*CancellableCmd, error) { + cmdContext, cmdCancel := context.WithCancel(ctx) + + cmdStr := &bytes.Buffer{} + cmdTemplate, err := template.New("").Parse(cmdTemplateRaw) + if err != nil { + cmdCancel() + return nil, err + } + cmdTemplate.Execute(cmdStr, words) + + cmd := client.WslCommand(cmdContext, cmdStr.String()) + return &CancellableCmd{cmd, cmdCancel}, nil +} + +func CpHostToRemote(ctx context.Context, client *Distro, sourcePath string, destPath string) error { + // warning: does not work on windows remote yet + bashInstalled, err := hasBashInstalled(ctx, client) + if err != nil { + return err + } + + var selectedTemplatesRaw map[string]string + if bashInstalled { + selectedTemplatesRaw = installTemplatesRawBash + } else { + log.Printf("bash is not installed on remote. attempting with default shell") + selectedTemplatesRaw = installTemplatesRawDefault + } + + // I need to use toSlash here to force unix keybindings + // this means we can't guarantee it will work on a remote windows machine + var installWords = map[string]string{ + "installDir": filepath.ToSlash(filepath.Dir(destPath)), + "tempPath": destPath + ".temp", + "installPath": destPath, + } + + installStepCmds := make(map[string]*CancellableCmd) + for cmdName, selectedTemplateRaw := range selectedTemplatesRaw { + cancellableCmd, err := makeCancellableCommand(ctx, client, selectedTemplateRaw, installWords) + if err != nil { + return err + } + installStepCmds[cmdName] = cancellableCmd + } + + _, err = installStepCmds["mkdir"].Cmd.Output() + if err != nil { + return err + } + + // the cat part of this is complicated since it requires stdin + catCmd := installStepCmds["cat"].Cmd + catStdin, err := catCmd.StdinPipe() + if err != nil { + return err + } + err = catCmd.Start() + if err != nil { + return err + } + input, err := os.Open(sourcePath) + if err != nil { + return fmt.Errorf("cannot open local file %s to send to host: %v", sourcePath, err) + } + go func() { + io.Copy(catStdin, input) + installStepCmds["cat"].Cancel() + + // backup just in case something weird happens + // could cause potential race condition, but very + // unlikely + time.Sleep(time.Second * 1) + process := catCmd.GetProcess() + if process != nil { + process.Kill() + } + }() + catErr := catCmd.Wait() + if catErr != nil && !errors.Is(catErr, context.Canceled) { + return catErr + } + + _, err = installStepCmds["mv"].Cmd.Output() + if err != nil { + return err + } + + _, err = installStepCmds["chmod"].Cmd.Output() + if err != nil { + return err + } + + return nil +} + +func InstallClientRcFiles(ctx context.Context, client *Distro) error { + path := GetWshPath(ctx, client) + log.Printf("path to wsh searched is: %s", path) + + cmd := client.WslCommand(ctx, path+" rcfiles") + _, err := cmd.Output() + return err +} + +func GetHomeDir(ctx context.Context, client *Distro) string { + // note: also works for powershell + cmd := client.WslCommand(ctx, `echo "$HOME"`) + out, err := cmd.Output() + if err == nil { + return strings.TrimSpace(string(out)) + } + + cmd = client.WslCommand(ctx, `echo %userprofile%`) + out, err = cmd.Output() + if err == nil { + return strings.TrimSpace(string(out)) + } + + return "~" +} + +func IsPowershell(shellPath string) bool { + // get the base path, and then check contains + shellBase := filepath.Base(shellPath) + return strings.Contains(shellBase, "powershell") || strings.Contains(shellBase, "pwsh") +} diff --git a/pkg/wsl/wsl-win.go b/pkg/wsl/wsl-win.go new file mode 100644 index 000000000..782e15719 --- /dev/null +++ b/pkg/wsl/wsl-win.go @@ -0,0 +1,125 @@ +//go:build windows + +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package wsl + +import ( + "context" + "fmt" + "io" + "os" + "sync" + + "github.com/ubuntu/gowsl" +) + +var RegisteredDistros = gowsl.RegisteredDistros +var DefaultDistro = gowsl.DefaultDistro + +type Distro struct { + gowsl.Distro +} + +type WslCmd struct { + c *gowsl.Cmd + wg *sync.WaitGroup + once *sync.Once + lock *sync.Mutex + waitErr error +} + +func (d *Distro) WslCommand(ctx context.Context, cmd string) *WslCmd { + if ctx == nil { + panic("nil Context") + } + innerCmd := d.Command(ctx, cmd) + var wg sync.WaitGroup + var lock *sync.Mutex + return &WslCmd{innerCmd, &wg, new(sync.Once), lock, nil} +} + +func (c *WslCmd) CombinedOutput() (out []byte, err error) { + return c.c.CombinedOutput() +} +func (c *WslCmd) Output() (out []byte, err error) { + return c.c.Output() +} +func (c *WslCmd) Run() error { + return c.c.Run() +} +func (c *WslCmd) Start() (err error) { + return c.c.Start() +} +func (c *WslCmd) StderrPipe() (r io.ReadCloser, err error) { + return c.c.StderrPipe() +} +func (c *WslCmd) StdinPipe() (w io.WriteCloser, err error) { + return c.c.StdinPipe() +} +func (c *WslCmd) StdoutPipe() (r io.ReadCloser, err error) { + return c.c.StdoutPipe() +} +func (c *WslCmd) Wait() (err error) { + c.wg.Add(1) + c.once.Do(func() { + c.waitErr = c.c.Wait() + }) + c.wg.Done() + c.wg.Wait() + if c.waitErr != nil && c.waitErr.Error() == "not started" { + c.once = new(sync.Once) + return c.waitErr + } + return c.waitErr +} +func (c *WslCmd) GetProcess() *os.Process { + return c.c.Process +} + +func (c *WslCmd) GetProcessState() *os.ProcessState { + return c.c.ProcessState +} + +func (c *WslCmd) SetStdin(stdin io.Reader) { + c.c.Stdin = stdin +} + +func (c *WslCmd) SetStdout(stdout io.Writer) { + c.c.Stdout = stdout +} + +func (c *WslCmd) SetStderr(stderr io.Writer) { + c.c.Stdout = stderr +} + +func GetDistroCmd(ctx context.Context, wslDistroName string, cmd string) (*WslCmd, error) { + distros, err := RegisteredDistros(ctx) + if err != nil { + return nil, err + } + for _, distro := range distros { + if distro.Name() != wslDistroName { + continue + } + wrappedDistro := Distro{distro} + return wrappedDistro.WslCommand(ctx, cmd), nil + } + return nil, fmt.Errorf("wsl distro %s not found", wslDistroName) +} + +func GetDistro(ctx context.Context, wslDistroName WslName) (*Distro, error) { + distros, err := RegisteredDistros(ctx) + if err != nil { + return nil, err + } + for _, distro := range distros { + if distro.Name() != wslDistroName.Distro { + continue + } + wrappedDistro := Distro{distro} + return &wrappedDistro, nil + } + return nil, fmt.Errorf("wsl distro %s not found", wslDistroName) +} diff --git a/pkg/wsl/wsl.go b/pkg/wsl/wsl.go new file mode 100644 index 000000000..0f5927ebb --- /dev/null +++ b/pkg/wsl/wsl.go @@ -0,0 +1,494 @@ +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package wsl + +import ( + "context" + "fmt" + "io" + "log" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/wavetermdev/waveterm/pkg/userinput" + "github.com/wavetermdev/waveterm/pkg/util/shellutil" + "github.com/wavetermdev/waveterm/pkg/wavebase" + "github.com/wavetermdev/waveterm/pkg/waveobj" + "github.com/wavetermdev/waveterm/pkg/wconfig" + "github.com/wavetermdev/waveterm/pkg/wps" + "github.com/wavetermdev/waveterm/pkg/wshrpc" + "github.com/wavetermdev/waveterm/pkg/wshutil" +) + +const ( + Status_Init = "init" + Status_Connecting = "connecting" + Status_Connected = "connected" + Status_Disconnected = "disconnected" + Status_Error = "error" +) + +const DefaultConnectionTimeout = 60 * time.Second + +var globalLock = &sync.Mutex{} +var clientControllerMap = make(map[string]*WslConn) +var activeConnCounter = &atomic.Int32{} + +type WslConn struct { + Lock *sync.Mutex + Status string + Name WslName + Client *Distro + SockName string + DomainSockListener net.Listener + ConnController *WslCmd + Error string + HasWaiter *atomic.Bool + LastConnectTime int64 + ActiveConnNum int + Context context.Context + cancelFn func() +} + +type WslName struct { + Distro string `json:"distro"` +} + +func GetAllConnStatus() []wshrpc.ConnStatus { + globalLock.Lock() + defer globalLock.Unlock() + + var connStatuses []wshrpc.ConnStatus + for _, conn := range clientControllerMap { + connStatuses = append(connStatuses, conn.DeriveConnStatus()) + } + return connStatuses +} + +func (conn *WslConn) DeriveConnStatus() wshrpc.ConnStatus { + conn.Lock.Lock() + defer conn.Lock.Unlock() + return wshrpc.ConnStatus{ + Status: conn.Status, + Connected: conn.Status == Status_Connected, + Connection: conn.GetName(), + HasConnected: (conn.LastConnectTime > 0), + ActiveConnNum: conn.ActiveConnNum, + Error: conn.Error, + } +} + +func (conn *WslConn) FireConnChangeEvent() { + status := conn.DeriveConnStatus() + event := wps.WaveEvent{ + Event: wps.Event_ConnChange, + Scopes: []string{ + fmt.Sprintf("connection:%s", conn.GetName()), + }, + Data: status, + } + log.Printf("sending event: %+#v", event) + wps.Broker.Publish(event) +} + +func (conn *WslConn) Close() error { + defer conn.FireConnChangeEvent() + conn.WithLock(func() { + if conn.Status == Status_Connected || conn.Status == Status_Connecting { + // if status is init, disconnected, or error don't change it + conn.Status = Status_Disconnected + } + conn.close_nolock() + }) + // we must wait for the waiter to complete + startTime := time.Now() + for conn.HasWaiter.Load() { + time.Sleep(10 * time.Millisecond) + if time.Since(startTime) > 2*time.Second { + return fmt.Errorf("timeout waiting for waiter to complete") + } + } + return nil +} + +func (conn *WslConn) close_nolock() { + // does not set status (that should happen at another level) + if conn.DomainSockListener != nil { + conn.DomainSockListener.Close() + conn.DomainSockListener = nil + } + if conn.ConnController != nil { + conn.cancelFn() // this suspends the conn controller + conn.ConnController = nil + } + if conn.Client != nil { + // conn.Client.Close() is not relevant here + // we do not want to completely close the wsl in case + // other applications are using it + conn.Client = nil + } +} + +func (conn *WslConn) GetDomainSocketName() string { + conn.Lock.Lock() + defer conn.Lock.Unlock() + return conn.SockName +} + +func (conn *WslConn) GetStatus() string { + conn.Lock.Lock() + defer conn.Lock.Unlock() + return conn.Status +} + +func (conn *WslConn) GetName() string { + // no lock required because opts is immutable + return "wsl://" + conn.Name.Distro +} + +/** + * This function is does not set a listener for WslConn + * It is still required in order to set SockName +**/ +func (conn *WslConn) OpenDomainSocketListener() error { + var allowed bool + conn.WithLock(func() { + if conn.Status != Status_Connecting { + allowed = false + } else { + allowed = true + } + }) + if !allowed { + return fmt.Errorf("cannot open domain socket for %q when status is %q", conn.GetName(), conn.GetStatus()) + } + conn.WithLock(func() { + conn.SockName = "~/.waveterm/wave-remote.sock" + }) + return nil +} + +func (conn *WslConn) StartConnServer() error { + var allowed bool + conn.WithLock(func() { + if conn.Status != Status_Connecting { + allowed = false + } else { + allowed = true + } + }) + if !allowed { + return fmt.Errorf("cannot start conn server for %q when status is %q", conn.GetName(), conn.GetStatus()) + } + client := conn.GetClient() + wshPath := GetWshPath(conn.Context, client) + rpcCtx := wshrpc.RpcContext{ + ClientType: wshrpc.ClientType_ConnServer, + Conn: conn.GetName(), + } + sockName := conn.GetDomainSocketName() + jwtToken, err := wshutil.MakeClientJWTToken(rpcCtx, sockName) + if err != nil { + return fmt.Errorf("unable to create jwt token for conn controller: %w", err) + } + shellPath, err := DetectShell(conn.Context, client) + if err != nil { + return err + } + var cmdStr string + if IsPowershell(shellPath) { + cmdStr = fmt.Sprintf("$env:%s=\"%s\"; %s connserver --router", wshutil.WaveJwtTokenVarName, jwtToken, wshPath) + } else { + cmdStr = fmt.Sprintf("%s=\"%s\" %s connserver --router", wshutil.WaveJwtTokenVarName, jwtToken, wshPath) + } + log.Printf("starting conn controller: %s\n", cmdStr) + cmd := client.WslCommand(conn.Context, cmdStr) + pipeRead, pipeWrite := io.Pipe() + inputPipeRead, inputPipeWrite := io.Pipe() + cmd.SetStdout(pipeWrite) + cmd.SetStderr(pipeWrite) + cmd.SetStdin(inputPipeRead) + err = cmd.Start() + if err != nil { + return fmt.Errorf("unable to start conn controller: %w", err) + } + conn.WithLock(func() { + conn.ConnController = cmd + }) + // service the I/O + go func() { + // wait for termination, clear the controller + defer conn.WithLock(func() { + conn.ConnController = nil + }) + waitErr := cmd.Wait() + log.Printf("conn controller (%q) terminated: %v", conn.GetName(), waitErr) + }() + go func() { + logName := fmt.Sprintf("conncontroller:%s", conn.GetName()) + wshutil.HandleStdIOClient(logName, pipeRead, inputPipeWrite) + }() + regCtx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second) + defer cancelFn() + err = wshutil.DefaultRouter.WaitForRegister(regCtx, wshutil.MakeConnectionRouteId(rpcCtx.Conn)) + if err != nil { + return fmt.Errorf("timeout waiting for connserver to register") + } + time.Sleep(300 * time.Millisecond) // TODO remove this sleep (but we need to wait until connserver is "ready") + return nil +} + +type WshInstallOpts struct { + Force bool + NoUserPrompt bool +} + +func (conn *WslConn) CheckAndInstallWsh(ctx context.Context, clientDisplayName string, opts *WshInstallOpts) error { + if opts == nil { + opts = &WshInstallOpts{} + } + client := conn.GetClient() + if client == nil { + return fmt.Errorf("client is nil") + } + // check that correct wsh extensions are installed + expectedVersion := fmt.Sprintf("wsh v%s", wavebase.WaveVersion) + clientVersion, err := GetWshVersion(ctx, client) + if err == nil && clientVersion == expectedVersion && !opts.Force { + return nil + } + var queryText string + var title string + if opts.Force { + queryText = fmt.Sprintf("ReInstalling Wave Shell Extensions (%s) on `%s`\n", wavebase.WaveVersion, clientDisplayName) + title = "Install Wave Shell Extensions" + } else if err != nil { + queryText = fmt.Sprintf("Wave requires Wave Shell Extensions to be \n"+ + "installed on `%s` \n"+ + "to ensure a seamless experience. \n\n"+ + "Would you like to install them?", clientDisplayName) + title = "Install Wave Shell Extensions" + } else { + // don't ask for upgrading the version + opts.NoUserPrompt = true + } + if !opts.NoUserPrompt { + request := &userinput.UserInputRequest{ + ResponseType: "confirm", + QueryText: queryText, + Title: title, + Markdown: true, + CheckBoxMsg: "Don't show me this again", + } + response, err := userinput.GetUserInput(ctx, request) + if err != nil || !response.Confirm { + return err + } + if response.CheckboxStat { + meta := waveobj.MetaMapType{ + wconfig.ConfigKey_ConnAskBeforeWshInstall: false, + } + err := wconfig.SetBaseConfigValue(meta) + if err != nil { + return fmt.Errorf("error setting conn:askbeforewshinstall value: %w", err) + } + } + } + log.Printf("attempting to install wsh to `%s`", clientDisplayName) + clientOs, err := GetClientOs(ctx, client) + if err != nil { + return err + } + clientArch, err := GetClientArch(ctx, client) + if err != nil { + return err + } + // attempt to install extension + wshLocalPath := shellutil.GetWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch) + err = CpHostToRemote(ctx, client, wshLocalPath, "~/.waveterm/bin/wsh") + if err != nil { + return err + } + log.Printf("successfully installed wsh on %s\n", conn.GetName()) + return nil +} + +func (conn *WslConn) GetClient() *Distro { + conn.Lock.Lock() + defer conn.Lock.Unlock() + return conn.Client +} + +func (conn *WslConn) Reconnect(ctx context.Context) error { + err := conn.Close() + if err != nil { + return err + } + return conn.Connect(ctx) +} + +func (conn *WslConn) WaitForConnect(ctx context.Context) error { + for { + status := conn.DeriveConnStatus() + if status.Status == Status_Connected { + return nil + } + if status.Status == Status_Connecting { + select { + case <-ctx.Done(): + return fmt.Errorf("context timeout") + case <-time.After(100 * time.Millisecond): + continue + } + } + if status.Status == Status_Init || status.Status == Status_Disconnected { + return fmt.Errorf("disconnected") + } + if status.Status == Status_Error { + return fmt.Errorf("error: %v", status.Error) + } + return fmt.Errorf("unknown status: %q", status.Status) + } +} + +// does not return an error since that error is stored inside of WslConn +func (conn *WslConn) Connect(ctx context.Context) error { + var connectAllowed bool + conn.WithLock(func() { + if conn.Status == Status_Connecting || conn.Status == Status_Connected { + connectAllowed = false + } else { + conn.Status = Status_Connecting + conn.Error = "" + connectAllowed = true + } + }) + log.Printf("Connect %s\n", conn.GetName()) + if !connectAllowed { + return fmt.Errorf("cannot connect to %q when status is %q", conn.GetName(), conn.GetStatus()) + } + conn.FireConnChangeEvent() + err := conn.connectInternal(ctx) + conn.WithLock(func() { + if err != nil { + conn.Status = Status_Error + conn.Error = err.Error() + conn.close_nolock() + } else { + conn.Status = Status_Connected + conn.LastConnectTime = time.Now().UnixMilli() + if conn.ActiveConnNum == 0 { + conn.ActiveConnNum = int(activeConnCounter.Add(1)) + } + } + }) + conn.FireConnChangeEvent() + return err +} + +func (conn *WslConn) WithLock(fn func()) { + conn.Lock.Lock() + defer conn.Lock.Unlock() + fn() +} + +func (conn *WslConn) connectInternal(ctx context.Context) error { + client, err := GetDistro(ctx, conn.Name) + if err != nil { + return err + } + conn.WithLock(func() { + conn.Client = client + }) + err = conn.OpenDomainSocketListener() + if err != nil { + return err + } + config := wconfig.ReadFullConfig() + installErr := conn.CheckAndInstallWsh(ctx, conn.GetName(), &WshInstallOpts{NoUserPrompt: !config.Settings.ConnAskBeforeWshInstall}) + if installErr != nil { + return fmt.Errorf("conncontroller %s wsh install error: %v", conn.GetName(), installErr) + } + csErr := conn.StartConnServer() + if csErr != nil { + return fmt.Errorf("conncontroller %s start wsh connserver error: %v", conn.GetName(), csErr) + } + conn.HasWaiter.Store(true) + go conn.waitForDisconnect() + return nil +} + +func (conn *WslConn) waitForDisconnect() { + defer conn.FireConnChangeEvent() + defer conn.HasWaiter.Store(false) + err := conn.ConnController.Wait() + conn.WithLock(func() { + // disconnects happen for a variety of reasons (like network, etc. and are typically transient) + // so we just set the status to "disconnected" here (not error) + // don't overwrite any existing error (or error status) + if err != nil && conn.Error == "" { + conn.Error = err.Error() + } + if conn.Status != Status_Error { + conn.Status = Status_Disconnected + } + conn.close_nolock() + }) +} + +func getConnInternal(name string) *WslConn { + globalLock.Lock() + defer globalLock.Unlock() + connName := WslName{Distro: name} + rtn := clientControllerMap[name] + if rtn == nil { + ctx, cancelFn := context.WithCancel(context.Background()) + rtn = &WslConn{Lock: &sync.Mutex{}, Status: Status_Init, Name: connName, HasWaiter: &atomic.Bool{}, Context: ctx, cancelFn: cancelFn} + clientControllerMap[name] = rtn + } + return rtn +} + +func GetWslConn(ctx context.Context, name string, shouldConnect bool) *WslConn { + conn := getConnInternal(name) + if conn.Client == nil && shouldConnect { + conn.Connect(ctx) + } + return conn +} + +// Convenience function for ensuring a connection is established +func EnsureConnection(ctx context.Context, connName string) error { + if connName == "" { + return nil + } + conn := GetWslConn(ctx, connName, false) + if conn == nil { + return fmt.Errorf("connection not found: %s", connName) + } + connStatus := conn.DeriveConnStatus() + switch connStatus.Status { + case Status_Connected: + return nil + case Status_Connecting: + return conn.WaitForConnect(ctx) + case Status_Init, Status_Disconnected: + return conn.Connect(ctx) + case Status_Error: + return fmt.Errorf("connection error: %s", connStatus.Error) + default: + return fmt.Errorf("unknown connection status %q", connStatus.Status) + } +} + +func DisconnectClient(connName string) error { + conn := getConnInternal(connName) + if conn == nil { + return fmt.Errorf("client %q not found", connName) + } + err := conn.Close() + return err +}