mirror of
https://github.com/wavetermdev/waveterm.git
synced 2024-12-21 16:38:23 +01:00
WSL Integration (#1031)
Adds support for connecting to local WSL installations on Windows. (also adds wshrpcmmultiproxy / connserver router)
This commit is contained in:
parent
4e86b67936
commit
8248637e00
2
.gitattributes
vendored
2
.gitattributes
vendored
@ -1 +1 @@
|
||||
* text=auto
|
||||
* text=auto eol=lf
|
@ -159,11 +159,11 @@ func shutdownActivityUpdate() {
|
||||
|
||||
func createMainWshClient() {
|
||||
rpc := wshserver.GetMainRpcClient()
|
||||
wshutil.DefaultRouter.RegisterRoute(wshutil.DefaultRoute, rpc)
|
||||
wshutil.DefaultRouter.RegisterRoute(wshutil.DefaultRoute, rpc, true)
|
||||
wps.Broker.SetClient(wshutil.DefaultRouter)
|
||||
localConnWsh := wshutil.MakeWshRpc(nil, nil, wshrpc.RpcContext{Conn: wshrpc.LocalConnName}, &wshremote.ServerImpl{})
|
||||
go wshremote.RunSysInfoLoop(localConnWsh, wshrpc.LocalConnName)
|
||||
wshutil.DefaultRouter.RegisterRoute(wshutil.MakeConnectionRouteId(wshrpc.LocalConnName), localConnWsh)
|
||||
wshutil.DefaultRouter.RegisterRoute(wshutil.MakeConnectionRouteId(wshrpc.LocalConnName), localConnWsh, true)
|
||||
}
|
||||
|
||||
func main() {
|
||||
|
@ -5,6 +5,7 @@ package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/wavetermdev/waveterm/pkg/remote"
|
||||
@ -25,17 +26,24 @@ func init() {
|
||||
}
|
||||
|
||||
func connStatus() error {
|
||||
resp, err := wshclient.ConnStatusCommand(RpcClient, nil)
|
||||
var allResp []wshrpc.ConnStatus
|
||||
sshResp, err := wshclient.ConnStatusCommand(RpcClient, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting connection status: %w", err)
|
||||
return fmt.Errorf("getting ssh connection status: %w", err)
|
||||
}
|
||||
if len(resp) == 0 {
|
||||
allResp = append(allResp, sshResp...)
|
||||
wslResp, err := wshclient.WslStatusCommand(RpcClient, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting wsl connection status: %w", err)
|
||||
}
|
||||
allResp = append(allResp, wslResp...)
|
||||
if len(allResp) == 0 {
|
||||
WriteStdout("no connections\n")
|
||||
return nil
|
||||
}
|
||||
WriteStdout("%-30s %-12s\n", "connection", "status")
|
||||
WriteStdout("----------------------------------------------\n")
|
||||
for _, conn := range resp {
|
||||
for _, conn := range allResp {
|
||||
str := fmt.Sprintf("%-30s %-12s", conn.Connection, conn.Status)
|
||||
if conn.Error != "" {
|
||||
str += fmt.Sprintf(" (%s)", conn.Error)
|
||||
@ -110,7 +118,7 @@ func connRun(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
connName = args[1]
|
||||
_, err := remote.ParseOpts(connName)
|
||||
if err != nil {
|
||||
if err != nil && !strings.HasPrefix(connName, "wsl://") {
|
||||
return fmt.Errorf("cannot parse connection name: %w", err)
|
||||
}
|
||||
}
|
||||
|
@ -4,10 +4,22 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/wavetermdev/waveterm/pkg/util/packetparser"
|
||||
"github.com/wavetermdev/waveterm/pkg/wavebase"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshremote"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
||||
)
|
||||
|
||||
var serverCmd = &cobra.Command{
|
||||
@ -15,18 +27,163 @@ var serverCmd = &cobra.Command{
|
||||
Hidden: true,
|
||||
Short: "remote server to power wave blocks",
|
||||
Args: cobra.NoArgs,
|
||||
Run: serverRun,
|
||||
PreRunE: preRunSetupRpcClient,
|
||||
RunE: serverRun,
|
||||
}
|
||||
|
||||
var connServerRouter bool
|
||||
|
||||
func init() {
|
||||
serverCmd.Flags().BoolVar(&connServerRouter, "router", false, "run in local router mode")
|
||||
rootCmd.AddCommand(serverCmd)
|
||||
}
|
||||
|
||||
func serverRun(cmd *cobra.Command, args []string) {
|
||||
func MakeRemoteUnixListener() (net.Listener, error) {
|
||||
serverAddr := wavebase.GetRemoteDomainSocketName()
|
||||
os.Remove(serverAddr) // ignore error
|
||||
rtn, err := net.Listen("unix", serverAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating listener at %v: %v", serverAddr, err)
|
||||
}
|
||||
os.Chmod(serverAddr, 0700)
|
||||
log.Printf("Server [unix-domain] listening on %s\n", serverAddr)
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func handleNewListenerConn(conn net.Conn, router *wshutil.WshRouter) {
|
||||
var routeIdContainer atomic.Pointer[string]
|
||||
proxy := wshutil.MakeRpcProxy()
|
||||
go func() {
|
||||
writeErr := wshutil.AdaptOutputChToStream(proxy.ToRemoteCh, conn)
|
||||
if writeErr != nil {
|
||||
log.Printf("error writing to domain socket: %v\n", writeErr)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
// when input is closed, close the connection
|
||||
defer func() {
|
||||
conn.Close()
|
||||
routeIdPtr := routeIdContainer.Load()
|
||||
if routeIdPtr != nil && *routeIdPtr != "" {
|
||||
router.UnregisterRoute(*routeIdPtr)
|
||||
disposeMsg := &wshutil.RpcMessage{
|
||||
Command: wshrpc.Command_Dispose,
|
||||
Data: wshrpc.CommandDisposeData{
|
||||
RouteId: *routeIdPtr,
|
||||
},
|
||||
Source: *routeIdPtr,
|
||||
AuthToken: proxy.GetAuthToken(),
|
||||
}
|
||||
disposeBytes, _ := json.Marshal(disposeMsg)
|
||||
router.InjectMessage(disposeBytes, *routeIdPtr)
|
||||
}
|
||||
}()
|
||||
wshutil.AdaptStreamToMsgCh(conn, proxy.FromRemoteCh)
|
||||
}()
|
||||
routeId, err := proxy.HandleClientProxyAuth(router)
|
||||
if err != nil {
|
||||
log.Printf("error handling client proxy auth: %v\n", err)
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
router.RegisterRoute(routeId, proxy, false)
|
||||
routeIdContainer.Store(&routeId)
|
||||
}
|
||||
|
||||
func runListener(listener net.Listener, router *wshutil.WshRouter) {
|
||||
defer func() {
|
||||
log.Printf("listener closed, exiting\n")
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
wshutil.DoShutdown("", 1, true)
|
||||
}()
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
log.Printf("error accepting connection: %v\n", err)
|
||||
continue
|
||||
}
|
||||
go handleNewListenerConn(conn, router)
|
||||
}
|
||||
}
|
||||
|
||||
func setupConnServerRpcClientWithRouter(router *wshutil.WshRouter) (*wshutil.WshRpc, error) {
|
||||
jwtToken := os.Getenv(wshutil.WaveJwtTokenVarName)
|
||||
if jwtToken == "" {
|
||||
return nil, fmt.Errorf("no jwt token found for connserver")
|
||||
}
|
||||
rpcCtx, err := wshutil.ExtractUnverifiedRpcContext(jwtToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error extracting rpc context from %s: %v", wshutil.WaveJwtTokenVarName, err)
|
||||
}
|
||||
authRtn, err := router.HandleProxyAuth(jwtToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error handling proxy auth: %v", err)
|
||||
}
|
||||
inputCh := make(chan []byte, wshutil.DefaultInputChSize)
|
||||
outputCh := make(chan []byte, wshutil.DefaultOutputChSize)
|
||||
connServerClient := wshutil.MakeWshRpc(inputCh, outputCh, *rpcCtx, &wshremote.ServerImpl{LogWriter: os.Stdout})
|
||||
connServerClient.SetAuthToken(authRtn.AuthToken)
|
||||
router.RegisterRoute(authRtn.RouteId, connServerClient, false)
|
||||
wshclient.RouteAnnounceCommand(connServerClient, nil)
|
||||
return connServerClient, nil
|
||||
}
|
||||
|
||||
func serverRunRouter() error {
|
||||
router := wshutil.NewWshRouter()
|
||||
termProxy := wshutil.MakeRpcProxy()
|
||||
rawCh := make(chan []byte, wshutil.DefaultOutputChSize)
|
||||
go packetparser.Parse(os.Stdin, termProxy.FromRemoteCh, rawCh)
|
||||
go func() {
|
||||
for msg := range termProxy.ToRemoteCh {
|
||||
packetparser.WritePacket(os.Stdout, msg)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
// just ignore and drain the rawCh (stdin)
|
||||
// when stdin is closed, shutdown
|
||||
defer wshutil.DoShutdown("", 0, true)
|
||||
for range rawCh {
|
||||
// ignore
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
for msg := range termProxy.FromRemoteCh {
|
||||
// send this to the router
|
||||
router.InjectMessage(msg, wshutil.UpstreamRoute)
|
||||
}
|
||||
}()
|
||||
router.SetUpstreamClient(termProxy)
|
||||
// now set up the domain socket
|
||||
unixListener, err := MakeRemoteUnixListener()
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot create unix listener: %v", err)
|
||||
}
|
||||
client, err := setupConnServerRpcClientWithRouter(router)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error setting up connserver rpc client: %v", err)
|
||||
}
|
||||
go runListener(unixListener, router)
|
||||
// run the sysinfo loop
|
||||
wshremote.RunSysInfoLoop(client, client.GetRpcContext().Conn)
|
||||
select {}
|
||||
}
|
||||
|
||||
func serverRunNormal() error {
|
||||
err := setupRpcClient(&wshremote.ServerImpl{LogWriter: os.Stdout})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
WriteStdout("running wsh connserver (%s)\n", RpcContext.Conn)
|
||||
go wshremote.RunSysInfoLoop(RpcClient, RpcContext.Conn)
|
||||
RpcClient.SetServerImpl(&wshremote.ServerImpl{LogWriter: os.Stdout})
|
||||
|
||||
select {} // run forever
|
||||
}
|
||||
|
||||
func serverRun(cmd *cobra.Command, args []string) error {
|
||||
if connServerRouter {
|
||||
return serverRunRouter()
|
||||
} else {
|
||||
return serverRunNormal()
|
||||
}
|
||||
}
|
||||
|
60
cmd/wsh/cmd/wshcmd-wsl.go
Normal file
60
cmd/wsh/cmd/wshcmd-wsl.go
Normal file
@ -0,0 +1,60 @@
|
||||
// Copyright 2024, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/wavetermdev/waveterm/pkg/waveobj"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
|
||||
)
|
||||
|
||||
var distroName string
|
||||
|
||||
var wslCmd = &cobra.Command{
|
||||
Use: "wsl [-d <Distro>]",
|
||||
Short: "connect this terminal to a local wsl connection",
|
||||
Args: cobra.NoArgs,
|
||||
Run: wslRun,
|
||||
PreRunE: preRunSetupRpcClient,
|
||||
}
|
||||
|
||||
func init() {
|
||||
wslCmd.Flags().StringVarP(&distroName, "distribution", "d", "", "Run the specified distribution")
|
||||
rootCmd.AddCommand(wslCmd)
|
||||
}
|
||||
|
||||
func wslRun(cmd *cobra.Command, args []string) {
|
||||
var err error
|
||||
if distroName == "" {
|
||||
// get default distro from the host
|
||||
distroName, err = wshclient.WslDefaultDistroCommand(RpcClient, nil)
|
||||
if err != nil {
|
||||
WriteStderr("[error] %s\n", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
if !strings.HasPrefix(distroName, "wsl://") {
|
||||
distroName = "wsl://" + distroName
|
||||
}
|
||||
blockId := RpcContext.BlockId
|
||||
if blockId == "" {
|
||||
WriteStderr("[error] cannot determine blockid (not in JWT)\n")
|
||||
return
|
||||
}
|
||||
data := wshrpc.CommandSetMetaData{
|
||||
ORef: waveobj.MakeORef(waveobj.OType_Block, blockId),
|
||||
Meta: map[string]any{
|
||||
waveobj.MetaKey_Connection: distroName,
|
||||
},
|
||||
}
|
||||
err = wshclient.SetMetaCommand(RpcClient, data, nil)
|
||||
if err != nil {
|
||||
WriteStderr("[error] setting switching connection: %v\n", err)
|
||||
return
|
||||
}
|
||||
WriteStderr("switched connection to %q\n", distroName)
|
||||
}
|
@ -521,6 +521,7 @@ const ChangeConnectionBlockModal = React.memo(
|
||||
const connStatusAtom = getConnStatusAtom(connection);
|
||||
const connStatus = jotai.useAtomValue(connStatusAtom);
|
||||
const [connList, setConnList] = React.useState<Array<string>>([]);
|
||||
const [wslList, setWslList] = React.useState<Array<string>>([]);
|
||||
const allConnStatus = jotai.useAtomValue(atoms.allConnStatus);
|
||||
const [rowIndex, setRowIndex] = React.useState(0);
|
||||
const connStatusMap = new Map<string, ConnStatus>();
|
||||
@ -540,6 +541,18 @@ const ChangeConnectionBlockModal = React.memo(
|
||||
prtn.then((newConnList) => {
|
||||
setConnList(newConnList ?? []);
|
||||
}).catch((e) => console.log("unable to load conn list from backend. using blank list: ", e));
|
||||
const p2rtn = RpcApi.WslListCommand(TabRpcClient, { timeout: 2000 });
|
||||
p2rtn
|
||||
.then((newWslList) => {
|
||||
console.log(newWslList);
|
||||
setWslList(newWslList ?? []);
|
||||
})
|
||||
.catch((e) => {
|
||||
// removing this log and failing silentyly since it will happen
|
||||
// if a system isn't using the wsl. and would happen every time the
|
||||
// typeahead was opened. good candidate for verbose log level.
|
||||
//console.log("unable to load wsl list from backend. using blank list: ", e)
|
||||
});
|
||||
}, [changeConnModalOpen, setConnList]);
|
||||
|
||||
const changeConnection = React.useCallback(
|
||||
@ -588,6 +601,15 @@ const ChangeConnectionBlockModal = React.memo(
|
||||
filteredList.push(conn);
|
||||
}
|
||||
}
|
||||
const filteredWslList: Array<string> = [];
|
||||
for (const conn of wslList) {
|
||||
if (conn === connSelected) {
|
||||
createNew = false;
|
||||
}
|
||||
if (conn.includes(connSelected)) {
|
||||
filteredWslList.push(conn);
|
||||
}
|
||||
}
|
||||
// priority handles special suggestions when necessary
|
||||
// for instance, when reconnecting
|
||||
const newConnectionSuggestion: SuggestionConnectionItem = {
|
||||
@ -637,6 +659,20 @@ const ChangeConnectionBlockModal = React.memo(
|
||||
label: localName,
|
||||
});
|
||||
}
|
||||
for (const wslConn of filteredWslList) {
|
||||
const connStatus = connStatusMap.get(wslConn);
|
||||
const connColorNum = computeConnColorNum(connStatus);
|
||||
localSuggestion.items.push({
|
||||
status: "connected",
|
||||
icon: "arrow-right-arrow-left",
|
||||
iconColor:
|
||||
connStatus?.status == "connected"
|
||||
? `var(--conn-icon-color-${connColorNum})`
|
||||
: "var(--grey-text-color)",
|
||||
value: "wsl://" + wslConn,
|
||||
label: "wsl://" + wslConn,
|
||||
});
|
||||
}
|
||||
const remoteItems = filteredList.map((connName) => {
|
||||
const connStatus = connStatusMap.get(connName);
|
||||
const connColorNum = computeConnColorNum(connStatus);
|
||||
|
@ -72,6 +72,11 @@ class RpcApiType {
|
||||
return client.wshRpcCall("deleteblock", data, opts);
|
||||
}
|
||||
|
||||
// command "dispose" [call]
|
||||
DisposeCommand(client: WshClient, data: CommandDisposeData, opts?: RpcOpts): Promise<void> {
|
||||
return client.wshRpcCall("dispose", data, opts);
|
||||
}
|
||||
|
||||
// command "eventpublish" [call]
|
||||
EventPublishCommand(client: WshClient, data: WaveEvent, opts?: RpcOpts): Promise<void> {
|
||||
return client.wshRpcCall("eventpublish", data, opts);
|
||||
@ -237,6 +242,21 @@ class RpcApiType {
|
||||
return client.wshRpcCall("webselector", data, opts);
|
||||
}
|
||||
|
||||
// command "wsldefaultdistro" [call]
|
||||
WslDefaultDistroCommand(client: WshClient, opts?: RpcOpts): Promise<string> {
|
||||
return client.wshRpcCall("wsldefaultdistro", null, opts);
|
||||
}
|
||||
|
||||
// command "wsllist" [call]
|
||||
WslListCommand(client: WshClient, opts?: RpcOpts): Promise<string[]> {
|
||||
return client.wshRpcCall("wsllist", null, opts);
|
||||
}
|
||||
|
||||
// command "wslstatus" [call]
|
||||
WslStatusCommand(client: WshClient, opts?: RpcOpts): Promise<ConnStatus[]> {
|
||||
return client.wshRpcCall("wslstatus", null, opts);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
export const RpcApi = new RpcApiType();
|
||||
|
7
frontend/types/gotypes.d.ts
vendored
7
frontend/types/gotypes.d.ts
vendored
@ -63,6 +63,7 @@ declare global {
|
||||
// wshrpc.CommandAuthenticateRtnData
|
||||
type CommandAuthenticateRtnData = {
|
||||
routeid: string;
|
||||
authtoken?: string;
|
||||
};
|
||||
|
||||
// wshrpc.CommandBlockInputData
|
||||
@ -100,6 +101,11 @@ declare global {
|
||||
blockid: string;
|
||||
};
|
||||
|
||||
// wshrpc.CommandDisposeData
|
||||
type CommandDisposeData = {
|
||||
routeid: string;
|
||||
};
|
||||
|
||||
// wshrpc.CommandEventReadHistoryData
|
||||
type CommandEventReadHistoryData = {
|
||||
event: string;
|
||||
@ -416,6 +422,7 @@ declare global {
|
||||
resid?: string;
|
||||
timeout?: number;
|
||||
route?: string;
|
||||
authtoken?: string;
|
||||
source?: string;
|
||||
cont?: boolean;
|
||||
cancel?: boolean;
|
||||
|
3
go.mod
3
go.mod
@ -21,6 +21,7 @@ require (
|
||||
github.com/shirou/gopsutil/v4 v4.24.9
|
||||
github.com/skeema/knownhosts v1.3.0
|
||||
github.com/spf13/cobra v1.8.1
|
||||
github.com/ubuntu/gowsl v0.0.0-20240906163211-049fd49bd93b
|
||||
github.com/wavetermdev/htmltoken v0.1.0
|
||||
golang.org/x/crypto v0.28.0
|
||||
golang.org/x/sys v0.26.0
|
||||
@ -36,9 +37,11 @@ require (
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
||||
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||
github.com/spf13/pflag v1.0.5 // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.12 // indirect
|
||||
github.com/tklauser/numcpus v0.6.1 // indirect
|
||||
github.com/ubuntu/decorate v0.0.0-20230125165522-2d5b0a9bb117 // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||
go.uber.org/atomic v1.7.0 // indirect
|
||||
golang.org/x/net v0.29.0 // indirect
|
||||
|
11
go.sum
11
go.sum
@ -1,5 +1,7 @@
|
||||
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||
github.com/0xrawsec/golang-utils v1.3.2 h1:ww4jrtHRSnX9xrGzJYbalx5nXoZewy4zPxiY+ubJgtg=
|
||||
github.com/0xrawsec/golang-utils v1.3.2/go.mod h1:m7AzHXgdSAkFCD9tWWsApxNVxMlyy7anpPVOyT/yM7E=
|
||||
github.com/alexflint/go-filemutex v1.3.0 h1:LgE+nTUWnQCyRKbpoceKZsPQbs84LivvgwUymZXdOcM=
|
||||
github.com/alexflint/go-filemutex v1.3.0/go.mod h1:U0+VA/i30mGBlLCrFPGtTe9y6wGQfNAWPBTekHQ+c8A=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
||||
@ -62,6 +64,8 @@ github.com/sawka/txwrap v0.2.0 h1:V3LfvKVLULxcYSxdMguLwFyQFMEU9nFDJopg0ZkL+94=
|
||||
github.com/sawka/txwrap v0.2.0/go.mod h1:wwQ2SQiN4U+6DU/iVPhbvr7OzXAtgZlQCIGuvOswEfA=
|
||||
github.com/shirou/gopsutil/v4 v4.24.9 h1:KIV+/HaHD5ka5f570RZq+2SaeFsb/pq+fp2DGNWYoOI=
|
||||
github.com/shirou/gopsutil/v4 v4.24.9/go.mod h1:3fkaHNeYsUFCGZ8+9vZVWtbyM1k2eRnlL+bWO8Bxa/Q=
|
||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/skeema/knownhosts v1.3.0 h1:AM+y0rI04VksttfwjkSTNQorvGqmwATnvnAHpSgc0LY=
|
||||
github.com/skeema/knownhosts v1.3.0/go.mod h1:sPINvnADmT/qYH1kfv+ePMmOBTH6Tbl7b5LvTDjFK7M=
|
||||
github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
|
||||
@ -71,12 +75,17 @@ github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
|
||||
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
|
||||
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
|
||||
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
|
||||
github.com/ubuntu/decorate v0.0.0-20230125165522-2d5b0a9bb117 h1:XQpsQG5lqRJlx4mUVHcJvyyc1rdTI9nHvwrdfcuy8aM=
|
||||
github.com/ubuntu/decorate v0.0.0-20230125165522-2d5b0a9bb117/go.mod h1:mx0TjbqsaDD9DUT5gA1s3hw47U6RIbbIBfvGzR85K0g=
|
||||
github.com/ubuntu/gowsl v0.0.0-20240906163211-049fd49bd93b h1:wFBKF5k5xbJQU8bYgcSoQ/ScvmYyq6KHUabAuVUjOWM=
|
||||
github.com/ubuntu/gowsl v0.0.0-20240906163211-049fd49bd93b/go.mod h1:N1CYNinssZru+ikvYTgVbVeSi21thHUTCoJ9xMvWe+s=
|
||||
github.com/wavetermdev/htmltoken v0.1.0 h1:RMdA9zTfnYa5jRC4RRG3XNoV5NOP8EDxpaVPjuVz//Q=
|
||||
github.com/wavetermdev/htmltoken v0.1.0/go.mod h1:5FM0XV6zNYiNza2iaTcFGj+hnMtgqumFHO31Z8euquk=
|
||||
github.com/wavetermdev/ssh_config v0.0.0-20240306041034-17e2087ebde2 h1:onqZrJVap1sm15AiIGTfWzdr6cEF0KdtddeuuOVhzyY=
|
||||
@ -91,6 +100,7 @@ golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo=
|
||||
golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220721230656-c6bc011c0c49/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
@ -102,5 +112,6 @@ golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"io"
|
||||
"io/fs"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@ -24,6 +25,7 @@ import (
|
||||
"github.com/wavetermdev/waveterm/pkg/wps"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
||||
"github.com/wavetermdev/waveterm/pkg/wsl"
|
||||
"github.com/wavetermdev/waveterm/pkg/wstore"
|
||||
)
|
||||
|
||||
@ -262,7 +264,30 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj
|
||||
return fmt.Errorf("unknown controller type %q", bc.ControllerType)
|
||||
}
|
||||
var shellProc *shellexec.ShellProc
|
||||
if remoteName != "" {
|
||||
if strings.HasPrefix(remoteName, "wsl://") {
|
||||
wslName := strings.TrimPrefix(remoteName, "wsl://")
|
||||
credentialCtx, cancelFunc := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancelFunc()
|
||||
|
||||
wslConn := wsl.GetWslConn(credentialCtx, wslName, false)
|
||||
connStatus := wslConn.DeriveConnStatus()
|
||||
if connStatus.Status != conncontroller.Status_Connected {
|
||||
return fmt.Errorf("not connected, cannot start shellproc")
|
||||
}
|
||||
|
||||
// create jwt
|
||||
if !blockMeta.GetBool(waveobj.MetaKey_CmdNoWsh, false) {
|
||||
jwtStr, err := wshutil.MakeClientJWTToken(wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId, Conn: wslConn.GetName()}, wslConn.GetDomainSocketName())
|
||||
if err != nil {
|
||||
return fmt.Errorf("error making jwt token: %w", err)
|
||||
}
|
||||
cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr
|
||||
}
|
||||
shellProc, err = shellexec.StartWslShellProc(ctx, rc.TermSize, cmdStr, cmdOpts, wslConn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else if remoteName != "" {
|
||||
credentialCtx, cancelFunc := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancelFunc()
|
||||
|
||||
@ -325,7 +350,7 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj
|
||||
// we don't need to authenticate this wshProxy since it is coming direct
|
||||
wshProxy := wshutil.MakeRpcProxy()
|
||||
wshProxy.SetRpcContext(&wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId})
|
||||
wshutil.DefaultRouter.RegisterRoute(wshutil.MakeControllerRouteId(bc.BlockId), wshProxy)
|
||||
wshutil.DefaultRouter.RegisterRoute(wshutil.MakeControllerRouteId(bc.BlockId), wshProxy, true)
|
||||
ptyBuffer := wshutil.MakePtyBuffer(wshutil.WaveOSCPrefix, bc.ShellProc.Cmd, wshProxy.FromRemoteCh)
|
||||
go func() {
|
||||
// handles regular output from the pty (goes to the blockfile and xterm)
|
||||
@ -494,6 +519,15 @@ func CheckConnStatus(blockId string) error {
|
||||
if connName == "" {
|
||||
return nil
|
||||
}
|
||||
if strings.HasPrefix(connName, "wsl://") {
|
||||
distroName := strings.TrimPrefix(connName, "wsl://")
|
||||
conn := wsl.GetWslConn(context.Background(), distroName, false)
|
||||
connStatus := conn.DeriveConnStatus()
|
||||
if connStatus.Status != conncontroller.Status_Connected {
|
||||
return fmt.Errorf("not connected: %s", connStatus.Status)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
opts, err := remote.ParseOpts(connName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing connection name: %w", err)
|
||||
|
@ -1,3 +1,6 @@
|
||||
// Copyright 2024, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package remote
|
||||
|
||||
import (
|
||||
|
@ -17,6 +17,7 @@ import (
|
||||
"github.com/wavetermdev/waveterm/pkg/wcore"
|
||||
"github.com/wavetermdev/waveterm/pkg/wlayout"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||
"github.com/wavetermdev/waveterm/pkg/wsl"
|
||||
"github.com/wavetermdev/waveterm/pkg/wstore"
|
||||
)
|
||||
|
||||
@ -77,7 +78,9 @@ func (cs *ClientService) MakeWindow(ctx context.Context) (*waveobj.Window, error
|
||||
}
|
||||
|
||||
func (cs *ClientService) GetAllConnStatus(ctx context.Context) ([]wshrpc.ConnStatus, error) {
|
||||
return conncontroller.GetAllConnStatus(), nil
|
||||
sshStatuses := conncontroller.GetAllConnStatus()
|
||||
wslStatuses := wsl.GetAllConnStatus()
|
||||
return append(sshStatuses, wslStatuses...), nil
|
||||
}
|
||||
|
||||
// moves the window to the front of the windowId stack
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/creack/pty"
|
||||
"github.com/wavetermdev/waveterm/pkg/wsl"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
@ -129,3 +130,42 @@ func (sw SessionWrap) StderrPipe() (io.ReadCloser, error) {
|
||||
func (sw SessionWrap) SetSize(h int, w int) error {
|
||||
return sw.Session.WindowChange(h, w)
|
||||
}
|
||||
|
||||
type WslCmdWrap struct {
|
||||
*wsl.WslCmd
|
||||
Tty pty.Tty
|
||||
pty.Pty
|
||||
}
|
||||
|
||||
func (wcw WslCmdWrap) Kill() {
|
||||
wcw.Tty.Close()
|
||||
wcw.Close()
|
||||
}
|
||||
|
||||
func (wcw WslCmdWrap) KillGraceful(timeout time.Duration) {
|
||||
process := wcw.WslCmd.GetProcess()
|
||||
if process == nil {
|
||||
return
|
||||
}
|
||||
processState := wcw.WslCmd.GetProcessState()
|
||||
if processState != nil && processState.Exited() {
|
||||
return
|
||||
}
|
||||
process.Signal(os.Interrupt)
|
||||
go func() {
|
||||
time.Sleep(timeout)
|
||||
process := wcw.WslCmd.GetProcess()
|
||||
processState := wcw.WslCmd.GetProcessState()
|
||||
if processState == nil || !processState.Exited() {
|
||||
process.Kill() // force kill if it is already not exited
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
/**
|
||||
* SetSize does nothing for WslCmdWrap as there
|
||||
* is no pty to manage.
|
||||
**/
|
||||
func (wcw WslCmdWrap) SetSize(w int, h int) error {
|
||||
return nil
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ package shellexec
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
@ -25,6 +26,7 @@ import (
|
||||
"github.com/wavetermdev/waveterm/pkg/wavebase"
|
||||
"github.com/wavetermdev/waveterm/pkg/waveobj"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
||||
"github.com/wavetermdev/waveterm/pkg/wsl"
|
||||
)
|
||||
|
||||
const DefaultGracefulKillWait = 400 * time.Millisecond
|
||||
@ -141,6 +143,96 @@ func (pp *PipePty) WriteString(s string) (n int, err error) {
|
||||
return pp.Write([]byte(s))
|
||||
}
|
||||
|
||||
func StartWslShellProc(ctx context.Context, termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *wsl.WslConn) (*ShellProc, error) {
|
||||
client := conn.GetClient()
|
||||
shellPath := cmdOpts.ShellPath
|
||||
if shellPath == "" {
|
||||
remoteShellPath, err := wsl.DetectShell(conn.Context, client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
shellPath = remoteShellPath
|
||||
}
|
||||
var shellOpts []string
|
||||
log.Printf("detected shell: %s", shellPath)
|
||||
|
||||
err := wsl.InstallClientRcFiles(conn.Context, client)
|
||||
if err != nil {
|
||||
log.Printf("error installing rc files: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
homeDir := wsl.GetHomeDir(conn.Context, client)
|
||||
shellOpts = append(shellOpts, "~", "-d", client.Name())
|
||||
|
||||
if isZshShell(shellPath) {
|
||||
shellOpts = append(shellOpts, fmt.Sprintf(`ZDOTDIR="%s/.waveterm/%s"`, homeDir, shellutil.ZshIntegrationDir))
|
||||
}
|
||||
var subShellOpts []string
|
||||
|
||||
if cmdStr == "" {
|
||||
/* transform command in order to inject environment vars */
|
||||
if isBashShell(shellPath) {
|
||||
log.Printf("recognized as bash shell")
|
||||
// add --rcfile
|
||||
// cant set -l or -i with --rcfile
|
||||
subShellOpts = append(subShellOpts, "--rcfile", fmt.Sprintf(`%s/.waveterm/%s/.bashrc`, homeDir, shellutil.BashIntegrationDir))
|
||||
} else if isFishShell(shellPath) {
|
||||
carg := fmt.Sprintf(`"set -x PATH \"%s\"/.waveterm/%s $PATH"`, homeDir, shellutil.WaveHomeBinDir)
|
||||
subShellOpts = append(subShellOpts, "-C", carg)
|
||||
} else if wsl.IsPowershell(shellPath) {
|
||||
// powershell is weird about quoted path executables and requires an ampersand first
|
||||
shellPath = "& " + shellPath
|
||||
subShellOpts = append(subShellOpts, "-ExecutionPolicy", "Bypass", "-NoExit", "-File", homeDir+fmt.Sprintf("/.waveterm/%s/wavepwsh.ps1", shellutil.PwshIntegrationDir))
|
||||
} else {
|
||||
if cmdOpts.Login {
|
||||
subShellOpts = append(subShellOpts, "-l")
|
||||
}
|
||||
if cmdOpts.Interactive {
|
||||
subShellOpts = append(subShellOpts, "-i")
|
||||
}
|
||||
// can't set environment vars this way
|
||||
// will try to do later if possible
|
||||
}
|
||||
} else {
|
||||
shellPath = cmdStr
|
||||
if cmdOpts.Login {
|
||||
subShellOpts = append(subShellOpts, "-l")
|
||||
}
|
||||
if cmdOpts.Interactive {
|
||||
subShellOpts = append(subShellOpts, "-i")
|
||||
}
|
||||
subShellOpts = append(subShellOpts, "-c", cmdStr)
|
||||
}
|
||||
|
||||
jwtToken, ok := cmdOpts.Env[wshutil.WaveJwtTokenVarName]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no jwt token provided to connection")
|
||||
}
|
||||
if remote.IsPowershell(shellPath) {
|
||||
shellOpts = append(shellOpts, "--", fmt.Sprintf(`$env:%s=%s;`, wshutil.WaveJwtTokenVarName, jwtToken))
|
||||
} else {
|
||||
shellOpts = append(shellOpts, "--", fmt.Sprintf(`%s=%s`, wshutil.WaveJwtTokenVarName, jwtToken))
|
||||
}
|
||||
shellOpts = append(shellOpts, shellPath)
|
||||
shellOpts = append(shellOpts, subShellOpts...)
|
||||
log.Printf("full cmd is: %s %s", "wsl.exe", strings.Join(shellOpts, " "))
|
||||
|
||||
ecmd := exec.Command("wsl.exe", shellOpts...)
|
||||
if termSize.Rows == 0 || termSize.Cols == 0 {
|
||||
termSize.Rows = shellutil.DefaultTermRows
|
||||
termSize.Cols = shellutil.DefaultTermCols
|
||||
}
|
||||
if termSize.Rows <= 0 || termSize.Cols <= 0 {
|
||||
return nil, fmt.Errorf("invalid term size: %v", termSize)
|
||||
}
|
||||
cmdPty, err := pty.StartWithSize(ecmd, &pty.Winsize{Rows: uint16(termSize.Rows), Cols: uint16(termSize.Cols)})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ShellProc{Cmd: CmdWrap{ecmd, cmdPty}, ConnName: conn.GetName(), CloseOnce: &sync.Once{}, DoneCh: make(chan any)}, nil
|
||||
}
|
||||
|
||||
func StartRemoteShellProc(termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *conncontroller.SSHConn) (*ShellProc, error) {
|
||||
client := conn.GetClient()
|
||||
shellPath := cmdOpts.ShellPath
|
||||
|
58
pkg/util/packetparser/packetparser.go
Normal file
58
pkg/util/packetparser/packetparser.go
Normal file
@ -0,0 +1,58 @@
|
||||
// Copyright 2024, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package packetparser
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
type PacketParser struct {
|
||||
Reader io.Reader
|
||||
Ch chan []byte
|
||||
}
|
||||
|
||||
func Parse(input io.Reader, packetCh chan []byte, rawCh chan []byte) error {
|
||||
bufReader := bufio.NewReader(input)
|
||||
defer close(packetCh)
|
||||
defer close(rawCh)
|
||||
for {
|
||||
line, err := bufReader.ReadBytes('\n')
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(line) <= 1 {
|
||||
// just a blank line
|
||||
continue
|
||||
}
|
||||
if bytes.HasPrefix(line, []byte{'#', '#', 'N', '{'}) && bytes.HasSuffix(line, []byte{'}', '\n'}) {
|
||||
// strip off the leading "##" and trailing "\n" (single byte)
|
||||
packetCh <- line[3 : len(line)-1]
|
||||
} else {
|
||||
rawCh <- line
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func WritePacket(output io.Writer, packet []byte) error {
|
||||
if len(packet) < 2 {
|
||||
return nil
|
||||
}
|
||||
if packet[0] != '{' || packet[len(packet)-1] != '}' {
|
||||
return fmt.Errorf("invalid packet, must start with '{' and end with '}'")
|
||||
}
|
||||
fullPacket := make([]byte, 0, len(packet)+5)
|
||||
// we add the extra newline to make sure the ## appears at the beginning of the line
|
||||
// since writer isn't buffered, we want to send this all at once
|
||||
fullPacket = append(fullPacket, '\n', '#', '#', 'N')
|
||||
fullPacket = append(fullPacket, packet...)
|
||||
fullPacket = append(fullPacket, '\n')
|
||||
_, err := output.Write(fullPacket)
|
||||
return err
|
||||
}
|
@ -30,10 +30,13 @@ const WaveDataHomeEnvVar = "WAVETERM_DATA_HOME"
|
||||
const WaveDevVarName = "WAVETERM_DEV"
|
||||
const WaveLockFile = "wave.lock"
|
||||
const DomainSocketBaseName = "wave.sock"
|
||||
const RemoteDomainSocketBaseName = "wave-remote.sock"
|
||||
const WaveDBDir = "db"
|
||||
const JwtSecret = "waveterm" // TODO generate and store this
|
||||
const ConfigDir = "config"
|
||||
|
||||
var RemoteWaveHome = ExpandHomeDirSafe("~/.waveterm")
|
||||
|
||||
const WaveAppPathVarName = "WAVETERM_APP_PATH"
|
||||
const AppPathBinDir = "bin"
|
||||
|
||||
@ -101,6 +104,10 @@ func GetDomainSocketName() string {
|
||||
return filepath.Join(GetWaveDataDir(), DomainSocketBaseName)
|
||||
}
|
||||
|
||||
func GetRemoteDomainSocketName() string {
|
||||
return filepath.Join(RemoteWaveHome, RemoteDomainSocketBaseName)
|
||||
}
|
||||
|
||||
func GetWaveDataDir() string {
|
||||
retVal, found := os.LookupEnv(WaveDataHomeEnvVar)
|
||||
if !found {
|
||||
|
@ -431,7 +431,7 @@ func MakeTCPListener(serviceName string) (net.Listener, error) {
|
||||
}
|
||||
|
||||
func MakeUnixListener() (net.Listener, error) {
|
||||
serverAddr := wavebase.GetWaveDataDir() + "/wave.sock"
|
||||
serverAddr := wavebase.GetDomainSocketName()
|
||||
os.Remove(serverAddr) // ignore error
|
||||
rtn, err := net.Listen("unix", serverAddr)
|
||||
if err != nil {
|
||||
|
@ -252,7 +252,7 @@ func registerConn(wsConnId string, routeId string, wproxy *wshutil.WshRpcProxy)
|
||||
wshutil.DefaultRouter.UnregisterRoute(routeId)
|
||||
}
|
||||
RouteToConnMap[routeId] = wsConnId
|
||||
wshutil.DefaultRouter.RegisterRoute(routeId, wproxy)
|
||||
wshutil.DefaultRouter.RegisterRoute(routeId, wproxy, true)
|
||||
}
|
||||
|
||||
func unregisterConn(wsConnId string, routeId string) {
|
||||
|
@ -92,6 +92,12 @@ func DeleteBlockCommand(w *wshutil.WshRpc, data wshrpc.CommandDeleteBlockData, o
|
||||
return err
|
||||
}
|
||||
|
||||
// command "dispose", wshserver.DisposeCommand
|
||||
func DisposeCommand(w *wshutil.WshRpc, data wshrpc.CommandDisposeData, opts *wshrpc.RpcOpts) error {
|
||||
_, err := sendRpcRequestCallHelper[any](w, "dispose", data, opts)
|
||||
return err
|
||||
}
|
||||
|
||||
// command "eventpublish", wshserver.EventPublishCommand
|
||||
func EventPublishCommand(w *wshutil.WshRpc, data wps.WaveEvent, opts *wshrpc.RpcOpts) error {
|
||||
_, err := sendRpcRequestCallHelper[any](w, "eventpublish", data, opts)
|
||||
@ -285,4 +291,22 @@ func WebSelectorCommand(w *wshutil.WshRpc, data wshrpc.CommandWebSelectorData, o
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// command "wsldefaultdistro", wshserver.WslDefaultDistroCommand
|
||||
func WslDefaultDistroCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) (string, error) {
|
||||
resp, err := sendRpcRequestCallHelper[string](w, "wsldefaultdistro", nil, opts)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// command "wsllist", wshserver.WslListCommand
|
||||
func WslListCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) ([]string, error) {
|
||||
resp, err := sendRpcRequestCallHelper[[]string](w, "wsllist", nil, opts)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// command "wslstatus", wshserver.WslStatusCommand
|
||||
func WslStatusCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) ([]wshrpc.ConnStatus, error) {
|
||||
resp, err := sendRpcRequestCallHelper[[]wshrpc.ConnStatus](w, "wslstatus", nil, opts)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
|
||||
|
@ -28,6 +28,7 @@ const (
|
||||
|
||||
const (
|
||||
Command_Authenticate = "authenticate" // special
|
||||
Command_Dispose = "dispose" // special (disposes of the route, for multiproxy only)
|
||||
Command_RouteAnnounce = "routeannounce" // special (for routing)
|
||||
Command_RouteUnannounce = "routeunannounce" // special (for routing)
|
||||
Command_Message = "message"
|
||||
@ -62,11 +63,15 @@ const (
|
||||
Command_RemoteFileDelete = "remotefiledelete"
|
||||
Command_RemoteFileJoiin = "remotefilejoin"
|
||||
|
||||
Command_ConnStatus = "connstatus"
|
||||
Command_WslStatus = "wslstatus"
|
||||
Command_ConnEnsure = "connensure"
|
||||
Command_ConnReinstallWsh = "connreinstallwsh"
|
||||
Command_ConnConnect = "connconnect"
|
||||
Command_ConnDisconnect = "conndisconnect"
|
||||
Command_ConnList = "connlist"
|
||||
Command_WslList = "wsllist"
|
||||
Command_WslDefaultDistro = "wsldefaultdistro"
|
||||
|
||||
Command_WebSelector = "webselector"
|
||||
Command_Notify = "notify"
|
||||
@ -83,6 +88,7 @@ type RespOrErrorUnion[T any] struct {
|
||||
|
||||
type WshRpcInterface interface {
|
||||
AuthenticateCommand(ctx context.Context, data string) (CommandAuthenticateRtnData, error)
|
||||
DisposeCommand(ctx context.Context, data CommandDisposeData) error
|
||||
RouteAnnounceCommand(ctx context.Context) error // (special) announces a new route to the main router
|
||||
RouteUnannounceCommand(ctx context.Context) error // (special) unannounces a route to the main router
|
||||
|
||||
@ -114,11 +120,14 @@ type WshRpcInterface interface {
|
||||
|
||||
// connection functions
|
||||
ConnStatusCommand(ctx context.Context) ([]ConnStatus, error)
|
||||
WslStatusCommand(ctx context.Context) ([]ConnStatus, error)
|
||||
ConnEnsureCommand(ctx context.Context, connName string) error
|
||||
ConnReinstallWshCommand(ctx context.Context, connName string) error
|
||||
ConnConnectCommand(ctx context.Context, connName string) error
|
||||
ConnDisconnectCommand(ctx context.Context, connName string) error
|
||||
ConnListCommand(ctx context.Context) ([]string, error)
|
||||
WslListCommand(ctx context.Context) ([]string, error)
|
||||
WslDefaultDistroCommand(ctx context.Context) (string, error)
|
||||
|
||||
// eventrecv is special, it's handled internally by WshRpc with EventListener
|
||||
EventRecvCommand(ctx context.Context, data wps.WaveEvent) error
|
||||
@ -201,6 +210,12 @@ func HackRpcContextIntoData(dataPtr any, rpcContext RpcContext) {
|
||||
|
||||
type CommandAuthenticateRtnData struct {
|
||||
RouteId string `json:"routeid"`
|
||||
AuthToken string `json:"authtoken,omitempty"`
|
||||
}
|
||||
|
||||
type CommandDisposeData struct {
|
||||
RouteId string `json:"routeid"`
|
||||
// auth token travels in the packet directly
|
||||
}
|
||||
|
||||
type CommandMessageData struct {
|
||||
|
@ -21,6 +21,7 @@ import (
|
||||
"github.com/wavetermdev/waveterm/pkg/filestore"
|
||||
"github.com/wavetermdev/waveterm/pkg/remote"
|
||||
"github.com/wavetermdev/waveterm/pkg/remote/conncontroller"
|
||||
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
|
||||
"github.com/wavetermdev/waveterm/pkg/waveai"
|
||||
"github.com/wavetermdev/waveterm/pkg/waveobj"
|
||||
"github.com/wavetermdev/waveterm/pkg/wconfig"
|
||||
@ -29,6 +30,7 @@ import (
|
||||
"github.com/wavetermdev/waveterm/pkg/wps"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
||||
"github.com/wavetermdev/waveterm/pkg/wsl"
|
||||
"github.com/wavetermdev/waveterm/pkg/wstore"
|
||||
)
|
||||
|
||||
@ -36,6 +38,7 @@ const SimpleId_This = "this"
|
||||
const SimpleId_Tab = "tab"
|
||||
|
||||
var SimpleId_BlockNum_Regex = regexp.MustCompile(`^\d+$`)
|
||||
var InvalidWslDistroNames = []string{"docker-desktop", "docker-desktop-data"}
|
||||
|
||||
type WshServer struct{}
|
||||
|
||||
@ -463,11 +466,28 @@ func (ws *WshServer) ConnStatusCommand(ctx context.Context) ([]wshrpc.ConnStatus
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func (ws *WshServer) WslStatusCommand(ctx context.Context) ([]wshrpc.ConnStatus, error) {
|
||||
rtn := wsl.GetAllConnStatus()
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func (ws *WshServer) ConnEnsureCommand(ctx context.Context, connName string) error {
|
||||
if strings.HasPrefix(connName, "wsl://") {
|
||||
distroName := strings.TrimPrefix(connName, "wsl://")
|
||||
return wsl.EnsureConnection(ctx, distroName)
|
||||
}
|
||||
return conncontroller.EnsureConnection(ctx, connName)
|
||||
}
|
||||
|
||||
func (ws *WshServer) ConnDisconnectCommand(ctx context.Context, connName string) error {
|
||||
if strings.HasPrefix(connName, "wsl://") {
|
||||
distroName := strings.TrimPrefix(connName, "wsl://")
|
||||
conn := wsl.GetWslConn(ctx, distroName, false)
|
||||
if conn == nil {
|
||||
return fmt.Errorf("distro not found: %s", connName)
|
||||
}
|
||||
return conn.Close()
|
||||
}
|
||||
connOpts, err := remote.ParseOpts(connName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing connection name: %w", err)
|
||||
@ -480,6 +500,14 @@ func (ws *WshServer) ConnDisconnectCommand(ctx context.Context, connName string)
|
||||
}
|
||||
|
||||
func (ws *WshServer) ConnConnectCommand(ctx context.Context, connName string) error {
|
||||
if strings.HasPrefix(connName, "wsl://") {
|
||||
distroName := strings.TrimPrefix(connName, "wsl://")
|
||||
conn := wsl.GetWslConn(ctx, distroName, false)
|
||||
if conn == nil {
|
||||
return fmt.Errorf("connection not found: %s", connName)
|
||||
}
|
||||
return conn.Connect(ctx)
|
||||
}
|
||||
connOpts, err := remote.ParseOpts(connName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing connection name: %w", err)
|
||||
@ -492,6 +520,14 @@ func (ws *WshServer) ConnConnectCommand(ctx context.Context, connName string) er
|
||||
}
|
||||
|
||||
func (ws *WshServer) ConnReinstallWshCommand(ctx context.Context, connName string) error {
|
||||
if strings.HasPrefix(connName, "wsl://") {
|
||||
distroName := strings.TrimPrefix(connName, "wsl://")
|
||||
conn := wsl.GetWslConn(ctx, distroName, false)
|
||||
if conn == nil {
|
||||
return fmt.Errorf("connection not found: %s", connName)
|
||||
}
|
||||
return conn.CheckAndInstallWsh(ctx, connName, &wsl.WshInstallOpts{Force: true, NoUserPrompt: true})
|
||||
}
|
||||
connOpts, err := remote.ParseOpts(connName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing connection name: %w", err)
|
||||
@ -507,6 +543,33 @@ func (ws *WshServer) ConnListCommand(ctx context.Context) ([]string, error) {
|
||||
return conncontroller.GetConnectionsList()
|
||||
}
|
||||
|
||||
func (ws *WshServer) WslListCommand(ctx context.Context) ([]string, error) {
|
||||
distros, err := wsl.RegisteredDistros(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var distroNames []string
|
||||
for _, distro := range distros {
|
||||
distroName := distro.Name()
|
||||
if utilfn.ContainsStr(InvalidWslDistroNames, distroName) {
|
||||
continue
|
||||
}
|
||||
distroNames = append(distroNames, distroName)
|
||||
}
|
||||
return distroNames, nil
|
||||
}
|
||||
|
||||
func (ws *WshServer) WslDefaultDistroCommand(ctx context.Context) (string, error) {
|
||||
distro, ok, err := wsl.DefaultDistro(ctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unable to determine default distro: %w", err)
|
||||
}
|
||||
if !ok {
|
||||
return "", fmt.Errorf("unable to determine default distro")
|
||||
}
|
||||
return distro.Name(), nil
|
||||
}
|
||||
|
||||
func (ws *WshServer) BlockInfoCommand(ctx context.Context, blockId string) (*wshrpc.BlockInfoData, error) {
|
||||
blockData, err := wstore.DBMustGet[*waveobj.Block](ctx, blockId)
|
||||
if err != nil {
|
||||
|
151
pkg/wshutil/wshmultiproxy.go
Normal file
151
pkg/wshutil/wshmultiproxy.go
Normal file
@ -0,0 +1,151 @@
|
||||
// Copyright 2024, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package wshutil
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||
)
|
||||
|
||||
type multiProxyRouteInfo struct {
|
||||
RouteId string
|
||||
AuthToken string
|
||||
Proxy *WshRpcProxy
|
||||
RpcContext *wshrpc.RpcContext
|
||||
}
|
||||
|
||||
// handles messages from multiple unauthenitcated clients
|
||||
type WshRpcMultiProxy struct {
|
||||
Lock *sync.Mutex
|
||||
RouteInfo map[string]*multiProxyRouteInfo // authtoken to info
|
||||
ToRemoteCh chan []byte
|
||||
FromRemoteRawCh chan []byte // raw message from the remote
|
||||
}
|
||||
|
||||
func MakeRpcMultiProxy() *WshRpcMultiProxy {
|
||||
return &WshRpcMultiProxy{
|
||||
Lock: &sync.Mutex{},
|
||||
RouteInfo: make(map[string]*multiProxyRouteInfo),
|
||||
ToRemoteCh: make(chan []byte, DefaultInputChSize),
|
||||
FromRemoteRawCh: make(chan []byte, DefaultOutputChSize),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *WshRpcMultiProxy) DisposeRoutes() {
|
||||
p.Lock.Lock()
|
||||
defer p.Lock.Unlock()
|
||||
for authToken, routeInfo := range p.RouteInfo {
|
||||
DefaultRouter.UnregisterRoute(routeInfo.RouteId)
|
||||
delete(p.RouteInfo, authToken)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *WshRpcMultiProxy) getRouteInfo(authToken string) *multiProxyRouteInfo {
|
||||
p.Lock.Lock()
|
||||
defer p.Lock.Unlock()
|
||||
return p.RouteInfo[authToken]
|
||||
}
|
||||
|
||||
func (p *WshRpcMultiProxy) setRouteInfo(authToken string, routeInfo *multiProxyRouteInfo) {
|
||||
p.Lock.Lock()
|
||||
defer p.Lock.Unlock()
|
||||
p.RouteInfo[authToken] = routeInfo
|
||||
}
|
||||
|
||||
func (p *WshRpcMultiProxy) removeRouteInfo(authToken string) {
|
||||
p.Lock.Lock()
|
||||
defer p.Lock.Unlock()
|
||||
delete(p.RouteInfo, authToken)
|
||||
}
|
||||
|
||||
func (p *WshRpcMultiProxy) sendResponseError(msg RpcMessage, sendErr error) {
|
||||
if msg.ReqId == "" {
|
||||
// no response needed
|
||||
return
|
||||
}
|
||||
resp := RpcMessage{
|
||||
ResId: msg.ReqId,
|
||||
Error: sendErr.Error(),
|
||||
}
|
||||
respBytes, _ := json.Marshal(resp)
|
||||
p.ToRemoteCh <- respBytes
|
||||
}
|
||||
|
||||
func (p *WshRpcMultiProxy) sendAuthResponse(msg RpcMessage, routeId string, authToken string) {
|
||||
if msg.ReqId == "" {
|
||||
// no response needed
|
||||
return
|
||||
}
|
||||
resp := RpcMessage{
|
||||
ResId: msg.ReqId,
|
||||
Data: wshrpc.CommandAuthenticateRtnData{RouteId: routeId, AuthToken: authToken},
|
||||
}
|
||||
respBytes, _ := json.Marshal(resp)
|
||||
p.ToRemoteCh <- respBytes
|
||||
}
|
||||
|
||||
func (p *WshRpcMultiProxy) handleUnauthMessage(msgBytes []byte) {
|
||||
var msg RpcMessage
|
||||
err := json.Unmarshal(msgBytes, &msg)
|
||||
if err != nil {
|
||||
// nothing to do here, malformed message
|
||||
return
|
||||
}
|
||||
if msg.Command == wshrpc.Command_Authenticate {
|
||||
rpcContext, routeId, err := handleAuthenticationCommand(msg)
|
||||
if err != nil {
|
||||
p.sendResponseError(msg, err)
|
||||
return
|
||||
}
|
||||
routeInfo := &multiProxyRouteInfo{
|
||||
RouteId: routeId,
|
||||
AuthToken: uuid.New().String(),
|
||||
RpcContext: rpcContext,
|
||||
}
|
||||
routeInfo.Proxy = MakeRpcProxy()
|
||||
routeInfo.Proxy.SetRpcContext(rpcContext)
|
||||
p.setRouteInfo(routeInfo.AuthToken, routeInfo)
|
||||
p.sendAuthResponse(msg, routeId, routeInfo.AuthToken)
|
||||
go func() {
|
||||
for msgBytes := range routeInfo.Proxy.ToRemoteCh {
|
||||
p.ToRemoteCh <- msgBytes
|
||||
}
|
||||
}()
|
||||
DefaultRouter.RegisterRoute(routeId, routeInfo.Proxy, true)
|
||||
return
|
||||
}
|
||||
if msg.AuthToken == "" {
|
||||
p.sendResponseError(msg, fmt.Errorf("no auth token"))
|
||||
return
|
||||
}
|
||||
routeInfo := p.getRouteInfo(msg.AuthToken)
|
||||
if routeInfo == nil {
|
||||
p.sendResponseError(msg, fmt.Errorf("invalid auth token"))
|
||||
return
|
||||
}
|
||||
if msg.Command != "" && msg.Source != routeInfo.RouteId {
|
||||
p.sendResponseError(msg, fmt.Errorf("invalid source route for auth token"))
|
||||
return
|
||||
}
|
||||
if msg.Command == wshrpc.Command_Dispose {
|
||||
DefaultRouter.UnregisterRoute(routeInfo.RouteId)
|
||||
p.removeRouteInfo(msg.AuthToken)
|
||||
close(routeInfo.Proxy.ToRemoteCh)
|
||||
close(routeInfo.Proxy.FromRemoteCh)
|
||||
return
|
||||
}
|
||||
routeInfo.Proxy.FromRemoteCh <- msgBytes
|
||||
}
|
||||
|
||||
func (p *WshRpcMultiProxy) RunUnauthLoop() {
|
||||
// loop over unauthenticated message
|
||||
// handle Authenicate commands, and pass authenticated messages to the AuthCh
|
||||
for msgBytes := range p.FromRemoteRawCh {
|
||||
p.handleUnauthMessage(msgBytes)
|
||||
}
|
||||
}
|
@ -6,7 +6,6 @@ package wshutil
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@ -18,6 +17,7 @@ type WshRpcProxy struct {
|
||||
RpcContext *wshrpc.RpcContext
|
||||
ToRemoteCh chan []byte
|
||||
FromRemoteCh chan []byte
|
||||
AuthToken string
|
||||
}
|
||||
|
||||
func MakeRpcProxy() *WshRpcProxy {
|
||||
@ -40,6 +40,18 @@ func (p *WshRpcProxy) GetRpcContext() *wshrpc.RpcContext {
|
||||
return p.RpcContext
|
||||
}
|
||||
|
||||
func (p *WshRpcProxy) SetAuthToken(authToken string) {
|
||||
p.Lock.Lock()
|
||||
defer p.Lock.Unlock()
|
||||
p.AuthToken = authToken
|
||||
}
|
||||
|
||||
func (p *WshRpcProxy) GetAuthToken() string {
|
||||
p.Lock.Lock()
|
||||
defer p.Lock.Unlock()
|
||||
return p.AuthToken
|
||||
}
|
||||
|
||||
func (p *WshRpcProxy) sendResponseError(msg RpcMessage, sendErr error) {
|
||||
if msg.ReqId == "" {
|
||||
// no response needed
|
||||
@ -54,7 +66,7 @@ func (p *WshRpcProxy) sendResponseError(msg RpcMessage, sendErr error) {
|
||||
p.SendRpcMessage(respBytes)
|
||||
}
|
||||
|
||||
func (p *WshRpcProxy) sendResponse(msg RpcMessage, routeId string) {
|
||||
func (p *WshRpcProxy) sendAuthenticateResponse(msg RpcMessage, routeId string) {
|
||||
if msg.ReqId == "" {
|
||||
// no response needed
|
||||
return
|
||||
@ -98,6 +110,49 @@ func handleAuthenticationCommand(msg RpcMessage) (*wshrpc.RpcContext, string, er
|
||||
return newCtx, routeId, nil
|
||||
}
|
||||
|
||||
// runs on the client (stdio client)
|
||||
func (p *WshRpcProxy) HandleClientProxyAuth(router *WshRouter) (string, error) {
|
||||
for {
|
||||
msgBytes, ok := <-p.FromRemoteCh
|
||||
if !ok {
|
||||
return "", fmt.Errorf("remote closed, not authenticated")
|
||||
}
|
||||
var origMsg RpcMessage
|
||||
err := json.Unmarshal(msgBytes, &origMsg)
|
||||
if err != nil {
|
||||
// nothing to do, can't even send a response since we don't have Source or ReqId
|
||||
continue
|
||||
}
|
||||
if origMsg.Command == "" {
|
||||
// this message is not allowed (protocol error at this point), ignore
|
||||
continue
|
||||
}
|
||||
// we only allow one command "authenticate", everything else returns an error
|
||||
if origMsg.Command != wshrpc.Command_Authenticate {
|
||||
respErr := fmt.Errorf("connection not authenticated")
|
||||
p.sendResponseError(origMsg, respErr)
|
||||
continue
|
||||
}
|
||||
authRtn, err := router.HandleProxyAuth(origMsg.Data)
|
||||
if err != nil {
|
||||
respErr := fmt.Errorf("error handling proxy auth: %w", err)
|
||||
p.sendResponseError(origMsg, respErr)
|
||||
return "", respErr
|
||||
}
|
||||
p.SetAuthToken(authRtn.AuthToken)
|
||||
announceMsg := RpcMessage{
|
||||
Command: wshrpc.Command_RouteAnnounce,
|
||||
Source: authRtn.RouteId,
|
||||
AuthToken: authRtn.AuthToken,
|
||||
}
|
||||
announceBytes, _ := json.Marshal(announceMsg)
|
||||
router.InjectMessage(announceBytes, authRtn.RouteId)
|
||||
p.sendAuthenticateResponse(origMsg, authRtn.RouteId)
|
||||
return authRtn.RouteId, nil
|
||||
}
|
||||
}
|
||||
|
||||
// runs on the server
|
||||
func (p *WshRpcProxy) HandleAuthentication() (*wshrpc.RpcContext, error) {
|
||||
for {
|
||||
msgBytes, ok := <-p.FromRemoteCh
|
||||
@ -122,11 +177,10 @@ func (p *WshRpcProxy) HandleAuthentication() (*wshrpc.RpcContext, error) {
|
||||
}
|
||||
newCtx, routeId, err := handleAuthenticationCommand(msg)
|
||||
if err != nil {
|
||||
log.Printf("error handling authentication: %v\n", err)
|
||||
p.sendResponseError(msg, err)
|
||||
continue
|
||||
}
|
||||
p.sendResponse(msg, routeId)
|
||||
p.sendAuthenticateResponse(msg, routeId)
|
||||
return newCtx, nil
|
||||
}
|
||||
}
|
||||
@ -136,9 +190,10 @@ func (p *WshRpcProxy) SendRpcMessage(msg []byte) {
|
||||
}
|
||||
|
||||
func (p *WshRpcProxy) RecvRpcMessage() ([]byte, bool) {
|
||||
msgBytes, ok := <-p.FromRemoteCh
|
||||
if !ok || p.RpcContext == nil {
|
||||
return msgBytes, ok
|
||||
msgBytes, more := <-p.FromRemoteCh
|
||||
authToken := p.GetAuthToken()
|
||||
if !more || (p.RpcContext == nil && authToken == "") {
|
||||
return msgBytes, more
|
||||
}
|
||||
var msg RpcMessage
|
||||
err := json.Unmarshal(msgBytes, &msg)
|
||||
@ -146,11 +201,16 @@ func (p *WshRpcProxy) RecvRpcMessage() ([]byte, bool) {
|
||||
// nothing to do here -- will error out at another level
|
||||
return msgBytes, true
|
||||
}
|
||||
if p.RpcContext != nil {
|
||||
msg.Data, err = recodeCommandData(msg.Command, msg.Data, p.RpcContext)
|
||||
if err != nil {
|
||||
// nothing to do here -- will error out at another level
|
||||
return msgBytes, true
|
||||
}
|
||||
}
|
||||
if msg.AuthToken == "" {
|
||||
msg.AuthToken = authToken
|
||||
}
|
||||
newBytes, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
// nothing to do here
|
||||
|
@ -12,11 +12,14 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
|
||||
"github.com/wavetermdev/waveterm/pkg/wps"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||
)
|
||||
|
||||
const DefaultRoute = "wavesrv"
|
||||
const UpstreamRoute = "upstream"
|
||||
const SysRoute = "sys" // this route doesn't exist, just a placeholder for system messages
|
||||
const ElectronRoute = "electron"
|
||||
|
||||
@ -41,6 +44,7 @@ type WshRouter struct {
|
||||
UpstreamClient AbstractRpcClient // upstream client (if we are not the terminal router)
|
||||
AnnouncedRoutes map[string]string // routeid => local routeid
|
||||
RpcMap map[string]*routeInfo // rpcid => routeinfo
|
||||
SimpleRequestMap map[string]chan *RpcMessage // simple reqid => response channel
|
||||
InputCh chan msgAndRoute
|
||||
}
|
||||
|
||||
@ -72,6 +76,7 @@ func NewWshRouter() *WshRouter {
|
||||
RouteMap: make(map[string]AbstractRpcClient),
|
||||
AnnouncedRoutes: make(map[string]string),
|
||||
RpcMap: make(map[string]*routeInfo),
|
||||
SimpleRequestMap: make(map[string]chan *RpcMessage),
|
||||
InputCh: make(chan msgAndRoute, DefaultInputChSize),
|
||||
}
|
||||
go rtn.runServer()
|
||||
@ -237,6 +242,10 @@ func (router *WshRouter) runServer() {
|
||||
router.sendRoutedMessage(msgBytes, routeInfo.DestRouteId)
|
||||
continue
|
||||
} else if msg.ResId != "" {
|
||||
ok := router.trySimpleResponse(&msg)
|
||||
if ok {
|
||||
continue
|
||||
}
|
||||
routeInfo := router.getRouteInfo(msg.ResId)
|
||||
if routeInfo == nil {
|
||||
// no route info, nothing to do
|
||||
@ -269,10 +278,10 @@ func (router *WshRouter) WaitForRegister(ctx context.Context, routeId string) er
|
||||
}
|
||||
|
||||
// this will also consume the output channel of the abstract client
|
||||
func (router *WshRouter) RegisterRoute(routeId string, rpc AbstractRpcClient) {
|
||||
if routeId == SysRoute {
|
||||
func (router *WshRouter) RegisterRoute(routeId string, rpc AbstractRpcClient, shouldAnnounce bool) {
|
||||
if routeId == SysRoute || routeId == UpstreamRoute {
|
||||
// cannot register sys route
|
||||
log.Printf("error: WshRouter cannot register sys route\n")
|
||||
log.Printf("error: WshRouter cannot register %s route\n", routeId)
|
||||
return
|
||||
}
|
||||
log.Printf("[router] registering wsh route %q\n", routeId)
|
||||
@ -285,7 +294,7 @@ func (router *WshRouter) RegisterRoute(routeId string, rpc AbstractRpcClient) {
|
||||
router.RouteMap[routeId] = rpc
|
||||
go func() {
|
||||
// announce
|
||||
if !alreadyExists && router.GetUpstreamClient() != nil {
|
||||
if shouldAnnounce && !alreadyExists && router.GetUpstreamClient() != nil {
|
||||
announceMsg := RpcMessage{Command: wshrpc.Command_RouteAnnounce, Source: routeId}
|
||||
announceBytes, _ := json.Marshal(announceMsg)
|
||||
router.GetUpstreamClient().SendRpcMessage(announceBytes)
|
||||
@ -352,3 +361,97 @@ func (router *WshRouter) GetUpstreamClient() AbstractRpcClient {
|
||||
defer router.Lock.Unlock()
|
||||
return router.UpstreamClient
|
||||
}
|
||||
|
||||
func (router *WshRouter) InjectMessage(msgBytes []byte, fromRouteId string) {
|
||||
router.InputCh <- msgAndRoute{msgBytes: msgBytes, fromRouteId: fromRouteId}
|
||||
}
|
||||
|
||||
func (router *WshRouter) registerSimpleRequest(reqId string) chan *RpcMessage {
|
||||
router.Lock.Lock()
|
||||
defer router.Lock.Unlock()
|
||||
rtn := make(chan *RpcMessage, 1)
|
||||
router.SimpleRequestMap[reqId] = rtn
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (router *WshRouter) trySimpleResponse(msg *RpcMessage) bool {
|
||||
router.Lock.Lock()
|
||||
defer router.Lock.Unlock()
|
||||
respCh := router.SimpleRequestMap[msg.ResId]
|
||||
if respCh == nil {
|
||||
return false
|
||||
}
|
||||
respCh <- msg
|
||||
delete(router.SimpleRequestMap, msg.ResId)
|
||||
return true
|
||||
}
|
||||
|
||||
func (router *WshRouter) clearSimpleRequest(reqId string) {
|
||||
router.Lock.Lock()
|
||||
defer router.Lock.Unlock()
|
||||
delete(router.SimpleRequestMap, reqId)
|
||||
}
|
||||
|
||||
func (router *WshRouter) RunSimpleRawCommand(ctx context.Context, msg RpcMessage, fromRouteId string) (*RpcMessage, error) {
|
||||
if msg.Command == "" {
|
||||
return nil, errors.New("no command")
|
||||
}
|
||||
msgBytes, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var respCh chan *RpcMessage
|
||||
if msg.ReqId != "" {
|
||||
respCh = router.registerSimpleRequest(msg.ReqId)
|
||||
}
|
||||
router.InjectMessage(msgBytes, fromRouteId)
|
||||
if respCh == nil {
|
||||
return nil, nil
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
router.clearSimpleRequest(msg.ReqId)
|
||||
return nil, ctx.Err()
|
||||
case resp := <-respCh:
|
||||
if resp.Error != "" {
|
||||
return nil, errors.New(resp.Error)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (router *WshRouter) HandleProxyAuth(jwtTokenAny any) (*wshrpc.CommandAuthenticateRtnData, error) {
|
||||
if jwtTokenAny == nil {
|
||||
return nil, errors.New("no jwt token")
|
||||
}
|
||||
jwtToken, ok := jwtTokenAny.(string)
|
||||
if !ok {
|
||||
return nil, errors.New("jwt token not a string")
|
||||
}
|
||||
if jwtToken == "" {
|
||||
return nil, errors.New("empty jwt token")
|
||||
}
|
||||
msg := RpcMessage{
|
||||
Command: wshrpc.Command_Authenticate,
|
||||
ReqId: uuid.New().String(),
|
||||
Data: jwtToken,
|
||||
}
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), DefaultTimeoutMs*time.Millisecond)
|
||||
defer cancelFn()
|
||||
resp, err := router.RunSimpleRawCommand(ctx, msg, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp == nil || resp.Data == nil {
|
||||
return nil, errors.New("no data in authenticate response")
|
||||
}
|
||||
var respData wshrpc.CommandAuthenticateRtnData
|
||||
err = utilfn.ReUnmarshal(&respData, resp.Data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error unmarshalling authenticate response: %v", err)
|
||||
}
|
||||
if respData.AuthToken == "" {
|
||||
return nil, errors.New("no auth token in authenticate response")
|
||||
}
|
||||
return &respData, nil
|
||||
}
|
||||
|
@ -45,10 +45,13 @@ type WshRpc struct {
|
||||
InputCh chan []byte
|
||||
OutputCh chan []byte
|
||||
RpcContext *atomic.Pointer[wshrpc.RpcContext]
|
||||
AuthToken string
|
||||
RpcMap map[string]*rpcData
|
||||
ServerImpl ServerImpl
|
||||
EventListener *EventListener
|
||||
ResponseHandlerMap map[string]*RpcResponseHandler // reqId => handler
|
||||
Debug bool
|
||||
DebugName string
|
||||
}
|
||||
|
||||
type wshRpcContextKey struct{}
|
||||
@ -109,6 +112,7 @@ type RpcMessage struct {
|
||||
ResId string `json:"resid,omitempty"`
|
||||
Timeout int `json:"timeout,omitempty"`
|
||||
Route string `json:"route,omitempty"` // to route/forward requests to alternate servers
|
||||
AuthToken string `json:"authtoken,omitempty"` // needed for routing unauthenticated requests (WshRpcMultiProxy)
|
||||
Source string `json:"source,omitempty"` // source route id
|
||||
Cont bool `json:"cont,omitempty"` // flag if additional requests/responses are forthcoming
|
||||
Cancel bool `json:"cancel,omitempty"` // used to cancel a streaming request or response (sent from the side that is not streaming)
|
||||
@ -226,6 +230,14 @@ func (w *WshRpc) SetRpcContext(ctx wshrpc.RpcContext) {
|
||||
w.RpcContext.Store(&ctx)
|
||||
}
|
||||
|
||||
func (w *WshRpc) SetAuthToken(token string) {
|
||||
w.AuthToken = token
|
||||
}
|
||||
|
||||
func (w *WshRpc) GetAuthToken() string {
|
||||
return w.AuthToken
|
||||
}
|
||||
|
||||
func (w *WshRpc) registerResponseHandler(reqId string, handler *RpcResponseHandler) {
|
||||
w.Lock.Lock()
|
||||
defer w.Lock.Unlock()
|
||||
@ -323,6 +335,9 @@ func (w *WshRpc) handleRequest(req *RpcMessage) {
|
||||
func (w *WshRpc) runServer() {
|
||||
defer close(w.OutputCh)
|
||||
for msgBytes := range w.InputCh {
|
||||
if w.Debug {
|
||||
log.Printf("[%s] received message: %s\n", w.DebugName, string(msgBytes))
|
||||
}
|
||||
var msg RpcMessage
|
||||
err := json.Unmarshal(msgBytes, &msg)
|
||||
if err != nil {
|
||||
@ -457,6 +472,7 @@ func (handler *RpcRequestHandler) SendCancel() {
|
||||
msg := &RpcMessage{
|
||||
Cancel: true,
|
||||
ReqId: handler.reqId,
|
||||
AuthToken: handler.w.GetAuthToken(),
|
||||
}
|
||||
barr, _ := json.Marshal(msg) // will never fail
|
||||
handler.w.OutputCh <- barr
|
||||
@ -550,6 +566,7 @@ func (handler *RpcResponseHandler) SendMessage(msg string) {
|
||||
Data: wshrpc.CommandMessageData{
|
||||
Message: msg,
|
||||
},
|
||||
AuthToken: handler.w.GetAuthToken(),
|
||||
}
|
||||
msgBytes, _ := json.Marshal(rpcMsg) // will never fail
|
||||
handler.w.OutputCh <- msgBytes
|
||||
@ -576,6 +593,7 @@ func (handler *RpcResponseHandler) SendResponse(data any, done bool) error {
|
||||
ResId: handler.reqId,
|
||||
Data: data,
|
||||
Cont: !done,
|
||||
AuthToken: handler.w.GetAuthToken(),
|
||||
}
|
||||
barr, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
@ -600,6 +618,7 @@ func (handler *RpcResponseHandler) SendResponseError(err error) {
|
||||
msg := &RpcMessage{
|
||||
ResId: handler.reqId,
|
||||
Error: err.Error(),
|
||||
AuthToken: handler.w.GetAuthToken(),
|
||||
}
|
||||
barr, _ := json.Marshal(msg) // will never fail
|
||||
handler.w.OutputCh <- barr
|
||||
@ -665,6 +684,7 @@ func (w *WshRpc) SendComplexRequest(command string, data any, opts *wshrpc.RpcOp
|
||||
Data: data,
|
||||
Timeout: timeoutMs,
|
||||
Route: opts.Route,
|
||||
AuthToken: w.GetAuthToken(),
|
||||
}
|
||||
barr, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
|
@ -19,6 +19,7 @@ import (
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/wavetermdev/waveterm/pkg/util/packetparser"
|
||||
"github.com/wavetermdev/waveterm/pkg/wavebase"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||
"golang.org/x/term"
|
||||
@ -204,11 +205,26 @@ func SetupTerminalRpcClient(serverImpl ServerImpl) (*WshRpc, io.Reader) {
|
||||
continue
|
||||
}
|
||||
os.Stdout.Write(barr)
|
||||
os.Stdout.Write([]byte{'\n'})
|
||||
}
|
||||
}()
|
||||
return rpcClient, ptyBuf
|
||||
}
|
||||
|
||||
func SetupPacketRpcClient(input io.Reader, output io.Writer, serverImpl ServerImpl) (*WshRpc, chan []byte) {
|
||||
messageCh := make(chan []byte, DefaultInputChSize)
|
||||
outputCh := make(chan []byte, DefaultOutputChSize)
|
||||
rawCh := make(chan []byte, DefaultOutputChSize)
|
||||
rpcClient := MakeWshRpc(messageCh, outputCh, wshrpc.RpcContext{}, serverImpl)
|
||||
go packetparser.Parse(input, messageCh, rawCh)
|
||||
go func() {
|
||||
for msg := range outputCh {
|
||||
packetparser.WritePacket(output, msg)
|
||||
}
|
||||
}()
|
||||
return rpcClient, rawCh
|
||||
}
|
||||
|
||||
func SetupConnRpcClient(conn net.Conn, serverImpl ServerImpl) (*WshRpc, chan error, error) {
|
||||
inputCh := make(chan []byte, DefaultInputChSize)
|
||||
outputCh := make(chan []byte, DefaultOutputChSize)
|
||||
@ -229,10 +245,22 @@ func SetupConnRpcClient(conn net.Conn, serverImpl ServerImpl) (*WshRpc, chan err
|
||||
return rtn, writeErrCh, nil
|
||||
}
|
||||
|
||||
func SetupDomainSocketRpcClient(sockName string, serverImpl ServerImpl) (*WshRpc, error) {
|
||||
conn, err := net.Dial("unix", sockName)
|
||||
func tryTcpSocket(sockName string) (net.Conn, error) {
|
||||
addr, err := net.ResolveTCPAddr("tcp", sockName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to Unix domain socket: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
return net.DialTCP("tcp", nil, addr)
|
||||
}
|
||||
|
||||
func SetupDomainSocketRpcClient(sockName string, serverImpl ServerImpl) (*WshRpc, error) {
|
||||
conn, tcpErr := tryTcpSocket(sockName)
|
||||
var unixErr error
|
||||
if tcpErr != nil {
|
||||
conn, unixErr = net.Dial("unix", sockName)
|
||||
}
|
||||
if tcpErr != nil && unixErr != nil {
|
||||
return nil, fmt.Errorf("failed to connect to tcp or unix domain socket: tcp err:%w: unix socket err: %w", tcpErr, unixErr)
|
||||
}
|
||||
rtn, errCh, err := SetupConnRpcClient(conn, serverImpl)
|
||||
go func() {
|
||||
@ -363,6 +391,46 @@ func MakeRouteIdFromCtx(rpcCtx *wshrpc.RpcContext) (string, error) {
|
||||
return MakeProcRouteId(procId), nil
|
||||
}
|
||||
|
||||
type WriteFlusher interface {
|
||||
Write([]byte) (int, error)
|
||||
Flush() error
|
||||
}
|
||||
|
||||
// blocking, returns if there is an error, or on EOF of input
|
||||
func HandleStdIOClient(logName string, input io.Reader, output io.Writer) {
|
||||
proxy := MakeRpcMultiProxy()
|
||||
rawCh := make(chan []byte, DefaultInputChSize)
|
||||
go packetparser.Parse(input, proxy.FromRemoteRawCh, rawCh)
|
||||
doneCh := make(chan struct{})
|
||||
var doneOnce sync.Once
|
||||
closeDoneCh := func() {
|
||||
doneOnce.Do(func() {
|
||||
close(doneCh)
|
||||
})
|
||||
proxy.DisposeRoutes()
|
||||
}
|
||||
go func() {
|
||||
proxy.RunUnauthLoop()
|
||||
}()
|
||||
go func() {
|
||||
defer closeDoneCh()
|
||||
for msg := range proxy.ToRemoteCh {
|
||||
err := packetparser.WritePacket(output, msg)
|
||||
if err != nil {
|
||||
log.Printf("[%s] error writing to output: %v\n", logName, err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer closeDoneCh()
|
||||
for msg := range rawCh {
|
||||
log.Printf("[%s:stdout] %s", logName, msg)
|
||||
}
|
||||
}()
|
||||
<-doneCh
|
||||
}
|
||||
|
||||
func handleDomainSocketClient(conn net.Conn) {
|
||||
var routeIdContainer atomic.Pointer[string]
|
||||
proxy := MakeRpcProxy()
|
||||
@ -399,7 +467,7 @@ func handleDomainSocketClient(conn net.Conn) {
|
||||
return
|
||||
}
|
||||
routeIdContainer.Store(&routeId)
|
||||
DefaultRouter.RegisterRoute(routeId, proxy)
|
||||
DefaultRouter.RegisterRoute(routeId, proxy, true)
|
||||
}
|
||||
|
||||
// only for use on client
|
||||
@ -433,5 +501,6 @@ func ExtractUnverifiedSocketName(tokenStr string) (string, error) {
|
||||
if !ok {
|
||||
return "", fmt.Errorf("sock claim is missing or invalid")
|
||||
}
|
||||
sockName = wavebase.ExpandHomeDirSafe(sockName)
|
||||
return sockName, nil
|
||||
}
|
||||
|
67
pkg/wsl/wsl-unix.go
Normal file
67
pkg/wsl/wsl-unix.go
Normal file
@ -0,0 +1,67 @@
|
||||
//go:build !windows
|
||||
|
||||
// Copyright 2024, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package wsl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
func RegisteredDistros(ctx context.Context) (distros []Distro, err error) {
|
||||
return nil, fmt.Errorf("RegisteredDistros not implemented on this system")
|
||||
}
|
||||
|
||||
func DefaultDistro(ctx context.Context) (d Distro, ok bool, err error) {
|
||||
return d, false, fmt.Errorf("DefaultDistro not implemented on this system")
|
||||
}
|
||||
|
||||
type Distro struct{}
|
||||
|
||||
func (d *Distro) Name() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (d *Distro) WslCommand(ctx context.Context, cmd string) *WslCmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
// just use the regular cmd since it's
|
||||
// similar enough to not cause issues
|
||||
// type WslCmd = exec.Cmd
|
||||
type WslCmd struct {
|
||||
exec.Cmd
|
||||
}
|
||||
|
||||
func (wc *WslCmd) GetProcess() *os.Process {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (wc *WslCmd) GetProcessState() *os.ProcessState {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *WslCmd) SetStdin(stdin io.Reader) {
|
||||
c.Stdin = stdin
|
||||
}
|
||||
|
||||
func (c *WslCmd) SetStdout(stdout io.Writer) {
|
||||
c.Stdout = stdout
|
||||
}
|
||||
|
||||
func (c *WslCmd) SetStderr(stderr io.Writer) {
|
||||
c.Stdout = stderr
|
||||
}
|
||||
|
||||
func GetDistroCmd(ctx context.Context, wslDistroName string, cmd string) (*WslCmd, error) {
|
||||
return nil, fmt.Errorf("GetDistroCmd not implemented on this system")
|
||||
}
|
||||
|
||||
func GetDistro(ctx context.Context, wslDistroName WslName) (*Distro, error) {
|
||||
return nil, fmt.Errorf("GetDistro not implemented on this system")
|
||||
}
|
296
pkg/wsl/wsl-util.go
Normal file
296
pkg/wsl/wsl-util.go
Normal file
@ -0,0 +1,296 @@
|
||||
// Copyright 2024, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package wsl
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func DetectShell(ctx context.Context, client *Distro) (string, error) {
|
||||
wshPath := GetWshPath(ctx, client)
|
||||
|
||||
cmd := client.WslCommand(ctx, wshPath+" shell")
|
||||
log.Printf("shell detecting using command: %s shell", wshPath)
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
log.Printf("unable to determine shell. defaulting to /bin/bash: %s", err)
|
||||
return "/bin/bash", nil
|
||||
}
|
||||
log.Printf("detecting shell: %s", out)
|
||||
|
||||
// quoting breaks this particular case
|
||||
return strings.TrimSpace(string(out)), nil
|
||||
}
|
||||
|
||||
func GetWshVersion(ctx context.Context, client *Distro) (string, error) {
|
||||
wshPath := GetWshPath(ctx, client)
|
||||
|
||||
cmd := client.WslCommand(ctx, wshPath+" version")
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return strings.TrimSpace(string(out)), nil
|
||||
}
|
||||
|
||||
func GetWshPath(ctx context.Context, client *Distro) string {
|
||||
defaultPath := "~/.waveterm/bin/wsh"
|
||||
|
||||
cmd := client.WslCommand(ctx, "which wsh")
|
||||
out, whichErr := cmd.Output()
|
||||
if whichErr == nil {
|
||||
return strings.TrimSpace(string(out))
|
||||
}
|
||||
|
||||
cmd = client.WslCommand(ctx, "where.exe wsh")
|
||||
out, whereErr := cmd.Output()
|
||||
if whereErr == nil {
|
||||
return strings.TrimSpace(string(out))
|
||||
}
|
||||
|
||||
// check cmd on windows since it requires an absolute path with backslashes
|
||||
cmd = client.WslCommand(ctx, "(dir 2>&1 *``|echo %userprofile%\\.waveterm%\\.waveterm\\bin\\wsh.exe);&<# rem #>echo none")
|
||||
out, cmdErr := cmd.Output()
|
||||
if cmdErr == nil && strings.TrimSpace(string(out)) != "none" {
|
||||
return strings.TrimSpace(string(out))
|
||||
}
|
||||
|
||||
// no custom install, use default path
|
||||
return defaultPath
|
||||
}
|
||||
|
||||
func hasBashInstalled(ctx context.Context, client *Distro) (bool, error) {
|
||||
cmd := client.WslCommand(ctx, "which bash")
|
||||
out, whichErr := cmd.Output()
|
||||
if whichErr == nil && len(out) != 0 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
cmd = client.WslCommand(ctx, "where.exe bash")
|
||||
out, whereErr := cmd.Output()
|
||||
if whereErr == nil && len(out) != 0 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// note: we could also check in /bin/bash explicitly
|
||||
// just in case that wasn't added to the path. but if
|
||||
// that's true, we will most likely have worse
|
||||
// problems going forward
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func GetClientOs(ctx context.Context, client *Distro) (string, error) {
|
||||
cmd := client.WslCommand(ctx, "uname -s")
|
||||
out, unixErr := cmd.Output()
|
||||
if unixErr == nil {
|
||||
formatted := strings.ToLower(string(out))
|
||||
formatted = strings.TrimSpace(formatted)
|
||||
return formatted, nil
|
||||
}
|
||||
|
||||
cmd = client.WslCommand(ctx, "echo %OS%")
|
||||
out, cmdErr := cmd.Output()
|
||||
if cmdErr == nil && strings.TrimSpace(string(out)) != "%OS%" {
|
||||
formatted := strings.ToLower(string(out))
|
||||
formatted = strings.TrimSpace(formatted)
|
||||
return strings.Split(formatted, "_")[0], nil
|
||||
}
|
||||
|
||||
cmd = client.WslCommand(ctx, "echo $env:OS")
|
||||
out, psErr := cmd.Output()
|
||||
if psErr == nil && strings.TrimSpace(string(out)) != "$env:OS" {
|
||||
formatted := strings.ToLower(string(out))
|
||||
formatted = strings.TrimSpace(formatted)
|
||||
return strings.Split(formatted, "_")[0], nil
|
||||
}
|
||||
return "", fmt.Errorf("unable to determine os: {unix: %s, cmd: %s, powershell: %s}", unixErr, cmdErr, psErr)
|
||||
}
|
||||
|
||||
func GetClientArch(ctx context.Context, client *Distro) (string, error) {
|
||||
cmd := client.WslCommand(ctx, "uname -m")
|
||||
out, unixErr := cmd.Output()
|
||||
if unixErr == nil {
|
||||
formatted := strings.ToLower(string(out))
|
||||
formatted = strings.TrimSpace(formatted)
|
||||
if formatted == "x86_64" {
|
||||
return "x64", nil
|
||||
}
|
||||
return formatted, nil
|
||||
}
|
||||
|
||||
cmd = client.WslCommand(ctx, "echo %PROCESSOR_ARCHITECTURE%")
|
||||
out, cmdErr := cmd.Output()
|
||||
if cmdErr == nil && strings.TrimSpace(string(out)) != "%PROCESSOR_ARCHITECTURE%" {
|
||||
formatted := strings.ToLower(string(out))
|
||||
return strings.TrimSpace(formatted), nil
|
||||
}
|
||||
|
||||
cmd = client.WslCommand(ctx, "echo $env:PROCESSOR_ARCHITECTURE")
|
||||
out, psErr := cmd.Output()
|
||||
if psErr == nil && strings.TrimSpace(string(out)) != "$env:PROCESSOR_ARCHITECTURE" {
|
||||
formatted := strings.ToLower(string(out))
|
||||
return strings.TrimSpace(formatted), nil
|
||||
}
|
||||
return "", fmt.Errorf("unable to determine architecture: {unix: %s, cmd: %s, powershell: %s}", unixErr, cmdErr, psErr)
|
||||
}
|
||||
|
||||
type CancellableCmd struct {
|
||||
Cmd *WslCmd
|
||||
Cancel func()
|
||||
}
|
||||
|
||||
var installTemplatesRawBash = map[string]string{
|
||||
"mkdir": `bash -c 'mkdir -p {{.installDir}}'`,
|
||||
"cat": `bash -c 'cat > {{.tempPath}}'`,
|
||||
"mv": `bash -c 'mv {{.tempPath}} {{.installPath}}'`,
|
||||
"chmod": `bash -c 'chmod a+x {{.installPath}}'`,
|
||||
}
|
||||
|
||||
var installTemplatesRawDefault = map[string]string{
|
||||
"mkdir": `mkdir -p {{.installDir}}`,
|
||||
"cat": `cat > {{.tempPath}}`,
|
||||
"mv": `mv {{.tempPath}} {{.installPath}}`,
|
||||
"chmod": `chmod a+x {{.installPath}}`,
|
||||
}
|
||||
|
||||
func makeCancellableCommand(ctx context.Context, client *Distro, cmdTemplateRaw string, words map[string]string) (*CancellableCmd, error) {
|
||||
cmdContext, cmdCancel := context.WithCancel(ctx)
|
||||
|
||||
cmdStr := &bytes.Buffer{}
|
||||
cmdTemplate, err := template.New("").Parse(cmdTemplateRaw)
|
||||
if err != nil {
|
||||
cmdCancel()
|
||||
return nil, err
|
||||
}
|
||||
cmdTemplate.Execute(cmdStr, words)
|
||||
|
||||
cmd := client.WslCommand(cmdContext, cmdStr.String())
|
||||
return &CancellableCmd{cmd, cmdCancel}, nil
|
||||
}
|
||||
|
||||
func CpHostToRemote(ctx context.Context, client *Distro, sourcePath string, destPath string) error {
|
||||
// warning: does not work on windows remote yet
|
||||
bashInstalled, err := hasBashInstalled(ctx, client)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var selectedTemplatesRaw map[string]string
|
||||
if bashInstalled {
|
||||
selectedTemplatesRaw = installTemplatesRawBash
|
||||
} else {
|
||||
log.Printf("bash is not installed on remote. attempting with default shell")
|
||||
selectedTemplatesRaw = installTemplatesRawDefault
|
||||
}
|
||||
|
||||
// I need to use toSlash here to force unix keybindings
|
||||
// this means we can't guarantee it will work on a remote windows machine
|
||||
var installWords = map[string]string{
|
||||
"installDir": filepath.ToSlash(filepath.Dir(destPath)),
|
||||
"tempPath": destPath + ".temp",
|
||||
"installPath": destPath,
|
||||
}
|
||||
|
||||
installStepCmds := make(map[string]*CancellableCmd)
|
||||
for cmdName, selectedTemplateRaw := range selectedTemplatesRaw {
|
||||
cancellableCmd, err := makeCancellableCommand(ctx, client, selectedTemplateRaw, installWords)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
installStepCmds[cmdName] = cancellableCmd
|
||||
}
|
||||
|
||||
_, err = installStepCmds["mkdir"].Cmd.Output()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// the cat part of this is complicated since it requires stdin
|
||||
catCmd := installStepCmds["cat"].Cmd
|
||||
catStdin, err := catCmd.StdinPipe()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = catCmd.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
input, err := os.Open(sourcePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot open local file %s to send to host: %v", sourcePath, err)
|
||||
}
|
||||
go func() {
|
||||
io.Copy(catStdin, input)
|
||||
installStepCmds["cat"].Cancel()
|
||||
|
||||
// backup just in case something weird happens
|
||||
// could cause potential race condition, but very
|
||||
// unlikely
|
||||
time.Sleep(time.Second * 1)
|
||||
process := catCmd.GetProcess()
|
||||
if process != nil {
|
||||
process.Kill()
|
||||
}
|
||||
}()
|
||||
catErr := catCmd.Wait()
|
||||
if catErr != nil && !errors.Is(catErr, context.Canceled) {
|
||||
return catErr
|
||||
}
|
||||
|
||||
_, err = installStepCmds["mv"].Cmd.Output()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = installStepCmds["chmod"].Cmd.Output()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func InstallClientRcFiles(ctx context.Context, client *Distro) error {
|
||||
path := GetWshPath(ctx, client)
|
||||
log.Printf("path to wsh searched is: %s", path)
|
||||
|
||||
cmd := client.WslCommand(ctx, path+" rcfiles")
|
||||
_, err := cmd.Output()
|
||||
return err
|
||||
}
|
||||
|
||||
func GetHomeDir(ctx context.Context, client *Distro) string {
|
||||
// note: also works for powershell
|
||||
cmd := client.WslCommand(ctx, `echo "$HOME"`)
|
||||
out, err := cmd.Output()
|
||||
if err == nil {
|
||||
return strings.TrimSpace(string(out))
|
||||
}
|
||||
|
||||
cmd = client.WslCommand(ctx, `echo %userprofile%`)
|
||||
out, err = cmd.Output()
|
||||
if err == nil {
|
||||
return strings.TrimSpace(string(out))
|
||||
}
|
||||
|
||||
return "~"
|
||||
}
|
||||
|
||||
func IsPowershell(shellPath string) bool {
|
||||
// get the base path, and then check contains
|
||||
shellBase := filepath.Base(shellPath)
|
||||
return strings.Contains(shellBase, "powershell") || strings.Contains(shellBase, "pwsh")
|
||||
}
|
125
pkg/wsl/wsl-win.go
Normal file
125
pkg/wsl/wsl-win.go
Normal file
@ -0,0 +1,125 @@
|
||||
//go:build windows
|
||||
|
||||
// Copyright 2024, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package wsl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/ubuntu/gowsl"
|
||||
)
|
||||
|
||||
var RegisteredDistros = gowsl.RegisteredDistros
|
||||
var DefaultDistro = gowsl.DefaultDistro
|
||||
|
||||
type Distro struct {
|
||||
gowsl.Distro
|
||||
}
|
||||
|
||||
type WslCmd struct {
|
||||
c *gowsl.Cmd
|
||||
wg *sync.WaitGroup
|
||||
once *sync.Once
|
||||
lock *sync.Mutex
|
||||
waitErr error
|
||||
}
|
||||
|
||||
func (d *Distro) WslCommand(ctx context.Context, cmd string) *WslCmd {
|
||||
if ctx == nil {
|
||||
panic("nil Context")
|
||||
}
|
||||
innerCmd := d.Command(ctx, cmd)
|
||||
var wg sync.WaitGroup
|
||||
var lock *sync.Mutex
|
||||
return &WslCmd{innerCmd, &wg, new(sync.Once), lock, nil}
|
||||
}
|
||||
|
||||
func (c *WslCmd) CombinedOutput() (out []byte, err error) {
|
||||
return c.c.CombinedOutput()
|
||||
}
|
||||
func (c *WslCmd) Output() (out []byte, err error) {
|
||||
return c.c.Output()
|
||||
}
|
||||
func (c *WslCmd) Run() error {
|
||||
return c.c.Run()
|
||||
}
|
||||
func (c *WslCmd) Start() (err error) {
|
||||
return c.c.Start()
|
||||
}
|
||||
func (c *WslCmd) StderrPipe() (r io.ReadCloser, err error) {
|
||||
return c.c.StderrPipe()
|
||||
}
|
||||
func (c *WslCmd) StdinPipe() (w io.WriteCloser, err error) {
|
||||
return c.c.StdinPipe()
|
||||
}
|
||||
func (c *WslCmd) StdoutPipe() (r io.ReadCloser, err error) {
|
||||
return c.c.StdoutPipe()
|
||||
}
|
||||
func (c *WslCmd) Wait() (err error) {
|
||||
c.wg.Add(1)
|
||||
c.once.Do(func() {
|
||||
c.waitErr = c.c.Wait()
|
||||
})
|
||||
c.wg.Done()
|
||||
c.wg.Wait()
|
||||
if c.waitErr != nil && c.waitErr.Error() == "not started" {
|
||||
c.once = new(sync.Once)
|
||||
return c.waitErr
|
||||
}
|
||||
return c.waitErr
|
||||
}
|
||||
func (c *WslCmd) GetProcess() *os.Process {
|
||||
return c.c.Process
|
||||
}
|
||||
|
||||
func (c *WslCmd) GetProcessState() *os.ProcessState {
|
||||
return c.c.ProcessState
|
||||
}
|
||||
|
||||
func (c *WslCmd) SetStdin(stdin io.Reader) {
|
||||
c.c.Stdin = stdin
|
||||
}
|
||||
|
||||
func (c *WslCmd) SetStdout(stdout io.Writer) {
|
||||
c.c.Stdout = stdout
|
||||
}
|
||||
|
||||
func (c *WslCmd) SetStderr(stderr io.Writer) {
|
||||
c.c.Stdout = stderr
|
||||
}
|
||||
|
||||
func GetDistroCmd(ctx context.Context, wslDistroName string, cmd string) (*WslCmd, error) {
|
||||
distros, err := RegisteredDistros(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, distro := range distros {
|
||||
if distro.Name() != wslDistroName {
|
||||
continue
|
||||
}
|
||||
wrappedDistro := Distro{distro}
|
||||
return wrappedDistro.WslCommand(ctx, cmd), nil
|
||||
}
|
||||
return nil, fmt.Errorf("wsl distro %s not found", wslDistroName)
|
||||
}
|
||||
|
||||
func GetDistro(ctx context.Context, wslDistroName WslName) (*Distro, error) {
|
||||
distros, err := RegisteredDistros(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, distro := range distros {
|
||||
if distro.Name() != wslDistroName.Distro {
|
||||
continue
|
||||
}
|
||||
wrappedDistro := Distro{distro}
|
||||
return &wrappedDistro, nil
|
||||
}
|
||||
return nil, fmt.Errorf("wsl distro %s not found", wslDistroName)
|
||||
}
|
494
pkg/wsl/wsl.go
Normal file
494
pkg/wsl/wsl.go
Normal file
@ -0,0 +1,494 @@
|
||||
// Copyright 2024, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package wsl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/wavetermdev/waveterm/pkg/userinput"
|
||||
"github.com/wavetermdev/waveterm/pkg/util/shellutil"
|
||||
"github.com/wavetermdev/waveterm/pkg/wavebase"
|
||||
"github.com/wavetermdev/waveterm/pkg/waveobj"
|
||||
"github.com/wavetermdev/waveterm/pkg/wconfig"
|
||||
"github.com/wavetermdev/waveterm/pkg/wps"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
||||
)
|
||||
|
||||
const (
|
||||
Status_Init = "init"
|
||||
Status_Connecting = "connecting"
|
||||
Status_Connected = "connected"
|
||||
Status_Disconnected = "disconnected"
|
||||
Status_Error = "error"
|
||||
)
|
||||
|
||||
const DefaultConnectionTimeout = 60 * time.Second
|
||||
|
||||
var globalLock = &sync.Mutex{}
|
||||
var clientControllerMap = make(map[string]*WslConn)
|
||||
var activeConnCounter = &atomic.Int32{}
|
||||
|
||||
type WslConn struct {
|
||||
Lock *sync.Mutex
|
||||
Status string
|
||||
Name WslName
|
||||
Client *Distro
|
||||
SockName string
|
||||
DomainSockListener net.Listener
|
||||
ConnController *WslCmd
|
||||
Error string
|
||||
HasWaiter *atomic.Bool
|
||||
LastConnectTime int64
|
||||
ActiveConnNum int
|
||||
Context context.Context
|
||||
cancelFn func()
|
||||
}
|
||||
|
||||
type WslName struct {
|
||||
Distro string `json:"distro"`
|
||||
}
|
||||
|
||||
func GetAllConnStatus() []wshrpc.ConnStatus {
|
||||
globalLock.Lock()
|
||||
defer globalLock.Unlock()
|
||||
|
||||
var connStatuses []wshrpc.ConnStatus
|
||||
for _, conn := range clientControllerMap {
|
||||
connStatuses = append(connStatuses, conn.DeriveConnStatus())
|
||||
}
|
||||
return connStatuses
|
||||
}
|
||||
|
||||
func (conn *WslConn) DeriveConnStatus() wshrpc.ConnStatus {
|
||||
conn.Lock.Lock()
|
||||
defer conn.Lock.Unlock()
|
||||
return wshrpc.ConnStatus{
|
||||
Status: conn.Status,
|
||||
Connected: conn.Status == Status_Connected,
|
||||
Connection: conn.GetName(),
|
||||
HasConnected: (conn.LastConnectTime > 0),
|
||||
ActiveConnNum: conn.ActiveConnNum,
|
||||
Error: conn.Error,
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *WslConn) FireConnChangeEvent() {
|
||||
status := conn.DeriveConnStatus()
|
||||
event := wps.WaveEvent{
|
||||
Event: wps.Event_ConnChange,
|
||||
Scopes: []string{
|
||||
fmt.Sprintf("connection:%s", conn.GetName()),
|
||||
},
|
||||
Data: status,
|
||||
}
|
||||
log.Printf("sending event: %+#v", event)
|
||||
wps.Broker.Publish(event)
|
||||
}
|
||||
|
||||
func (conn *WslConn) Close() error {
|
||||
defer conn.FireConnChangeEvent()
|
||||
conn.WithLock(func() {
|
||||
if conn.Status == Status_Connected || conn.Status == Status_Connecting {
|
||||
// if status is init, disconnected, or error don't change it
|
||||
conn.Status = Status_Disconnected
|
||||
}
|
||||
conn.close_nolock()
|
||||
})
|
||||
// we must wait for the waiter to complete
|
||||
startTime := time.Now()
|
||||
for conn.HasWaiter.Load() {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
if time.Since(startTime) > 2*time.Second {
|
||||
return fmt.Errorf("timeout waiting for waiter to complete")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (conn *WslConn) close_nolock() {
|
||||
// does not set status (that should happen at another level)
|
||||
if conn.DomainSockListener != nil {
|
||||
conn.DomainSockListener.Close()
|
||||
conn.DomainSockListener = nil
|
||||
}
|
||||
if conn.ConnController != nil {
|
||||
conn.cancelFn() // this suspends the conn controller
|
||||
conn.ConnController = nil
|
||||
}
|
||||
if conn.Client != nil {
|
||||
// conn.Client.Close() is not relevant here
|
||||
// we do not want to completely close the wsl in case
|
||||
// other applications are using it
|
||||
conn.Client = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *WslConn) GetDomainSocketName() string {
|
||||
conn.Lock.Lock()
|
||||
defer conn.Lock.Unlock()
|
||||
return conn.SockName
|
||||
}
|
||||
|
||||
func (conn *WslConn) GetStatus() string {
|
||||
conn.Lock.Lock()
|
||||
defer conn.Lock.Unlock()
|
||||
return conn.Status
|
||||
}
|
||||
|
||||
func (conn *WslConn) GetName() string {
|
||||
// no lock required because opts is immutable
|
||||
return "wsl://" + conn.Name.Distro
|
||||
}
|
||||
|
||||
/**
|
||||
* This function is does not set a listener for WslConn
|
||||
* It is still required in order to set SockName
|
||||
**/
|
||||
func (conn *WslConn) OpenDomainSocketListener() error {
|
||||
var allowed bool
|
||||
conn.WithLock(func() {
|
||||
if conn.Status != Status_Connecting {
|
||||
allowed = false
|
||||
} else {
|
||||
allowed = true
|
||||
}
|
||||
})
|
||||
if !allowed {
|
||||
return fmt.Errorf("cannot open domain socket for %q when status is %q", conn.GetName(), conn.GetStatus())
|
||||
}
|
||||
conn.WithLock(func() {
|
||||
conn.SockName = "~/.waveterm/wave-remote.sock"
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (conn *WslConn) StartConnServer() error {
|
||||
var allowed bool
|
||||
conn.WithLock(func() {
|
||||
if conn.Status != Status_Connecting {
|
||||
allowed = false
|
||||
} else {
|
||||
allowed = true
|
||||
}
|
||||
})
|
||||
if !allowed {
|
||||
return fmt.Errorf("cannot start conn server for %q when status is %q", conn.GetName(), conn.GetStatus())
|
||||
}
|
||||
client := conn.GetClient()
|
||||
wshPath := GetWshPath(conn.Context, client)
|
||||
rpcCtx := wshrpc.RpcContext{
|
||||
ClientType: wshrpc.ClientType_ConnServer,
|
||||
Conn: conn.GetName(),
|
||||
}
|
||||
sockName := conn.GetDomainSocketName()
|
||||
jwtToken, err := wshutil.MakeClientJWTToken(rpcCtx, sockName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create jwt token for conn controller: %w", err)
|
||||
}
|
||||
shellPath, err := DetectShell(conn.Context, client)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var cmdStr string
|
||||
if IsPowershell(shellPath) {
|
||||
cmdStr = fmt.Sprintf("$env:%s=\"%s\"; %s connserver --router", wshutil.WaveJwtTokenVarName, jwtToken, wshPath)
|
||||
} else {
|
||||
cmdStr = fmt.Sprintf("%s=\"%s\" %s connserver --router", wshutil.WaveJwtTokenVarName, jwtToken, wshPath)
|
||||
}
|
||||
log.Printf("starting conn controller: %s\n", cmdStr)
|
||||
cmd := client.WslCommand(conn.Context, cmdStr)
|
||||
pipeRead, pipeWrite := io.Pipe()
|
||||
inputPipeRead, inputPipeWrite := io.Pipe()
|
||||
cmd.SetStdout(pipeWrite)
|
||||
cmd.SetStderr(pipeWrite)
|
||||
cmd.SetStdin(inputPipeRead)
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to start conn controller: %w", err)
|
||||
}
|
||||
conn.WithLock(func() {
|
||||
conn.ConnController = cmd
|
||||
})
|
||||
// service the I/O
|
||||
go func() {
|
||||
// wait for termination, clear the controller
|
||||
defer conn.WithLock(func() {
|
||||
conn.ConnController = nil
|
||||
})
|
||||
waitErr := cmd.Wait()
|
||||
log.Printf("conn controller (%q) terminated: %v", conn.GetName(), waitErr)
|
||||
}()
|
||||
go func() {
|
||||
logName := fmt.Sprintf("conncontroller:%s", conn.GetName())
|
||||
wshutil.HandleStdIOClient(logName, pipeRead, inputPipeWrite)
|
||||
}()
|
||||
regCtx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancelFn()
|
||||
err = wshutil.DefaultRouter.WaitForRegister(regCtx, wshutil.MakeConnectionRouteId(rpcCtx.Conn))
|
||||
if err != nil {
|
||||
return fmt.Errorf("timeout waiting for connserver to register")
|
||||
}
|
||||
time.Sleep(300 * time.Millisecond) // TODO remove this sleep (but we need to wait until connserver is "ready")
|
||||
return nil
|
||||
}
|
||||
|
||||
type WshInstallOpts struct {
|
||||
Force bool
|
||||
NoUserPrompt bool
|
||||
}
|
||||
|
||||
func (conn *WslConn) CheckAndInstallWsh(ctx context.Context, clientDisplayName string, opts *WshInstallOpts) error {
|
||||
if opts == nil {
|
||||
opts = &WshInstallOpts{}
|
||||
}
|
||||
client := conn.GetClient()
|
||||
if client == nil {
|
||||
return fmt.Errorf("client is nil")
|
||||
}
|
||||
// check that correct wsh extensions are installed
|
||||
expectedVersion := fmt.Sprintf("wsh v%s", wavebase.WaveVersion)
|
||||
clientVersion, err := GetWshVersion(ctx, client)
|
||||
if err == nil && clientVersion == expectedVersion && !opts.Force {
|
||||
return nil
|
||||
}
|
||||
var queryText string
|
||||
var title string
|
||||
if opts.Force {
|
||||
queryText = fmt.Sprintf("ReInstalling Wave Shell Extensions (%s) on `%s`\n", wavebase.WaveVersion, clientDisplayName)
|
||||
title = "Install Wave Shell Extensions"
|
||||
} else if err != nil {
|
||||
queryText = fmt.Sprintf("Wave requires Wave Shell Extensions to be \n"+
|
||||
"installed on `%s` \n"+
|
||||
"to ensure a seamless experience. \n\n"+
|
||||
"Would you like to install them?", clientDisplayName)
|
||||
title = "Install Wave Shell Extensions"
|
||||
} else {
|
||||
// don't ask for upgrading the version
|
||||
opts.NoUserPrompt = true
|
||||
}
|
||||
if !opts.NoUserPrompt {
|
||||
request := &userinput.UserInputRequest{
|
||||
ResponseType: "confirm",
|
||||
QueryText: queryText,
|
||||
Title: title,
|
||||
Markdown: true,
|
||||
CheckBoxMsg: "Don't show me this again",
|
||||
}
|
||||
response, err := userinput.GetUserInput(ctx, request)
|
||||
if err != nil || !response.Confirm {
|
||||
return err
|
||||
}
|
||||
if response.CheckboxStat {
|
||||
meta := waveobj.MetaMapType{
|
||||
wconfig.ConfigKey_ConnAskBeforeWshInstall: false,
|
||||
}
|
||||
err := wconfig.SetBaseConfigValue(meta)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error setting conn:askbeforewshinstall value: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
log.Printf("attempting to install wsh to `%s`", clientDisplayName)
|
||||
clientOs, err := GetClientOs(ctx, client)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
clientArch, err := GetClientArch(ctx, client)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// attempt to install extension
|
||||
wshLocalPath := shellutil.GetWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch)
|
||||
err = CpHostToRemote(ctx, client, wshLocalPath, "~/.waveterm/bin/wsh")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("successfully installed wsh on %s\n", conn.GetName())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (conn *WslConn) GetClient() *Distro {
|
||||
conn.Lock.Lock()
|
||||
defer conn.Lock.Unlock()
|
||||
return conn.Client
|
||||
}
|
||||
|
||||
func (conn *WslConn) Reconnect(ctx context.Context) error {
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return conn.Connect(ctx)
|
||||
}
|
||||
|
||||
func (conn *WslConn) WaitForConnect(ctx context.Context) error {
|
||||
for {
|
||||
status := conn.DeriveConnStatus()
|
||||
if status.Status == Status_Connected {
|
||||
return nil
|
||||
}
|
||||
if status.Status == Status_Connecting {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("context timeout")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
continue
|
||||
}
|
||||
}
|
||||
if status.Status == Status_Init || status.Status == Status_Disconnected {
|
||||
return fmt.Errorf("disconnected")
|
||||
}
|
||||
if status.Status == Status_Error {
|
||||
return fmt.Errorf("error: %v", status.Error)
|
||||
}
|
||||
return fmt.Errorf("unknown status: %q", status.Status)
|
||||
}
|
||||
}
|
||||
|
||||
// does not return an error since that error is stored inside of WslConn
|
||||
func (conn *WslConn) Connect(ctx context.Context) error {
|
||||
var connectAllowed bool
|
||||
conn.WithLock(func() {
|
||||
if conn.Status == Status_Connecting || conn.Status == Status_Connected {
|
||||
connectAllowed = false
|
||||
} else {
|
||||
conn.Status = Status_Connecting
|
||||
conn.Error = ""
|
||||
connectAllowed = true
|
||||
}
|
||||
})
|
||||
log.Printf("Connect %s\n", conn.GetName())
|
||||
if !connectAllowed {
|
||||
return fmt.Errorf("cannot connect to %q when status is %q", conn.GetName(), conn.GetStatus())
|
||||
}
|
||||
conn.FireConnChangeEvent()
|
||||
err := conn.connectInternal(ctx)
|
||||
conn.WithLock(func() {
|
||||
if err != nil {
|
||||
conn.Status = Status_Error
|
||||
conn.Error = err.Error()
|
||||
conn.close_nolock()
|
||||
} else {
|
||||
conn.Status = Status_Connected
|
||||
conn.LastConnectTime = time.Now().UnixMilli()
|
||||
if conn.ActiveConnNum == 0 {
|
||||
conn.ActiveConnNum = int(activeConnCounter.Add(1))
|
||||
}
|
||||
}
|
||||
})
|
||||
conn.FireConnChangeEvent()
|
||||
return err
|
||||
}
|
||||
|
||||
func (conn *WslConn) WithLock(fn func()) {
|
||||
conn.Lock.Lock()
|
||||
defer conn.Lock.Unlock()
|
||||
fn()
|
||||
}
|
||||
|
||||
func (conn *WslConn) connectInternal(ctx context.Context) error {
|
||||
client, err := GetDistro(ctx, conn.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conn.WithLock(func() {
|
||||
conn.Client = client
|
||||
})
|
||||
err = conn.OpenDomainSocketListener()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
config := wconfig.ReadFullConfig()
|
||||
installErr := conn.CheckAndInstallWsh(ctx, conn.GetName(), &WshInstallOpts{NoUserPrompt: !config.Settings.ConnAskBeforeWshInstall})
|
||||
if installErr != nil {
|
||||
return fmt.Errorf("conncontroller %s wsh install error: %v", conn.GetName(), installErr)
|
||||
}
|
||||
csErr := conn.StartConnServer()
|
||||
if csErr != nil {
|
||||
return fmt.Errorf("conncontroller %s start wsh connserver error: %v", conn.GetName(), csErr)
|
||||
}
|
||||
conn.HasWaiter.Store(true)
|
||||
go conn.waitForDisconnect()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (conn *WslConn) waitForDisconnect() {
|
||||
defer conn.FireConnChangeEvent()
|
||||
defer conn.HasWaiter.Store(false)
|
||||
err := conn.ConnController.Wait()
|
||||
conn.WithLock(func() {
|
||||
// disconnects happen for a variety of reasons (like network, etc. and are typically transient)
|
||||
// so we just set the status to "disconnected" here (not error)
|
||||
// don't overwrite any existing error (or error status)
|
||||
if err != nil && conn.Error == "" {
|
||||
conn.Error = err.Error()
|
||||
}
|
||||
if conn.Status != Status_Error {
|
||||
conn.Status = Status_Disconnected
|
||||
}
|
||||
conn.close_nolock()
|
||||
})
|
||||
}
|
||||
|
||||
func getConnInternal(name string) *WslConn {
|
||||
globalLock.Lock()
|
||||
defer globalLock.Unlock()
|
||||
connName := WslName{Distro: name}
|
||||
rtn := clientControllerMap[name]
|
||||
if rtn == nil {
|
||||
ctx, cancelFn := context.WithCancel(context.Background())
|
||||
rtn = &WslConn{Lock: &sync.Mutex{}, Status: Status_Init, Name: connName, HasWaiter: &atomic.Bool{}, Context: ctx, cancelFn: cancelFn}
|
||||
clientControllerMap[name] = rtn
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func GetWslConn(ctx context.Context, name string, shouldConnect bool) *WslConn {
|
||||
conn := getConnInternal(name)
|
||||
if conn.Client == nil && shouldConnect {
|
||||
conn.Connect(ctx)
|
||||
}
|
||||
return conn
|
||||
}
|
||||
|
||||
// Convenience function for ensuring a connection is established
|
||||
func EnsureConnection(ctx context.Context, connName string) error {
|
||||
if connName == "" {
|
||||
return nil
|
||||
}
|
||||
conn := GetWslConn(ctx, connName, false)
|
||||
if conn == nil {
|
||||
return fmt.Errorf("connection not found: %s", connName)
|
||||
}
|
||||
connStatus := conn.DeriveConnStatus()
|
||||
switch connStatus.Status {
|
||||
case Status_Connected:
|
||||
return nil
|
||||
case Status_Connecting:
|
||||
return conn.WaitForConnect(ctx)
|
||||
case Status_Init, Status_Disconnected:
|
||||
return conn.Connect(ctx)
|
||||
case Status_Error:
|
||||
return fmt.Errorf("connection error: %s", connStatus.Error)
|
||||
default:
|
||||
return fmt.Errorf("unknown connection status %q", connStatus.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func DisconnectClient(connName string) error {
|
||||
conn := getConnInternal(connName)
|
||||
if conn == nil {
|
||||
return fmt.Errorf("client %q not found", connName)
|
||||
}
|
||||
err := conn.Close()
|
||||
return err
|
||||
}
|
Loading…
Reference in New Issue
Block a user