WSL Integration (#1031)

Adds support for connecting to local WSL installations on Windows.

(also adds wshrpcmmultiproxy / connserver router)
This commit is contained in:
Sylvie Crowe 2024-10-23 22:43:17 -07:00 committed by GitHub
parent 4e86b67936
commit 8248637e00
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
31 changed files with 2101 additions and 75 deletions

2
.gitattributes vendored
View File

@ -1 +1 @@
* text=auto
* text=auto eol=lf

View File

@ -159,11 +159,11 @@ func shutdownActivityUpdate() {
func createMainWshClient() {
rpc := wshserver.GetMainRpcClient()
wshutil.DefaultRouter.RegisterRoute(wshutil.DefaultRoute, rpc)
wshutil.DefaultRouter.RegisterRoute(wshutil.DefaultRoute, rpc, true)
wps.Broker.SetClient(wshutil.DefaultRouter)
localConnWsh := wshutil.MakeWshRpc(nil, nil, wshrpc.RpcContext{Conn: wshrpc.LocalConnName}, &wshremote.ServerImpl{})
go wshremote.RunSysInfoLoop(localConnWsh, wshrpc.LocalConnName)
wshutil.DefaultRouter.RegisterRoute(wshutil.MakeConnectionRouteId(wshrpc.LocalConnName), localConnWsh)
wshutil.DefaultRouter.RegisterRoute(wshutil.MakeConnectionRouteId(wshrpc.LocalConnName), localConnWsh, true)
}
func main() {

View File

@ -5,6 +5,7 @@ package cmd
import (
"fmt"
"strings"
"github.com/spf13/cobra"
"github.com/wavetermdev/waveterm/pkg/remote"
@ -25,17 +26,24 @@ func init() {
}
func connStatus() error {
resp, err := wshclient.ConnStatusCommand(RpcClient, nil)
var allResp []wshrpc.ConnStatus
sshResp, err := wshclient.ConnStatusCommand(RpcClient, nil)
if err != nil {
return fmt.Errorf("getting connection status: %w", err)
return fmt.Errorf("getting ssh connection status: %w", err)
}
if len(resp) == 0 {
allResp = append(allResp, sshResp...)
wslResp, err := wshclient.WslStatusCommand(RpcClient, nil)
if err != nil {
return fmt.Errorf("getting wsl connection status: %w", err)
}
allResp = append(allResp, wslResp...)
if len(allResp) == 0 {
WriteStdout("no connections\n")
return nil
}
WriteStdout("%-30s %-12s\n", "connection", "status")
WriteStdout("----------------------------------------------\n")
for _, conn := range resp {
for _, conn := range allResp {
str := fmt.Sprintf("%-30s %-12s", conn.Connection, conn.Status)
if conn.Error != "" {
str += fmt.Sprintf(" (%s)", conn.Error)
@ -110,7 +118,7 @@ func connRun(cmd *cobra.Command, args []string) error {
}
connName = args[1]
_, err := remote.ParseOpts(connName)
if err != nil {
if err != nil && !strings.HasPrefix(connName, "wsl://") {
return fmt.Errorf("cannot parse connection name: %w", err)
}
}

View File

@ -4,29 +4,186 @@
package cmd
import (
"encoding/json"
"fmt"
"io"
"log"
"net"
"os"
"sync/atomic"
"time"
"github.com/spf13/cobra"
"github.com/wavetermdev/waveterm/pkg/util/packetparser"
"github.com/wavetermdev/waveterm/pkg/wavebase"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshremote"
"github.com/wavetermdev/waveterm/pkg/wshutil"
)
var serverCmd = &cobra.Command{
Use: "connserver",
Hidden: true,
Short: "remote server to power wave blocks",
Args: cobra.NoArgs,
Run: serverRun,
PreRunE: preRunSetupRpcClient,
Use: "connserver",
Hidden: true,
Short: "remote server to power wave blocks",
Args: cobra.NoArgs,
RunE: serverRun,
}
var connServerRouter bool
func init() {
serverCmd.Flags().BoolVar(&connServerRouter, "router", false, "run in local router mode")
rootCmd.AddCommand(serverCmd)
}
func serverRun(cmd *cobra.Command, args []string) {
func MakeRemoteUnixListener() (net.Listener, error) {
serverAddr := wavebase.GetRemoteDomainSocketName()
os.Remove(serverAddr) // ignore error
rtn, err := net.Listen("unix", serverAddr)
if err != nil {
return nil, fmt.Errorf("error creating listener at %v: %v", serverAddr, err)
}
os.Chmod(serverAddr, 0700)
log.Printf("Server [unix-domain] listening on %s\n", serverAddr)
return rtn, nil
}
func handleNewListenerConn(conn net.Conn, router *wshutil.WshRouter) {
var routeIdContainer atomic.Pointer[string]
proxy := wshutil.MakeRpcProxy()
go func() {
writeErr := wshutil.AdaptOutputChToStream(proxy.ToRemoteCh, conn)
if writeErr != nil {
log.Printf("error writing to domain socket: %v\n", writeErr)
}
}()
go func() {
// when input is closed, close the connection
defer func() {
conn.Close()
routeIdPtr := routeIdContainer.Load()
if routeIdPtr != nil && *routeIdPtr != "" {
router.UnregisterRoute(*routeIdPtr)
disposeMsg := &wshutil.RpcMessage{
Command: wshrpc.Command_Dispose,
Data: wshrpc.CommandDisposeData{
RouteId: *routeIdPtr,
},
Source: *routeIdPtr,
AuthToken: proxy.GetAuthToken(),
}
disposeBytes, _ := json.Marshal(disposeMsg)
router.InjectMessage(disposeBytes, *routeIdPtr)
}
}()
wshutil.AdaptStreamToMsgCh(conn, proxy.FromRemoteCh)
}()
routeId, err := proxy.HandleClientProxyAuth(router)
if err != nil {
log.Printf("error handling client proxy auth: %v\n", err)
conn.Close()
return
}
router.RegisterRoute(routeId, proxy, false)
routeIdContainer.Store(&routeId)
}
func runListener(listener net.Listener, router *wshutil.WshRouter) {
defer func() {
log.Printf("listener closed, exiting\n")
time.Sleep(500 * time.Millisecond)
wshutil.DoShutdown("", 1, true)
}()
for {
conn, err := listener.Accept()
if err == io.EOF {
break
}
if err != nil {
log.Printf("error accepting connection: %v\n", err)
continue
}
go handleNewListenerConn(conn, router)
}
}
func setupConnServerRpcClientWithRouter(router *wshutil.WshRouter) (*wshutil.WshRpc, error) {
jwtToken := os.Getenv(wshutil.WaveJwtTokenVarName)
if jwtToken == "" {
return nil, fmt.Errorf("no jwt token found for connserver")
}
rpcCtx, err := wshutil.ExtractUnverifiedRpcContext(jwtToken)
if err != nil {
return nil, fmt.Errorf("error extracting rpc context from %s: %v", wshutil.WaveJwtTokenVarName, err)
}
authRtn, err := router.HandleProxyAuth(jwtToken)
if err != nil {
return nil, fmt.Errorf("error handling proxy auth: %v", err)
}
inputCh := make(chan []byte, wshutil.DefaultInputChSize)
outputCh := make(chan []byte, wshutil.DefaultOutputChSize)
connServerClient := wshutil.MakeWshRpc(inputCh, outputCh, *rpcCtx, &wshremote.ServerImpl{LogWriter: os.Stdout})
connServerClient.SetAuthToken(authRtn.AuthToken)
router.RegisterRoute(authRtn.RouteId, connServerClient, false)
wshclient.RouteAnnounceCommand(connServerClient, nil)
return connServerClient, nil
}
func serverRunRouter() error {
router := wshutil.NewWshRouter()
termProxy := wshutil.MakeRpcProxy()
rawCh := make(chan []byte, wshutil.DefaultOutputChSize)
go packetparser.Parse(os.Stdin, termProxy.FromRemoteCh, rawCh)
go func() {
for msg := range termProxy.ToRemoteCh {
packetparser.WritePacket(os.Stdout, msg)
}
}()
go func() {
// just ignore and drain the rawCh (stdin)
// when stdin is closed, shutdown
defer wshutil.DoShutdown("", 0, true)
for range rawCh {
// ignore
}
}()
go func() {
for msg := range termProxy.FromRemoteCh {
// send this to the router
router.InjectMessage(msg, wshutil.UpstreamRoute)
}
}()
router.SetUpstreamClient(termProxy)
// now set up the domain socket
unixListener, err := MakeRemoteUnixListener()
if err != nil {
return fmt.Errorf("cannot create unix listener: %v", err)
}
client, err := setupConnServerRpcClientWithRouter(router)
if err != nil {
return fmt.Errorf("error setting up connserver rpc client: %v", err)
}
go runListener(unixListener, router)
// run the sysinfo loop
wshremote.RunSysInfoLoop(client, client.GetRpcContext().Conn)
select {}
}
func serverRunNormal() error {
err := setupRpcClient(&wshremote.ServerImpl{LogWriter: os.Stdout})
if err != nil {
return err
}
WriteStdout("running wsh connserver (%s)\n", RpcContext.Conn)
go wshremote.RunSysInfoLoop(RpcClient, RpcContext.Conn)
RpcClient.SetServerImpl(&wshremote.ServerImpl{LogWriter: os.Stdout})
select {} // run forever
}
func serverRun(cmd *cobra.Command, args []string) error {
if connServerRouter {
return serverRunRouter()
} else {
return serverRunNormal()
}
}

60
cmd/wsh/cmd/wshcmd-wsl.go Normal file
View File

@ -0,0 +1,60 @@
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package cmd
import (
"strings"
"github.com/spf13/cobra"
"github.com/wavetermdev/waveterm/pkg/waveobj"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
)
var distroName string
var wslCmd = &cobra.Command{
Use: "wsl [-d <Distro>]",
Short: "connect this terminal to a local wsl connection",
Args: cobra.NoArgs,
Run: wslRun,
PreRunE: preRunSetupRpcClient,
}
func init() {
wslCmd.Flags().StringVarP(&distroName, "distribution", "d", "", "Run the specified distribution")
rootCmd.AddCommand(wslCmd)
}
func wslRun(cmd *cobra.Command, args []string) {
var err error
if distroName == "" {
// get default distro from the host
distroName, err = wshclient.WslDefaultDistroCommand(RpcClient, nil)
if err != nil {
WriteStderr("[error] %s\n", err)
return
}
}
if !strings.HasPrefix(distroName, "wsl://") {
distroName = "wsl://" + distroName
}
blockId := RpcContext.BlockId
if blockId == "" {
WriteStderr("[error] cannot determine blockid (not in JWT)\n")
return
}
data := wshrpc.CommandSetMetaData{
ORef: waveobj.MakeORef(waveobj.OType_Block, blockId),
Meta: map[string]any{
waveobj.MetaKey_Connection: distroName,
},
}
err = wshclient.SetMetaCommand(RpcClient, data, nil)
if err != nil {
WriteStderr("[error] setting switching connection: %v\n", err)
return
}
WriteStderr("switched connection to %q\n", distroName)
}

View File

@ -521,6 +521,7 @@ const ChangeConnectionBlockModal = React.memo(
const connStatusAtom = getConnStatusAtom(connection);
const connStatus = jotai.useAtomValue(connStatusAtom);
const [connList, setConnList] = React.useState<Array<string>>([]);
const [wslList, setWslList] = React.useState<Array<string>>([]);
const allConnStatus = jotai.useAtomValue(atoms.allConnStatus);
const [rowIndex, setRowIndex] = React.useState(0);
const connStatusMap = new Map<string, ConnStatus>();
@ -540,6 +541,18 @@ const ChangeConnectionBlockModal = React.memo(
prtn.then((newConnList) => {
setConnList(newConnList ?? []);
}).catch((e) => console.log("unable to load conn list from backend. using blank list: ", e));
const p2rtn = RpcApi.WslListCommand(TabRpcClient, { timeout: 2000 });
p2rtn
.then((newWslList) => {
console.log(newWslList);
setWslList(newWslList ?? []);
})
.catch((e) => {
// removing this log and failing silentyly since it will happen
// if a system isn't using the wsl. and would happen every time the
// typeahead was opened. good candidate for verbose log level.
//console.log("unable to load wsl list from backend. using blank list: ", e)
});
}, [changeConnModalOpen, setConnList]);
const changeConnection = React.useCallback(
@ -588,6 +601,15 @@ const ChangeConnectionBlockModal = React.memo(
filteredList.push(conn);
}
}
const filteredWslList: Array<string> = [];
for (const conn of wslList) {
if (conn === connSelected) {
createNew = false;
}
if (conn.includes(connSelected)) {
filteredWslList.push(conn);
}
}
// priority handles special suggestions when necessary
// for instance, when reconnecting
const newConnectionSuggestion: SuggestionConnectionItem = {
@ -637,6 +659,20 @@ const ChangeConnectionBlockModal = React.memo(
label: localName,
});
}
for (const wslConn of filteredWslList) {
const connStatus = connStatusMap.get(wslConn);
const connColorNum = computeConnColorNum(connStatus);
localSuggestion.items.push({
status: "connected",
icon: "arrow-right-arrow-left",
iconColor:
connStatus?.status == "connected"
? `var(--conn-icon-color-${connColorNum})`
: "var(--grey-text-color)",
value: "wsl://" + wslConn,
label: "wsl://" + wslConn,
});
}
const remoteItems = filteredList.map((connName) => {
const connStatus = connStatusMap.get(connName);
const connColorNum = computeConnColorNum(connStatus);

View File

@ -72,6 +72,11 @@ class RpcApiType {
return client.wshRpcCall("deleteblock", data, opts);
}
// command "dispose" [call]
DisposeCommand(client: WshClient, data: CommandDisposeData, opts?: RpcOpts): Promise<void> {
return client.wshRpcCall("dispose", data, opts);
}
// command "eventpublish" [call]
EventPublishCommand(client: WshClient, data: WaveEvent, opts?: RpcOpts): Promise<void> {
return client.wshRpcCall("eventpublish", data, opts);
@ -237,6 +242,21 @@ class RpcApiType {
return client.wshRpcCall("webselector", data, opts);
}
// command "wsldefaultdistro" [call]
WslDefaultDistroCommand(client: WshClient, opts?: RpcOpts): Promise<string> {
return client.wshRpcCall("wsldefaultdistro", null, opts);
}
// command "wsllist" [call]
WslListCommand(client: WshClient, opts?: RpcOpts): Promise<string[]> {
return client.wshRpcCall("wsllist", null, opts);
}
// command "wslstatus" [call]
WslStatusCommand(client: WshClient, opts?: RpcOpts): Promise<ConnStatus[]> {
return client.wshRpcCall("wslstatus", null, opts);
}
}
export const RpcApi = new RpcApiType();

View File

@ -63,6 +63,7 @@ declare global {
// wshrpc.CommandAuthenticateRtnData
type CommandAuthenticateRtnData = {
routeid: string;
authtoken?: string;
};
// wshrpc.CommandBlockInputData
@ -100,6 +101,11 @@ declare global {
blockid: string;
};
// wshrpc.CommandDisposeData
type CommandDisposeData = {
routeid: string;
};
// wshrpc.CommandEventReadHistoryData
type CommandEventReadHistoryData = {
event: string;
@ -416,6 +422,7 @@ declare global {
resid?: string;
timeout?: number;
route?: string;
authtoken?: string;
source?: string;
cont?: boolean;
cancel?: boolean;

3
go.mod
View File

@ -21,6 +21,7 @@ require (
github.com/shirou/gopsutil/v4 v4.24.9
github.com/skeema/knownhosts v1.3.0
github.com/spf13/cobra v1.8.1
github.com/ubuntu/gowsl v0.0.0-20240906163211-049fd49bd93b
github.com/wavetermdev/htmltoken v0.1.0
golang.org/x/crypto v0.28.0
golang.org/x/sys v0.26.0
@ -36,9 +37,11 @@ require (
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/tklauser/go-sysconf v0.3.12 // indirect
github.com/tklauser/numcpus v0.6.1 // indirect
github.com/ubuntu/decorate v0.0.0-20230125165522-2d5b0a9bb117 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
go.uber.org/atomic v1.7.0 // indirect
golang.org/x/net v0.29.0 // indirect

11
go.sum
View File

@ -1,5 +1,7 @@
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/0xrawsec/golang-utils v1.3.2 h1:ww4jrtHRSnX9xrGzJYbalx5nXoZewy4zPxiY+ubJgtg=
github.com/0xrawsec/golang-utils v1.3.2/go.mod h1:m7AzHXgdSAkFCD9tWWsApxNVxMlyy7anpPVOyT/yM7E=
github.com/alexflint/go-filemutex v1.3.0 h1:LgE+nTUWnQCyRKbpoceKZsPQbs84LivvgwUymZXdOcM=
github.com/alexflint/go-filemutex v1.3.0/go.mod h1:U0+VA/i30mGBlLCrFPGtTe9y6wGQfNAWPBTekHQ+c8A=
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
@ -62,6 +64,8 @@ github.com/sawka/txwrap v0.2.0 h1:V3LfvKVLULxcYSxdMguLwFyQFMEU9nFDJopg0ZkL+94=
github.com/sawka/txwrap v0.2.0/go.mod h1:wwQ2SQiN4U+6DU/iVPhbvr7OzXAtgZlQCIGuvOswEfA=
github.com/shirou/gopsutil/v4 v4.24.9 h1:KIV+/HaHD5ka5f570RZq+2SaeFsb/pq+fp2DGNWYoOI=
github.com/shirou/gopsutil/v4 v4.24.9/go.mod h1:3fkaHNeYsUFCGZ8+9vZVWtbyM1k2eRnlL+bWO8Bxa/Q=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/skeema/knownhosts v1.3.0 h1:AM+y0rI04VksttfwjkSTNQorvGqmwATnvnAHpSgc0LY=
github.com/skeema/knownhosts v1.3.0/go.mod h1:sPINvnADmT/qYH1kfv+ePMmOBTH6Tbl7b5LvTDjFK7M=
github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
@ -71,12 +75,17 @@ github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
github.com/ubuntu/decorate v0.0.0-20230125165522-2d5b0a9bb117 h1:XQpsQG5lqRJlx4mUVHcJvyyc1rdTI9nHvwrdfcuy8aM=
github.com/ubuntu/decorate v0.0.0-20230125165522-2d5b0a9bb117/go.mod h1:mx0TjbqsaDD9DUT5gA1s3hw47U6RIbbIBfvGzR85K0g=
github.com/ubuntu/gowsl v0.0.0-20240906163211-049fd49bd93b h1:wFBKF5k5xbJQU8bYgcSoQ/ScvmYyq6KHUabAuVUjOWM=
github.com/ubuntu/gowsl v0.0.0-20240906163211-049fd49bd93b/go.mod h1:N1CYNinssZru+ikvYTgVbVeSi21thHUTCoJ9xMvWe+s=
github.com/wavetermdev/htmltoken v0.1.0 h1:RMdA9zTfnYa5jRC4RRG3XNoV5NOP8EDxpaVPjuVz//Q=
github.com/wavetermdev/htmltoken v0.1.0/go.mod h1:5FM0XV6zNYiNza2iaTcFGj+hnMtgqumFHO31Z8euquk=
github.com/wavetermdev/ssh_config v0.0.0-20240306041034-17e2087ebde2 h1:onqZrJVap1sm15AiIGTfWzdr6cEF0KdtddeuuOVhzyY=
@ -91,6 +100,7 @@ golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo=
golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220721230656-c6bc011c0c49/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@ -102,5 +112,6 @@ golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -11,6 +11,7 @@ import (
"io"
"io/fs"
"log"
"strings"
"sync"
"time"
@ -24,6 +25,7 @@ import (
"github.com/wavetermdev/waveterm/pkg/wps"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
"github.com/wavetermdev/waveterm/pkg/wshutil"
"github.com/wavetermdev/waveterm/pkg/wsl"
"github.com/wavetermdev/waveterm/pkg/wstore"
)
@ -262,7 +264,30 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj
return fmt.Errorf("unknown controller type %q", bc.ControllerType)
}
var shellProc *shellexec.ShellProc
if remoteName != "" {
if strings.HasPrefix(remoteName, "wsl://") {
wslName := strings.TrimPrefix(remoteName, "wsl://")
credentialCtx, cancelFunc := context.WithTimeout(context.Background(), 60*time.Second)
defer cancelFunc()
wslConn := wsl.GetWslConn(credentialCtx, wslName, false)
connStatus := wslConn.DeriveConnStatus()
if connStatus.Status != conncontroller.Status_Connected {
return fmt.Errorf("not connected, cannot start shellproc")
}
// create jwt
if !blockMeta.GetBool(waveobj.MetaKey_CmdNoWsh, false) {
jwtStr, err := wshutil.MakeClientJWTToken(wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId, Conn: wslConn.GetName()}, wslConn.GetDomainSocketName())
if err != nil {
return fmt.Errorf("error making jwt token: %w", err)
}
cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr
}
shellProc, err = shellexec.StartWslShellProc(ctx, rc.TermSize, cmdStr, cmdOpts, wslConn)
if err != nil {
return err
}
} else if remoteName != "" {
credentialCtx, cancelFunc := context.WithTimeout(context.Background(), 60*time.Second)
defer cancelFunc()
@ -325,7 +350,7 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj
// we don't need to authenticate this wshProxy since it is coming direct
wshProxy := wshutil.MakeRpcProxy()
wshProxy.SetRpcContext(&wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId})
wshutil.DefaultRouter.RegisterRoute(wshutil.MakeControllerRouteId(bc.BlockId), wshProxy)
wshutil.DefaultRouter.RegisterRoute(wshutil.MakeControllerRouteId(bc.BlockId), wshProxy, true)
ptyBuffer := wshutil.MakePtyBuffer(wshutil.WaveOSCPrefix, bc.ShellProc.Cmd, wshProxy.FromRemoteCh)
go func() {
// handles regular output from the pty (goes to the blockfile and xterm)
@ -494,6 +519,15 @@ func CheckConnStatus(blockId string) error {
if connName == "" {
return nil
}
if strings.HasPrefix(connName, "wsl://") {
distroName := strings.TrimPrefix(connName, "wsl://")
conn := wsl.GetWslConn(context.Background(), distroName, false)
connStatus := conn.DeriveConnStatus()
if connStatus.Status != conncontroller.Status_Connected {
return fmt.Errorf("not connected: %s", connStatus.Status)
}
return nil
}
opts, err := remote.ParseOpts(connName)
if err != nil {
return fmt.Errorf("error parsing connection name: %w", err)

View File

@ -1,3 +1,6 @@
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package remote
import (

View File

@ -17,6 +17,7 @@ import (
"github.com/wavetermdev/waveterm/pkg/wcore"
"github.com/wavetermdev/waveterm/pkg/wlayout"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
"github.com/wavetermdev/waveterm/pkg/wsl"
"github.com/wavetermdev/waveterm/pkg/wstore"
)
@ -77,7 +78,9 @@ func (cs *ClientService) MakeWindow(ctx context.Context) (*waveobj.Window, error
}
func (cs *ClientService) GetAllConnStatus(ctx context.Context) ([]wshrpc.ConnStatus, error) {
return conncontroller.GetAllConnStatus(), nil
sshStatuses := conncontroller.GetAllConnStatus()
wslStatuses := wsl.GetAllConnStatus()
return append(sshStatuses, wslStatuses...), nil
}
// moves the window to the front of the windowId stack

View File

@ -7,6 +7,7 @@ import (
"time"
"github.com/creack/pty"
"github.com/wavetermdev/waveterm/pkg/wsl"
"golang.org/x/crypto/ssh"
)
@ -129,3 +130,42 @@ func (sw SessionWrap) StderrPipe() (io.ReadCloser, error) {
func (sw SessionWrap) SetSize(h int, w int) error {
return sw.Session.WindowChange(h, w)
}
type WslCmdWrap struct {
*wsl.WslCmd
Tty pty.Tty
pty.Pty
}
func (wcw WslCmdWrap) Kill() {
wcw.Tty.Close()
wcw.Close()
}
func (wcw WslCmdWrap) KillGraceful(timeout time.Duration) {
process := wcw.WslCmd.GetProcess()
if process == nil {
return
}
processState := wcw.WslCmd.GetProcessState()
if processState != nil && processState.Exited() {
return
}
process.Signal(os.Interrupt)
go func() {
time.Sleep(timeout)
process := wcw.WslCmd.GetProcess()
processState := wcw.WslCmd.GetProcessState()
if processState == nil || !processState.Exited() {
process.Kill() // force kill if it is already not exited
}
}()
}
/**
* SetSize does nothing for WslCmdWrap as there
* is no pty to manage.
**/
func (wcw WslCmdWrap) SetSize(w int, h int) error {
return nil
}

View File

@ -5,6 +5,7 @@ package shellexec
import (
"bytes"
"context"
"fmt"
"io"
"log"
@ -25,6 +26,7 @@ import (
"github.com/wavetermdev/waveterm/pkg/wavebase"
"github.com/wavetermdev/waveterm/pkg/waveobj"
"github.com/wavetermdev/waveterm/pkg/wshutil"
"github.com/wavetermdev/waveterm/pkg/wsl"
)
const DefaultGracefulKillWait = 400 * time.Millisecond
@ -141,6 +143,96 @@ func (pp *PipePty) WriteString(s string) (n int, err error) {
return pp.Write([]byte(s))
}
func StartWslShellProc(ctx context.Context, termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *wsl.WslConn) (*ShellProc, error) {
client := conn.GetClient()
shellPath := cmdOpts.ShellPath
if shellPath == "" {
remoteShellPath, err := wsl.DetectShell(conn.Context, client)
if err != nil {
return nil, err
}
shellPath = remoteShellPath
}
var shellOpts []string
log.Printf("detected shell: %s", shellPath)
err := wsl.InstallClientRcFiles(conn.Context, client)
if err != nil {
log.Printf("error installing rc files: %v", err)
return nil, err
}
homeDir := wsl.GetHomeDir(conn.Context, client)
shellOpts = append(shellOpts, "~", "-d", client.Name())
if isZshShell(shellPath) {
shellOpts = append(shellOpts, fmt.Sprintf(`ZDOTDIR="%s/.waveterm/%s"`, homeDir, shellutil.ZshIntegrationDir))
}
var subShellOpts []string
if cmdStr == "" {
/* transform command in order to inject environment vars */
if isBashShell(shellPath) {
log.Printf("recognized as bash shell")
// add --rcfile
// cant set -l or -i with --rcfile
subShellOpts = append(subShellOpts, "--rcfile", fmt.Sprintf(`%s/.waveterm/%s/.bashrc`, homeDir, shellutil.BashIntegrationDir))
} else if isFishShell(shellPath) {
carg := fmt.Sprintf(`"set -x PATH \"%s\"/.waveterm/%s $PATH"`, homeDir, shellutil.WaveHomeBinDir)
subShellOpts = append(subShellOpts, "-C", carg)
} else if wsl.IsPowershell(shellPath) {
// powershell is weird about quoted path executables and requires an ampersand first
shellPath = "& " + shellPath
subShellOpts = append(subShellOpts, "-ExecutionPolicy", "Bypass", "-NoExit", "-File", homeDir+fmt.Sprintf("/.waveterm/%s/wavepwsh.ps1", shellutil.PwshIntegrationDir))
} else {
if cmdOpts.Login {
subShellOpts = append(subShellOpts, "-l")
}
if cmdOpts.Interactive {
subShellOpts = append(subShellOpts, "-i")
}
// can't set environment vars this way
// will try to do later if possible
}
} else {
shellPath = cmdStr
if cmdOpts.Login {
subShellOpts = append(subShellOpts, "-l")
}
if cmdOpts.Interactive {
subShellOpts = append(subShellOpts, "-i")
}
subShellOpts = append(subShellOpts, "-c", cmdStr)
}
jwtToken, ok := cmdOpts.Env[wshutil.WaveJwtTokenVarName]
if !ok {
return nil, fmt.Errorf("no jwt token provided to connection")
}
if remote.IsPowershell(shellPath) {
shellOpts = append(shellOpts, "--", fmt.Sprintf(`$env:%s=%s;`, wshutil.WaveJwtTokenVarName, jwtToken))
} else {
shellOpts = append(shellOpts, "--", fmt.Sprintf(`%s=%s`, wshutil.WaveJwtTokenVarName, jwtToken))
}
shellOpts = append(shellOpts, shellPath)
shellOpts = append(shellOpts, subShellOpts...)
log.Printf("full cmd is: %s %s", "wsl.exe", strings.Join(shellOpts, " "))
ecmd := exec.Command("wsl.exe", shellOpts...)
if termSize.Rows == 0 || termSize.Cols == 0 {
termSize.Rows = shellutil.DefaultTermRows
termSize.Cols = shellutil.DefaultTermCols
}
if termSize.Rows <= 0 || termSize.Cols <= 0 {
return nil, fmt.Errorf("invalid term size: %v", termSize)
}
cmdPty, err := pty.StartWithSize(ecmd, &pty.Winsize{Rows: uint16(termSize.Rows), Cols: uint16(termSize.Cols)})
if err != nil {
return nil, err
}
return &ShellProc{Cmd: CmdWrap{ecmd, cmdPty}, ConnName: conn.GetName(), CloseOnce: &sync.Once{}, DoneCh: make(chan any)}, nil
}
func StartRemoteShellProc(termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *conncontroller.SSHConn) (*ShellProc, error) {
client := conn.GetClient()
shellPath := cmdOpts.ShellPath

View File

@ -0,0 +1,58 @@
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package packetparser
import (
"bufio"
"bytes"
"fmt"
"io"
)
type PacketParser struct {
Reader io.Reader
Ch chan []byte
}
func Parse(input io.Reader, packetCh chan []byte, rawCh chan []byte) error {
bufReader := bufio.NewReader(input)
defer close(packetCh)
defer close(rawCh)
for {
line, err := bufReader.ReadBytes('\n')
if err == io.EOF {
return nil
}
if err != nil {
return err
}
if len(line) <= 1 {
// just a blank line
continue
}
if bytes.HasPrefix(line, []byte{'#', '#', 'N', '{'}) && bytes.HasSuffix(line, []byte{'}', '\n'}) {
// strip off the leading "##" and trailing "\n" (single byte)
packetCh <- line[3 : len(line)-1]
} else {
rawCh <- line
}
}
}
func WritePacket(output io.Writer, packet []byte) error {
if len(packet) < 2 {
return nil
}
if packet[0] != '{' || packet[len(packet)-1] != '}' {
return fmt.Errorf("invalid packet, must start with '{' and end with '}'")
}
fullPacket := make([]byte, 0, len(packet)+5)
// we add the extra newline to make sure the ## appears at the beginning of the line
// since writer isn't buffered, we want to send this all at once
fullPacket = append(fullPacket, '\n', '#', '#', 'N')
fullPacket = append(fullPacket, packet...)
fullPacket = append(fullPacket, '\n')
_, err := output.Write(fullPacket)
return err
}

View File

@ -30,10 +30,13 @@ const WaveDataHomeEnvVar = "WAVETERM_DATA_HOME"
const WaveDevVarName = "WAVETERM_DEV"
const WaveLockFile = "wave.lock"
const DomainSocketBaseName = "wave.sock"
const RemoteDomainSocketBaseName = "wave-remote.sock"
const WaveDBDir = "db"
const JwtSecret = "waveterm" // TODO generate and store this
const ConfigDir = "config"
var RemoteWaveHome = ExpandHomeDirSafe("~/.waveterm")
const WaveAppPathVarName = "WAVETERM_APP_PATH"
const AppPathBinDir = "bin"
@ -101,6 +104,10 @@ func GetDomainSocketName() string {
return filepath.Join(GetWaveDataDir(), DomainSocketBaseName)
}
func GetRemoteDomainSocketName() string {
return filepath.Join(RemoteWaveHome, RemoteDomainSocketBaseName)
}
func GetWaveDataDir() string {
retVal, found := os.LookupEnv(WaveDataHomeEnvVar)
if !found {

View File

@ -431,7 +431,7 @@ func MakeTCPListener(serviceName string) (net.Listener, error) {
}
func MakeUnixListener() (net.Listener, error) {
serverAddr := wavebase.GetWaveDataDir() + "/wave.sock"
serverAddr := wavebase.GetDomainSocketName()
os.Remove(serverAddr) // ignore error
rtn, err := net.Listen("unix", serverAddr)
if err != nil {

View File

@ -252,7 +252,7 @@ func registerConn(wsConnId string, routeId string, wproxy *wshutil.WshRpcProxy)
wshutil.DefaultRouter.UnregisterRoute(routeId)
}
RouteToConnMap[routeId] = wsConnId
wshutil.DefaultRouter.RegisterRoute(routeId, wproxy)
wshutil.DefaultRouter.RegisterRoute(routeId, wproxy, true)
}
func unregisterConn(wsConnId string, routeId string) {

View File

@ -92,6 +92,12 @@ func DeleteBlockCommand(w *wshutil.WshRpc, data wshrpc.CommandDeleteBlockData, o
return err
}
// command "dispose", wshserver.DisposeCommand
func DisposeCommand(w *wshutil.WshRpc, data wshrpc.CommandDisposeData, opts *wshrpc.RpcOpts) error {
_, err := sendRpcRequestCallHelper[any](w, "dispose", data, opts)
return err
}
// command "eventpublish", wshserver.EventPublishCommand
func EventPublishCommand(w *wshutil.WshRpc, data wps.WaveEvent, opts *wshrpc.RpcOpts) error {
_, err := sendRpcRequestCallHelper[any](w, "eventpublish", data, opts)
@ -285,4 +291,22 @@ func WebSelectorCommand(w *wshutil.WshRpc, data wshrpc.CommandWebSelectorData, o
return resp, err
}
// command "wsldefaultdistro", wshserver.WslDefaultDistroCommand
func WslDefaultDistroCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) (string, error) {
resp, err := sendRpcRequestCallHelper[string](w, "wsldefaultdistro", nil, opts)
return resp, err
}
// command "wsllist", wshserver.WslListCommand
func WslListCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) ([]string, error) {
resp, err := sendRpcRequestCallHelper[[]string](w, "wsllist", nil, opts)
return resp, err
}
// command "wslstatus", wshserver.WslStatusCommand
func WslStatusCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) ([]wshrpc.ConnStatus, error) {
resp, err := sendRpcRequestCallHelper[[]wshrpc.ConnStatus](w, "wslstatus", nil, opts)
return resp, err
}

View File

@ -28,6 +28,7 @@ const (
const (
Command_Authenticate = "authenticate" // special
Command_Dispose = "dispose" // special (disposes of the route, for multiproxy only)
Command_RouteAnnounce = "routeannounce" // special (for routing)
Command_RouteUnannounce = "routeunannounce" // special (for routing)
Command_Message = "message"
@ -62,11 +63,15 @@ const (
Command_RemoteFileDelete = "remotefiledelete"
Command_RemoteFileJoiin = "remotefilejoin"
Command_ConnStatus = "connstatus"
Command_WslStatus = "wslstatus"
Command_ConnEnsure = "connensure"
Command_ConnReinstallWsh = "connreinstallwsh"
Command_ConnConnect = "connconnect"
Command_ConnDisconnect = "conndisconnect"
Command_ConnList = "connlist"
Command_WslList = "wsllist"
Command_WslDefaultDistro = "wsldefaultdistro"
Command_WebSelector = "webselector"
Command_Notify = "notify"
@ -83,6 +88,7 @@ type RespOrErrorUnion[T any] struct {
type WshRpcInterface interface {
AuthenticateCommand(ctx context.Context, data string) (CommandAuthenticateRtnData, error)
DisposeCommand(ctx context.Context, data CommandDisposeData) error
RouteAnnounceCommand(ctx context.Context) error // (special) announces a new route to the main router
RouteUnannounceCommand(ctx context.Context) error // (special) unannounces a route to the main router
@ -114,11 +120,14 @@ type WshRpcInterface interface {
// connection functions
ConnStatusCommand(ctx context.Context) ([]ConnStatus, error)
WslStatusCommand(ctx context.Context) ([]ConnStatus, error)
ConnEnsureCommand(ctx context.Context, connName string) error
ConnReinstallWshCommand(ctx context.Context, connName string) error
ConnConnectCommand(ctx context.Context, connName string) error
ConnDisconnectCommand(ctx context.Context, connName string) error
ConnListCommand(ctx context.Context) ([]string, error)
WslListCommand(ctx context.Context) ([]string, error)
WslDefaultDistroCommand(ctx context.Context) (string, error)
// eventrecv is special, it's handled internally by WshRpc with EventListener
EventRecvCommand(ctx context.Context, data wps.WaveEvent) error
@ -200,7 +209,13 @@ func HackRpcContextIntoData(dataPtr any, rpcContext RpcContext) {
}
type CommandAuthenticateRtnData struct {
RouteId string `json:"routeid"`
AuthToken string `json:"authtoken,omitempty"`
}
type CommandDisposeData struct {
RouteId string `json:"routeid"`
// auth token travels in the packet directly
}
type CommandMessageData struct {

View File

@ -21,6 +21,7 @@ import (
"github.com/wavetermdev/waveterm/pkg/filestore"
"github.com/wavetermdev/waveterm/pkg/remote"
"github.com/wavetermdev/waveterm/pkg/remote/conncontroller"
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
"github.com/wavetermdev/waveterm/pkg/waveai"
"github.com/wavetermdev/waveterm/pkg/waveobj"
"github.com/wavetermdev/waveterm/pkg/wconfig"
@ -29,6 +30,7 @@ import (
"github.com/wavetermdev/waveterm/pkg/wps"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
"github.com/wavetermdev/waveterm/pkg/wshutil"
"github.com/wavetermdev/waveterm/pkg/wsl"
"github.com/wavetermdev/waveterm/pkg/wstore"
)
@ -36,6 +38,7 @@ const SimpleId_This = "this"
const SimpleId_Tab = "tab"
var SimpleId_BlockNum_Regex = regexp.MustCompile(`^\d+$`)
var InvalidWslDistroNames = []string{"docker-desktop", "docker-desktop-data"}
type WshServer struct{}
@ -463,11 +466,28 @@ func (ws *WshServer) ConnStatusCommand(ctx context.Context) ([]wshrpc.ConnStatus
return rtn, nil
}
func (ws *WshServer) WslStatusCommand(ctx context.Context) ([]wshrpc.ConnStatus, error) {
rtn := wsl.GetAllConnStatus()
return rtn, nil
}
func (ws *WshServer) ConnEnsureCommand(ctx context.Context, connName string) error {
if strings.HasPrefix(connName, "wsl://") {
distroName := strings.TrimPrefix(connName, "wsl://")
return wsl.EnsureConnection(ctx, distroName)
}
return conncontroller.EnsureConnection(ctx, connName)
}
func (ws *WshServer) ConnDisconnectCommand(ctx context.Context, connName string) error {
if strings.HasPrefix(connName, "wsl://") {
distroName := strings.TrimPrefix(connName, "wsl://")
conn := wsl.GetWslConn(ctx, distroName, false)
if conn == nil {
return fmt.Errorf("distro not found: %s", connName)
}
return conn.Close()
}
connOpts, err := remote.ParseOpts(connName)
if err != nil {
return fmt.Errorf("error parsing connection name: %w", err)
@ -480,6 +500,14 @@ func (ws *WshServer) ConnDisconnectCommand(ctx context.Context, connName string)
}
func (ws *WshServer) ConnConnectCommand(ctx context.Context, connName string) error {
if strings.HasPrefix(connName, "wsl://") {
distroName := strings.TrimPrefix(connName, "wsl://")
conn := wsl.GetWslConn(ctx, distroName, false)
if conn == nil {
return fmt.Errorf("connection not found: %s", connName)
}
return conn.Connect(ctx)
}
connOpts, err := remote.ParseOpts(connName)
if err != nil {
return fmt.Errorf("error parsing connection name: %w", err)
@ -492,6 +520,14 @@ func (ws *WshServer) ConnConnectCommand(ctx context.Context, connName string) er
}
func (ws *WshServer) ConnReinstallWshCommand(ctx context.Context, connName string) error {
if strings.HasPrefix(connName, "wsl://") {
distroName := strings.TrimPrefix(connName, "wsl://")
conn := wsl.GetWslConn(ctx, distroName, false)
if conn == nil {
return fmt.Errorf("connection not found: %s", connName)
}
return conn.CheckAndInstallWsh(ctx, connName, &wsl.WshInstallOpts{Force: true, NoUserPrompt: true})
}
connOpts, err := remote.ParseOpts(connName)
if err != nil {
return fmt.Errorf("error parsing connection name: %w", err)
@ -507,6 +543,33 @@ func (ws *WshServer) ConnListCommand(ctx context.Context) ([]string, error) {
return conncontroller.GetConnectionsList()
}
func (ws *WshServer) WslListCommand(ctx context.Context) ([]string, error) {
distros, err := wsl.RegisteredDistros(ctx)
if err != nil {
return nil, err
}
var distroNames []string
for _, distro := range distros {
distroName := distro.Name()
if utilfn.ContainsStr(InvalidWslDistroNames, distroName) {
continue
}
distroNames = append(distroNames, distroName)
}
return distroNames, nil
}
func (ws *WshServer) WslDefaultDistroCommand(ctx context.Context) (string, error) {
distro, ok, err := wsl.DefaultDistro(ctx)
if err != nil {
return "", fmt.Errorf("unable to determine default distro: %w", err)
}
if !ok {
return "", fmt.Errorf("unable to determine default distro")
}
return distro.Name(), nil
}
func (ws *WshServer) BlockInfoCommand(ctx context.Context, blockId string) (*wshrpc.BlockInfoData, error) {
blockData, err := wstore.DBMustGet[*waveobj.Block](ctx, blockId)
if err != nil {

View File

@ -0,0 +1,151 @@
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package wshutil
import (
"encoding/json"
"fmt"
"sync"
"github.com/google/uuid"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
)
type multiProxyRouteInfo struct {
RouteId string
AuthToken string
Proxy *WshRpcProxy
RpcContext *wshrpc.RpcContext
}
// handles messages from multiple unauthenitcated clients
type WshRpcMultiProxy struct {
Lock *sync.Mutex
RouteInfo map[string]*multiProxyRouteInfo // authtoken to info
ToRemoteCh chan []byte
FromRemoteRawCh chan []byte // raw message from the remote
}
func MakeRpcMultiProxy() *WshRpcMultiProxy {
return &WshRpcMultiProxy{
Lock: &sync.Mutex{},
RouteInfo: make(map[string]*multiProxyRouteInfo),
ToRemoteCh: make(chan []byte, DefaultInputChSize),
FromRemoteRawCh: make(chan []byte, DefaultOutputChSize),
}
}
func (p *WshRpcMultiProxy) DisposeRoutes() {
p.Lock.Lock()
defer p.Lock.Unlock()
for authToken, routeInfo := range p.RouteInfo {
DefaultRouter.UnregisterRoute(routeInfo.RouteId)
delete(p.RouteInfo, authToken)
}
}
func (p *WshRpcMultiProxy) getRouteInfo(authToken string) *multiProxyRouteInfo {
p.Lock.Lock()
defer p.Lock.Unlock()
return p.RouteInfo[authToken]
}
func (p *WshRpcMultiProxy) setRouteInfo(authToken string, routeInfo *multiProxyRouteInfo) {
p.Lock.Lock()
defer p.Lock.Unlock()
p.RouteInfo[authToken] = routeInfo
}
func (p *WshRpcMultiProxy) removeRouteInfo(authToken string) {
p.Lock.Lock()
defer p.Lock.Unlock()
delete(p.RouteInfo, authToken)
}
func (p *WshRpcMultiProxy) sendResponseError(msg RpcMessage, sendErr error) {
if msg.ReqId == "" {
// no response needed
return
}
resp := RpcMessage{
ResId: msg.ReqId,
Error: sendErr.Error(),
}
respBytes, _ := json.Marshal(resp)
p.ToRemoteCh <- respBytes
}
func (p *WshRpcMultiProxy) sendAuthResponse(msg RpcMessage, routeId string, authToken string) {
if msg.ReqId == "" {
// no response needed
return
}
resp := RpcMessage{
ResId: msg.ReqId,
Data: wshrpc.CommandAuthenticateRtnData{RouteId: routeId, AuthToken: authToken},
}
respBytes, _ := json.Marshal(resp)
p.ToRemoteCh <- respBytes
}
func (p *WshRpcMultiProxy) handleUnauthMessage(msgBytes []byte) {
var msg RpcMessage
err := json.Unmarshal(msgBytes, &msg)
if err != nil {
// nothing to do here, malformed message
return
}
if msg.Command == wshrpc.Command_Authenticate {
rpcContext, routeId, err := handleAuthenticationCommand(msg)
if err != nil {
p.sendResponseError(msg, err)
return
}
routeInfo := &multiProxyRouteInfo{
RouteId: routeId,
AuthToken: uuid.New().String(),
RpcContext: rpcContext,
}
routeInfo.Proxy = MakeRpcProxy()
routeInfo.Proxy.SetRpcContext(rpcContext)
p.setRouteInfo(routeInfo.AuthToken, routeInfo)
p.sendAuthResponse(msg, routeId, routeInfo.AuthToken)
go func() {
for msgBytes := range routeInfo.Proxy.ToRemoteCh {
p.ToRemoteCh <- msgBytes
}
}()
DefaultRouter.RegisterRoute(routeId, routeInfo.Proxy, true)
return
}
if msg.AuthToken == "" {
p.sendResponseError(msg, fmt.Errorf("no auth token"))
return
}
routeInfo := p.getRouteInfo(msg.AuthToken)
if routeInfo == nil {
p.sendResponseError(msg, fmt.Errorf("invalid auth token"))
return
}
if msg.Command != "" && msg.Source != routeInfo.RouteId {
p.sendResponseError(msg, fmt.Errorf("invalid source route for auth token"))
return
}
if msg.Command == wshrpc.Command_Dispose {
DefaultRouter.UnregisterRoute(routeInfo.RouteId)
p.removeRouteInfo(msg.AuthToken)
close(routeInfo.Proxy.ToRemoteCh)
close(routeInfo.Proxy.FromRemoteCh)
return
}
routeInfo.Proxy.FromRemoteCh <- msgBytes
}
func (p *WshRpcMultiProxy) RunUnauthLoop() {
// loop over unauthenticated message
// handle Authenicate commands, and pass authenticated messages to the AuthCh
for msgBytes := range p.FromRemoteRawCh {
p.handleUnauthMessage(msgBytes)
}
}

View File

@ -6,7 +6,6 @@ package wshutil
import (
"encoding/json"
"fmt"
"log"
"sync"
"github.com/google/uuid"
@ -18,6 +17,7 @@ type WshRpcProxy struct {
RpcContext *wshrpc.RpcContext
ToRemoteCh chan []byte
FromRemoteCh chan []byte
AuthToken string
}
func MakeRpcProxy() *WshRpcProxy {
@ -40,6 +40,18 @@ func (p *WshRpcProxy) GetRpcContext() *wshrpc.RpcContext {
return p.RpcContext
}
func (p *WshRpcProxy) SetAuthToken(authToken string) {
p.Lock.Lock()
defer p.Lock.Unlock()
p.AuthToken = authToken
}
func (p *WshRpcProxy) GetAuthToken() string {
p.Lock.Lock()
defer p.Lock.Unlock()
return p.AuthToken
}
func (p *WshRpcProxy) sendResponseError(msg RpcMessage, sendErr error) {
if msg.ReqId == "" {
// no response needed
@ -54,7 +66,7 @@ func (p *WshRpcProxy) sendResponseError(msg RpcMessage, sendErr error) {
p.SendRpcMessage(respBytes)
}
func (p *WshRpcProxy) sendResponse(msg RpcMessage, routeId string) {
func (p *WshRpcProxy) sendAuthenticateResponse(msg RpcMessage, routeId string) {
if msg.ReqId == "" {
// no response needed
return
@ -98,6 +110,49 @@ func handleAuthenticationCommand(msg RpcMessage) (*wshrpc.RpcContext, string, er
return newCtx, routeId, nil
}
// runs on the client (stdio client)
func (p *WshRpcProxy) HandleClientProxyAuth(router *WshRouter) (string, error) {
for {
msgBytes, ok := <-p.FromRemoteCh
if !ok {
return "", fmt.Errorf("remote closed, not authenticated")
}
var origMsg RpcMessage
err := json.Unmarshal(msgBytes, &origMsg)
if err != nil {
// nothing to do, can't even send a response since we don't have Source or ReqId
continue
}
if origMsg.Command == "" {
// this message is not allowed (protocol error at this point), ignore
continue
}
// we only allow one command "authenticate", everything else returns an error
if origMsg.Command != wshrpc.Command_Authenticate {
respErr := fmt.Errorf("connection not authenticated")
p.sendResponseError(origMsg, respErr)
continue
}
authRtn, err := router.HandleProxyAuth(origMsg.Data)
if err != nil {
respErr := fmt.Errorf("error handling proxy auth: %w", err)
p.sendResponseError(origMsg, respErr)
return "", respErr
}
p.SetAuthToken(authRtn.AuthToken)
announceMsg := RpcMessage{
Command: wshrpc.Command_RouteAnnounce,
Source: authRtn.RouteId,
AuthToken: authRtn.AuthToken,
}
announceBytes, _ := json.Marshal(announceMsg)
router.InjectMessage(announceBytes, authRtn.RouteId)
p.sendAuthenticateResponse(origMsg, authRtn.RouteId)
return authRtn.RouteId, nil
}
}
// runs on the server
func (p *WshRpcProxy) HandleAuthentication() (*wshrpc.RpcContext, error) {
for {
msgBytes, ok := <-p.FromRemoteCh
@ -122,11 +177,10 @@ func (p *WshRpcProxy) HandleAuthentication() (*wshrpc.RpcContext, error) {
}
newCtx, routeId, err := handleAuthenticationCommand(msg)
if err != nil {
log.Printf("error handling authentication: %v\n", err)
p.sendResponseError(msg, err)
continue
}
p.sendResponse(msg, routeId)
p.sendAuthenticateResponse(msg, routeId)
return newCtx, nil
}
}
@ -136,9 +190,10 @@ func (p *WshRpcProxy) SendRpcMessage(msg []byte) {
}
func (p *WshRpcProxy) RecvRpcMessage() ([]byte, bool) {
msgBytes, ok := <-p.FromRemoteCh
if !ok || p.RpcContext == nil {
return msgBytes, ok
msgBytes, more := <-p.FromRemoteCh
authToken := p.GetAuthToken()
if !more || (p.RpcContext == nil && authToken == "") {
return msgBytes, more
}
var msg RpcMessage
err := json.Unmarshal(msgBytes, &msg)
@ -146,10 +201,15 @@ func (p *WshRpcProxy) RecvRpcMessage() ([]byte, bool) {
// nothing to do here -- will error out at another level
return msgBytes, true
}
msg.Data, err = recodeCommandData(msg.Command, msg.Data, p.RpcContext)
if err != nil {
// nothing to do here -- will error out at another level
return msgBytes, true
if p.RpcContext != nil {
msg.Data, err = recodeCommandData(msg.Command, msg.Data, p.RpcContext)
if err != nil {
// nothing to do here -- will error out at another level
return msgBytes, true
}
}
if msg.AuthToken == "" {
msg.AuthToken = authToken
}
newBytes, err := json.Marshal(msg)
if err != nil {

View File

@ -12,11 +12,14 @@ import (
"sync"
"time"
"github.com/google/uuid"
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
"github.com/wavetermdev/waveterm/pkg/wps"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
)
const DefaultRoute = "wavesrv"
const UpstreamRoute = "upstream"
const SysRoute = "sys" // this route doesn't exist, just a placeholder for system messages
const ElectronRoute = "electron"
@ -36,12 +39,13 @@ type msgAndRoute struct {
}
type WshRouter struct {
Lock *sync.Mutex
RouteMap map[string]AbstractRpcClient // routeid => client
UpstreamClient AbstractRpcClient // upstream client (if we are not the terminal router)
AnnouncedRoutes map[string]string // routeid => local routeid
RpcMap map[string]*routeInfo // rpcid => routeinfo
InputCh chan msgAndRoute
Lock *sync.Mutex
RouteMap map[string]AbstractRpcClient // routeid => client
UpstreamClient AbstractRpcClient // upstream client (if we are not the terminal router)
AnnouncedRoutes map[string]string // routeid => local routeid
RpcMap map[string]*routeInfo // rpcid => routeinfo
SimpleRequestMap map[string]chan *RpcMessage // simple reqid => response channel
InputCh chan msgAndRoute
}
func MakeConnectionRouteId(connId string) string {
@ -68,11 +72,12 @@ var DefaultRouter = NewWshRouter()
func NewWshRouter() *WshRouter {
rtn := &WshRouter{
Lock: &sync.Mutex{},
RouteMap: make(map[string]AbstractRpcClient),
AnnouncedRoutes: make(map[string]string),
RpcMap: make(map[string]*routeInfo),
InputCh: make(chan msgAndRoute, DefaultInputChSize),
Lock: &sync.Mutex{},
RouteMap: make(map[string]AbstractRpcClient),
AnnouncedRoutes: make(map[string]string),
RpcMap: make(map[string]*routeInfo),
SimpleRequestMap: make(map[string]chan *RpcMessage),
InputCh: make(chan msgAndRoute, DefaultInputChSize),
}
go rtn.runServer()
return rtn
@ -237,6 +242,10 @@ func (router *WshRouter) runServer() {
router.sendRoutedMessage(msgBytes, routeInfo.DestRouteId)
continue
} else if msg.ResId != "" {
ok := router.trySimpleResponse(&msg)
if ok {
continue
}
routeInfo := router.getRouteInfo(msg.ResId)
if routeInfo == nil {
// no route info, nothing to do
@ -269,10 +278,10 @@ func (router *WshRouter) WaitForRegister(ctx context.Context, routeId string) er
}
// this will also consume the output channel of the abstract client
func (router *WshRouter) RegisterRoute(routeId string, rpc AbstractRpcClient) {
if routeId == SysRoute {
func (router *WshRouter) RegisterRoute(routeId string, rpc AbstractRpcClient, shouldAnnounce bool) {
if routeId == SysRoute || routeId == UpstreamRoute {
// cannot register sys route
log.Printf("error: WshRouter cannot register sys route\n")
log.Printf("error: WshRouter cannot register %s route\n", routeId)
return
}
log.Printf("[router] registering wsh route %q\n", routeId)
@ -285,7 +294,7 @@ func (router *WshRouter) RegisterRoute(routeId string, rpc AbstractRpcClient) {
router.RouteMap[routeId] = rpc
go func() {
// announce
if !alreadyExists && router.GetUpstreamClient() != nil {
if shouldAnnounce && !alreadyExists && router.GetUpstreamClient() != nil {
announceMsg := RpcMessage{Command: wshrpc.Command_RouteAnnounce, Source: routeId}
announceBytes, _ := json.Marshal(announceMsg)
router.GetUpstreamClient().SendRpcMessage(announceBytes)
@ -352,3 +361,97 @@ func (router *WshRouter) GetUpstreamClient() AbstractRpcClient {
defer router.Lock.Unlock()
return router.UpstreamClient
}
func (router *WshRouter) InjectMessage(msgBytes []byte, fromRouteId string) {
router.InputCh <- msgAndRoute{msgBytes: msgBytes, fromRouteId: fromRouteId}
}
func (router *WshRouter) registerSimpleRequest(reqId string) chan *RpcMessage {
router.Lock.Lock()
defer router.Lock.Unlock()
rtn := make(chan *RpcMessage, 1)
router.SimpleRequestMap[reqId] = rtn
return rtn
}
func (router *WshRouter) trySimpleResponse(msg *RpcMessage) bool {
router.Lock.Lock()
defer router.Lock.Unlock()
respCh := router.SimpleRequestMap[msg.ResId]
if respCh == nil {
return false
}
respCh <- msg
delete(router.SimpleRequestMap, msg.ResId)
return true
}
func (router *WshRouter) clearSimpleRequest(reqId string) {
router.Lock.Lock()
defer router.Lock.Unlock()
delete(router.SimpleRequestMap, reqId)
}
func (router *WshRouter) RunSimpleRawCommand(ctx context.Context, msg RpcMessage, fromRouteId string) (*RpcMessage, error) {
if msg.Command == "" {
return nil, errors.New("no command")
}
msgBytes, err := json.Marshal(msg)
if err != nil {
return nil, err
}
var respCh chan *RpcMessage
if msg.ReqId != "" {
respCh = router.registerSimpleRequest(msg.ReqId)
}
router.InjectMessage(msgBytes, fromRouteId)
if respCh == nil {
return nil, nil
}
select {
case <-ctx.Done():
router.clearSimpleRequest(msg.ReqId)
return nil, ctx.Err()
case resp := <-respCh:
if resp.Error != "" {
return nil, errors.New(resp.Error)
}
return resp, nil
}
}
func (router *WshRouter) HandleProxyAuth(jwtTokenAny any) (*wshrpc.CommandAuthenticateRtnData, error) {
if jwtTokenAny == nil {
return nil, errors.New("no jwt token")
}
jwtToken, ok := jwtTokenAny.(string)
if !ok {
return nil, errors.New("jwt token not a string")
}
if jwtToken == "" {
return nil, errors.New("empty jwt token")
}
msg := RpcMessage{
Command: wshrpc.Command_Authenticate,
ReqId: uuid.New().String(),
Data: jwtToken,
}
ctx, cancelFn := context.WithTimeout(context.Background(), DefaultTimeoutMs*time.Millisecond)
defer cancelFn()
resp, err := router.RunSimpleRawCommand(ctx, msg, "")
if err != nil {
return nil, err
}
if resp == nil || resp.Data == nil {
return nil, errors.New("no data in authenticate response")
}
var respData wshrpc.CommandAuthenticateRtnData
err = utilfn.ReUnmarshal(&respData, resp.Data)
if err != nil {
return nil, fmt.Errorf("error unmarshalling authenticate response: %v", err)
}
if respData.AuthToken == "" {
return nil, errors.New("no auth token in authenticate response")
}
return &respData, nil
}

View File

@ -45,10 +45,13 @@ type WshRpc struct {
InputCh chan []byte
OutputCh chan []byte
RpcContext *atomic.Pointer[wshrpc.RpcContext]
AuthToken string
RpcMap map[string]*rpcData
ServerImpl ServerImpl
EventListener *EventListener
ResponseHandlerMap map[string]*RpcResponseHandler // reqId => handler
Debug bool
DebugName string
}
type wshRpcContextKey struct{}
@ -104,17 +107,18 @@ func (w *WshRpc) RecvRpcMessage() ([]byte, bool) {
}
type RpcMessage struct {
Command string `json:"command,omitempty"`
ReqId string `json:"reqid,omitempty"`
ResId string `json:"resid,omitempty"`
Timeout int `json:"timeout,omitempty"`
Route string `json:"route,omitempty"` // to route/forward requests to alternate servers
Source string `json:"source,omitempty"` // source route id
Cont bool `json:"cont,omitempty"` // flag if additional requests/responses are forthcoming
Cancel bool `json:"cancel,omitempty"` // used to cancel a streaming request or response (sent from the side that is not streaming)
Error string `json:"error,omitempty"`
DataType string `json:"datatype,omitempty"`
Data any `json:"data,omitempty"`
Command string `json:"command,omitempty"`
ReqId string `json:"reqid,omitempty"`
ResId string `json:"resid,omitempty"`
Timeout int `json:"timeout,omitempty"`
Route string `json:"route,omitempty"` // to route/forward requests to alternate servers
AuthToken string `json:"authtoken,omitempty"` // needed for routing unauthenticated requests (WshRpcMultiProxy)
Source string `json:"source,omitempty"` // source route id
Cont bool `json:"cont,omitempty"` // flag if additional requests/responses are forthcoming
Cancel bool `json:"cancel,omitempty"` // used to cancel a streaming request or response (sent from the side that is not streaming)
Error string `json:"error,omitempty"`
DataType string `json:"datatype,omitempty"`
Data any `json:"data,omitempty"`
}
func (r *RpcMessage) IsRpcRequest() bool {
@ -226,6 +230,14 @@ func (w *WshRpc) SetRpcContext(ctx wshrpc.RpcContext) {
w.RpcContext.Store(&ctx)
}
func (w *WshRpc) SetAuthToken(token string) {
w.AuthToken = token
}
func (w *WshRpc) GetAuthToken() string {
return w.AuthToken
}
func (w *WshRpc) registerResponseHandler(reqId string, handler *RpcResponseHandler) {
w.Lock.Lock()
defer w.Lock.Unlock()
@ -323,6 +335,9 @@ func (w *WshRpc) handleRequest(req *RpcMessage) {
func (w *WshRpc) runServer() {
defer close(w.OutputCh)
for msgBytes := range w.InputCh {
if w.Debug {
log.Printf("[%s] received message: %s\n", w.DebugName, string(msgBytes))
}
var msg RpcMessage
err := json.Unmarshal(msgBytes, &msg)
if err != nil {
@ -455,8 +470,9 @@ func (handler *RpcRequestHandler) SendCancel() {
}
}()
msg := &RpcMessage{
Cancel: true,
ReqId: handler.reqId,
Cancel: true,
ReqId: handler.reqId,
AuthToken: handler.w.GetAuthToken(),
}
barr, _ := json.Marshal(msg) // will never fail
handler.w.OutputCh <- barr
@ -550,6 +566,7 @@ func (handler *RpcResponseHandler) SendMessage(msg string) {
Data: wshrpc.CommandMessageData{
Message: msg,
},
AuthToken: handler.w.GetAuthToken(),
}
msgBytes, _ := json.Marshal(rpcMsg) // will never fail
handler.w.OutputCh <- msgBytes
@ -573,9 +590,10 @@ func (handler *RpcResponseHandler) SendResponse(data any, done bool) error {
defer handler.close()
}
msg := &RpcMessage{
ResId: handler.reqId,
Data: data,
Cont: !done,
ResId: handler.reqId,
Data: data,
Cont: !done,
AuthToken: handler.w.GetAuthToken(),
}
barr, err := json.Marshal(msg)
if err != nil {
@ -598,8 +616,9 @@ func (handler *RpcResponseHandler) SendResponseError(err error) {
}
defer handler.close()
msg := &RpcMessage{
ResId: handler.reqId,
Error: err.Error(),
ResId: handler.reqId,
Error: err.Error(),
AuthToken: handler.w.GetAuthToken(),
}
barr, _ := json.Marshal(msg) // will never fail
handler.w.OutputCh <- barr
@ -660,11 +679,12 @@ func (w *WshRpc) SendComplexRequest(command string, data any, opts *wshrpc.RpcOp
handler.reqId = uuid.New().String()
}
req := &RpcMessage{
Command: command,
ReqId: handler.reqId,
Data: data,
Timeout: timeoutMs,
Route: opts.Route,
Command: command,
ReqId: handler.reqId,
Data: data,
Timeout: timeoutMs,
Route: opts.Route,
AuthToken: w.GetAuthToken(),
}
barr, err := json.Marshal(req)
if err != nil {

View File

@ -19,6 +19,7 @@ import (
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/wavetermdev/waveterm/pkg/util/packetparser"
"github.com/wavetermdev/waveterm/pkg/wavebase"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
"golang.org/x/term"
@ -204,11 +205,26 @@ func SetupTerminalRpcClient(serverImpl ServerImpl) (*WshRpc, io.Reader) {
continue
}
os.Stdout.Write(barr)
os.Stdout.Write([]byte{'\n'})
}
}()
return rpcClient, ptyBuf
}
func SetupPacketRpcClient(input io.Reader, output io.Writer, serverImpl ServerImpl) (*WshRpc, chan []byte) {
messageCh := make(chan []byte, DefaultInputChSize)
outputCh := make(chan []byte, DefaultOutputChSize)
rawCh := make(chan []byte, DefaultOutputChSize)
rpcClient := MakeWshRpc(messageCh, outputCh, wshrpc.RpcContext{}, serverImpl)
go packetparser.Parse(input, messageCh, rawCh)
go func() {
for msg := range outputCh {
packetparser.WritePacket(output, msg)
}
}()
return rpcClient, rawCh
}
func SetupConnRpcClient(conn net.Conn, serverImpl ServerImpl) (*WshRpc, chan error, error) {
inputCh := make(chan []byte, DefaultInputChSize)
outputCh := make(chan []byte, DefaultOutputChSize)
@ -229,10 +245,22 @@ func SetupConnRpcClient(conn net.Conn, serverImpl ServerImpl) (*WshRpc, chan err
return rtn, writeErrCh, nil
}
func SetupDomainSocketRpcClient(sockName string, serverImpl ServerImpl) (*WshRpc, error) {
conn, err := net.Dial("unix", sockName)
func tryTcpSocket(sockName string) (net.Conn, error) {
addr, err := net.ResolveTCPAddr("tcp", sockName)
if err != nil {
return nil, fmt.Errorf("failed to connect to Unix domain socket: %w", err)
return nil, err
}
return net.DialTCP("tcp", nil, addr)
}
func SetupDomainSocketRpcClient(sockName string, serverImpl ServerImpl) (*WshRpc, error) {
conn, tcpErr := tryTcpSocket(sockName)
var unixErr error
if tcpErr != nil {
conn, unixErr = net.Dial("unix", sockName)
}
if tcpErr != nil && unixErr != nil {
return nil, fmt.Errorf("failed to connect to tcp or unix domain socket: tcp err:%w: unix socket err: %w", tcpErr, unixErr)
}
rtn, errCh, err := SetupConnRpcClient(conn, serverImpl)
go func() {
@ -363,6 +391,46 @@ func MakeRouteIdFromCtx(rpcCtx *wshrpc.RpcContext) (string, error) {
return MakeProcRouteId(procId), nil
}
type WriteFlusher interface {
Write([]byte) (int, error)
Flush() error
}
// blocking, returns if there is an error, or on EOF of input
func HandleStdIOClient(logName string, input io.Reader, output io.Writer) {
proxy := MakeRpcMultiProxy()
rawCh := make(chan []byte, DefaultInputChSize)
go packetparser.Parse(input, proxy.FromRemoteRawCh, rawCh)
doneCh := make(chan struct{})
var doneOnce sync.Once
closeDoneCh := func() {
doneOnce.Do(func() {
close(doneCh)
})
proxy.DisposeRoutes()
}
go func() {
proxy.RunUnauthLoop()
}()
go func() {
defer closeDoneCh()
for msg := range proxy.ToRemoteCh {
err := packetparser.WritePacket(output, msg)
if err != nil {
log.Printf("[%s] error writing to output: %v\n", logName, err)
break
}
}
}()
go func() {
defer closeDoneCh()
for msg := range rawCh {
log.Printf("[%s:stdout] %s", logName, msg)
}
}()
<-doneCh
}
func handleDomainSocketClient(conn net.Conn) {
var routeIdContainer atomic.Pointer[string]
proxy := MakeRpcProxy()
@ -399,7 +467,7 @@ func handleDomainSocketClient(conn net.Conn) {
return
}
routeIdContainer.Store(&routeId)
DefaultRouter.RegisterRoute(routeId, proxy)
DefaultRouter.RegisterRoute(routeId, proxy, true)
}
// only for use on client
@ -433,5 +501,6 @@ func ExtractUnverifiedSocketName(tokenStr string) (string, error) {
if !ok {
return "", fmt.Errorf("sock claim is missing or invalid")
}
sockName = wavebase.ExpandHomeDirSafe(sockName)
return sockName, nil
}

67
pkg/wsl/wsl-unix.go Normal file
View File

@ -0,0 +1,67 @@
//go:build !windows
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package wsl
import (
"context"
"fmt"
"io"
"os"
"os/exec"
)
func RegisteredDistros(ctx context.Context) (distros []Distro, err error) {
return nil, fmt.Errorf("RegisteredDistros not implemented on this system")
}
func DefaultDistro(ctx context.Context) (d Distro, ok bool, err error) {
return d, false, fmt.Errorf("DefaultDistro not implemented on this system")
}
type Distro struct{}
func (d *Distro) Name() string {
return ""
}
func (d *Distro) WslCommand(ctx context.Context, cmd string) *WslCmd {
return nil
}
// just use the regular cmd since it's
// similar enough to not cause issues
// type WslCmd = exec.Cmd
type WslCmd struct {
exec.Cmd
}
func (wc *WslCmd) GetProcess() *os.Process {
return nil
}
func (wc *WslCmd) GetProcessState() *os.ProcessState {
return nil
}
func (c *WslCmd) SetStdin(stdin io.Reader) {
c.Stdin = stdin
}
func (c *WslCmd) SetStdout(stdout io.Writer) {
c.Stdout = stdout
}
func (c *WslCmd) SetStderr(stderr io.Writer) {
c.Stdout = stderr
}
func GetDistroCmd(ctx context.Context, wslDistroName string, cmd string) (*WslCmd, error) {
return nil, fmt.Errorf("GetDistroCmd not implemented on this system")
}
func GetDistro(ctx context.Context, wslDistroName WslName) (*Distro, error) {
return nil, fmt.Errorf("GetDistro not implemented on this system")
}

296
pkg/wsl/wsl-util.go Normal file
View File

@ -0,0 +1,296 @@
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package wsl
import (
"bytes"
"context"
"errors"
"fmt"
"html/template"
"io"
"log"
"os"
"path/filepath"
"strings"
"time"
)
func DetectShell(ctx context.Context, client *Distro) (string, error) {
wshPath := GetWshPath(ctx, client)
cmd := client.WslCommand(ctx, wshPath+" shell")
log.Printf("shell detecting using command: %s shell", wshPath)
out, err := cmd.Output()
if err != nil {
log.Printf("unable to determine shell. defaulting to /bin/bash: %s", err)
return "/bin/bash", nil
}
log.Printf("detecting shell: %s", out)
// quoting breaks this particular case
return strings.TrimSpace(string(out)), nil
}
func GetWshVersion(ctx context.Context, client *Distro) (string, error) {
wshPath := GetWshPath(ctx, client)
cmd := client.WslCommand(ctx, wshPath+" version")
out, err := cmd.Output()
if err != nil {
return "", err
}
return strings.TrimSpace(string(out)), nil
}
func GetWshPath(ctx context.Context, client *Distro) string {
defaultPath := "~/.waveterm/bin/wsh"
cmd := client.WslCommand(ctx, "which wsh")
out, whichErr := cmd.Output()
if whichErr == nil {
return strings.TrimSpace(string(out))
}
cmd = client.WslCommand(ctx, "where.exe wsh")
out, whereErr := cmd.Output()
if whereErr == nil {
return strings.TrimSpace(string(out))
}
// check cmd on windows since it requires an absolute path with backslashes
cmd = client.WslCommand(ctx, "(dir 2>&1 *``|echo %userprofile%\\.waveterm%\\.waveterm\\bin\\wsh.exe);&<# rem #>echo none")
out, cmdErr := cmd.Output()
if cmdErr == nil && strings.TrimSpace(string(out)) != "none" {
return strings.TrimSpace(string(out))
}
// no custom install, use default path
return defaultPath
}
func hasBashInstalled(ctx context.Context, client *Distro) (bool, error) {
cmd := client.WslCommand(ctx, "which bash")
out, whichErr := cmd.Output()
if whichErr == nil && len(out) != 0 {
return true, nil
}
cmd = client.WslCommand(ctx, "where.exe bash")
out, whereErr := cmd.Output()
if whereErr == nil && len(out) != 0 {
return true, nil
}
// note: we could also check in /bin/bash explicitly
// just in case that wasn't added to the path. but if
// that's true, we will most likely have worse
// problems going forward
return false, nil
}
func GetClientOs(ctx context.Context, client *Distro) (string, error) {
cmd := client.WslCommand(ctx, "uname -s")
out, unixErr := cmd.Output()
if unixErr == nil {
formatted := strings.ToLower(string(out))
formatted = strings.TrimSpace(formatted)
return formatted, nil
}
cmd = client.WslCommand(ctx, "echo %OS%")
out, cmdErr := cmd.Output()
if cmdErr == nil && strings.TrimSpace(string(out)) != "%OS%" {
formatted := strings.ToLower(string(out))
formatted = strings.TrimSpace(formatted)
return strings.Split(formatted, "_")[0], nil
}
cmd = client.WslCommand(ctx, "echo $env:OS")
out, psErr := cmd.Output()
if psErr == nil && strings.TrimSpace(string(out)) != "$env:OS" {
formatted := strings.ToLower(string(out))
formatted = strings.TrimSpace(formatted)
return strings.Split(formatted, "_")[0], nil
}
return "", fmt.Errorf("unable to determine os: {unix: %s, cmd: %s, powershell: %s}", unixErr, cmdErr, psErr)
}
func GetClientArch(ctx context.Context, client *Distro) (string, error) {
cmd := client.WslCommand(ctx, "uname -m")
out, unixErr := cmd.Output()
if unixErr == nil {
formatted := strings.ToLower(string(out))
formatted = strings.TrimSpace(formatted)
if formatted == "x86_64" {
return "x64", nil
}
return formatted, nil
}
cmd = client.WslCommand(ctx, "echo %PROCESSOR_ARCHITECTURE%")
out, cmdErr := cmd.Output()
if cmdErr == nil && strings.TrimSpace(string(out)) != "%PROCESSOR_ARCHITECTURE%" {
formatted := strings.ToLower(string(out))
return strings.TrimSpace(formatted), nil
}
cmd = client.WslCommand(ctx, "echo $env:PROCESSOR_ARCHITECTURE")
out, psErr := cmd.Output()
if psErr == nil && strings.TrimSpace(string(out)) != "$env:PROCESSOR_ARCHITECTURE" {
formatted := strings.ToLower(string(out))
return strings.TrimSpace(formatted), nil
}
return "", fmt.Errorf("unable to determine architecture: {unix: %s, cmd: %s, powershell: %s}", unixErr, cmdErr, psErr)
}
type CancellableCmd struct {
Cmd *WslCmd
Cancel func()
}
var installTemplatesRawBash = map[string]string{
"mkdir": `bash -c 'mkdir -p {{.installDir}}'`,
"cat": `bash -c 'cat > {{.tempPath}}'`,
"mv": `bash -c 'mv {{.tempPath}} {{.installPath}}'`,
"chmod": `bash -c 'chmod a+x {{.installPath}}'`,
}
var installTemplatesRawDefault = map[string]string{
"mkdir": `mkdir -p {{.installDir}}`,
"cat": `cat > {{.tempPath}}`,
"mv": `mv {{.tempPath}} {{.installPath}}`,
"chmod": `chmod a+x {{.installPath}}`,
}
func makeCancellableCommand(ctx context.Context, client *Distro, cmdTemplateRaw string, words map[string]string) (*CancellableCmd, error) {
cmdContext, cmdCancel := context.WithCancel(ctx)
cmdStr := &bytes.Buffer{}
cmdTemplate, err := template.New("").Parse(cmdTemplateRaw)
if err != nil {
cmdCancel()
return nil, err
}
cmdTemplate.Execute(cmdStr, words)
cmd := client.WslCommand(cmdContext, cmdStr.String())
return &CancellableCmd{cmd, cmdCancel}, nil
}
func CpHostToRemote(ctx context.Context, client *Distro, sourcePath string, destPath string) error {
// warning: does not work on windows remote yet
bashInstalled, err := hasBashInstalled(ctx, client)
if err != nil {
return err
}
var selectedTemplatesRaw map[string]string
if bashInstalled {
selectedTemplatesRaw = installTemplatesRawBash
} else {
log.Printf("bash is not installed on remote. attempting with default shell")
selectedTemplatesRaw = installTemplatesRawDefault
}
// I need to use toSlash here to force unix keybindings
// this means we can't guarantee it will work on a remote windows machine
var installWords = map[string]string{
"installDir": filepath.ToSlash(filepath.Dir(destPath)),
"tempPath": destPath + ".temp",
"installPath": destPath,
}
installStepCmds := make(map[string]*CancellableCmd)
for cmdName, selectedTemplateRaw := range selectedTemplatesRaw {
cancellableCmd, err := makeCancellableCommand(ctx, client, selectedTemplateRaw, installWords)
if err != nil {
return err
}
installStepCmds[cmdName] = cancellableCmd
}
_, err = installStepCmds["mkdir"].Cmd.Output()
if err != nil {
return err
}
// the cat part of this is complicated since it requires stdin
catCmd := installStepCmds["cat"].Cmd
catStdin, err := catCmd.StdinPipe()
if err != nil {
return err
}
err = catCmd.Start()
if err != nil {
return err
}
input, err := os.Open(sourcePath)
if err != nil {
return fmt.Errorf("cannot open local file %s to send to host: %v", sourcePath, err)
}
go func() {
io.Copy(catStdin, input)
installStepCmds["cat"].Cancel()
// backup just in case something weird happens
// could cause potential race condition, but very
// unlikely
time.Sleep(time.Second * 1)
process := catCmd.GetProcess()
if process != nil {
process.Kill()
}
}()
catErr := catCmd.Wait()
if catErr != nil && !errors.Is(catErr, context.Canceled) {
return catErr
}
_, err = installStepCmds["mv"].Cmd.Output()
if err != nil {
return err
}
_, err = installStepCmds["chmod"].Cmd.Output()
if err != nil {
return err
}
return nil
}
func InstallClientRcFiles(ctx context.Context, client *Distro) error {
path := GetWshPath(ctx, client)
log.Printf("path to wsh searched is: %s", path)
cmd := client.WslCommand(ctx, path+" rcfiles")
_, err := cmd.Output()
return err
}
func GetHomeDir(ctx context.Context, client *Distro) string {
// note: also works for powershell
cmd := client.WslCommand(ctx, `echo "$HOME"`)
out, err := cmd.Output()
if err == nil {
return strings.TrimSpace(string(out))
}
cmd = client.WslCommand(ctx, `echo %userprofile%`)
out, err = cmd.Output()
if err == nil {
return strings.TrimSpace(string(out))
}
return "~"
}
func IsPowershell(shellPath string) bool {
// get the base path, and then check contains
shellBase := filepath.Base(shellPath)
return strings.Contains(shellBase, "powershell") || strings.Contains(shellBase, "pwsh")
}

125
pkg/wsl/wsl-win.go Normal file
View File

@ -0,0 +1,125 @@
//go:build windows
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package wsl
import (
"context"
"fmt"
"io"
"os"
"sync"
"github.com/ubuntu/gowsl"
)
var RegisteredDistros = gowsl.RegisteredDistros
var DefaultDistro = gowsl.DefaultDistro
type Distro struct {
gowsl.Distro
}
type WslCmd struct {
c *gowsl.Cmd
wg *sync.WaitGroup
once *sync.Once
lock *sync.Mutex
waitErr error
}
func (d *Distro) WslCommand(ctx context.Context, cmd string) *WslCmd {
if ctx == nil {
panic("nil Context")
}
innerCmd := d.Command(ctx, cmd)
var wg sync.WaitGroup
var lock *sync.Mutex
return &WslCmd{innerCmd, &wg, new(sync.Once), lock, nil}
}
func (c *WslCmd) CombinedOutput() (out []byte, err error) {
return c.c.CombinedOutput()
}
func (c *WslCmd) Output() (out []byte, err error) {
return c.c.Output()
}
func (c *WslCmd) Run() error {
return c.c.Run()
}
func (c *WslCmd) Start() (err error) {
return c.c.Start()
}
func (c *WslCmd) StderrPipe() (r io.ReadCloser, err error) {
return c.c.StderrPipe()
}
func (c *WslCmd) StdinPipe() (w io.WriteCloser, err error) {
return c.c.StdinPipe()
}
func (c *WslCmd) StdoutPipe() (r io.ReadCloser, err error) {
return c.c.StdoutPipe()
}
func (c *WslCmd) Wait() (err error) {
c.wg.Add(1)
c.once.Do(func() {
c.waitErr = c.c.Wait()
})
c.wg.Done()
c.wg.Wait()
if c.waitErr != nil && c.waitErr.Error() == "not started" {
c.once = new(sync.Once)
return c.waitErr
}
return c.waitErr
}
func (c *WslCmd) GetProcess() *os.Process {
return c.c.Process
}
func (c *WslCmd) GetProcessState() *os.ProcessState {
return c.c.ProcessState
}
func (c *WslCmd) SetStdin(stdin io.Reader) {
c.c.Stdin = stdin
}
func (c *WslCmd) SetStdout(stdout io.Writer) {
c.c.Stdout = stdout
}
func (c *WslCmd) SetStderr(stderr io.Writer) {
c.c.Stdout = stderr
}
func GetDistroCmd(ctx context.Context, wslDistroName string, cmd string) (*WslCmd, error) {
distros, err := RegisteredDistros(ctx)
if err != nil {
return nil, err
}
for _, distro := range distros {
if distro.Name() != wslDistroName {
continue
}
wrappedDistro := Distro{distro}
return wrappedDistro.WslCommand(ctx, cmd), nil
}
return nil, fmt.Errorf("wsl distro %s not found", wslDistroName)
}
func GetDistro(ctx context.Context, wslDistroName WslName) (*Distro, error) {
distros, err := RegisteredDistros(ctx)
if err != nil {
return nil, err
}
for _, distro := range distros {
if distro.Name() != wslDistroName.Distro {
continue
}
wrappedDistro := Distro{distro}
return &wrappedDistro, nil
}
return nil, fmt.Errorf("wsl distro %s not found", wslDistroName)
}

494
pkg/wsl/wsl.go Normal file
View File

@ -0,0 +1,494 @@
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package wsl
import (
"context"
"fmt"
"io"
"log"
"net"
"sync"
"sync/atomic"
"time"
"github.com/wavetermdev/waveterm/pkg/userinput"
"github.com/wavetermdev/waveterm/pkg/util/shellutil"
"github.com/wavetermdev/waveterm/pkg/wavebase"
"github.com/wavetermdev/waveterm/pkg/waveobj"
"github.com/wavetermdev/waveterm/pkg/wconfig"
"github.com/wavetermdev/waveterm/pkg/wps"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
"github.com/wavetermdev/waveterm/pkg/wshutil"
)
const (
Status_Init = "init"
Status_Connecting = "connecting"
Status_Connected = "connected"
Status_Disconnected = "disconnected"
Status_Error = "error"
)
const DefaultConnectionTimeout = 60 * time.Second
var globalLock = &sync.Mutex{}
var clientControllerMap = make(map[string]*WslConn)
var activeConnCounter = &atomic.Int32{}
type WslConn struct {
Lock *sync.Mutex
Status string
Name WslName
Client *Distro
SockName string
DomainSockListener net.Listener
ConnController *WslCmd
Error string
HasWaiter *atomic.Bool
LastConnectTime int64
ActiveConnNum int
Context context.Context
cancelFn func()
}
type WslName struct {
Distro string `json:"distro"`
}
func GetAllConnStatus() []wshrpc.ConnStatus {
globalLock.Lock()
defer globalLock.Unlock()
var connStatuses []wshrpc.ConnStatus
for _, conn := range clientControllerMap {
connStatuses = append(connStatuses, conn.DeriveConnStatus())
}
return connStatuses
}
func (conn *WslConn) DeriveConnStatus() wshrpc.ConnStatus {
conn.Lock.Lock()
defer conn.Lock.Unlock()
return wshrpc.ConnStatus{
Status: conn.Status,
Connected: conn.Status == Status_Connected,
Connection: conn.GetName(),
HasConnected: (conn.LastConnectTime > 0),
ActiveConnNum: conn.ActiveConnNum,
Error: conn.Error,
}
}
func (conn *WslConn) FireConnChangeEvent() {
status := conn.DeriveConnStatus()
event := wps.WaveEvent{
Event: wps.Event_ConnChange,
Scopes: []string{
fmt.Sprintf("connection:%s", conn.GetName()),
},
Data: status,
}
log.Printf("sending event: %+#v", event)
wps.Broker.Publish(event)
}
func (conn *WslConn) Close() error {
defer conn.FireConnChangeEvent()
conn.WithLock(func() {
if conn.Status == Status_Connected || conn.Status == Status_Connecting {
// if status is init, disconnected, or error don't change it
conn.Status = Status_Disconnected
}
conn.close_nolock()
})
// we must wait for the waiter to complete
startTime := time.Now()
for conn.HasWaiter.Load() {
time.Sleep(10 * time.Millisecond)
if time.Since(startTime) > 2*time.Second {
return fmt.Errorf("timeout waiting for waiter to complete")
}
}
return nil
}
func (conn *WslConn) close_nolock() {
// does not set status (that should happen at another level)
if conn.DomainSockListener != nil {
conn.DomainSockListener.Close()
conn.DomainSockListener = nil
}
if conn.ConnController != nil {
conn.cancelFn() // this suspends the conn controller
conn.ConnController = nil
}
if conn.Client != nil {
// conn.Client.Close() is not relevant here
// we do not want to completely close the wsl in case
// other applications are using it
conn.Client = nil
}
}
func (conn *WslConn) GetDomainSocketName() string {
conn.Lock.Lock()
defer conn.Lock.Unlock()
return conn.SockName
}
func (conn *WslConn) GetStatus() string {
conn.Lock.Lock()
defer conn.Lock.Unlock()
return conn.Status
}
func (conn *WslConn) GetName() string {
// no lock required because opts is immutable
return "wsl://" + conn.Name.Distro
}
/**
* This function is does not set a listener for WslConn
* It is still required in order to set SockName
**/
func (conn *WslConn) OpenDomainSocketListener() error {
var allowed bool
conn.WithLock(func() {
if conn.Status != Status_Connecting {
allowed = false
} else {
allowed = true
}
})
if !allowed {
return fmt.Errorf("cannot open domain socket for %q when status is %q", conn.GetName(), conn.GetStatus())
}
conn.WithLock(func() {
conn.SockName = "~/.waveterm/wave-remote.sock"
})
return nil
}
func (conn *WslConn) StartConnServer() error {
var allowed bool
conn.WithLock(func() {
if conn.Status != Status_Connecting {
allowed = false
} else {
allowed = true
}
})
if !allowed {
return fmt.Errorf("cannot start conn server for %q when status is %q", conn.GetName(), conn.GetStatus())
}
client := conn.GetClient()
wshPath := GetWshPath(conn.Context, client)
rpcCtx := wshrpc.RpcContext{
ClientType: wshrpc.ClientType_ConnServer,
Conn: conn.GetName(),
}
sockName := conn.GetDomainSocketName()
jwtToken, err := wshutil.MakeClientJWTToken(rpcCtx, sockName)
if err != nil {
return fmt.Errorf("unable to create jwt token for conn controller: %w", err)
}
shellPath, err := DetectShell(conn.Context, client)
if err != nil {
return err
}
var cmdStr string
if IsPowershell(shellPath) {
cmdStr = fmt.Sprintf("$env:%s=\"%s\"; %s connserver --router", wshutil.WaveJwtTokenVarName, jwtToken, wshPath)
} else {
cmdStr = fmt.Sprintf("%s=\"%s\" %s connserver --router", wshutil.WaveJwtTokenVarName, jwtToken, wshPath)
}
log.Printf("starting conn controller: %s\n", cmdStr)
cmd := client.WslCommand(conn.Context, cmdStr)
pipeRead, pipeWrite := io.Pipe()
inputPipeRead, inputPipeWrite := io.Pipe()
cmd.SetStdout(pipeWrite)
cmd.SetStderr(pipeWrite)
cmd.SetStdin(inputPipeRead)
err = cmd.Start()
if err != nil {
return fmt.Errorf("unable to start conn controller: %w", err)
}
conn.WithLock(func() {
conn.ConnController = cmd
})
// service the I/O
go func() {
// wait for termination, clear the controller
defer conn.WithLock(func() {
conn.ConnController = nil
})
waitErr := cmd.Wait()
log.Printf("conn controller (%q) terminated: %v", conn.GetName(), waitErr)
}()
go func() {
logName := fmt.Sprintf("conncontroller:%s", conn.GetName())
wshutil.HandleStdIOClient(logName, pipeRead, inputPipeWrite)
}()
regCtx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
defer cancelFn()
err = wshutil.DefaultRouter.WaitForRegister(regCtx, wshutil.MakeConnectionRouteId(rpcCtx.Conn))
if err != nil {
return fmt.Errorf("timeout waiting for connserver to register")
}
time.Sleep(300 * time.Millisecond) // TODO remove this sleep (but we need to wait until connserver is "ready")
return nil
}
type WshInstallOpts struct {
Force bool
NoUserPrompt bool
}
func (conn *WslConn) CheckAndInstallWsh(ctx context.Context, clientDisplayName string, opts *WshInstallOpts) error {
if opts == nil {
opts = &WshInstallOpts{}
}
client := conn.GetClient()
if client == nil {
return fmt.Errorf("client is nil")
}
// check that correct wsh extensions are installed
expectedVersion := fmt.Sprintf("wsh v%s", wavebase.WaveVersion)
clientVersion, err := GetWshVersion(ctx, client)
if err == nil && clientVersion == expectedVersion && !opts.Force {
return nil
}
var queryText string
var title string
if opts.Force {
queryText = fmt.Sprintf("ReInstalling Wave Shell Extensions (%s) on `%s`\n", wavebase.WaveVersion, clientDisplayName)
title = "Install Wave Shell Extensions"
} else if err != nil {
queryText = fmt.Sprintf("Wave requires Wave Shell Extensions to be \n"+
"installed on `%s` \n"+
"to ensure a seamless experience. \n\n"+
"Would you like to install them?", clientDisplayName)
title = "Install Wave Shell Extensions"
} else {
// don't ask for upgrading the version
opts.NoUserPrompt = true
}
if !opts.NoUserPrompt {
request := &userinput.UserInputRequest{
ResponseType: "confirm",
QueryText: queryText,
Title: title,
Markdown: true,
CheckBoxMsg: "Don't show me this again",
}
response, err := userinput.GetUserInput(ctx, request)
if err != nil || !response.Confirm {
return err
}
if response.CheckboxStat {
meta := waveobj.MetaMapType{
wconfig.ConfigKey_ConnAskBeforeWshInstall: false,
}
err := wconfig.SetBaseConfigValue(meta)
if err != nil {
return fmt.Errorf("error setting conn:askbeforewshinstall value: %w", err)
}
}
}
log.Printf("attempting to install wsh to `%s`", clientDisplayName)
clientOs, err := GetClientOs(ctx, client)
if err != nil {
return err
}
clientArch, err := GetClientArch(ctx, client)
if err != nil {
return err
}
// attempt to install extension
wshLocalPath := shellutil.GetWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch)
err = CpHostToRemote(ctx, client, wshLocalPath, "~/.waveterm/bin/wsh")
if err != nil {
return err
}
log.Printf("successfully installed wsh on %s\n", conn.GetName())
return nil
}
func (conn *WslConn) GetClient() *Distro {
conn.Lock.Lock()
defer conn.Lock.Unlock()
return conn.Client
}
func (conn *WslConn) Reconnect(ctx context.Context) error {
err := conn.Close()
if err != nil {
return err
}
return conn.Connect(ctx)
}
func (conn *WslConn) WaitForConnect(ctx context.Context) error {
for {
status := conn.DeriveConnStatus()
if status.Status == Status_Connected {
return nil
}
if status.Status == Status_Connecting {
select {
case <-ctx.Done():
return fmt.Errorf("context timeout")
case <-time.After(100 * time.Millisecond):
continue
}
}
if status.Status == Status_Init || status.Status == Status_Disconnected {
return fmt.Errorf("disconnected")
}
if status.Status == Status_Error {
return fmt.Errorf("error: %v", status.Error)
}
return fmt.Errorf("unknown status: %q", status.Status)
}
}
// does not return an error since that error is stored inside of WslConn
func (conn *WslConn) Connect(ctx context.Context) error {
var connectAllowed bool
conn.WithLock(func() {
if conn.Status == Status_Connecting || conn.Status == Status_Connected {
connectAllowed = false
} else {
conn.Status = Status_Connecting
conn.Error = ""
connectAllowed = true
}
})
log.Printf("Connect %s\n", conn.GetName())
if !connectAllowed {
return fmt.Errorf("cannot connect to %q when status is %q", conn.GetName(), conn.GetStatus())
}
conn.FireConnChangeEvent()
err := conn.connectInternal(ctx)
conn.WithLock(func() {
if err != nil {
conn.Status = Status_Error
conn.Error = err.Error()
conn.close_nolock()
} else {
conn.Status = Status_Connected
conn.LastConnectTime = time.Now().UnixMilli()
if conn.ActiveConnNum == 0 {
conn.ActiveConnNum = int(activeConnCounter.Add(1))
}
}
})
conn.FireConnChangeEvent()
return err
}
func (conn *WslConn) WithLock(fn func()) {
conn.Lock.Lock()
defer conn.Lock.Unlock()
fn()
}
func (conn *WslConn) connectInternal(ctx context.Context) error {
client, err := GetDistro(ctx, conn.Name)
if err != nil {
return err
}
conn.WithLock(func() {
conn.Client = client
})
err = conn.OpenDomainSocketListener()
if err != nil {
return err
}
config := wconfig.ReadFullConfig()
installErr := conn.CheckAndInstallWsh(ctx, conn.GetName(), &WshInstallOpts{NoUserPrompt: !config.Settings.ConnAskBeforeWshInstall})
if installErr != nil {
return fmt.Errorf("conncontroller %s wsh install error: %v", conn.GetName(), installErr)
}
csErr := conn.StartConnServer()
if csErr != nil {
return fmt.Errorf("conncontroller %s start wsh connserver error: %v", conn.GetName(), csErr)
}
conn.HasWaiter.Store(true)
go conn.waitForDisconnect()
return nil
}
func (conn *WslConn) waitForDisconnect() {
defer conn.FireConnChangeEvent()
defer conn.HasWaiter.Store(false)
err := conn.ConnController.Wait()
conn.WithLock(func() {
// disconnects happen for a variety of reasons (like network, etc. and are typically transient)
// so we just set the status to "disconnected" here (not error)
// don't overwrite any existing error (or error status)
if err != nil && conn.Error == "" {
conn.Error = err.Error()
}
if conn.Status != Status_Error {
conn.Status = Status_Disconnected
}
conn.close_nolock()
})
}
func getConnInternal(name string) *WslConn {
globalLock.Lock()
defer globalLock.Unlock()
connName := WslName{Distro: name}
rtn := clientControllerMap[name]
if rtn == nil {
ctx, cancelFn := context.WithCancel(context.Background())
rtn = &WslConn{Lock: &sync.Mutex{}, Status: Status_Init, Name: connName, HasWaiter: &atomic.Bool{}, Context: ctx, cancelFn: cancelFn}
clientControllerMap[name] = rtn
}
return rtn
}
func GetWslConn(ctx context.Context, name string, shouldConnect bool) *WslConn {
conn := getConnInternal(name)
if conn.Client == nil && shouldConnect {
conn.Connect(ctx)
}
return conn
}
// Convenience function for ensuring a connection is established
func EnsureConnection(ctx context.Context, connName string) error {
if connName == "" {
return nil
}
conn := GetWslConn(ctx, connName, false)
if conn == nil {
return fmt.Errorf("connection not found: %s", connName)
}
connStatus := conn.DeriveConnStatus()
switch connStatus.Status {
case Status_Connected:
return nil
case Status_Connecting:
return conn.WaitForConnect(ctx)
case Status_Init, Status_Disconnected:
return conn.Connect(ctx)
case Status_Error:
return fmt.Errorf("connection error: %s", connStatus.Error)
default:
return fmt.Errorf("unknown connection status %q", connStatus.Status)
}
}
func DisconnectClient(connName string) error {
conn := getConnInternal(connName)
if conn == nil {
return fmt.Errorf("client %q not found", connName)
}
err := conn.Close()
return err
}