WSL Integration (#1031)

Adds support for connecting to local WSL installations on Windows.

(also adds wshrpcmmultiproxy / connserver router)
This commit is contained in:
Sylvie Crowe 2024-10-23 22:43:17 -07:00 committed by GitHub
parent 4e86b67936
commit 8248637e00
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
31 changed files with 2101 additions and 75 deletions

2
.gitattributes vendored
View File

@ -1 +1 @@
* text=auto * text=auto eol=lf

View File

@ -159,11 +159,11 @@ func shutdownActivityUpdate() {
func createMainWshClient() { func createMainWshClient() {
rpc := wshserver.GetMainRpcClient() rpc := wshserver.GetMainRpcClient()
wshutil.DefaultRouter.RegisterRoute(wshutil.DefaultRoute, rpc) wshutil.DefaultRouter.RegisterRoute(wshutil.DefaultRoute, rpc, true)
wps.Broker.SetClient(wshutil.DefaultRouter) wps.Broker.SetClient(wshutil.DefaultRouter)
localConnWsh := wshutil.MakeWshRpc(nil, nil, wshrpc.RpcContext{Conn: wshrpc.LocalConnName}, &wshremote.ServerImpl{}) localConnWsh := wshutil.MakeWshRpc(nil, nil, wshrpc.RpcContext{Conn: wshrpc.LocalConnName}, &wshremote.ServerImpl{})
go wshremote.RunSysInfoLoop(localConnWsh, wshrpc.LocalConnName) go wshremote.RunSysInfoLoop(localConnWsh, wshrpc.LocalConnName)
wshutil.DefaultRouter.RegisterRoute(wshutil.MakeConnectionRouteId(wshrpc.LocalConnName), localConnWsh) wshutil.DefaultRouter.RegisterRoute(wshutil.MakeConnectionRouteId(wshrpc.LocalConnName), localConnWsh, true)
} }
func main() { func main() {

View File

@ -5,6 +5,7 @@ package cmd
import ( import (
"fmt" "fmt"
"strings"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/wavetermdev/waveterm/pkg/remote" "github.com/wavetermdev/waveterm/pkg/remote"
@ -25,17 +26,24 @@ func init() {
} }
func connStatus() error { func connStatus() error {
resp, err := wshclient.ConnStatusCommand(RpcClient, nil) var allResp []wshrpc.ConnStatus
sshResp, err := wshclient.ConnStatusCommand(RpcClient, nil)
if err != nil { if err != nil {
return fmt.Errorf("getting connection status: %w", err) return fmt.Errorf("getting ssh connection status: %w", err)
} }
if len(resp) == 0 { allResp = append(allResp, sshResp...)
wslResp, err := wshclient.WslStatusCommand(RpcClient, nil)
if err != nil {
return fmt.Errorf("getting wsl connection status: %w", err)
}
allResp = append(allResp, wslResp...)
if len(allResp) == 0 {
WriteStdout("no connections\n") WriteStdout("no connections\n")
return nil return nil
} }
WriteStdout("%-30s %-12s\n", "connection", "status") WriteStdout("%-30s %-12s\n", "connection", "status")
WriteStdout("----------------------------------------------\n") WriteStdout("----------------------------------------------\n")
for _, conn := range resp { for _, conn := range allResp {
str := fmt.Sprintf("%-30s %-12s", conn.Connection, conn.Status) str := fmt.Sprintf("%-30s %-12s", conn.Connection, conn.Status)
if conn.Error != "" { if conn.Error != "" {
str += fmt.Sprintf(" (%s)", conn.Error) str += fmt.Sprintf(" (%s)", conn.Error)
@ -110,7 +118,7 @@ func connRun(cmd *cobra.Command, args []string) error {
} }
connName = args[1] connName = args[1]
_, err := remote.ParseOpts(connName) _, err := remote.ParseOpts(connName)
if err != nil { if err != nil && !strings.HasPrefix(connName, "wsl://") {
return fmt.Errorf("cannot parse connection name: %w", err) return fmt.Errorf("cannot parse connection name: %w", err)
} }
} }

View File

@ -4,29 +4,186 @@
package cmd package cmd
import ( import (
"encoding/json"
"fmt"
"io"
"log"
"net"
"os" "os"
"sync/atomic"
"time"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/wavetermdev/waveterm/pkg/util/packetparser"
"github.com/wavetermdev/waveterm/pkg/wavebase"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshremote" "github.com/wavetermdev/waveterm/pkg/wshrpc/wshremote"
"github.com/wavetermdev/waveterm/pkg/wshutil"
) )
var serverCmd = &cobra.Command{ var serverCmd = &cobra.Command{
Use: "connserver", Use: "connserver",
Hidden: true, Hidden: true,
Short: "remote server to power wave blocks", Short: "remote server to power wave blocks",
Args: cobra.NoArgs, Args: cobra.NoArgs,
Run: serverRun, RunE: serverRun,
PreRunE: preRunSetupRpcClient,
} }
var connServerRouter bool
func init() { func init() {
serverCmd.Flags().BoolVar(&connServerRouter, "router", false, "run in local router mode")
rootCmd.AddCommand(serverCmd) rootCmd.AddCommand(serverCmd)
} }
func serverRun(cmd *cobra.Command, args []string) { func MakeRemoteUnixListener() (net.Listener, error) {
serverAddr := wavebase.GetRemoteDomainSocketName()
os.Remove(serverAddr) // ignore error
rtn, err := net.Listen("unix", serverAddr)
if err != nil {
return nil, fmt.Errorf("error creating listener at %v: %v", serverAddr, err)
}
os.Chmod(serverAddr, 0700)
log.Printf("Server [unix-domain] listening on %s\n", serverAddr)
return rtn, nil
}
func handleNewListenerConn(conn net.Conn, router *wshutil.WshRouter) {
var routeIdContainer atomic.Pointer[string]
proxy := wshutil.MakeRpcProxy()
go func() {
writeErr := wshutil.AdaptOutputChToStream(proxy.ToRemoteCh, conn)
if writeErr != nil {
log.Printf("error writing to domain socket: %v\n", writeErr)
}
}()
go func() {
// when input is closed, close the connection
defer func() {
conn.Close()
routeIdPtr := routeIdContainer.Load()
if routeIdPtr != nil && *routeIdPtr != "" {
router.UnregisterRoute(*routeIdPtr)
disposeMsg := &wshutil.RpcMessage{
Command: wshrpc.Command_Dispose,
Data: wshrpc.CommandDisposeData{
RouteId: *routeIdPtr,
},
Source: *routeIdPtr,
AuthToken: proxy.GetAuthToken(),
}
disposeBytes, _ := json.Marshal(disposeMsg)
router.InjectMessage(disposeBytes, *routeIdPtr)
}
}()
wshutil.AdaptStreamToMsgCh(conn, proxy.FromRemoteCh)
}()
routeId, err := proxy.HandleClientProxyAuth(router)
if err != nil {
log.Printf("error handling client proxy auth: %v\n", err)
conn.Close()
return
}
router.RegisterRoute(routeId, proxy, false)
routeIdContainer.Store(&routeId)
}
func runListener(listener net.Listener, router *wshutil.WshRouter) {
defer func() {
log.Printf("listener closed, exiting\n")
time.Sleep(500 * time.Millisecond)
wshutil.DoShutdown("", 1, true)
}()
for {
conn, err := listener.Accept()
if err == io.EOF {
break
}
if err != nil {
log.Printf("error accepting connection: %v\n", err)
continue
}
go handleNewListenerConn(conn, router)
}
}
func setupConnServerRpcClientWithRouter(router *wshutil.WshRouter) (*wshutil.WshRpc, error) {
jwtToken := os.Getenv(wshutil.WaveJwtTokenVarName)
if jwtToken == "" {
return nil, fmt.Errorf("no jwt token found for connserver")
}
rpcCtx, err := wshutil.ExtractUnverifiedRpcContext(jwtToken)
if err != nil {
return nil, fmt.Errorf("error extracting rpc context from %s: %v", wshutil.WaveJwtTokenVarName, err)
}
authRtn, err := router.HandleProxyAuth(jwtToken)
if err != nil {
return nil, fmt.Errorf("error handling proxy auth: %v", err)
}
inputCh := make(chan []byte, wshutil.DefaultInputChSize)
outputCh := make(chan []byte, wshutil.DefaultOutputChSize)
connServerClient := wshutil.MakeWshRpc(inputCh, outputCh, *rpcCtx, &wshremote.ServerImpl{LogWriter: os.Stdout})
connServerClient.SetAuthToken(authRtn.AuthToken)
router.RegisterRoute(authRtn.RouteId, connServerClient, false)
wshclient.RouteAnnounceCommand(connServerClient, nil)
return connServerClient, nil
}
func serverRunRouter() error {
router := wshutil.NewWshRouter()
termProxy := wshutil.MakeRpcProxy()
rawCh := make(chan []byte, wshutil.DefaultOutputChSize)
go packetparser.Parse(os.Stdin, termProxy.FromRemoteCh, rawCh)
go func() {
for msg := range termProxy.ToRemoteCh {
packetparser.WritePacket(os.Stdout, msg)
}
}()
go func() {
// just ignore and drain the rawCh (stdin)
// when stdin is closed, shutdown
defer wshutil.DoShutdown("", 0, true)
for range rawCh {
// ignore
}
}()
go func() {
for msg := range termProxy.FromRemoteCh {
// send this to the router
router.InjectMessage(msg, wshutil.UpstreamRoute)
}
}()
router.SetUpstreamClient(termProxy)
// now set up the domain socket
unixListener, err := MakeRemoteUnixListener()
if err != nil {
return fmt.Errorf("cannot create unix listener: %v", err)
}
client, err := setupConnServerRpcClientWithRouter(router)
if err != nil {
return fmt.Errorf("error setting up connserver rpc client: %v", err)
}
go runListener(unixListener, router)
// run the sysinfo loop
wshremote.RunSysInfoLoop(client, client.GetRpcContext().Conn)
select {}
}
func serverRunNormal() error {
err := setupRpcClient(&wshremote.ServerImpl{LogWriter: os.Stdout})
if err != nil {
return err
}
WriteStdout("running wsh connserver (%s)\n", RpcContext.Conn) WriteStdout("running wsh connserver (%s)\n", RpcContext.Conn)
go wshremote.RunSysInfoLoop(RpcClient, RpcContext.Conn) go wshremote.RunSysInfoLoop(RpcClient, RpcContext.Conn)
RpcClient.SetServerImpl(&wshremote.ServerImpl{LogWriter: os.Stdout})
select {} // run forever select {} // run forever
} }
func serverRun(cmd *cobra.Command, args []string) error {
if connServerRouter {
return serverRunRouter()
} else {
return serverRunNormal()
}
}

60
cmd/wsh/cmd/wshcmd-wsl.go Normal file
View File

@ -0,0 +1,60 @@
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package cmd
import (
"strings"
"github.com/spf13/cobra"
"github.com/wavetermdev/waveterm/pkg/waveobj"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
)
var distroName string
var wslCmd = &cobra.Command{
Use: "wsl [-d <Distro>]",
Short: "connect this terminal to a local wsl connection",
Args: cobra.NoArgs,
Run: wslRun,
PreRunE: preRunSetupRpcClient,
}
func init() {
wslCmd.Flags().StringVarP(&distroName, "distribution", "d", "", "Run the specified distribution")
rootCmd.AddCommand(wslCmd)
}
func wslRun(cmd *cobra.Command, args []string) {
var err error
if distroName == "" {
// get default distro from the host
distroName, err = wshclient.WslDefaultDistroCommand(RpcClient, nil)
if err != nil {
WriteStderr("[error] %s\n", err)
return
}
}
if !strings.HasPrefix(distroName, "wsl://") {
distroName = "wsl://" + distroName
}
blockId := RpcContext.BlockId
if blockId == "" {
WriteStderr("[error] cannot determine blockid (not in JWT)\n")
return
}
data := wshrpc.CommandSetMetaData{
ORef: waveobj.MakeORef(waveobj.OType_Block, blockId),
Meta: map[string]any{
waveobj.MetaKey_Connection: distroName,
},
}
err = wshclient.SetMetaCommand(RpcClient, data, nil)
if err != nil {
WriteStderr("[error] setting switching connection: %v\n", err)
return
}
WriteStderr("switched connection to %q\n", distroName)
}

View File

@ -521,6 +521,7 @@ const ChangeConnectionBlockModal = React.memo(
const connStatusAtom = getConnStatusAtom(connection); const connStatusAtom = getConnStatusAtom(connection);
const connStatus = jotai.useAtomValue(connStatusAtom); const connStatus = jotai.useAtomValue(connStatusAtom);
const [connList, setConnList] = React.useState<Array<string>>([]); const [connList, setConnList] = React.useState<Array<string>>([]);
const [wslList, setWslList] = React.useState<Array<string>>([]);
const allConnStatus = jotai.useAtomValue(atoms.allConnStatus); const allConnStatus = jotai.useAtomValue(atoms.allConnStatus);
const [rowIndex, setRowIndex] = React.useState(0); const [rowIndex, setRowIndex] = React.useState(0);
const connStatusMap = new Map<string, ConnStatus>(); const connStatusMap = new Map<string, ConnStatus>();
@ -540,6 +541,18 @@ const ChangeConnectionBlockModal = React.memo(
prtn.then((newConnList) => { prtn.then((newConnList) => {
setConnList(newConnList ?? []); setConnList(newConnList ?? []);
}).catch((e) => console.log("unable to load conn list from backend. using blank list: ", e)); }).catch((e) => console.log("unable to load conn list from backend. using blank list: ", e));
const p2rtn = RpcApi.WslListCommand(TabRpcClient, { timeout: 2000 });
p2rtn
.then((newWslList) => {
console.log(newWslList);
setWslList(newWslList ?? []);
})
.catch((e) => {
// removing this log and failing silentyly since it will happen
// if a system isn't using the wsl. and would happen every time the
// typeahead was opened. good candidate for verbose log level.
//console.log("unable to load wsl list from backend. using blank list: ", e)
});
}, [changeConnModalOpen, setConnList]); }, [changeConnModalOpen, setConnList]);
const changeConnection = React.useCallback( const changeConnection = React.useCallback(
@ -588,6 +601,15 @@ const ChangeConnectionBlockModal = React.memo(
filteredList.push(conn); filteredList.push(conn);
} }
} }
const filteredWslList: Array<string> = [];
for (const conn of wslList) {
if (conn === connSelected) {
createNew = false;
}
if (conn.includes(connSelected)) {
filteredWslList.push(conn);
}
}
// priority handles special suggestions when necessary // priority handles special suggestions when necessary
// for instance, when reconnecting // for instance, when reconnecting
const newConnectionSuggestion: SuggestionConnectionItem = { const newConnectionSuggestion: SuggestionConnectionItem = {
@ -637,6 +659,20 @@ const ChangeConnectionBlockModal = React.memo(
label: localName, label: localName,
}); });
} }
for (const wslConn of filteredWslList) {
const connStatus = connStatusMap.get(wslConn);
const connColorNum = computeConnColorNum(connStatus);
localSuggestion.items.push({
status: "connected",
icon: "arrow-right-arrow-left",
iconColor:
connStatus?.status == "connected"
? `var(--conn-icon-color-${connColorNum})`
: "var(--grey-text-color)",
value: "wsl://" + wslConn,
label: "wsl://" + wslConn,
});
}
const remoteItems = filteredList.map((connName) => { const remoteItems = filteredList.map((connName) => {
const connStatus = connStatusMap.get(connName); const connStatus = connStatusMap.get(connName);
const connColorNum = computeConnColorNum(connStatus); const connColorNum = computeConnColorNum(connStatus);

View File

@ -72,6 +72,11 @@ class RpcApiType {
return client.wshRpcCall("deleteblock", data, opts); return client.wshRpcCall("deleteblock", data, opts);
} }
// command "dispose" [call]
DisposeCommand(client: WshClient, data: CommandDisposeData, opts?: RpcOpts): Promise<void> {
return client.wshRpcCall("dispose", data, opts);
}
// command "eventpublish" [call] // command "eventpublish" [call]
EventPublishCommand(client: WshClient, data: WaveEvent, opts?: RpcOpts): Promise<void> { EventPublishCommand(client: WshClient, data: WaveEvent, opts?: RpcOpts): Promise<void> {
return client.wshRpcCall("eventpublish", data, opts); return client.wshRpcCall("eventpublish", data, opts);
@ -237,6 +242,21 @@ class RpcApiType {
return client.wshRpcCall("webselector", data, opts); return client.wshRpcCall("webselector", data, opts);
} }
// command "wsldefaultdistro" [call]
WslDefaultDistroCommand(client: WshClient, opts?: RpcOpts): Promise<string> {
return client.wshRpcCall("wsldefaultdistro", null, opts);
}
// command "wsllist" [call]
WslListCommand(client: WshClient, opts?: RpcOpts): Promise<string[]> {
return client.wshRpcCall("wsllist", null, opts);
}
// command "wslstatus" [call]
WslStatusCommand(client: WshClient, opts?: RpcOpts): Promise<ConnStatus[]> {
return client.wshRpcCall("wslstatus", null, opts);
}
} }
export const RpcApi = new RpcApiType(); export const RpcApi = new RpcApiType();

View File

@ -63,6 +63,7 @@ declare global {
// wshrpc.CommandAuthenticateRtnData // wshrpc.CommandAuthenticateRtnData
type CommandAuthenticateRtnData = { type CommandAuthenticateRtnData = {
routeid: string; routeid: string;
authtoken?: string;
}; };
// wshrpc.CommandBlockInputData // wshrpc.CommandBlockInputData
@ -100,6 +101,11 @@ declare global {
blockid: string; blockid: string;
}; };
// wshrpc.CommandDisposeData
type CommandDisposeData = {
routeid: string;
};
// wshrpc.CommandEventReadHistoryData // wshrpc.CommandEventReadHistoryData
type CommandEventReadHistoryData = { type CommandEventReadHistoryData = {
event: string; event: string;
@ -416,6 +422,7 @@ declare global {
resid?: string; resid?: string;
timeout?: number; timeout?: number;
route?: string; route?: string;
authtoken?: string;
source?: string; source?: string;
cont?: boolean; cont?: boolean;
cancel?: boolean; cancel?: boolean;

3
go.mod
View File

@ -21,6 +21,7 @@ require (
github.com/shirou/gopsutil/v4 v4.24.9 github.com/shirou/gopsutil/v4 v4.24.9
github.com/skeema/knownhosts v1.3.0 github.com/skeema/knownhosts v1.3.0
github.com/spf13/cobra v1.8.1 github.com/spf13/cobra v1.8.1
github.com/ubuntu/gowsl v0.0.0-20240906163211-049fd49bd93b
github.com/wavetermdev/htmltoken v0.1.0 github.com/wavetermdev/htmltoken v0.1.0
golang.org/x/crypto v0.28.0 golang.org/x/crypto v0.28.0
golang.org/x/sys v0.26.0 golang.org/x/sys v0.26.0
@ -36,9 +37,11 @@ require (
github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/spf13/pflag v1.0.5 // indirect github.com/spf13/pflag v1.0.5 // indirect
github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect
github.com/tklauser/numcpus v0.6.1 // indirect github.com/tklauser/numcpus v0.6.1 // indirect
github.com/ubuntu/decorate v0.0.0-20230125165522-2d5b0a9bb117 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect
go.uber.org/atomic v1.7.0 // indirect go.uber.org/atomic v1.7.0 // indirect
golang.org/x/net v0.29.0 // indirect golang.org/x/net v0.29.0 // indirect

11
go.sum
View File

@ -1,5 +1,7 @@
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/0xrawsec/golang-utils v1.3.2 h1:ww4jrtHRSnX9xrGzJYbalx5nXoZewy4zPxiY+ubJgtg=
github.com/0xrawsec/golang-utils v1.3.2/go.mod h1:m7AzHXgdSAkFCD9tWWsApxNVxMlyy7anpPVOyT/yM7E=
github.com/alexflint/go-filemutex v1.3.0 h1:LgE+nTUWnQCyRKbpoceKZsPQbs84LivvgwUymZXdOcM= github.com/alexflint/go-filemutex v1.3.0 h1:LgE+nTUWnQCyRKbpoceKZsPQbs84LivvgwUymZXdOcM=
github.com/alexflint/go-filemutex v1.3.0/go.mod h1:U0+VA/i30mGBlLCrFPGtTe9y6wGQfNAWPBTekHQ+c8A= github.com/alexflint/go-filemutex v1.3.0/go.mod h1:U0+VA/i30mGBlLCrFPGtTe9y6wGQfNAWPBTekHQ+c8A=
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
@ -62,6 +64,8 @@ github.com/sawka/txwrap v0.2.0 h1:V3LfvKVLULxcYSxdMguLwFyQFMEU9nFDJopg0ZkL+94=
github.com/sawka/txwrap v0.2.0/go.mod h1:wwQ2SQiN4U+6DU/iVPhbvr7OzXAtgZlQCIGuvOswEfA= github.com/sawka/txwrap v0.2.0/go.mod h1:wwQ2SQiN4U+6DU/iVPhbvr7OzXAtgZlQCIGuvOswEfA=
github.com/shirou/gopsutil/v4 v4.24.9 h1:KIV+/HaHD5ka5f570RZq+2SaeFsb/pq+fp2DGNWYoOI= github.com/shirou/gopsutil/v4 v4.24.9 h1:KIV+/HaHD5ka5f570RZq+2SaeFsb/pq+fp2DGNWYoOI=
github.com/shirou/gopsutil/v4 v4.24.9/go.mod h1:3fkaHNeYsUFCGZ8+9vZVWtbyM1k2eRnlL+bWO8Bxa/Q= github.com/shirou/gopsutil/v4 v4.24.9/go.mod h1:3fkaHNeYsUFCGZ8+9vZVWtbyM1k2eRnlL+bWO8Bxa/Q=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/skeema/knownhosts v1.3.0 h1:AM+y0rI04VksttfwjkSTNQorvGqmwATnvnAHpSgc0LY= github.com/skeema/knownhosts v1.3.0 h1:AM+y0rI04VksttfwjkSTNQorvGqmwATnvnAHpSgc0LY=
github.com/skeema/knownhosts v1.3.0/go.mod h1:sPINvnADmT/qYH1kfv+ePMmOBTH6Tbl7b5LvTDjFK7M= github.com/skeema/knownhosts v1.3.0/go.mod h1:sPINvnADmT/qYH1kfv+ePMmOBTH6Tbl7b5LvTDjFK7M=
github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
@ -71,12 +75,17 @@ github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk= github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY= github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
github.com/ubuntu/decorate v0.0.0-20230125165522-2d5b0a9bb117 h1:XQpsQG5lqRJlx4mUVHcJvyyc1rdTI9nHvwrdfcuy8aM=
github.com/ubuntu/decorate v0.0.0-20230125165522-2d5b0a9bb117/go.mod h1:mx0TjbqsaDD9DUT5gA1s3hw47U6RIbbIBfvGzR85K0g=
github.com/ubuntu/gowsl v0.0.0-20240906163211-049fd49bd93b h1:wFBKF5k5xbJQU8bYgcSoQ/ScvmYyq6KHUabAuVUjOWM=
github.com/ubuntu/gowsl v0.0.0-20240906163211-049fd49bd93b/go.mod h1:N1CYNinssZru+ikvYTgVbVeSi21thHUTCoJ9xMvWe+s=
github.com/wavetermdev/htmltoken v0.1.0 h1:RMdA9zTfnYa5jRC4RRG3XNoV5NOP8EDxpaVPjuVz//Q= github.com/wavetermdev/htmltoken v0.1.0 h1:RMdA9zTfnYa5jRC4RRG3XNoV5NOP8EDxpaVPjuVz//Q=
github.com/wavetermdev/htmltoken v0.1.0/go.mod h1:5FM0XV6zNYiNza2iaTcFGj+hnMtgqumFHO31Z8euquk= github.com/wavetermdev/htmltoken v0.1.0/go.mod h1:5FM0XV6zNYiNza2iaTcFGj+hnMtgqumFHO31Z8euquk=
github.com/wavetermdev/ssh_config v0.0.0-20240306041034-17e2087ebde2 h1:onqZrJVap1sm15AiIGTfWzdr6cEF0KdtddeuuOVhzyY= github.com/wavetermdev/ssh_config v0.0.0-20240306041034-17e2087ebde2 h1:onqZrJVap1sm15AiIGTfWzdr6cEF0KdtddeuuOVhzyY=
@ -91,6 +100,7 @@ golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo=
golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0= golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220721230656-c6bc011c0c49/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220721230656-c6bc011c0c49/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@ -102,5 +112,6 @@ golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -11,6 +11,7 @@ import (
"io" "io"
"io/fs" "io/fs"
"log" "log"
"strings"
"sync" "sync"
"time" "time"
@ -24,6 +25,7 @@ import (
"github.com/wavetermdev/waveterm/pkg/wps" "github.com/wavetermdev/waveterm/pkg/wps"
"github.com/wavetermdev/waveterm/pkg/wshrpc" "github.com/wavetermdev/waveterm/pkg/wshrpc"
"github.com/wavetermdev/waveterm/pkg/wshutil" "github.com/wavetermdev/waveterm/pkg/wshutil"
"github.com/wavetermdev/waveterm/pkg/wsl"
"github.com/wavetermdev/waveterm/pkg/wstore" "github.com/wavetermdev/waveterm/pkg/wstore"
) )
@ -262,7 +264,30 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj
return fmt.Errorf("unknown controller type %q", bc.ControllerType) return fmt.Errorf("unknown controller type %q", bc.ControllerType)
} }
var shellProc *shellexec.ShellProc var shellProc *shellexec.ShellProc
if remoteName != "" { if strings.HasPrefix(remoteName, "wsl://") {
wslName := strings.TrimPrefix(remoteName, "wsl://")
credentialCtx, cancelFunc := context.WithTimeout(context.Background(), 60*time.Second)
defer cancelFunc()
wslConn := wsl.GetWslConn(credentialCtx, wslName, false)
connStatus := wslConn.DeriveConnStatus()
if connStatus.Status != conncontroller.Status_Connected {
return fmt.Errorf("not connected, cannot start shellproc")
}
// create jwt
if !blockMeta.GetBool(waveobj.MetaKey_CmdNoWsh, false) {
jwtStr, err := wshutil.MakeClientJWTToken(wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId, Conn: wslConn.GetName()}, wslConn.GetDomainSocketName())
if err != nil {
return fmt.Errorf("error making jwt token: %w", err)
}
cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr
}
shellProc, err = shellexec.StartWslShellProc(ctx, rc.TermSize, cmdStr, cmdOpts, wslConn)
if err != nil {
return err
}
} else if remoteName != "" {
credentialCtx, cancelFunc := context.WithTimeout(context.Background(), 60*time.Second) credentialCtx, cancelFunc := context.WithTimeout(context.Background(), 60*time.Second)
defer cancelFunc() defer cancelFunc()
@ -325,7 +350,7 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj
// we don't need to authenticate this wshProxy since it is coming direct // we don't need to authenticate this wshProxy since it is coming direct
wshProxy := wshutil.MakeRpcProxy() wshProxy := wshutil.MakeRpcProxy()
wshProxy.SetRpcContext(&wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId}) wshProxy.SetRpcContext(&wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId})
wshutil.DefaultRouter.RegisterRoute(wshutil.MakeControllerRouteId(bc.BlockId), wshProxy) wshutil.DefaultRouter.RegisterRoute(wshutil.MakeControllerRouteId(bc.BlockId), wshProxy, true)
ptyBuffer := wshutil.MakePtyBuffer(wshutil.WaveOSCPrefix, bc.ShellProc.Cmd, wshProxy.FromRemoteCh) ptyBuffer := wshutil.MakePtyBuffer(wshutil.WaveOSCPrefix, bc.ShellProc.Cmd, wshProxy.FromRemoteCh)
go func() { go func() {
// handles regular output from the pty (goes to the blockfile and xterm) // handles regular output from the pty (goes to the blockfile and xterm)
@ -494,6 +519,15 @@ func CheckConnStatus(blockId string) error {
if connName == "" { if connName == "" {
return nil return nil
} }
if strings.HasPrefix(connName, "wsl://") {
distroName := strings.TrimPrefix(connName, "wsl://")
conn := wsl.GetWslConn(context.Background(), distroName, false)
connStatus := conn.DeriveConnStatus()
if connStatus.Status != conncontroller.Status_Connected {
return fmt.Errorf("not connected: %s", connStatus.Status)
}
return nil
}
opts, err := remote.ParseOpts(connName) opts, err := remote.ParseOpts(connName)
if err != nil { if err != nil {
return fmt.Errorf("error parsing connection name: %w", err) return fmt.Errorf("error parsing connection name: %w", err)

View File

@ -1,3 +1,6 @@
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package remote package remote
import ( import (

View File

@ -17,6 +17,7 @@ import (
"github.com/wavetermdev/waveterm/pkg/wcore" "github.com/wavetermdev/waveterm/pkg/wcore"
"github.com/wavetermdev/waveterm/pkg/wlayout" "github.com/wavetermdev/waveterm/pkg/wlayout"
"github.com/wavetermdev/waveterm/pkg/wshrpc" "github.com/wavetermdev/waveterm/pkg/wshrpc"
"github.com/wavetermdev/waveterm/pkg/wsl"
"github.com/wavetermdev/waveterm/pkg/wstore" "github.com/wavetermdev/waveterm/pkg/wstore"
) )
@ -77,7 +78,9 @@ func (cs *ClientService) MakeWindow(ctx context.Context) (*waveobj.Window, error
} }
func (cs *ClientService) GetAllConnStatus(ctx context.Context) ([]wshrpc.ConnStatus, error) { func (cs *ClientService) GetAllConnStatus(ctx context.Context) ([]wshrpc.ConnStatus, error) {
return conncontroller.GetAllConnStatus(), nil sshStatuses := conncontroller.GetAllConnStatus()
wslStatuses := wsl.GetAllConnStatus()
return append(sshStatuses, wslStatuses...), nil
} }
// moves the window to the front of the windowId stack // moves the window to the front of the windowId stack

View File

@ -7,6 +7,7 @@ import (
"time" "time"
"github.com/creack/pty" "github.com/creack/pty"
"github.com/wavetermdev/waveterm/pkg/wsl"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
@ -129,3 +130,42 @@ func (sw SessionWrap) StderrPipe() (io.ReadCloser, error) {
func (sw SessionWrap) SetSize(h int, w int) error { func (sw SessionWrap) SetSize(h int, w int) error {
return sw.Session.WindowChange(h, w) return sw.Session.WindowChange(h, w)
} }
type WslCmdWrap struct {
*wsl.WslCmd
Tty pty.Tty
pty.Pty
}
func (wcw WslCmdWrap) Kill() {
wcw.Tty.Close()
wcw.Close()
}
func (wcw WslCmdWrap) KillGraceful(timeout time.Duration) {
process := wcw.WslCmd.GetProcess()
if process == nil {
return
}
processState := wcw.WslCmd.GetProcessState()
if processState != nil && processState.Exited() {
return
}
process.Signal(os.Interrupt)
go func() {
time.Sleep(timeout)
process := wcw.WslCmd.GetProcess()
processState := wcw.WslCmd.GetProcessState()
if processState == nil || !processState.Exited() {
process.Kill() // force kill if it is already not exited
}
}()
}
/**
* SetSize does nothing for WslCmdWrap as there
* is no pty to manage.
**/
func (wcw WslCmdWrap) SetSize(w int, h int) error {
return nil
}

View File

@ -5,6 +5,7 @@ package shellexec
import ( import (
"bytes" "bytes"
"context"
"fmt" "fmt"
"io" "io"
"log" "log"
@ -25,6 +26,7 @@ import (
"github.com/wavetermdev/waveterm/pkg/wavebase" "github.com/wavetermdev/waveterm/pkg/wavebase"
"github.com/wavetermdev/waveterm/pkg/waveobj" "github.com/wavetermdev/waveterm/pkg/waveobj"
"github.com/wavetermdev/waveterm/pkg/wshutil" "github.com/wavetermdev/waveterm/pkg/wshutil"
"github.com/wavetermdev/waveterm/pkg/wsl"
) )
const DefaultGracefulKillWait = 400 * time.Millisecond const DefaultGracefulKillWait = 400 * time.Millisecond
@ -141,6 +143,96 @@ func (pp *PipePty) WriteString(s string) (n int, err error) {
return pp.Write([]byte(s)) return pp.Write([]byte(s))
} }
func StartWslShellProc(ctx context.Context, termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *wsl.WslConn) (*ShellProc, error) {
client := conn.GetClient()
shellPath := cmdOpts.ShellPath
if shellPath == "" {
remoteShellPath, err := wsl.DetectShell(conn.Context, client)
if err != nil {
return nil, err
}
shellPath = remoteShellPath
}
var shellOpts []string
log.Printf("detected shell: %s", shellPath)
err := wsl.InstallClientRcFiles(conn.Context, client)
if err != nil {
log.Printf("error installing rc files: %v", err)
return nil, err
}
homeDir := wsl.GetHomeDir(conn.Context, client)
shellOpts = append(shellOpts, "~", "-d", client.Name())
if isZshShell(shellPath) {
shellOpts = append(shellOpts, fmt.Sprintf(`ZDOTDIR="%s/.waveterm/%s"`, homeDir, shellutil.ZshIntegrationDir))
}
var subShellOpts []string
if cmdStr == "" {
/* transform command in order to inject environment vars */
if isBashShell(shellPath) {
log.Printf("recognized as bash shell")
// add --rcfile
// cant set -l or -i with --rcfile
subShellOpts = append(subShellOpts, "--rcfile", fmt.Sprintf(`%s/.waveterm/%s/.bashrc`, homeDir, shellutil.BashIntegrationDir))
} else if isFishShell(shellPath) {
carg := fmt.Sprintf(`"set -x PATH \"%s\"/.waveterm/%s $PATH"`, homeDir, shellutil.WaveHomeBinDir)
subShellOpts = append(subShellOpts, "-C", carg)
} else if wsl.IsPowershell(shellPath) {
// powershell is weird about quoted path executables and requires an ampersand first
shellPath = "& " + shellPath
subShellOpts = append(subShellOpts, "-ExecutionPolicy", "Bypass", "-NoExit", "-File", homeDir+fmt.Sprintf("/.waveterm/%s/wavepwsh.ps1", shellutil.PwshIntegrationDir))
} else {
if cmdOpts.Login {
subShellOpts = append(subShellOpts, "-l")
}
if cmdOpts.Interactive {
subShellOpts = append(subShellOpts, "-i")
}
// can't set environment vars this way
// will try to do later if possible
}
} else {
shellPath = cmdStr
if cmdOpts.Login {
subShellOpts = append(subShellOpts, "-l")
}
if cmdOpts.Interactive {
subShellOpts = append(subShellOpts, "-i")
}
subShellOpts = append(subShellOpts, "-c", cmdStr)
}
jwtToken, ok := cmdOpts.Env[wshutil.WaveJwtTokenVarName]
if !ok {
return nil, fmt.Errorf("no jwt token provided to connection")
}
if remote.IsPowershell(shellPath) {
shellOpts = append(shellOpts, "--", fmt.Sprintf(`$env:%s=%s;`, wshutil.WaveJwtTokenVarName, jwtToken))
} else {
shellOpts = append(shellOpts, "--", fmt.Sprintf(`%s=%s`, wshutil.WaveJwtTokenVarName, jwtToken))
}
shellOpts = append(shellOpts, shellPath)
shellOpts = append(shellOpts, subShellOpts...)
log.Printf("full cmd is: %s %s", "wsl.exe", strings.Join(shellOpts, " "))
ecmd := exec.Command("wsl.exe", shellOpts...)
if termSize.Rows == 0 || termSize.Cols == 0 {
termSize.Rows = shellutil.DefaultTermRows
termSize.Cols = shellutil.DefaultTermCols
}
if termSize.Rows <= 0 || termSize.Cols <= 0 {
return nil, fmt.Errorf("invalid term size: %v", termSize)
}
cmdPty, err := pty.StartWithSize(ecmd, &pty.Winsize{Rows: uint16(termSize.Rows), Cols: uint16(termSize.Cols)})
if err != nil {
return nil, err
}
return &ShellProc{Cmd: CmdWrap{ecmd, cmdPty}, ConnName: conn.GetName(), CloseOnce: &sync.Once{}, DoneCh: make(chan any)}, nil
}
func StartRemoteShellProc(termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *conncontroller.SSHConn) (*ShellProc, error) { func StartRemoteShellProc(termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *conncontroller.SSHConn) (*ShellProc, error) {
client := conn.GetClient() client := conn.GetClient()
shellPath := cmdOpts.ShellPath shellPath := cmdOpts.ShellPath

View File

@ -0,0 +1,58 @@
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package packetparser
import (
"bufio"
"bytes"
"fmt"
"io"
)
type PacketParser struct {
Reader io.Reader
Ch chan []byte
}
func Parse(input io.Reader, packetCh chan []byte, rawCh chan []byte) error {
bufReader := bufio.NewReader(input)
defer close(packetCh)
defer close(rawCh)
for {
line, err := bufReader.ReadBytes('\n')
if err == io.EOF {
return nil
}
if err != nil {
return err
}
if len(line) <= 1 {
// just a blank line
continue
}
if bytes.HasPrefix(line, []byte{'#', '#', 'N', '{'}) && bytes.HasSuffix(line, []byte{'}', '\n'}) {
// strip off the leading "##" and trailing "\n" (single byte)
packetCh <- line[3 : len(line)-1]
} else {
rawCh <- line
}
}
}
func WritePacket(output io.Writer, packet []byte) error {
if len(packet) < 2 {
return nil
}
if packet[0] != '{' || packet[len(packet)-1] != '}' {
return fmt.Errorf("invalid packet, must start with '{' and end with '}'")
}
fullPacket := make([]byte, 0, len(packet)+5)
// we add the extra newline to make sure the ## appears at the beginning of the line
// since writer isn't buffered, we want to send this all at once
fullPacket = append(fullPacket, '\n', '#', '#', 'N')
fullPacket = append(fullPacket, packet...)
fullPacket = append(fullPacket, '\n')
_, err := output.Write(fullPacket)
return err
}

View File

@ -30,10 +30,13 @@ const WaveDataHomeEnvVar = "WAVETERM_DATA_HOME"
const WaveDevVarName = "WAVETERM_DEV" const WaveDevVarName = "WAVETERM_DEV"
const WaveLockFile = "wave.lock" const WaveLockFile = "wave.lock"
const DomainSocketBaseName = "wave.sock" const DomainSocketBaseName = "wave.sock"
const RemoteDomainSocketBaseName = "wave-remote.sock"
const WaveDBDir = "db" const WaveDBDir = "db"
const JwtSecret = "waveterm" // TODO generate and store this const JwtSecret = "waveterm" // TODO generate and store this
const ConfigDir = "config" const ConfigDir = "config"
var RemoteWaveHome = ExpandHomeDirSafe("~/.waveterm")
const WaveAppPathVarName = "WAVETERM_APP_PATH" const WaveAppPathVarName = "WAVETERM_APP_PATH"
const AppPathBinDir = "bin" const AppPathBinDir = "bin"
@ -101,6 +104,10 @@ func GetDomainSocketName() string {
return filepath.Join(GetWaveDataDir(), DomainSocketBaseName) return filepath.Join(GetWaveDataDir(), DomainSocketBaseName)
} }
func GetRemoteDomainSocketName() string {
return filepath.Join(RemoteWaveHome, RemoteDomainSocketBaseName)
}
func GetWaveDataDir() string { func GetWaveDataDir() string {
retVal, found := os.LookupEnv(WaveDataHomeEnvVar) retVal, found := os.LookupEnv(WaveDataHomeEnvVar)
if !found { if !found {

View File

@ -431,7 +431,7 @@ func MakeTCPListener(serviceName string) (net.Listener, error) {
} }
func MakeUnixListener() (net.Listener, error) { func MakeUnixListener() (net.Listener, error) {
serverAddr := wavebase.GetWaveDataDir() + "/wave.sock" serverAddr := wavebase.GetDomainSocketName()
os.Remove(serverAddr) // ignore error os.Remove(serverAddr) // ignore error
rtn, err := net.Listen("unix", serverAddr) rtn, err := net.Listen("unix", serverAddr)
if err != nil { if err != nil {

View File

@ -252,7 +252,7 @@ func registerConn(wsConnId string, routeId string, wproxy *wshutil.WshRpcProxy)
wshutil.DefaultRouter.UnregisterRoute(routeId) wshutil.DefaultRouter.UnregisterRoute(routeId)
} }
RouteToConnMap[routeId] = wsConnId RouteToConnMap[routeId] = wsConnId
wshutil.DefaultRouter.RegisterRoute(routeId, wproxy) wshutil.DefaultRouter.RegisterRoute(routeId, wproxy, true)
} }
func unregisterConn(wsConnId string, routeId string) { func unregisterConn(wsConnId string, routeId string) {

View File

@ -92,6 +92,12 @@ func DeleteBlockCommand(w *wshutil.WshRpc, data wshrpc.CommandDeleteBlockData, o
return err return err
} }
// command "dispose", wshserver.DisposeCommand
func DisposeCommand(w *wshutil.WshRpc, data wshrpc.CommandDisposeData, opts *wshrpc.RpcOpts) error {
_, err := sendRpcRequestCallHelper[any](w, "dispose", data, opts)
return err
}
// command "eventpublish", wshserver.EventPublishCommand // command "eventpublish", wshserver.EventPublishCommand
func EventPublishCommand(w *wshutil.WshRpc, data wps.WaveEvent, opts *wshrpc.RpcOpts) error { func EventPublishCommand(w *wshutil.WshRpc, data wps.WaveEvent, opts *wshrpc.RpcOpts) error {
_, err := sendRpcRequestCallHelper[any](w, "eventpublish", data, opts) _, err := sendRpcRequestCallHelper[any](w, "eventpublish", data, opts)
@ -285,4 +291,22 @@ func WebSelectorCommand(w *wshutil.WshRpc, data wshrpc.CommandWebSelectorData, o
return resp, err return resp, err
} }
// command "wsldefaultdistro", wshserver.WslDefaultDistroCommand
func WslDefaultDistroCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) (string, error) {
resp, err := sendRpcRequestCallHelper[string](w, "wsldefaultdistro", nil, opts)
return resp, err
}
// command "wsllist", wshserver.WslListCommand
func WslListCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) ([]string, error) {
resp, err := sendRpcRequestCallHelper[[]string](w, "wsllist", nil, opts)
return resp, err
}
// command "wslstatus", wshserver.WslStatusCommand
func WslStatusCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) ([]wshrpc.ConnStatus, error) {
resp, err := sendRpcRequestCallHelper[[]wshrpc.ConnStatus](w, "wslstatus", nil, opts)
return resp, err
}

View File

@ -28,6 +28,7 @@ const (
const ( const (
Command_Authenticate = "authenticate" // special Command_Authenticate = "authenticate" // special
Command_Dispose = "dispose" // special (disposes of the route, for multiproxy only)
Command_RouteAnnounce = "routeannounce" // special (for routing) Command_RouteAnnounce = "routeannounce" // special (for routing)
Command_RouteUnannounce = "routeunannounce" // special (for routing) Command_RouteUnannounce = "routeunannounce" // special (for routing)
Command_Message = "message" Command_Message = "message"
@ -62,11 +63,15 @@ const (
Command_RemoteFileDelete = "remotefiledelete" Command_RemoteFileDelete = "remotefiledelete"
Command_RemoteFileJoiin = "remotefilejoin" Command_RemoteFileJoiin = "remotefilejoin"
Command_ConnStatus = "connstatus"
Command_WslStatus = "wslstatus"
Command_ConnEnsure = "connensure" Command_ConnEnsure = "connensure"
Command_ConnReinstallWsh = "connreinstallwsh" Command_ConnReinstallWsh = "connreinstallwsh"
Command_ConnConnect = "connconnect" Command_ConnConnect = "connconnect"
Command_ConnDisconnect = "conndisconnect" Command_ConnDisconnect = "conndisconnect"
Command_ConnList = "connlist" Command_ConnList = "connlist"
Command_WslList = "wsllist"
Command_WslDefaultDistro = "wsldefaultdistro"
Command_WebSelector = "webselector" Command_WebSelector = "webselector"
Command_Notify = "notify" Command_Notify = "notify"
@ -83,6 +88,7 @@ type RespOrErrorUnion[T any] struct {
type WshRpcInterface interface { type WshRpcInterface interface {
AuthenticateCommand(ctx context.Context, data string) (CommandAuthenticateRtnData, error) AuthenticateCommand(ctx context.Context, data string) (CommandAuthenticateRtnData, error)
DisposeCommand(ctx context.Context, data CommandDisposeData) error
RouteAnnounceCommand(ctx context.Context) error // (special) announces a new route to the main router RouteAnnounceCommand(ctx context.Context) error // (special) announces a new route to the main router
RouteUnannounceCommand(ctx context.Context) error // (special) unannounces a route to the main router RouteUnannounceCommand(ctx context.Context) error // (special) unannounces a route to the main router
@ -114,11 +120,14 @@ type WshRpcInterface interface {
// connection functions // connection functions
ConnStatusCommand(ctx context.Context) ([]ConnStatus, error) ConnStatusCommand(ctx context.Context) ([]ConnStatus, error)
WslStatusCommand(ctx context.Context) ([]ConnStatus, error)
ConnEnsureCommand(ctx context.Context, connName string) error ConnEnsureCommand(ctx context.Context, connName string) error
ConnReinstallWshCommand(ctx context.Context, connName string) error ConnReinstallWshCommand(ctx context.Context, connName string) error
ConnConnectCommand(ctx context.Context, connName string) error ConnConnectCommand(ctx context.Context, connName string) error
ConnDisconnectCommand(ctx context.Context, connName string) error ConnDisconnectCommand(ctx context.Context, connName string) error
ConnListCommand(ctx context.Context) ([]string, error) ConnListCommand(ctx context.Context) ([]string, error)
WslListCommand(ctx context.Context) ([]string, error)
WslDefaultDistroCommand(ctx context.Context) (string, error)
// eventrecv is special, it's handled internally by WshRpc with EventListener // eventrecv is special, it's handled internally by WshRpc with EventListener
EventRecvCommand(ctx context.Context, data wps.WaveEvent) error EventRecvCommand(ctx context.Context, data wps.WaveEvent) error
@ -200,7 +209,13 @@ func HackRpcContextIntoData(dataPtr any, rpcContext RpcContext) {
} }
type CommandAuthenticateRtnData struct { type CommandAuthenticateRtnData struct {
RouteId string `json:"routeid"`
AuthToken string `json:"authtoken,omitempty"`
}
type CommandDisposeData struct {
RouteId string `json:"routeid"` RouteId string `json:"routeid"`
// auth token travels in the packet directly
} }
type CommandMessageData struct { type CommandMessageData struct {

View File

@ -21,6 +21,7 @@ import (
"github.com/wavetermdev/waveterm/pkg/filestore" "github.com/wavetermdev/waveterm/pkg/filestore"
"github.com/wavetermdev/waveterm/pkg/remote" "github.com/wavetermdev/waveterm/pkg/remote"
"github.com/wavetermdev/waveterm/pkg/remote/conncontroller" "github.com/wavetermdev/waveterm/pkg/remote/conncontroller"
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
"github.com/wavetermdev/waveterm/pkg/waveai" "github.com/wavetermdev/waveterm/pkg/waveai"
"github.com/wavetermdev/waveterm/pkg/waveobj" "github.com/wavetermdev/waveterm/pkg/waveobj"
"github.com/wavetermdev/waveterm/pkg/wconfig" "github.com/wavetermdev/waveterm/pkg/wconfig"
@ -29,6 +30,7 @@ import (
"github.com/wavetermdev/waveterm/pkg/wps" "github.com/wavetermdev/waveterm/pkg/wps"
"github.com/wavetermdev/waveterm/pkg/wshrpc" "github.com/wavetermdev/waveterm/pkg/wshrpc"
"github.com/wavetermdev/waveterm/pkg/wshutil" "github.com/wavetermdev/waveterm/pkg/wshutil"
"github.com/wavetermdev/waveterm/pkg/wsl"
"github.com/wavetermdev/waveterm/pkg/wstore" "github.com/wavetermdev/waveterm/pkg/wstore"
) )
@ -36,6 +38,7 @@ const SimpleId_This = "this"
const SimpleId_Tab = "tab" const SimpleId_Tab = "tab"
var SimpleId_BlockNum_Regex = regexp.MustCompile(`^\d+$`) var SimpleId_BlockNum_Regex = regexp.MustCompile(`^\d+$`)
var InvalidWslDistroNames = []string{"docker-desktop", "docker-desktop-data"}
type WshServer struct{} type WshServer struct{}
@ -463,11 +466,28 @@ func (ws *WshServer) ConnStatusCommand(ctx context.Context) ([]wshrpc.ConnStatus
return rtn, nil return rtn, nil
} }
func (ws *WshServer) WslStatusCommand(ctx context.Context) ([]wshrpc.ConnStatus, error) {
rtn := wsl.GetAllConnStatus()
return rtn, nil
}
func (ws *WshServer) ConnEnsureCommand(ctx context.Context, connName string) error { func (ws *WshServer) ConnEnsureCommand(ctx context.Context, connName string) error {
if strings.HasPrefix(connName, "wsl://") {
distroName := strings.TrimPrefix(connName, "wsl://")
return wsl.EnsureConnection(ctx, distroName)
}
return conncontroller.EnsureConnection(ctx, connName) return conncontroller.EnsureConnection(ctx, connName)
} }
func (ws *WshServer) ConnDisconnectCommand(ctx context.Context, connName string) error { func (ws *WshServer) ConnDisconnectCommand(ctx context.Context, connName string) error {
if strings.HasPrefix(connName, "wsl://") {
distroName := strings.TrimPrefix(connName, "wsl://")
conn := wsl.GetWslConn(ctx, distroName, false)
if conn == nil {
return fmt.Errorf("distro not found: %s", connName)
}
return conn.Close()
}
connOpts, err := remote.ParseOpts(connName) connOpts, err := remote.ParseOpts(connName)
if err != nil { if err != nil {
return fmt.Errorf("error parsing connection name: %w", err) return fmt.Errorf("error parsing connection name: %w", err)
@ -480,6 +500,14 @@ func (ws *WshServer) ConnDisconnectCommand(ctx context.Context, connName string)
} }
func (ws *WshServer) ConnConnectCommand(ctx context.Context, connName string) error { func (ws *WshServer) ConnConnectCommand(ctx context.Context, connName string) error {
if strings.HasPrefix(connName, "wsl://") {
distroName := strings.TrimPrefix(connName, "wsl://")
conn := wsl.GetWslConn(ctx, distroName, false)
if conn == nil {
return fmt.Errorf("connection not found: %s", connName)
}
return conn.Connect(ctx)
}
connOpts, err := remote.ParseOpts(connName) connOpts, err := remote.ParseOpts(connName)
if err != nil { if err != nil {
return fmt.Errorf("error parsing connection name: %w", err) return fmt.Errorf("error parsing connection name: %w", err)
@ -492,6 +520,14 @@ func (ws *WshServer) ConnConnectCommand(ctx context.Context, connName string) er
} }
func (ws *WshServer) ConnReinstallWshCommand(ctx context.Context, connName string) error { func (ws *WshServer) ConnReinstallWshCommand(ctx context.Context, connName string) error {
if strings.HasPrefix(connName, "wsl://") {
distroName := strings.TrimPrefix(connName, "wsl://")
conn := wsl.GetWslConn(ctx, distroName, false)
if conn == nil {
return fmt.Errorf("connection not found: %s", connName)
}
return conn.CheckAndInstallWsh(ctx, connName, &wsl.WshInstallOpts{Force: true, NoUserPrompt: true})
}
connOpts, err := remote.ParseOpts(connName) connOpts, err := remote.ParseOpts(connName)
if err != nil { if err != nil {
return fmt.Errorf("error parsing connection name: %w", err) return fmt.Errorf("error parsing connection name: %w", err)
@ -507,6 +543,33 @@ func (ws *WshServer) ConnListCommand(ctx context.Context) ([]string, error) {
return conncontroller.GetConnectionsList() return conncontroller.GetConnectionsList()
} }
func (ws *WshServer) WslListCommand(ctx context.Context) ([]string, error) {
distros, err := wsl.RegisteredDistros(ctx)
if err != nil {
return nil, err
}
var distroNames []string
for _, distro := range distros {
distroName := distro.Name()
if utilfn.ContainsStr(InvalidWslDistroNames, distroName) {
continue
}
distroNames = append(distroNames, distroName)
}
return distroNames, nil
}
func (ws *WshServer) WslDefaultDistroCommand(ctx context.Context) (string, error) {
distro, ok, err := wsl.DefaultDistro(ctx)
if err != nil {
return "", fmt.Errorf("unable to determine default distro: %w", err)
}
if !ok {
return "", fmt.Errorf("unable to determine default distro")
}
return distro.Name(), nil
}
func (ws *WshServer) BlockInfoCommand(ctx context.Context, blockId string) (*wshrpc.BlockInfoData, error) { func (ws *WshServer) BlockInfoCommand(ctx context.Context, blockId string) (*wshrpc.BlockInfoData, error) {
blockData, err := wstore.DBMustGet[*waveobj.Block](ctx, blockId) blockData, err := wstore.DBMustGet[*waveobj.Block](ctx, blockId)
if err != nil { if err != nil {

View File

@ -0,0 +1,151 @@
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package wshutil
import (
"encoding/json"
"fmt"
"sync"
"github.com/google/uuid"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
)
type multiProxyRouteInfo struct {
RouteId string
AuthToken string
Proxy *WshRpcProxy
RpcContext *wshrpc.RpcContext
}
// handles messages from multiple unauthenitcated clients
type WshRpcMultiProxy struct {
Lock *sync.Mutex
RouteInfo map[string]*multiProxyRouteInfo // authtoken to info
ToRemoteCh chan []byte
FromRemoteRawCh chan []byte // raw message from the remote
}
func MakeRpcMultiProxy() *WshRpcMultiProxy {
return &WshRpcMultiProxy{
Lock: &sync.Mutex{},
RouteInfo: make(map[string]*multiProxyRouteInfo),
ToRemoteCh: make(chan []byte, DefaultInputChSize),
FromRemoteRawCh: make(chan []byte, DefaultOutputChSize),
}
}
func (p *WshRpcMultiProxy) DisposeRoutes() {
p.Lock.Lock()
defer p.Lock.Unlock()
for authToken, routeInfo := range p.RouteInfo {
DefaultRouter.UnregisterRoute(routeInfo.RouteId)
delete(p.RouteInfo, authToken)
}
}
func (p *WshRpcMultiProxy) getRouteInfo(authToken string) *multiProxyRouteInfo {
p.Lock.Lock()
defer p.Lock.Unlock()
return p.RouteInfo[authToken]
}
func (p *WshRpcMultiProxy) setRouteInfo(authToken string, routeInfo *multiProxyRouteInfo) {
p.Lock.Lock()
defer p.Lock.Unlock()
p.RouteInfo[authToken] = routeInfo
}
func (p *WshRpcMultiProxy) removeRouteInfo(authToken string) {
p.Lock.Lock()
defer p.Lock.Unlock()
delete(p.RouteInfo, authToken)
}
func (p *WshRpcMultiProxy) sendResponseError(msg RpcMessage, sendErr error) {
if msg.ReqId == "" {
// no response needed
return
}
resp := RpcMessage{
ResId: msg.ReqId,
Error: sendErr.Error(),
}
respBytes, _ := json.Marshal(resp)
p.ToRemoteCh <- respBytes
}
func (p *WshRpcMultiProxy) sendAuthResponse(msg RpcMessage, routeId string, authToken string) {
if msg.ReqId == "" {
// no response needed
return
}
resp := RpcMessage{
ResId: msg.ReqId,
Data: wshrpc.CommandAuthenticateRtnData{RouteId: routeId, AuthToken: authToken},
}
respBytes, _ := json.Marshal(resp)
p.ToRemoteCh <- respBytes
}
func (p *WshRpcMultiProxy) handleUnauthMessage(msgBytes []byte) {
var msg RpcMessage
err := json.Unmarshal(msgBytes, &msg)
if err != nil {
// nothing to do here, malformed message
return
}
if msg.Command == wshrpc.Command_Authenticate {
rpcContext, routeId, err := handleAuthenticationCommand(msg)
if err != nil {
p.sendResponseError(msg, err)
return
}
routeInfo := &multiProxyRouteInfo{
RouteId: routeId,
AuthToken: uuid.New().String(),
RpcContext: rpcContext,
}
routeInfo.Proxy = MakeRpcProxy()
routeInfo.Proxy.SetRpcContext(rpcContext)
p.setRouteInfo(routeInfo.AuthToken, routeInfo)
p.sendAuthResponse(msg, routeId, routeInfo.AuthToken)
go func() {
for msgBytes := range routeInfo.Proxy.ToRemoteCh {
p.ToRemoteCh <- msgBytes
}
}()
DefaultRouter.RegisterRoute(routeId, routeInfo.Proxy, true)
return
}
if msg.AuthToken == "" {
p.sendResponseError(msg, fmt.Errorf("no auth token"))
return
}
routeInfo := p.getRouteInfo(msg.AuthToken)
if routeInfo == nil {
p.sendResponseError(msg, fmt.Errorf("invalid auth token"))
return
}
if msg.Command != "" && msg.Source != routeInfo.RouteId {
p.sendResponseError(msg, fmt.Errorf("invalid source route for auth token"))
return
}
if msg.Command == wshrpc.Command_Dispose {
DefaultRouter.UnregisterRoute(routeInfo.RouteId)
p.removeRouteInfo(msg.AuthToken)
close(routeInfo.Proxy.ToRemoteCh)
close(routeInfo.Proxy.FromRemoteCh)
return
}
routeInfo.Proxy.FromRemoteCh <- msgBytes
}
func (p *WshRpcMultiProxy) RunUnauthLoop() {
// loop over unauthenticated message
// handle Authenicate commands, and pass authenticated messages to the AuthCh
for msgBytes := range p.FromRemoteRawCh {
p.handleUnauthMessage(msgBytes)
}
}

View File

@ -6,7 +6,6 @@ package wshutil
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"log"
"sync" "sync"
"github.com/google/uuid" "github.com/google/uuid"
@ -18,6 +17,7 @@ type WshRpcProxy struct {
RpcContext *wshrpc.RpcContext RpcContext *wshrpc.RpcContext
ToRemoteCh chan []byte ToRemoteCh chan []byte
FromRemoteCh chan []byte FromRemoteCh chan []byte
AuthToken string
} }
func MakeRpcProxy() *WshRpcProxy { func MakeRpcProxy() *WshRpcProxy {
@ -40,6 +40,18 @@ func (p *WshRpcProxy) GetRpcContext() *wshrpc.RpcContext {
return p.RpcContext return p.RpcContext
} }
func (p *WshRpcProxy) SetAuthToken(authToken string) {
p.Lock.Lock()
defer p.Lock.Unlock()
p.AuthToken = authToken
}
func (p *WshRpcProxy) GetAuthToken() string {
p.Lock.Lock()
defer p.Lock.Unlock()
return p.AuthToken
}
func (p *WshRpcProxy) sendResponseError(msg RpcMessage, sendErr error) { func (p *WshRpcProxy) sendResponseError(msg RpcMessage, sendErr error) {
if msg.ReqId == "" { if msg.ReqId == "" {
// no response needed // no response needed
@ -54,7 +66,7 @@ func (p *WshRpcProxy) sendResponseError(msg RpcMessage, sendErr error) {
p.SendRpcMessage(respBytes) p.SendRpcMessage(respBytes)
} }
func (p *WshRpcProxy) sendResponse(msg RpcMessage, routeId string) { func (p *WshRpcProxy) sendAuthenticateResponse(msg RpcMessage, routeId string) {
if msg.ReqId == "" { if msg.ReqId == "" {
// no response needed // no response needed
return return
@ -98,6 +110,49 @@ func handleAuthenticationCommand(msg RpcMessage) (*wshrpc.RpcContext, string, er
return newCtx, routeId, nil return newCtx, routeId, nil
} }
// runs on the client (stdio client)
func (p *WshRpcProxy) HandleClientProxyAuth(router *WshRouter) (string, error) {
for {
msgBytes, ok := <-p.FromRemoteCh
if !ok {
return "", fmt.Errorf("remote closed, not authenticated")
}
var origMsg RpcMessage
err := json.Unmarshal(msgBytes, &origMsg)
if err != nil {
// nothing to do, can't even send a response since we don't have Source or ReqId
continue
}
if origMsg.Command == "" {
// this message is not allowed (protocol error at this point), ignore
continue
}
// we only allow one command "authenticate", everything else returns an error
if origMsg.Command != wshrpc.Command_Authenticate {
respErr := fmt.Errorf("connection not authenticated")
p.sendResponseError(origMsg, respErr)
continue
}
authRtn, err := router.HandleProxyAuth(origMsg.Data)
if err != nil {
respErr := fmt.Errorf("error handling proxy auth: %w", err)
p.sendResponseError(origMsg, respErr)
return "", respErr
}
p.SetAuthToken(authRtn.AuthToken)
announceMsg := RpcMessage{
Command: wshrpc.Command_RouteAnnounce,
Source: authRtn.RouteId,
AuthToken: authRtn.AuthToken,
}
announceBytes, _ := json.Marshal(announceMsg)
router.InjectMessage(announceBytes, authRtn.RouteId)
p.sendAuthenticateResponse(origMsg, authRtn.RouteId)
return authRtn.RouteId, nil
}
}
// runs on the server
func (p *WshRpcProxy) HandleAuthentication() (*wshrpc.RpcContext, error) { func (p *WshRpcProxy) HandleAuthentication() (*wshrpc.RpcContext, error) {
for { for {
msgBytes, ok := <-p.FromRemoteCh msgBytes, ok := <-p.FromRemoteCh
@ -122,11 +177,10 @@ func (p *WshRpcProxy) HandleAuthentication() (*wshrpc.RpcContext, error) {
} }
newCtx, routeId, err := handleAuthenticationCommand(msg) newCtx, routeId, err := handleAuthenticationCommand(msg)
if err != nil { if err != nil {
log.Printf("error handling authentication: %v\n", err)
p.sendResponseError(msg, err) p.sendResponseError(msg, err)
continue continue
} }
p.sendResponse(msg, routeId) p.sendAuthenticateResponse(msg, routeId)
return newCtx, nil return newCtx, nil
} }
} }
@ -136,9 +190,10 @@ func (p *WshRpcProxy) SendRpcMessage(msg []byte) {
} }
func (p *WshRpcProxy) RecvRpcMessage() ([]byte, bool) { func (p *WshRpcProxy) RecvRpcMessage() ([]byte, bool) {
msgBytes, ok := <-p.FromRemoteCh msgBytes, more := <-p.FromRemoteCh
if !ok || p.RpcContext == nil { authToken := p.GetAuthToken()
return msgBytes, ok if !more || (p.RpcContext == nil && authToken == "") {
return msgBytes, more
} }
var msg RpcMessage var msg RpcMessage
err := json.Unmarshal(msgBytes, &msg) err := json.Unmarshal(msgBytes, &msg)
@ -146,10 +201,15 @@ func (p *WshRpcProxy) RecvRpcMessage() ([]byte, bool) {
// nothing to do here -- will error out at another level // nothing to do here -- will error out at another level
return msgBytes, true return msgBytes, true
} }
msg.Data, err = recodeCommandData(msg.Command, msg.Data, p.RpcContext) if p.RpcContext != nil {
if err != nil { msg.Data, err = recodeCommandData(msg.Command, msg.Data, p.RpcContext)
// nothing to do here -- will error out at another level if err != nil {
return msgBytes, true // nothing to do here -- will error out at another level
return msgBytes, true
}
}
if msg.AuthToken == "" {
msg.AuthToken = authToken
} }
newBytes, err := json.Marshal(msg) newBytes, err := json.Marshal(msg)
if err != nil { if err != nil {

View File

@ -12,11 +12,14 @@ import (
"sync" "sync"
"time" "time"
"github.com/google/uuid"
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
"github.com/wavetermdev/waveterm/pkg/wps" "github.com/wavetermdev/waveterm/pkg/wps"
"github.com/wavetermdev/waveterm/pkg/wshrpc" "github.com/wavetermdev/waveterm/pkg/wshrpc"
) )
const DefaultRoute = "wavesrv" const DefaultRoute = "wavesrv"
const UpstreamRoute = "upstream"
const SysRoute = "sys" // this route doesn't exist, just a placeholder for system messages const SysRoute = "sys" // this route doesn't exist, just a placeholder for system messages
const ElectronRoute = "electron" const ElectronRoute = "electron"
@ -36,12 +39,13 @@ type msgAndRoute struct {
} }
type WshRouter struct { type WshRouter struct {
Lock *sync.Mutex Lock *sync.Mutex
RouteMap map[string]AbstractRpcClient // routeid => client RouteMap map[string]AbstractRpcClient // routeid => client
UpstreamClient AbstractRpcClient // upstream client (if we are not the terminal router) UpstreamClient AbstractRpcClient // upstream client (if we are not the terminal router)
AnnouncedRoutes map[string]string // routeid => local routeid AnnouncedRoutes map[string]string // routeid => local routeid
RpcMap map[string]*routeInfo // rpcid => routeinfo RpcMap map[string]*routeInfo // rpcid => routeinfo
InputCh chan msgAndRoute SimpleRequestMap map[string]chan *RpcMessage // simple reqid => response channel
InputCh chan msgAndRoute
} }
func MakeConnectionRouteId(connId string) string { func MakeConnectionRouteId(connId string) string {
@ -68,11 +72,12 @@ var DefaultRouter = NewWshRouter()
func NewWshRouter() *WshRouter { func NewWshRouter() *WshRouter {
rtn := &WshRouter{ rtn := &WshRouter{
Lock: &sync.Mutex{}, Lock: &sync.Mutex{},
RouteMap: make(map[string]AbstractRpcClient), RouteMap: make(map[string]AbstractRpcClient),
AnnouncedRoutes: make(map[string]string), AnnouncedRoutes: make(map[string]string),
RpcMap: make(map[string]*routeInfo), RpcMap: make(map[string]*routeInfo),
InputCh: make(chan msgAndRoute, DefaultInputChSize), SimpleRequestMap: make(map[string]chan *RpcMessage),
InputCh: make(chan msgAndRoute, DefaultInputChSize),
} }
go rtn.runServer() go rtn.runServer()
return rtn return rtn
@ -237,6 +242,10 @@ func (router *WshRouter) runServer() {
router.sendRoutedMessage(msgBytes, routeInfo.DestRouteId) router.sendRoutedMessage(msgBytes, routeInfo.DestRouteId)
continue continue
} else if msg.ResId != "" { } else if msg.ResId != "" {
ok := router.trySimpleResponse(&msg)
if ok {
continue
}
routeInfo := router.getRouteInfo(msg.ResId) routeInfo := router.getRouteInfo(msg.ResId)
if routeInfo == nil { if routeInfo == nil {
// no route info, nothing to do // no route info, nothing to do
@ -269,10 +278,10 @@ func (router *WshRouter) WaitForRegister(ctx context.Context, routeId string) er
} }
// this will also consume the output channel of the abstract client // this will also consume the output channel of the abstract client
func (router *WshRouter) RegisterRoute(routeId string, rpc AbstractRpcClient) { func (router *WshRouter) RegisterRoute(routeId string, rpc AbstractRpcClient, shouldAnnounce bool) {
if routeId == SysRoute { if routeId == SysRoute || routeId == UpstreamRoute {
// cannot register sys route // cannot register sys route
log.Printf("error: WshRouter cannot register sys route\n") log.Printf("error: WshRouter cannot register %s route\n", routeId)
return return
} }
log.Printf("[router] registering wsh route %q\n", routeId) log.Printf("[router] registering wsh route %q\n", routeId)
@ -285,7 +294,7 @@ func (router *WshRouter) RegisterRoute(routeId string, rpc AbstractRpcClient) {
router.RouteMap[routeId] = rpc router.RouteMap[routeId] = rpc
go func() { go func() {
// announce // announce
if !alreadyExists && router.GetUpstreamClient() != nil { if shouldAnnounce && !alreadyExists && router.GetUpstreamClient() != nil {
announceMsg := RpcMessage{Command: wshrpc.Command_RouteAnnounce, Source: routeId} announceMsg := RpcMessage{Command: wshrpc.Command_RouteAnnounce, Source: routeId}
announceBytes, _ := json.Marshal(announceMsg) announceBytes, _ := json.Marshal(announceMsg)
router.GetUpstreamClient().SendRpcMessage(announceBytes) router.GetUpstreamClient().SendRpcMessage(announceBytes)
@ -352,3 +361,97 @@ func (router *WshRouter) GetUpstreamClient() AbstractRpcClient {
defer router.Lock.Unlock() defer router.Lock.Unlock()
return router.UpstreamClient return router.UpstreamClient
} }
func (router *WshRouter) InjectMessage(msgBytes []byte, fromRouteId string) {
router.InputCh <- msgAndRoute{msgBytes: msgBytes, fromRouteId: fromRouteId}
}
func (router *WshRouter) registerSimpleRequest(reqId string) chan *RpcMessage {
router.Lock.Lock()
defer router.Lock.Unlock()
rtn := make(chan *RpcMessage, 1)
router.SimpleRequestMap[reqId] = rtn
return rtn
}
func (router *WshRouter) trySimpleResponse(msg *RpcMessage) bool {
router.Lock.Lock()
defer router.Lock.Unlock()
respCh := router.SimpleRequestMap[msg.ResId]
if respCh == nil {
return false
}
respCh <- msg
delete(router.SimpleRequestMap, msg.ResId)
return true
}
func (router *WshRouter) clearSimpleRequest(reqId string) {
router.Lock.Lock()
defer router.Lock.Unlock()
delete(router.SimpleRequestMap, reqId)
}
func (router *WshRouter) RunSimpleRawCommand(ctx context.Context, msg RpcMessage, fromRouteId string) (*RpcMessage, error) {
if msg.Command == "" {
return nil, errors.New("no command")
}
msgBytes, err := json.Marshal(msg)
if err != nil {
return nil, err
}
var respCh chan *RpcMessage
if msg.ReqId != "" {
respCh = router.registerSimpleRequest(msg.ReqId)
}
router.InjectMessage(msgBytes, fromRouteId)
if respCh == nil {
return nil, nil
}
select {
case <-ctx.Done():
router.clearSimpleRequest(msg.ReqId)
return nil, ctx.Err()
case resp := <-respCh:
if resp.Error != "" {
return nil, errors.New(resp.Error)
}
return resp, nil
}
}
func (router *WshRouter) HandleProxyAuth(jwtTokenAny any) (*wshrpc.CommandAuthenticateRtnData, error) {
if jwtTokenAny == nil {
return nil, errors.New("no jwt token")
}
jwtToken, ok := jwtTokenAny.(string)
if !ok {
return nil, errors.New("jwt token not a string")
}
if jwtToken == "" {
return nil, errors.New("empty jwt token")
}
msg := RpcMessage{
Command: wshrpc.Command_Authenticate,
ReqId: uuid.New().String(),
Data: jwtToken,
}
ctx, cancelFn := context.WithTimeout(context.Background(), DefaultTimeoutMs*time.Millisecond)
defer cancelFn()
resp, err := router.RunSimpleRawCommand(ctx, msg, "")
if err != nil {
return nil, err
}
if resp == nil || resp.Data == nil {
return nil, errors.New("no data in authenticate response")
}
var respData wshrpc.CommandAuthenticateRtnData
err = utilfn.ReUnmarshal(&respData, resp.Data)
if err != nil {
return nil, fmt.Errorf("error unmarshalling authenticate response: %v", err)
}
if respData.AuthToken == "" {
return nil, errors.New("no auth token in authenticate response")
}
return &respData, nil
}

View File

@ -45,10 +45,13 @@ type WshRpc struct {
InputCh chan []byte InputCh chan []byte
OutputCh chan []byte OutputCh chan []byte
RpcContext *atomic.Pointer[wshrpc.RpcContext] RpcContext *atomic.Pointer[wshrpc.RpcContext]
AuthToken string
RpcMap map[string]*rpcData RpcMap map[string]*rpcData
ServerImpl ServerImpl ServerImpl ServerImpl
EventListener *EventListener EventListener *EventListener
ResponseHandlerMap map[string]*RpcResponseHandler // reqId => handler ResponseHandlerMap map[string]*RpcResponseHandler // reqId => handler
Debug bool
DebugName string
} }
type wshRpcContextKey struct{} type wshRpcContextKey struct{}
@ -104,17 +107,18 @@ func (w *WshRpc) RecvRpcMessage() ([]byte, bool) {
} }
type RpcMessage struct { type RpcMessage struct {
Command string `json:"command,omitempty"` Command string `json:"command,omitempty"`
ReqId string `json:"reqid,omitempty"` ReqId string `json:"reqid,omitempty"`
ResId string `json:"resid,omitempty"` ResId string `json:"resid,omitempty"`
Timeout int `json:"timeout,omitempty"` Timeout int `json:"timeout,omitempty"`
Route string `json:"route,omitempty"` // to route/forward requests to alternate servers Route string `json:"route,omitempty"` // to route/forward requests to alternate servers
Source string `json:"source,omitempty"` // source route id AuthToken string `json:"authtoken,omitempty"` // needed for routing unauthenticated requests (WshRpcMultiProxy)
Cont bool `json:"cont,omitempty"` // flag if additional requests/responses are forthcoming Source string `json:"source,omitempty"` // source route id
Cancel bool `json:"cancel,omitempty"` // used to cancel a streaming request or response (sent from the side that is not streaming) Cont bool `json:"cont,omitempty"` // flag if additional requests/responses are forthcoming
Error string `json:"error,omitempty"` Cancel bool `json:"cancel,omitempty"` // used to cancel a streaming request or response (sent from the side that is not streaming)
DataType string `json:"datatype,omitempty"` Error string `json:"error,omitempty"`
Data any `json:"data,omitempty"` DataType string `json:"datatype,omitempty"`
Data any `json:"data,omitempty"`
} }
func (r *RpcMessage) IsRpcRequest() bool { func (r *RpcMessage) IsRpcRequest() bool {
@ -226,6 +230,14 @@ func (w *WshRpc) SetRpcContext(ctx wshrpc.RpcContext) {
w.RpcContext.Store(&ctx) w.RpcContext.Store(&ctx)
} }
func (w *WshRpc) SetAuthToken(token string) {
w.AuthToken = token
}
func (w *WshRpc) GetAuthToken() string {
return w.AuthToken
}
func (w *WshRpc) registerResponseHandler(reqId string, handler *RpcResponseHandler) { func (w *WshRpc) registerResponseHandler(reqId string, handler *RpcResponseHandler) {
w.Lock.Lock() w.Lock.Lock()
defer w.Lock.Unlock() defer w.Lock.Unlock()
@ -323,6 +335,9 @@ func (w *WshRpc) handleRequest(req *RpcMessage) {
func (w *WshRpc) runServer() { func (w *WshRpc) runServer() {
defer close(w.OutputCh) defer close(w.OutputCh)
for msgBytes := range w.InputCh { for msgBytes := range w.InputCh {
if w.Debug {
log.Printf("[%s] received message: %s\n", w.DebugName, string(msgBytes))
}
var msg RpcMessage var msg RpcMessage
err := json.Unmarshal(msgBytes, &msg) err := json.Unmarshal(msgBytes, &msg)
if err != nil { if err != nil {
@ -455,8 +470,9 @@ func (handler *RpcRequestHandler) SendCancel() {
} }
}() }()
msg := &RpcMessage{ msg := &RpcMessage{
Cancel: true, Cancel: true,
ReqId: handler.reqId, ReqId: handler.reqId,
AuthToken: handler.w.GetAuthToken(),
} }
barr, _ := json.Marshal(msg) // will never fail barr, _ := json.Marshal(msg) // will never fail
handler.w.OutputCh <- barr handler.w.OutputCh <- barr
@ -550,6 +566,7 @@ func (handler *RpcResponseHandler) SendMessage(msg string) {
Data: wshrpc.CommandMessageData{ Data: wshrpc.CommandMessageData{
Message: msg, Message: msg,
}, },
AuthToken: handler.w.GetAuthToken(),
} }
msgBytes, _ := json.Marshal(rpcMsg) // will never fail msgBytes, _ := json.Marshal(rpcMsg) // will never fail
handler.w.OutputCh <- msgBytes handler.w.OutputCh <- msgBytes
@ -573,9 +590,10 @@ func (handler *RpcResponseHandler) SendResponse(data any, done bool) error {
defer handler.close() defer handler.close()
} }
msg := &RpcMessage{ msg := &RpcMessage{
ResId: handler.reqId, ResId: handler.reqId,
Data: data, Data: data,
Cont: !done, Cont: !done,
AuthToken: handler.w.GetAuthToken(),
} }
barr, err := json.Marshal(msg) barr, err := json.Marshal(msg)
if err != nil { if err != nil {
@ -598,8 +616,9 @@ func (handler *RpcResponseHandler) SendResponseError(err error) {
} }
defer handler.close() defer handler.close()
msg := &RpcMessage{ msg := &RpcMessage{
ResId: handler.reqId, ResId: handler.reqId,
Error: err.Error(), Error: err.Error(),
AuthToken: handler.w.GetAuthToken(),
} }
barr, _ := json.Marshal(msg) // will never fail barr, _ := json.Marshal(msg) // will never fail
handler.w.OutputCh <- barr handler.w.OutputCh <- barr
@ -660,11 +679,12 @@ func (w *WshRpc) SendComplexRequest(command string, data any, opts *wshrpc.RpcOp
handler.reqId = uuid.New().String() handler.reqId = uuid.New().String()
} }
req := &RpcMessage{ req := &RpcMessage{
Command: command, Command: command,
ReqId: handler.reqId, ReqId: handler.reqId,
Data: data, Data: data,
Timeout: timeoutMs, Timeout: timeoutMs,
Route: opts.Route, Route: opts.Route,
AuthToken: w.GetAuthToken(),
} }
barr, err := json.Marshal(req) barr, err := json.Marshal(req)
if err != nil { if err != nil {

View File

@ -19,6 +19,7 @@ import (
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/wavetermdev/waveterm/pkg/util/packetparser"
"github.com/wavetermdev/waveterm/pkg/wavebase" "github.com/wavetermdev/waveterm/pkg/wavebase"
"github.com/wavetermdev/waveterm/pkg/wshrpc" "github.com/wavetermdev/waveterm/pkg/wshrpc"
"golang.org/x/term" "golang.org/x/term"
@ -204,11 +205,26 @@ func SetupTerminalRpcClient(serverImpl ServerImpl) (*WshRpc, io.Reader) {
continue continue
} }
os.Stdout.Write(barr) os.Stdout.Write(barr)
os.Stdout.Write([]byte{'\n'})
} }
}() }()
return rpcClient, ptyBuf return rpcClient, ptyBuf
} }
func SetupPacketRpcClient(input io.Reader, output io.Writer, serverImpl ServerImpl) (*WshRpc, chan []byte) {
messageCh := make(chan []byte, DefaultInputChSize)
outputCh := make(chan []byte, DefaultOutputChSize)
rawCh := make(chan []byte, DefaultOutputChSize)
rpcClient := MakeWshRpc(messageCh, outputCh, wshrpc.RpcContext{}, serverImpl)
go packetparser.Parse(input, messageCh, rawCh)
go func() {
for msg := range outputCh {
packetparser.WritePacket(output, msg)
}
}()
return rpcClient, rawCh
}
func SetupConnRpcClient(conn net.Conn, serverImpl ServerImpl) (*WshRpc, chan error, error) { func SetupConnRpcClient(conn net.Conn, serverImpl ServerImpl) (*WshRpc, chan error, error) {
inputCh := make(chan []byte, DefaultInputChSize) inputCh := make(chan []byte, DefaultInputChSize)
outputCh := make(chan []byte, DefaultOutputChSize) outputCh := make(chan []byte, DefaultOutputChSize)
@ -229,10 +245,22 @@ func SetupConnRpcClient(conn net.Conn, serverImpl ServerImpl) (*WshRpc, chan err
return rtn, writeErrCh, nil return rtn, writeErrCh, nil
} }
func SetupDomainSocketRpcClient(sockName string, serverImpl ServerImpl) (*WshRpc, error) { func tryTcpSocket(sockName string) (net.Conn, error) {
conn, err := net.Dial("unix", sockName) addr, err := net.ResolveTCPAddr("tcp", sockName)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to connect to Unix domain socket: %w", err) return nil, err
}
return net.DialTCP("tcp", nil, addr)
}
func SetupDomainSocketRpcClient(sockName string, serverImpl ServerImpl) (*WshRpc, error) {
conn, tcpErr := tryTcpSocket(sockName)
var unixErr error
if tcpErr != nil {
conn, unixErr = net.Dial("unix", sockName)
}
if tcpErr != nil && unixErr != nil {
return nil, fmt.Errorf("failed to connect to tcp or unix domain socket: tcp err:%w: unix socket err: %w", tcpErr, unixErr)
} }
rtn, errCh, err := SetupConnRpcClient(conn, serverImpl) rtn, errCh, err := SetupConnRpcClient(conn, serverImpl)
go func() { go func() {
@ -363,6 +391,46 @@ func MakeRouteIdFromCtx(rpcCtx *wshrpc.RpcContext) (string, error) {
return MakeProcRouteId(procId), nil return MakeProcRouteId(procId), nil
} }
type WriteFlusher interface {
Write([]byte) (int, error)
Flush() error
}
// blocking, returns if there is an error, or on EOF of input
func HandleStdIOClient(logName string, input io.Reader, output io.Writer) {
proxy := MakeRpcMultiProxy()
rawCh := make(chan []byte, DefaultInputChSize)
go packetparser.Parse(input, proxy.FromRemoteRawCh, rawCh)
doneCh := make(chan struct{})
var doneOnce sync.Once
closeDoneCh := func() {
doneOnce.Do(func() {
close(doneCh)
})
proxy.DisposeRoutes()
}
go func() {
proxy.RunUnauthLoop()
}()
go func() {
defer closeDoneCh()
for msg := range proxy.ToRemoteCh {
err := packetparser.WritePacket(output, msg)
if err != nil {
log.Printf("[%s] error writing to output: %v\n", logName, err)
break
}
}
}()
go func() {
defer closeDoneCh()
for msg := range rawCh {
log.Printf("[%s:stdout] %s", logName, msg)
}
}()
<-doneCh
}
func handleDomainSocketClient(conn net.Conn) { func handleDomainSocketClient(conn net.Conn) {
var routeIdContainer atomic.Pointer[string] var routeIdContainer atomic.Pointer[string]
proxy := MakeRpcProxy() proxy := MakeRpcProxy()
@ -399,7 +467,7 @@ func handleDomainSocketClient(conn net.Conn) {
return return
} }
routeIdContainer.Store(&routeId) routeIdContainer.Store(&routeId)
DefaultRouter.RegisterRoute(routeId, proxy) DefaultRouter.RegisterRoute(routeId, proxy, true)
} }
// only for use on client // only for use on client
@ -433,5 +501,6 @@ func ExtractUnverifiedSocketName(tokenStr string) (string, error) {
if !ok { if !ok {
return "", fmt.Errorf("sock claim is missing or invalid") return "", fmt.Errorf("sock claim is missing or invalid")
} }
sockName = wavebase.ExpandHomeDirSafe(sockName)
return sockName, nil return sockName, nil
} }

67
pkg/wsl/wsl-unix.go Normal file
View File

@ -0,0 +1,67 @@
//go:build !windows
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package wsl
import (
"context"
"fmt"
"io"
"os"
"os/exec"
)
func RegisteredDistros(ctx context.Context) (distros []Distro, err error) {
return nil, fmt.Errorf("RegisteredDistros not implemented on this system")
}
func DefaultDistro(ctx context.Context) (d Distro, ok bool, err error) {
return d, false, fmt.Errorf("DefaultDistro not implemented on this system")
}
type Distro struct{}
func (d *Distro) Name() string {
return ""
}
func (d *Distro) WslCommand(ctx context.Context, cmd string) *WslCmd {
return nil
}
// just use the regular cmd since it's
// similar enough to not cause issues
// type WslCmd = exec.Cmd
type WslCmd struct {
exec.Cmd
}
func (wc *WslCmd) GetProcess() *os.Process {
return nil
}
func (wc *WslCmd) GetProcessState() *os.ProcessState {
return nil
}
func (c *WslCmd) SetStdin(stdin io.Reader) {
c.Stdin = stdin
}
func (c *WslCmd) SetStdout(stdout io.Writer) {
c.Stdout = stdout
}
func (c *WslCmd) SetStderr(stderr io.Writer) {
c.Stdout = stderr
}
func GetDistroCmd(ctx context.Context, wslDistroName string, cmd string) (*WslCmd, error) {
return nil, fmt.Errorf("GetDistroCmd not implemented on this system")
}
func GetDistro(ctx context.Context, wslDistroName WslName) (*Distro, error) {
return nil, fmt.Errorf("GetDistro not implemented on this system")
}

296
pkg/wsl/wsl-util.go Normal file
View File

@ -0,0 +1,296 @@
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package wsl
import (
"bytes"
"context"
"errors"
"fmt"
"html/template"
"io"
"log"
"os"
"path/filepath"
"strings"
"time"
)
func DetectShell(ctx context.Context, client *Distro) (string, error) {
wshPath := GetWshPath(ctx, client)
cmd := client.WslCommand(ctx, wshPath+" shell")
log.Printf("shell detecting using command: %s shell", wshPath)
out, err := cmd.Output()
if err != nil {
log.Printf("unable to determine shell. defaulting to /bin/bash: %s", err)
return "/bin/bash", nil
}
log.Printf("detecting shell: %s", out)
// quoting breaks this particular case
return strings.TrimSpace(string(out)), nil
}
func GetWshVersion(ctx context.Context, client *Distro) (string, error) {
wshPath := GetWshPath(ctx, client)
cmd := client.WslCommand(ctx, wshPath+" version")
out, err := cmd.Output()
if err != nil {
return "", err
}
return strings.TrimSpace(string(out)), nil
}
func GetWshPath(ctx context.Context, client *Distro) string {
defaultPath := "~/.waveterm/bin/wsh"
cmd := client.WslCommand(ctx, "which wsh")
out, whichErr := cmd.Output()
if whichErr == nil {
return strings.TrimSpace(string(out))
}
cmd = client.WslCommand(ctx, "where.exe wsh")
out, whereErr := cmd.Output()
if whereErr == nil {
return strings.TrimSpace(string(out))
}
// check cmd on windows since it requires an absolute path with backslashes
cmd = client.WslCommand(ctx, "(dir 2>&1 *``|echo %userprofile%\\.waveterm%\\.waveterm\\bin\\wsh.exe);&<# rem #>echo none")
out, cmdErr := cmd.Output()
if cmdErr == nil && strings.TrimSpace(string(out)) != "none" {
return strings.TrimSpace(string(out))
}
// no custom install, use default path
return defaultPath
}
func hasBashInstalled(ctx context.Context, client *Distro) (bool, error) {
cmd := client.WslCommand(ctx, "which bash")
out, whichErr := cmd.Output()
if whichErr == nil && len(out) != 0 {
return true, nil
}
cmd = client.WslCommand(ctx, "where.exe bash")
out, whereErr := cmd.Output()
if whereErr == nil && len(out) != 0 {
return true, nil
}
// note: we could also check in /bin/bash explicitly
// just in case that wasn't added to the path. but if
// that's true, we will most likely have worse
// problems going forward
return false, nil
}
func GetClientOs(ctx context.Context, client *Distro) (string, error) {
cmd := client.WslCommand(ctx, "uname -s")
out, unixErr := cmd.Output()
if unixErr == nil {
formatted := strings.ToLower(string(out))
formatted = strings.TrimSpace(formatted)
return formatted, nil
}
cmd = client.WslCommand(ctx, "echo %OS%")
out, cmdErr := cmd.Output()
if cmdErr == nil && strings.TrimSpace(string(out)) != "%OS%" {
formatted := strings.ToLower(string(out))
formatted = strings.TrimSpace(formatted)
return strings.Split(formatted, "_")[0], nil
}
cmd = client.WslCommand(ctx, "echo $env:OS")
out, psErr := cmd.Output()
if psErr == nil && strings.TrimSpace(string(out)) != "$env:OS" {
formatted := strings.ToLower(string(out))
formatted = strings.TrimSpace(formatted)
return strings.Split(formatted, "_")[0], nil
}
return "", fmt.Errorf("unable to determine os: {unix: %s, cmd: %s, powershell: %s}", unixErr, cmdErr, psErr)
}
func GetClientArch(ctx context.Context, client *Distro) (string, error) {
cmd := client.WslCommand(ctx, "uname -m")
out, unixErr := cmd.Output()
if unixErr == nil {
formatted := strings.ToLower(string(out))
formatted = strings.TrimSpace(formatted)
if formatted == "x86_64" {
return "x64", nil
}
return formatted, nil
}
cmd = client.WslCommand(ctx, "echo %PROCESSOR_ARCHITECTURE%")
out, cmdErr := cmd.Output()
if cmdErr == nil && strings.TrimSpace(string(out)) != "%PROCESSOR_ARCHITECTURE%" {
formatted := strings.ToLower(string(out))
return strings.TrimSpace(formatted), nil
}
cmd = client.WslCommand(ctx, "echo $env:PROCESSOR_ARCHITECTURE")
out, psErr := cmd.Output()
if psErr == nil && strings.TrimSpace(string(out)) != "$env:PROCESSOR_ARCHITECTURE" {
formatted := strings.ToLower(string(out))
return strings.TrimSpace(formatted), nil
}
return "", fmt.Errorf("unable to determine architecture: {unix: %s, cmd: %s, powershell: %s}", unixErr, cmdErr, psErr)
}
type CancellableCmd struct {
Cmd *WslCmd
Cancel func()
}
var installTemplatesRawBash = map[string]string{
"mkdir": `bash -c 'mkdir -p {{.installDir}}'`,
"cat": `bash -c 'cat > {{.tempPath}}'`,
"mv": `bash -c 'mv {{.tempPath}} {{.installPath}}'`,
"chmod": `bash -c 'chmod a+x {{.installPath}}'`,
}
var installTemplatesRawDefault = map[string]string{
"mkdir": `mkdir -p {{.installDir}}`,
"cat": `cat > {{.tempPath}}`,
"mv": `mv {{.tempPath}} {{.installPath}}`,
"chmod": `chmod a+x {{.installPath}}`,
}
func makeCancellableCommand(ctx context.Context, client *Distro, cmdTemplateRaw string, words map[string]string) (*CancellableCmd, error) {
cmdContext, cmdCancel := context.WithCancel(ctx)
cmdStr := &bytes.Buffer{}
cmdTemplate, err := template.New("").Parse(cmdTemplateRaw)
if err != nil {
cmdCancel()
return nil, err
}
cmdTemplate.Execute(cmdStr, words)
cmd := client.WslCommand(cmdContext, cmdStr.String())
return &CancellableCmd{cmd, cmdCancel}, nil
}
func CpHostToRemote(ctx context.Context, client *Distro, sourcePath string, destPath string) error {
// warning: does not work on windows remote yet
bashInstalled, err := hasBashInstalled(ctx, client)
if err != nil {
return err
}
var selectedTemplatesRaw map[string]string
if bashInstalled {
selectedTemplatesRaw = installTemplatesRawBash
} else {
log.Printf("bash is not installed on remote. attempting with default shell")
selectedTemplatesRaw = installTemplatesRawDefault
}
// I need to use toSlash here to force unix keybindings
// this means we can't guarantee it will work on a remote windows machine
var installWords = map[string]string{
"installDir": filepath.ToSlash(filepath.Dir(destPath)),
"tempPath": destPath + ".temp",
"installPath": destPath,
}
installStepCmds := make(map[string]*CancellableCmd)
for cmdName, selectedTemplateRaw := range selectedTemplatesRaw {
cancellableCmd, err := makeCancellableCommand(ctx, client, selectedTemplateRaw, installWords)
if err != nil {
return err
}
installStepCmds[cmdName] = cancellableCmd
}
_, err = installStepCmds["mkdir"].Cmd.Output()
if err != nil {
return err
}
// the cat part of this is complicated since it requires stdin
catCmd := installStepCmds["cat"].Cmd
catStdin, err := catCmd.StdinPipe()
if err != nil {
return err
}
err = catCmd.Start()
if err != nil {
return err
}
input, err := os.Open(sourcePath)
if err != nil {
return fmt.Errorf("cannot open local file %s to send to host: %v", sourcePath, err)
}
go func() {
io.Copy(catStdin, input)
installStepCmds["cat"].Cancel()
// backup just in case something weird happens
// could cause potential race condition, but very
// unlikely
time.Sleep(time.Second * 1)
process := catCmd.GetProcess()
if process != nil {
process.Kill()
}
}()
catErr := catCmd.Wait()
if catErr != nil && !errors.Is(catErr, context.Canceled) {
return catErr
}
_, err = installStepCmds["mv"].Cmd.Output()
if err != nil {
return err
}
_, err = installStepCmds["chmod"].Cmd.Output()
if err != nil {
return err
}
return nil
}
func InstallClientRcFiles(ctx context.Context, client *Distro) error {
path := GetWshPath(ctx, client)
log.Printf("path to wsh searched is: %s", path)
cmd := client.WslCommand(ctx, path+" rcfiles")
_, err := cmd.Output()
return err
}
func GetHomeDir(ctx context.Context, client *Distro) string {
// note: also works for powershell
cmd := client.WslCommand(ctx, `echo "$HOME"`)
out, err := cmd.Output()
if err == nil {
return strings.TrimSpace(string(out))
}
cmd = client.WslCommand(ctx, `echo %userprofile%`)
out, err = cmd.Output()
if err == nil {
return strings.TrimSpace(string(out))
}
return "~"
}
func IsPowershell(shellPath string) bool {
// get the base path, and then check contains
shellBase := filepath.Base(shellPath)
return strings.Contains(shellBase, "powershell") || strings.Contains(shellBase, "pwsh")
}

125
pkg/wsl/wsl-win.go Normal file
View File

@ -0,0 +1,125 @@
//go:build windows
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package wsl
import (
"context"
"fmt"
"io"
"os"
"sync"
"github.com/ubuntu/gowsl"
)
var RegisteredDistros = gowsl.RegisteredDistros
var DefaultDistro = gowsl.DefaultDistro
type Distro struct {
gowsl.Distro
}
type WslCmd struct {
c *gowsl.Cmd
wg *sync.WaitGroup
once *sync.Once
lock *sync.Mutex
waitErr error
}
func (d *Distro) WslCommand(ctx context.Context, cmd string) *WslCmd {
if ctx == nil {
panic("nil Context")
}
innerCmd := d.Command(ctx, cmd)
var wg sync.WaitGroup
var lock *sync.Mutex
return &WslCmd{innerCmd, &wg, new(sync.Once), lock, nil}
}
func (c *WslCmd) CombinedOutput() (out []byte, err error) {
return c.c.CombinedOutput()
}
func (c *WslCmd) Output() (out []byte, err error) {
return c.c.Output()
}
func (c *WslCmd) Run() error {
return c.c.Run()
}
func (c *WslCmd) Start() (err error) {
return c.c.Start()
}
func (c *WslCmd) StderrPipe() (r io.ReadCloser, err error) {
return c.c.StderrPipe()
}
func (c *WslCmd) StdinPipe() (w io.WriteCloser, err error) {
return c.c.StdinPipe()
}
func (c *WslCmd) StdoutPipe() (r io.ReadCloser, err error) {
return c.c.StdoutPipe()
}
func (c *WslCmd) Wait() (err error) {
c.wg.Add(1)
c.once.Do(func() {
c.waitErr = c.c.Wait()
})
c.wg.Done()
c.wg.Wait()
if c.waitErr != nil && c.waitErr.Error() == "not started" {
c.once = new(sync.Once)
return c.waitErr
}
return c.waitErr
}
func (c *WslCmd) GetProcess() *os.Process {
return c.c.Process
}
func (c *WslCmd) GetProcessState() *os.ProcessState {
return c.c.ProcessState
}
func (c *WslCmd) SetStdin(stdin io.Reader) {
c.c.Stdin = stdin
}
func (c *WslCmd) SetStdout(stdout io.Writer) {
c.c.Stdout = stdout
}
func (c *WslCmd) SetStderr(stderr io.Writer) {
c.c.Stdout = stderr
}
func GetDistroCmd(ctx context.Context, wslDistroName string, cmd string) (*WslCmd, error) {
distros, err := RegisteredDistros(ctx)
if err != nil {
return nil, err
}
for _, distro := range distros {
if distro.Name() != wslDistroName {
continue
}
wrappedDistro := Distro{distro}
return wrappedDistro.WslCommand(ctx, cmd), nil
}
return nil, fmt.Errorf("wsl distro %s not found", wslDistroName)
}
func GetDistro(ctx context.Context, wslDistroName WslName) (*Distro, error) {
distros, err := RegisteredDistros(ctx)
if err != nil {
return nil, err
}
for _, distro := range distros {
if distro.Name() != wslDistroName.Distro {
continue
}
wrappedDistro := Distro{distro}
return &wrappedDistro, nil
}
return nil, fmt.Errorf("wsl distro %s not found", wslDistroName)
}

494
pkg/wsl/wsl.go Normal file
View File

@ -0,0 +1,494 @@
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package wsl
import (
"context"
"fmt"
"io"
"log"
"net"
"sync"
"sync/atomic"
"time"
"github.com/wavetermdev/waveterm/pkg/userinput"
"github.com/wavetermdev/waveterm/pkg/util/shellutil"
"github.com/wavetermdev/waveterm/pkg/wavebase"
"github.com/wavetermdev/waveterm/pkg/waveobj"
"github.com/wavetermdev/waveterm/pkg/wconfig"
"github.com/wavetermdev/waveterm/pkg/wps"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
"github.com/wavetermdev/waveterm/pkg/wshutil"
)
const (
Status_Init = "init"
Status_Connecting = "connecting"
Status_Connected = "connected"
Status_Disconnected = "disconnected"
Status_Error = "error"
)
const DefaultConnectionTimeout = 60 * time.Second
var globalLock = &sync.Mutex{}
var clientControllerMap = make(map[string]*WslConn)
var activeConnCounter = &atomic.Int32{}
type WslConn struct {
Lock *sync.Mutex
Status string
Name WslName
Client *Distro
SockName string
DomainSockListener net.Listener
ConnController *WslCmd
Error string
HasWaiter *atomic.Bool
LastConnectTime int64
ActiveConnNum int
Context context.Context
cancelFn func()
}
type WslName struct {
Distro string `json:"distro"`
}
func GetAllConnStatus() []wshrpc.ConnStatus {
globalLock.Lock()
defer globalLock.Unlock()
var connStatuses []wshrpc.ConnStatus
for _, conn := range clientControllerMap {
connStatuses = append(connStatuses, conn.DeriveConnStatus())
}
return connStatuses
}
func (conn *WslConn) DeriveConnStatus() wshrpc.ConnStatus {
conn.Lock.Lock()
defer conn.Lock.Unlock()
return wshrpc.ConnStatus{
Status: conn.Status,
Connected: conn.Status == Status_Connected,
Connection: conn.GetName(),
HasConnected: (conn.LastConnectTime > 0),
ActiveConnNum: conn.ActiveConnNum,
Error: conn.Error,
}
}
func (conn *WslConn) FireConnChangeEvent() {
status := conn.DeriveConnStatus()
event := wps.WaveEvent{
Event: wps.Event_ConnChange,
Scopes: []string{
fmt.Sprintf("connection:%s", conn.GetName()),
},
Data: status,
}
log.Printf("sending event: %+#v", event)
wps.Broker.Publish(event)
}
func (conn *WslConn) Close() error {
defer conn.FireConnChangeEvent()
conn.WithLock(func() {
if conn.Status == Status_Connected || conn.Status == Status_Connecting {
// if status is init, disconnected, or error don't change it
conn.Status = Status_Disconnected
}
conn.close_nolock()
})
// we must wait for the waiter to complete
startTime := time.Now()
for conn.HasWaiter.Load() {
time.Sleep(10 * time.Millisecond)
if time.Since(startTime) > 2*time.Second {
return fmt.Errorf("timeout waiting for waiter to complete")
}
}
return nil
}
func (conn *WslConn) close_nolock() {
// does not set status (that should happen at another level)
if conn.DomainSockListener != nil {
conn.DomainSockListener.Close()
conn.DomainSockListener = nil
}
if conn.ConnController != nil {
conn.cancelFn() // this suspends the conn controller
conn.ConnController = nil
}
if conn.Client != nil {
// conn.Client.Close() is not relevant here
// we do not want to completely close the wsl in case
// other applications are using it
conn.Client = nil
}
}
func (conn *WslConn) GetDomainSocketName() string {
conn.Lock.Lock()
defer conn.Lock.Unlock()
return conn.SockName
}
func (conn *WslConn) GetStatus() string {
conn.Lock.Lock()
defer conn.Lock.Unlock()
return conn.Status
}
func (conn *WslConn) GetName() string {
// no lock required because opts is immutable
return "wsl://" + conn.Name.Distro
}
/**
* This function is does not set a listener for WslConn
* It is still required in order to set SockName
**/
func (conn *WslConn) OpenDomainSocketListener() error {
var allowed bool
conn.WithLock(func() {
if conn.Status != Status_Connecting {
allowed = false
} else {
allowed = true
}
})
if !allowed {
return fmt.Errorf("cannot open domain socket for %q when status is %q", conn.GetName(), conn.GetStatus())
}
conn.WithLock(func() {
conn.SockName = "~/.waveterm/wave-remote.sock"
})
return nil
}
func (conn *WslConn) StartConnServer() error {
var allowed bool
conn.WithLock(func() {
if conn.Status != Status_Connecting {
allowed = false
} else {
allowed = true
}
})
if !allowed {
return fmt.Errorf("cannot start conn server for %q when status is %q", conn.GetName(), conn.GetStatus())
}
client := conn.GetClient()
wshPath := GetWshPath(conn.Context, client)
rpcCtx := wshrpc.RpcContext{
ClientType: wshrpc.ClientType_ConnServer,
Conn: conn.GetName(),
}
sockName := conn.GetDomainSocketName()
jwtToken, err := wshutil.MakeClientJWTToken(rpcCtx, sockName)
if err != nil {
return fmt.Errorf("unable to create jwt token for conn controller: %w", err)
}
shellPath, err := DetectShell(conn.Context, client)
if err != nil {
return err
}
var cmdStr string
if IsPowershell(shellPath) {
cmdStr = fmt.Sprintf("$env:%s=\"%s\"; %s connserver --router", wshutil.WaveJwtTokenVarName, jwtToken, wshPath)
} else {
cmdStr = fmt.Sprintf("%s=\"%s\" %s connserver --router", wshutil.WaveJwtTokenVarName, jwtToken, wshPath)
}
log.Printf("starting conn controller: %s\n", cmdStr)
cmd := client.WslCommand(conn.Context, cmdStr)
pipeRead, pipeWrite := io.Pipe()
inputPipeRead, inputPipeWrite := io.Pipe()
cmd.SetStdout(pipeWrite)
cmd.SetStderr(pipeWrite)
cmd.SetStdin(inputPipeRead)
err = cmd.Start()
if err != nil {
return fmt.Errorf("unable to start conn controller: %w", err)
}
conn.WithLock(func() {
conn.ConnController = cmd
})
// service the I/O
go func() {
// wait for termination, clear the controller
defer conn.WithLock(func() {
conn.ConnController = nil
})
waitErr := cmd.Wait()
log.Printf("conn controller (%q) terminated: %v", conn.GetName(), waitErr)
}()
go func() {
logName := fmt.Sprintf("conncontroller:%s", conn.GetName())
wshutil.HandleStdIOClient(logName, pipeRead, inputPipeWrite)
}()
regCtx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
defer cancelFn()
err = wshutil.DefaultRouter.WaitForRegister(regCtx, wshutil.MakeConnectionRouteId(rpcCtx.Conn))
if err != nil {
return fmt.Errorf("timeout waiting for connserver to register")
}
time.Sleep(300 * time.Millisecond) // TODO remove this sleep (but we need to wait until connserver is "ready")
return nil
}
type WshInstallOpts struct {
Force bool
NoUserPrompt bool
}
func (conn *WslConn) CheckAndInstallWsh(ctx context.Context, clientDisplayName string, opts *WshInstallOpts) error {
if opts == nil {
opts = &WshInstallOpts{}
}
client := conn.GetClient()
if client == nil {
return fmt.Errorf("client is nil")
}
// check that correct wsh extensions are installed
expectedVersion := fmt.Sprintf("wsh v%s", wavebase.WaveVersion)
clientVersion, err := GetWshVersion(ctx, client)
if err == nil && clientVersion == expectedVersion && !opts.Force {
return nil
}
var queryText string
var title string
if opts.Force {
queryText = fmt.Sprintf("ReInstalling Wave Shell Extensions (%s) on `%s`\n", wavebase.WaveVersion, clientDisplayName)
title = "Install Wave Shell Extensions"
} else if err != nil {
queryText = fmt.Sprintf("Wave requires Wave Shell Extensions to be \n"+
"installed on `%s` \n"+
"to ensure a seamless experience. \n\n"+
"Would you like to install them?", clientDisplayName)
title = "Install Wave Shell Extensions"
} else {
// don't ask for upgrading the version
opts.NoUserPrompt = true
}
if !opts.NoUserPrompt {
request := &userinput.UserInputRequest{
ResponseType: "confirm",
QueryText: queryText,
Title: title,
Markdown: true,
CheckBoxMsg: "Don't show me this again",
}
response, err := userinput.GetUserInput(ctx, request)
if err != nil || !response.Confirm {
return err
}
if response.CheckboxStat {
meta := waveobj.MetaMapType{
wconfig.ConfigKey_ConnAskBeforeWshInstall: false,
}
err := wconfig.SetBaseConfigValue(meta)
if err != nil {
return fmt.Errorf("error setting conn:askbeforewshinstall value: %w", err)
}
}
}
log.Printf("attempting to install wsh to `%s`", clientDisplayName)
clientOs, err := GetClientOs(ctx, client)
if err != nil {
return err
}
clientArch, err := GetClientArch(ctx, client)
if err != nil {
return err
}
// attempt to install extension
wshLocalPath := shellutil.GetWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch)
err = CpHostToRemote(ctx, client, wshLocalPath, "~/.waveterm/bin/wsh")
if err != nil {
return err
}
log.Printf("successfully installed wsh on %s\n", conn.GetName())
return nil
}
func (conn *WslConn) GetClient() *Distro {
conn.Lock.Lock()
defer conn.Lock.Unlock()
return conn.Client
}
func (conn *WslConn) Reconnect(ctx context.Context) error {
err := conn.Close()
if err != nil {
return err
}
return conn.Connect(ctx)
}
func (conn *WslConn) WaitForConnect(ctx context.Context) error {
for {
status := conn.DeriveConnStatus()
if status.Status == Status_Connected {
return nil
}
if status.Status == Status_Connecting {
select {
case <-ctx.Done():
return fmt.Errorf("context timeout")
case <-time.After(100 * time.Millisecond):
continue
}
}
if status.Status == Status_Init || status.Status == Status_Disconnected {
return fmt.Errorf("disconnected")
}
if status.Status == Status_Error {
return fmt.Errorf("error: %v", status.Error)
}
return fmt.Errorf("unknown status: %q", status.Status)
}
}
// does not return an error since that error is stored inside of WslConn
func (conn *WslConn) Connect(ctx context.Context) error {
var connectAllowed bool
conn.WithLock(func() {
if conn.Status == Status_Connecting || conn.Status == Status_Connected {
connectAllowed = false
} else {
conn.Status = Status_Connecting
conn.Error = ""
connectAllowed = true
}
})
log.Printf("Connect %s\n", conn.GetName())
if !connectAllowed {
return fmt.Errorf("cannot connect to %q when status is %q", conn.GetName(), conn.GetStatus())
}
conn.FireConnChangeEvent()
err := conn.connectInternal(ctx)
conn.WithLock(func() {
if err != nil {
conn.Status = Status_Error
conn.Error = err.Error()
conn.close_nolock()
} else {
conn.Status = Status_Connected
conn.LastConnectTime = time.Now().UnixMilli()
if conn.ActiveConnNum == 0 {
conn.ActiveConnNum = int(activeConnCounter.Add(1))
}
}
})
conn.FireConnChangeEvent()
return err
}
func (conn *WslConn) WithLock(fn func()) {
conn.Lock.Lock()
defer conn.Lock.Unlock()
fn()
}
func (conn *WslConn) connectInternal(ctx context.Context) error {
client, err := GetDistro(ctx, conn.Name)
if err != nil {
return err
}
conn.WithLock(func() {
conn.Client = client
})
err = conn.OpenDomainSocketListener()
if err != nil {
return err
}
config := wconfig.ReadFullConfig()
installErr := conn.CheckAndInstallWsh(ctx, conn.GetName(), &WshInstallOpts{NoUserPrompt: !config.Settings.ConnAskBeforeWshInstall})
if installErr != nil {
return fmt.Errorf("conncontroller %s wsh install error: %v", conn.GetName(), installErr)
}
csErr := conn.StartConnServer()
if csErr != nil {
return fmt.Errorf("conncontroller %s start wsh connserver error: %v", conn.GetName(), csErr)
}
conn.HasWaiter.Store(true)
go conn.waitForDisconnect()
return nil
}
func (conn *WslConn) waitForDisconnect() {
defer conn.FireConnChangeEvent()
defer conn.HasWaiter.Store(false)
err := conn.ConnController.Wait()
conn.WithLock(func() {
// disconnects happen for a variety of reasons (like network, etc. and are typically transient)
// so we just set the status to "disconnected" here (not error)
// don't overwrite any existing error (or error status)
if err != nil && conn.Error == "" {
conn.Error = err.Error()
}
if conn.Status != Status_Error {
conn.Status = Status_Disconnected
}
conn.close_nolock()
})
}
func getConnInternal(name string) *WslConn {
globalLock.Lock()
defer globalLock.Unlock()
connName := WslName{Distro: name}
rtn := clientControllerMap[name]
if rtn == nil {
ctx, cancelFn := context.WithCancel(context.Background())
rtn = &WslConn{Lock: &sync.Mutex{}, Status: Status_Init, Name: connName, HasWaiter: &atomic.Bool{}, Context: ctx, cancelFn: cancelFn}
clientControllerMap[name] = rtn
}
return rtn
}
func GetWslConn(ctx context.Context, name string, shouldConnect bool) *WslConn {
conn := getConnInternal(name)
if conn.Client == nil && shouldConnect {
conn.Connect(ctx)
}
return conn
}
// Convenience function for ensuring a connection is established
func EnsureConnection(ctx context.Context, connName string) error {
if connName == "" {
return nil
}
conn := GetWslConn(ctx, connName, false)
if conn == nil {
return fmt.Errorf("connection not found: %s", connName)
}
connStatus := conn.DeriveConnStatus()
switch connStatus.Status {
case Status_Connected:
return nil
case Status_Connecting:
return conn.WaitForConnect(ctx)
case Status_Init, Status_Disconnected:
return conn.Connect(ctx)
case Status_Error:
return fmt.Errorf("connection error: %s", connStatus.Error)
default:
return fmt.Errorf("unknown connection status %q", connStatus.Status)
}
}
func DisconnectClient(connName string) error {
conn := getConnInternal(connName)
if conn == nil {
return fmt.Errorf("client %q not found", connName)
}
err := conn.Close()
return err
}