mirror of
https://github.com/wavetermdev/waveterm.git
synced 2024-12-21 16:38:23 +01:00
WSL Integration (#1031)
Adds support for connecting to local WSL installations on Windows. (also adds wshrpcmmultiproxy / connserver router)
This commit is contained in:
parent
4e86b67936
commit
8248637e00
2
.gitattributes
vendored
2
.gitattributes
vendored
@ -1 +1 @@
|
|||||||
* text=auto
|
* text=auto eol=lf
|
@ -159,11 +159,11 @@ func shutdownActivityUpdate() {
|
|||||||
|
|
||||||
func createMainWshClient() {
|
func createMainWshClient() {
|
||||||
rpc := wshserver.GetMainRpcClient()
|
rpc := wshserver.GetMainRpcClient()
|
||||||
wshutil.DefaultRouter.RegisterRoute(wshutil.DefaultRoute, rpc)
|
wshutil.DefaultRouter.RegisterRoute(wshutil.DefaultRoute, rpc, true)
|
||||||
wps.Broker.SetClient(wshutil.DefaultRouter)
|
wps.Broker.SetClient(wshutil.DefaultRouter)
|
||||||
localConnWsh := wshutil.MakeWshRpc(nil, nil, wshrpc.RpcContext{Conn: wshrpc.LocalConnName}, &wshremote.ServerImpl{})
|
localConnWsh := wshutil.MakeWshRpc(nil, nil, wshrpc.RpcContext{Conn: wshrpc.LocalConnName}, &wshremote.ServerImpl{})
|
||||||
go wshremote.RunSysInfoLoop(localConnWsh, wshrpc.LocalConnName)
|
go wshremote.RunSysInfoLoop(localConnWsh, wshrpc.LocalConnName)
|
||||||
wshutil.DefaultRouter.RegisterRoute(wshutil.MakeConnectionRouteId(wshrpc.LocalConnName), localConnWsh)
|
wshutil.DefaultRouter.RegisterRoute(wshutil.MakeConnectionRouteId(wshrpc.LocalConnName), localConnWsh, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
@ -5,6 +5,7 @@ package cmd
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"github.com/wavetermdev/waveterm/pkg/remote"
|
"github.com/wavetermdev/waveterm/pkg/remote"
|
||||||
@ -25,17 +26,24 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func connStatus() error {
|
func connStatus() error {
|
||||||
resp, err := wshclient.ConnStatusCommand(RpcClient, nil)
|
var allResp []wshrpc.ConnStatus
|
||||||
|
sshResp, err := wshclient.ConnStatusCommand(RpcClient, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("getting connection status: %w", err)
|
return fmt.Errorf("getting ssh connection status: %w", err)
|
||||||
}
|
}
|
||||||
if len(resp) == 0 {
|
allResp = append(allResp, sshResp...)
|
||||||
|
wslResp, err := wshclient.WslStatusCommand(RpcClient, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("getting wsl connection status: %w", err)
|
||||||
|
}
|
||||||
|
allResp = append(allResp, wslResp...)
|
||||||
|
if len(allResp) == 0 {
|
||||||
WriteStdout("no connections\n")
|
WriteStdout("no connections\n")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
WriteStdout("%-30s %-12s\n", "connection", "status")
|
WriteStdout("%-30s %-12s\n", "connection", "status")
|
||||||
WriteStdout("----------------------------------------------\n")
|
WriteStdout("----------------------------------------------\n")
|
||||||
for _, conn := range resp {
|
for _, conn := range allResp {
|
||||||
str := fmt.Sprintf("%-30s %-12s", conn.Connection, conn.Status)
|
str := fmt.Sprintf("%-30s %-12s", conn.Connection, conn.Status)
|
||||||
if conn.Error != "" {
|
if conn.Error != "" {
|
||||||
str += fmt.Sprintf(" (%s)", conn.Error)
|
str += fmt.Sprintf(" (%s)", conn.Error)
|
||||||
@ -110,7 +118,7 @@ func connRun(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
connName = args[1]
|
connName = args[1]
|
||||||
_, err := remote.ParseOpts(connName)
|
_, err := remote.ParseOpts(connName)
|
||||||
if err != nil {
|
if err != nil && !strings.HasPrefix(connName, "wsl://") {
|
||||||
return fmt.Errorf("cannot parse connection name: %w", err)
|
return fmt.Errorf("cannot parse connection name: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -4,29 +4,186 @@
|
|||||||
package cmd
|
package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
"os"
|
"os"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/util/packetparser"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/wavebase"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshremote"
|
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshremote"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
var serverCmd = &cobra.Command{
|
var serverCmd = &cobra.Command{
|
||||||
Use: "connserver",
|
Use: "connserver",
|
||||||
Hidden: true,
|
Hidden: true,
|
||||||
Short: "remote server to power wave blocks",
|
Short: "remote server to power wave blocks",
|
||||||
Args: cobra.NoArgs,
|
Args: cobra.NoArgs,
|
||||||
Run: serverRun,
|
RunE: serverRun,
|
||||||
PreRunE: preRunSetupRpcClient,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var connServerRouter bool
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
serverCmd.Flags().BoolVar(&connServerRouter, "router", false, "run in local router mode")
|
||||||
rootCmd.AddCommand(serverCmd)
|
rootCmd.AddCommand(serverCmd)
|
||||||
}
|
}
|
||||||
|
|
||||||
func serverRun(cmd *cobra.Command, args []string) {
|
func MakeRemoteUnixListener() (net.Listener, error) {
|
||||||
|
serverAddr := wavebase.GetRemoteDomainSocketName()
|
||||||
|
os.Remove(serverAddr) // ignore error
|
||||||
|
rtn, err := net.Listen("unix", serverAddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error creating listener at %v: %v", serverAddr, err)
|
||||||
|
}
|
||||||
|
os.Chmod(serverAddr, 0700)
|
||||||
|
log.Printf("Server [unix-domain] listening on %s\n", serverAddr)
|
||||||
|
return rtn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleNewListenerConn(conn net.Conn, router *wshutil.WshRouter) {
|
||||||
|
var routeIdContainer atomic.Pointer[string]
|
||||||
|
proxy := wshutil.MakeRpcProxy()
|
||||||
|
go func() {
|
||||||
|
writeErr := wshutil.AdaptOutputChToStream(proxy.ToRemoteCh, conn)
|
||||||
|
if writeErr != nil {
|
||||||
|
log.Printf("error writing to domain socket: %v\n", writeErr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
// when input is closed, close the connection
|
||||||
|
defer func() {
|
||||||
|
conn.Close()
|
||||||
|
routeIdPtr := routeIdContainer.Load()
|
||||||
|
if routeIdPtr != nil && *routeIdPtr != "" {
|
||||||
|
router.UnregisterRoute(*routeIdPtr)
|
||||||
|
disposeMsg := &wshutil.RpcMessage{
|
||||||
|
Command: wshrpc.Command_Dispose,
|
||||||
|
Data: wshrpc.CommandDisposeData{
|
||||||
|
RouteId: *routeIdPtr,
|
||||||
|
},
|
||||||
|
Source: *routeIdPtr,
|
||||||
|
AuthToken: proxy.GetAuthToken(),
|
||||||
|
}
|
||||||
|
disposeBytes, _ := json.Marshal(disposeMsg)
|
||||||
|
router.InjectMessage(disposeBytes, *routeIdPtr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
wshutil.AdaptStreamToMsgCh(conn, proxy.FromRemoteCh)
|
||||||
|
}()
|
||||||
|
routeId, err := proxy.HandleClientProxyAuth(router)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("error handling client proxy auth: %v\n", err)
|
||||||
|
conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
router.RegisterRoute(routeId, proxy, false)
|
||||||
|
routeIdContainer.Store(&routeId)
|
||||||
|
}
|
||||||
|
|
||||||
|
func runListener(listener net.Listener, router *wshutil.WshRouter) {
|
||||||
|
defer func() {
|
||||||
|
log.Printf("listener closed, exiting\n")
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
wshutil.DoShutdown("", 1, true)
|
||||||
|
}()
|
||||||
|
for {
|
||||||
|
conn, err := listener.Accept()
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("error accepting connection: %v\n", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
go handleNewListenerConn(conn, router)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupConnServerRpcClientWithRouter(router *wshutil.WshRouter) (*wshutil.WshRpc, error) {
|
||||||
|
jwtToken := os.Getenv(wshutil.WaveJwtTokenVarName)
|
||||||
|
if jwtToken == "" {
|
||||||
|
return nil, fmt.Errorf("no jwt token found for connserver")
|
||||||
|
}
|
||||||
|
rpcCtx, err := wshutil.ExtractUnverifiedRpcContext(jwtToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error extracting rpc context from %s: %v", wshutil.WaveJwtTokenVarName, err)
|
||||||
|
}
|
||||||
|
authRtn, err := router.HandleProxyAuth(jwtToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error handling proxy auth: %v", err)
|
||||||
|
}
|
||||||
|
inputCh := make(chan []byte, wshutil.DefaultInputChSize)
|
||||||
|
outputCh := make(chan []byte, wshutil.DefaultOutputChSize)
|
||||||
|
connServerClient := wshutil.MakeWshRpc(inputCh, outputCh, *rpcCtx, &wshremote.ServerImpl{LogWriter: os.Stdout})
|
||||||
|
connServerClient.SetAuthToken(authRtn.AuthToken)
|
||||||
|
router.RegisterRoute(authRtn.RouteId, connServerClient, false)
|
||||||
|
wshclient.RouteAnnounceCommand(connServerClient, nil)
|
||||||
|
return connServerClient, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func serverRunRouter() error {
|
||||||
|
router := wshutil.NewWshRouter()
|
||||||
|
termProxy := wshutil.MakeRpcProxy()
|
||||||
|
rawCh := make(chan []byte, wshutil.DefaultOutputChSize)
|
||||||
|
go packetparser.Parse(os.Stdin, termProxy.FromRemoteCh, rawCh)
|
||||||
|
go func() {
|
||||||
|
for msg := range termProxy.ToRemoteCh {
|
||||||
|
packetparser.WritePacket(os.Stdout, msg)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
// just ignore and drain the rawCh (stdin)
|
||||||
|
// when stdin is closed, shutdown
|
||||||
|
defer wshutil.DoShutdown("", 0, true)
|
||||||
|
for range rawCh {
|
||||||
|
// ignore
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
for msg := range termProxy.FromRemoteCh {
|
||||||
|
// send this to the router
|
||||||
|
router.InjectMessage(msg, wshutil.UpstreamRoute)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
router.SetUpstreamClient(termProxy)
|
||||||
|
// now set up the domain socket
|
||||||
|
unixListener, err := MakeRemoteUnixListener()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot create unix listener: %v", err)
|
||||||
|
}
|
||||||
|
client, err := setupConnServerRpcClientWithRouter(router)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error setting up connserver rpc client: %v", err)
|
||||||
|
}
|
||||||
|
go runListener(unixListener, router)
|
||||||
|
// run the sysinfo loop
|
||||||
|
wshremote.RunSysInfoLoop(client, client.GetRpcContext().Conn)
|
||||||
|
select {}
|
||||||
|
}
|
||||||
|
|
||||||
|
func serverRunNormal() error {
|
||||||
|
err := setupRpcClient(&wshremote.ServerImpl{LogWriter: os.Stdout})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
WriteStdout("running wsh connserver (%s)\n", RpcContext.Conn)
|
WriteStdout("running wsh connserver (%s)\n", RpcContext.Conn)
|
||||||
go wshremote.RunSysInfoLoop(RpcClient, RpcContext.Conn)
|
go wshremote.RunSysInfoLoop(RpcClient, RpcContext.Conn)
|
||||||
RpcClient.SetServerImpl(&wshremote.ServerImpl{LogWriter: os.Stdout})
|
|
||||||
|
|
||||||
select {} // run forever
|
select {} // run forever
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func serverRun(cmd *cobra.Command, args []string) error {
|
||||||
|
if connServerRouter {
|
||||||
|
return serverRunRouter()
|
||||||
|
} else {
|
||||||
|
return serverRunNormal()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
60
cmd/wsh/cmd/wshcmd-wsl.go
Normal file
60
cmd/wsh/cmd/wshcmd-wsl.go
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
// Copyright 2024, Command Line Inc.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/waveobj"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
|
||||||
|
)
|
||||||
|
|
||||||
|
var distroName string
|
||||||
|
|
||||||
|
var wslCmd = &cobra.Command{
|
||||||
|
Use: "wsl [-d <Distro>]",
|
||||||
|
Short: "connect this terminal to a local wsl connection",
|
||||||
|
Args: cobra.NoArgs,
|
||||||
|
Run: wslRun,
|
||||||
|
PreRunE: preRunSetupRpcClient,
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
wslCmd.Flags().StringVarP(&distroName, "distribution", "d", "", "Run the specified distribution")
|
||||||
|
rootCmd.AddCommand(wslCmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
func wslRun(cmd *cobra.Command, args []string) {
|
||||||
|
var err error
|
||||||
|
if distroName == "" {
|
||||||
|
// get default distro from the host
|
||||||
|
distroName, err = wshclient.WslDefaultDistroCommand(RpcClient, nil)
|
||||||
|
if err != nil {
|
||||||
|
WriteStderr("[error] %s\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(distroName, "wsl://") {
|
||||||
|
distroName = "wsl://" + distroName
|
||||||
|
}
|
||||||
|
blockId := RpcContext.BlockId
|
||||||
|
if blockId == "" {
|
||||||
|
WriteStderr("[error] cannot determine blockid (not in JWT)\n")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data := wshrpc.CommandSetMetaData{
|
||||||
|
ORef: waveobj.MakeORef(waveobj.OType_Block, blockId),
|
||||||
|
Meta: map[string]any{
|
||||||
|
waveobj.MetaKey_Connection: distroName,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err = wshclient.SetMetaCommand(RpcClient, data, nil)
|
||||||
|
if err != nil {
|
||||||
|
WriteStderr("[error] setting switching connection: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
WriteStderr("switched connection to %q\n", distroName)
|
||||||
|
}
|
@ -521,6 +521,7 @@ const ChangeConnectionBlockModal = React.memo(
|
|||||||
const connStatusAtom = getConnStatusAtom(connection);
|
const connStatusAtom = getConnStatusAtom(connection);
|
||||||
const connStatus = jotai.useAtomValue(connStatusAtom);
|
const connStatus = jotai.useAtomValue(connStatusAtom);
|
||||||
const [connList, setConnList] = React.useState<Array<string>>([]);
|
const [connList, setConnList] = React.useState<Array<string>>([]);
|
||||||
|
const [wslList, setWslList] = React.useState<Array<string>>([]);
|
||||||
const allConnStatus = jotai.useAtomValue(atoms.allConnStatus);
|
const allConnStatus = jotai.useAtomValue(atoms.allConnStatus);
|
||||||
const [rowIndex, setRowIndex] = React.useState(0);
|
const [rowIndex, setRowIndex] = React.useState(0);
|
||||||
const connStatusMap = new Map<string, ConnStatus>();
|
const connStatusMap = new Map<string, ConnStatus>();
|
||||||
@ -540,6 +541,18 @@ const ChangeConnectionBlockModal = React.memo(
|
|||||||
prtn.then((newConnList) => {
|
prtn.then((newConnList) => {
|
||||||
setConnList(newConnList ?? []);
|
setConnList(newConnList ?? []);
|
||||||
}).catch((e) => console.log("unable to load conn list from backend. using blank list: ", e));
|
}).catch((e) => console.log("unable to load conn list from backend. using blank list: ", e));
|
||||||
|
const p2rtn = RpcApi.WslListCommand(TabRpcClient, { timeout: 2000 });
|
||||||
|
p2rtn
|
||||||
|
.then((newWslList) => {
|
||||||
|
console.log(newWslList);
|
||||||
|
setWslList(newWslList ?? []);
|
||||||
|
})
|
||||||
|
.catch((e) => {
|
||||||
|
// removing this log and failing silentyly since it will happen
|
||||||
|
// if a system isn't using the wsl. and would happen every time the
|
||||||
|
// typeahead was opened. good candidate for verbose log level.
|
||||||
|
//console.log("unable to load wsl list from backend. using blank list: ", e)
|
||||||
|
});
|
||||||
}, [changeConnModalOpen, setConnList]);
|
}, [changeConnModalOpen, setConnList]);
|
||||||
|
|
||||||
const changeConnection = React.useCallback(
|
const changeConnection = React.useCallback(
|
||||||
@ -588,6 +601,15 @@ const ChangeConnectionBlockModal = React.memo(
|
|||||||
filteredList.push(conn);
|
filteredList.push(conn);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
const filteredWslList: Array<string> = [];
|
||||||
|
for (const conn of wslList) {
|
||||||
|
if (conn === connSelected) {
|
||||||
|
createNew = false;
|
||||||
|
}
|
||||||
|
if (conn.includes(connSelected)) {
|
||||||
|
filteredWslList.push(conn);
|
||||||
|
}
|
||||||
|
}
|
||||||
// priority handles special suggestions when necessary
|
// priority handles special suggestions when necessary
|
||||||
// for instance, when reconnecting
|
// for instance, when reconnecting
|
||||||
const newConnectionSuggestion: SuggestionConnectionItem = {
|
const newConnectionSuggestion: SuggestionConnectionItem = {
|
||||||
@ -637,6 +659,20 @@ const ChangeConnectionBlockModal = React.memo(
|
|||||||
label: localName,
|
label: localName,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
for (const wslConn of filteredWslList) {
|
||||||
|
const connStatus = connStatusMap.get(wslConn);
|
||||||
|
const connColorNum = computeConnColorNum(connStatus);
|
||||||
|
localSuggestion.items.push({
|
||||||
|
status: "connected",
|
||||||
|
icon: "arrow-right-arrow-left",
|
||||||
|
iconColor:
|
||||||
|
connStatus?.status == "connected"
|
||||||
|
? `var(--conn-icon-color-${connColorNum})`
|
||||||
|
: "var(--grey-text-color)",
|
||||||
|
value: "wsl://" + wslConn,
|
||||||
|
label: "wsl://" + wslConn,
|
||||||
|
});
|
||||||
|
}
|
||||||
const remoteItems = filteredList.map((connName) => {
|
const remoteItems = filteredList.map((connName) => {
|
||||||
const connStatus = connStatusMap.get(connName);
|
const connStatus = connStatusMap.get(connName);
|
||||||
const connColorNum = computeConnColorNum(connStatus);
|
const connColorNum = computeConnColorNum(connStatus);
|
||||||
|
@ -72,6 +72,11 @@ class RpcApiType {
|
|||||||
return client.wshRpcCall("deleteblock", data, opts);
|
return client.wshRpcCall("deleteblock", data, opts);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// command "dispose" [call]
|
||||||
|
DisposeCommand(client: WshClient, data: CommandDisposeData, opts?: RpcOpts): Promise<void> {
|
||||||
|
return client.wshRpcCall("dispose", data, opts);
|
||||||
|
}
|
||||||
|
|
||||||
// command "eventpublish" [call]
|
// command "eventpublish" [call]
|
||||||
EventPublishCommand(client: WshClient, data: WaveEvent, opts?: RpcOpts): Promise<void> {
|
EventPublishCommand(client: WshClient, data: WaveEvent, opts?: RpcOpts): Promise<void> {
|
||||||
return client.wshRpcCall("eventpublish", data, opts);
|
return client.wshRpcCall("eventpublish", data, opts);
|
||||||
@ -237,6 +242,21 @@ class RpcApiType {
|
|||||||
return client.wshRpcCall("webselector", data, opts);
|
return client.wshRpcCall("webselector", data, opts);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// command "wsldefaultdistro" [call]
|
||||||
|
WslDefaultDistroCommand(client: WshClient, opts?: RpcOpts): Promise<string> {
|
||||||
|
return client.wshRpcCall("wsldefaultdistro", null, opts);
|
||||||
|
}
|
||||||
|
|
||||||
|
// command "wsllist" [call]
|
||||||
|
WslListCommand(client: WshClient, opts?: RpcOpts): Promise<string[]> {
|
||||||
|
return client.wshRpcCall("wsllist", null, opts);
|
||||||
|
}
|
||||||
|
|
||||||
|
// command "wslstatus" [call]
|
||||||
|
WslStatusCommand(client: WshClient, opts?: RpcOpts): Promise<ConnStatus[]> {
|
||||||
|
return client.wshRpcCall("wslstatus", null, opts);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export const RpcApi = new RpcApiType();
|
export const RpcApi = new RpcApiType();
|
||||||
|
7
frontend/types/gotypes.d.ts
vendored
7
frontend/types/gotypes.d.ts
vendored
@ -63,6 +63,7 @@ declare global {
|
|||||||
// wshrpc.CommandAuthenticateRtnData
|
// wshrpc.CommandAuthenticateRtnData
|
||||||
type CommandAuthenticateRtnData = {
|
type CommandAuthenticateRtnData = {
|
||||||
routeid: string;
|
routeid: string;
|
||||||
|
authtoken?: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
// wshrpc.CommandBlockInputData
|
// wshrpc.CommandBlockInputData
|
||||||
@ -100,6 +101,11 @@ declare global {
|
|||||||
blockid: string;
|
blockid: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// wshrpc.CommandDisposeData
|
||||||
|
type CommandDisposeData = {
|
||||||
|
routeid: string;
|
||||||
|
};
|
||||||
|
|
||||||
// wshrpc.CommandEventReadHistoryData
|
// wshrpc.CommandEventReadHistoryData
|
||||||
type CommandEventReadHistoryData = {
|
type CommandEventReadHistoryData = {
|
||||||
event: string;
|
event: string;
|
||||||
@ -416,6 +422,7 @@ declare global {
|
|||||||
resid?: string;
|
resid?: string;
|
||||||
timeout?: number;
|
timeout?: number;
|
||||||
route?: string;
|
route?: string;
|
||||||
|
authtoken?: string;
|
||||||
source?: string;
|
source?: string;
|
||||||
cont?: boolean;
|
cont?: boolean;
|
||||||
cancel?: boolean;
|
cancel?: boolean;
|
||||||
|
3
go.mod
3
go.mod
@ -21,6 +21,7 @@ require (
|
|||||||
github.com/shirou/gopsutil/v4 v4.24.9
|
github.com/shirou/gopsutil/v4 v4.24.9
|
||||||
github.com/skeema/knownhosts v1.3.0
|
github.com/skeema/knownhosts v1.3.0
|
||||||
github.com/spf13/cobra v1.8.1
|
github.com/spf13/cobra v1.8.1
|
||||||
|
github.com/ubuntu/gowsl v0.0.0-20240906163211-049fd49bd93b
|
||||||
github.com/wavetermdev/htmltoken v0.1.0
|
github.com/wavetermdev/htmltoken v0.1.0
|
||||||
golang.org/x/crypto v0.28.0
|
golang.org/x/crypto v0.28.0
|
||||||
golang.org/x/sys v0.26.0
|
golang.org/x/sys v0.26.0
|
||||||
@ -36,9 +37,11 @@ require (
|
|||||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
||||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
||||||
|
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||||
github.com/spf13/pflag v1.0.5 // indirect
|
github.com/spf13/pflag v1.0.5 // indirect
|
||||||
github.com/tklauser/go-sysconf v0.3.12 // indirect
|
github.com/tklauser/go-sysconf v0.3.12 // indirect
|
||||||
github.com/tklauser/numcpus v0.6.1 // indirect
|
github.com/tklauser/numcpus v0.6.1 // indirect
|
||||||
|
github.com/ubuntu/decorate v0.0.0-20230125165522-2d5b0a9bb117 // indirect
|
||||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||||
go.uber.org/atomic v1.7.0 // indirect
|
go.uber.org/atomic v1.7.0 // indirect
|
||||||
golang.org/x/net v0.29.0 // indirect
|
golang.org/x/net v0.29.0 // indirect
|
||||||
|
11
go.sum
11
go.sum
@ -1,5 +1,7 @@
|
|||||||
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
||||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||||
|
github.com/0xrawsec/golang-utils v1.3.2 h1:ww4jrtHRSnX9xrGzJYbalx5nXoZewy4zPxiY+ubJgtg=
|
||||||
|
github.com/0xrawsec/golang-utils v1.3.2/go.mod h1:m7AzHXgdSAkFCD9tWWsApxNVxMlyy7anpPVOyT/yM7E=
|
||||||
github.com/alexflint/go-filemutex v1.3.0 h1:LgE+nTUWnQCyRKbpoceKZsPQbs84LivvgwUymZXdOcM=
|
github.com/alexflint/go-filemutex v1.3.0 h1:LgE+nTUWnQCyRKbpoceKZsPQbs84LivvgwUymZXdOcM=
|
||||||
github.com/alexflint/go-filemutex v1.3.0/go.mod h1:U0+VA/i30mGBlLCrFPGtTe9y6wGQfNAWPBTekHQ+c8A=
|
github.com/alexflint/go-filemutex v1.3.0/go.mod h1:U0+VA/i30mGBlLCrFPGtTe9y6wGQfNAWPBTekHQ+c8A=
|
||||||
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
||||||
@ -62,6 +64,8 @@ github.com/sawka/txwrap v0.2.0 h1:V3LfvKVLULxcYSxdMguLwFyQFMEU9nFDJopg0ZkL+94=
|
|||||||
github.com/sawka/txwrap v0.2.0/go.mod h1:wwQ2SQiN4U+6DU/iVPhbvr7OzXAtgZlQCIGuvOswEfA=
|
github.com/sawka/txwrap v0.2.0/go.mod h1:wwQ2SQiN4U+6DU/iVPhbvr7OzXAtgZlQCIGuvOswEfA=
|
||||||
github.com/shirou/gopsutil/v4 v4.24.9 h1:KIV+/HaHD5ka5f570RZq+2SaeFsb/pq+fp2DGNWYoOI=
|
github.com/shirou/gopsutil/v4 v4.24.9 h1:KIV+/HaHD5ka5f570RZq+2SaeFsb/pq+fp2DGNWYoOI=
|
||||||
github.com/shirou/gopsutil/v4 v4.24.9/go.mod h1:3fkaHNeYsUFCGZ8+9vZVWtbyM1k2eRnlL+bWO8Bxa/Q=
|
github.com/shirou/gopsutil/v4 v4.24.9/go.mod h1:3fkaHNeYsUFCGZ8+9vZVWtbyM1k2eRnlL+bWO8Bxa/Q=
|
||||||
|
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||||
|
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||||
github.com/skeema/knownhosts v1.3.0 h1:AM+y0rI04VksttfwjkSTNQorvGqmwATnvnAHpSgc0LY=
|
github.com/skeema/knownhosts v1.3.0 h1:AM+y0rI04VksttfwjkSTNQorvGqmwATnvnAHpSgc0LY=
|
||||||
github.com/skeema/knownhosts v1.3.0/go.mod h1:sPINvnADmT/qYH1kfv+ePMmOBTH6Tbl7b5LvTDjFK7M=
|
github.com/skeema/knownhosts v1.3.0/go.mod h1:sPINvnADmT/qYH1kfv+ePMmOBTH6Tbl7b5LvTDjFK7M=
|
||||||
github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
|
github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
|
||||||
@ -71,12 +75,17 @@ github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An
|
|||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||||
|
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
|
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
|
||||||
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
|
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
|
||||||
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
|
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
|
||||||
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
|
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
|
||||||
|
github.com/ubuntu/decorate v0.0.0-20230125165522-2d5b0a9bb117 h1:XQpsQG5lqRJlx4mUVHcJvyyc1rdTI9nHvwrdfcuy8aM=
|
||||||
|
github.com/ubuntu/decorate v0.0.0-20230125165522-2d5b0a9bb117/go.mod h1:mx0TjbqsaDD9DUT5gA1s3hw47U6RIbbIBfvGzR85K0g=
|
||||||
|
github.com/ubuntu/gowsl v0.0.0-20240906163211-049fd49bd93b h1:wFBKF5k5xbJQU8bYgcSoQ/ScvmYyq6KHUabAuVUjOWM=
|
||||||
|
github.com/ubuntu/gowsl v0.0.0-20240906163211-049fd49bd93b/go.mod h1:N1CYNinssZru+ikvYTgVbVeSi21thHUTCoJ9xMvWe+s=
|
||||||
github.com/wavetermdev/htmltoken v0.1.0 h1:RMdA9zTfnYa5jRC4RRG3XNoV5NOP8EDxpaVPjuVz//Q=
|
github.com/wavetermdev/htmltoken v0.1.0 h1:RMdA9zTfnYa5jRC4RRG3XNoV5NOP8EDxpaVPjuVz//Q=
|
||||||
github.com/wavetermdev/htmltoken v0.1.0/go.mod h1:5FM0XV6zNYiNza2iaTcFGj+hnMtgqumFHO31Z8euquk=
|
github.com/wavetermdev/htmltoken v0.1.0/go.mod h1:5FM0XV6zNYiNza2iaTcFGj+hnMtgqumFHO31Z8euquk=
|
||||||
github.com/wavetermdev/ssh_config v0.0.0-20240306041034-17e2087ebde2 h1:onqZrJVap1sm15AiIGTfWzdr6cEF0KdtddeuuOVhzyY=
|
github.com/wavetermdev/ssh_config v0.0.0-20240306041034-17e2087ebde2 h1:onqZrJVap1sm15AiIGTfWzdr6cEF0KdtddeuuOVhzyY=
|
||||||
@ -91,6 +100,7 @@ golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo=
|
|||||||
golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0=
|
golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0=
|
||||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20220721230656-c6bc011c0c49/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220721230656-c6bc011c0c49/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
@ -102,5 +112,6 @@ golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M=
|
|||||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
@ -11,6 +11,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"log"
|
"log"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -24,6 +25,7 @@ import (
|
|||||||
"github.com/wavetermdev/waveterm/pkg/wps"
|
"github.com/wavetermdev/waveterm/pkg/wps"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/wsl"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wstore"
|
"github.com/wavetermdev/waveterm/pkg/wstore"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -262,7 +264,30 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj
|
|||||||
return fmt.Errorf("unknown controller type %q", bc.ControllerType)
|
return fmt.Errorf("unknown controller type %q", bc.ControllerType)
|
||||||
}
|
}
|
||||||
var shellProc *shellexec.ShellProc
|
var shellProc *shellexec.ShellProc
|
||||||
if remoteName != "" {
|
if strings.HasPrefix(remoteName, "wsl://") {
|
||||||
|
wslName := strings.TrimPrefix(remoteName, "wsl://")
|
||||||
|
credentialCtx, cancelFunc := context.WithTimeout(context.Background(), 60*time.Second)
|
||||||
|
defer cancelFunc()
|
||||||
|
|
||||||
|
wslConn := wsl.GetWslConn(credentialCtx, wslName, false)
|
||||||
|
connStatus := wslConn.DeriveConnStatus()
|
||||||
|
if connStatus.Status != conncontroller.Status_Connected {
|
||||||
|
return fmt.Errorf("not connected, cannot start shellproc")
|
||||||
|
}
|
||||||
|
|
||||||
|
// create jwt
|
||||||
|
if !blockMeta.GetBool(waveobj.MetaKey_CmdNoWsh, false) {
|
||||||
|
jwtStr, err := wshutil.MakeClientJWTToken(wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId, Conn: wslConn.GetName()}, wslConn.GetDomainSocketName())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error making jwt token: %w", err)
|
||||||
|
}
|
||||||
|
cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr
|
||||||
|
}
|
||||||
|
shellProc, err = shellexec.StartWslShellProc(ctx, rc.TermSize, cmdStr, cmdOpts, wslConn)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else if remoteName != "" {
|
||||||
credentialCtx, cancelFunc := context.WithTimeout(context.Background(), 60*time.Second)
|
credentialCtx, cancelFunc := context.WithTimeout(context.Background(), 60*time.Second)
|
||||||
defer cancelFunc()
|
defer cancelFunc()
|
||||||
|
|
||||||
@ -325,7 +350,7 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj
|
|||||||
// we don't need to authenticate this wshProxy since it is coming direct
|
// we don't need to authenticate this wshProxy since it is coming direct
|
||||||
wshProxy := wshutil.MakeRpcProxy()
|
wshProxy := wshutil.MakeRpcProxy()
|
||||||
wshProxy.SetRpcContext(&wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId})
|
wshProxy.SetRpcContext(&wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId})
|
||||||
wshutil.DefaultRouter.RegisterRoute(wshutil.MakeControllerRouteId(bc.BlockId), wshProxy)
|
wshutil.DefaultRouter.RegisterRoute(wshutil.MakeControllerRouteId(bc.BlockId), wshProxy, true)
|
||||||
ptyBuffer := wshutil.MakePtyBuffer(wshutil.WaveOSCPrefix, bc.ShellProc.Cmd, wshProxy.FromRemoteCh)
|
ptyBuffer := wshutil.MakePtyBuffer(wshutil.WaveOSCPrefix, bc.ShellProc.Cmd, wshProxy.FromRemoteCh)
|
||||||
go func() {
|
go func() {
|
||||||
// handles regular output from the pty (goes to the blockfile and xterm)
|
// handles regular output from the pty (goes to the blockfile and xterm)
|
||||||
@ -494,6 +519,15 @@ func CheckConnStatus(blockId string) error {
|
|||||||
if connName == "" {
|
if connName == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
if strings.HasPrefix(connName, "wsl://") {
|
||||||
|
distroName := strings.TrimPrefix(connName, "wsl://")
|
||||||
|
conn := wsl.GetWslConn(context.Background(), distroName, false)
|
||||||
|
connStatus := conn.DeriveConnStatus()
|
||||||
|
if connStatus.Status != conncontroller.Status_Connected {
|
||||||
|
return fmt.Errorf("not connected: %s", connStatus.Status)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
opts, err := remote.ParseOpts(connName)
|
opts, err := remote.ParseOpts(connName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error parsing connection name: %w", err)
|
return fmt.Errorf("error parsing connection name: %w", err)
|
||||||
|
@ -1,3 +1,6 @@
|
|||||||
|
// Copyright 2024, Command Line Inc.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
package remote
|
package remote
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
@ -17,6 +17,7 @@ import (
|
|||||||
"github.com/wavetermdev/waveterm/pkg/wcore"
|
"github.com/wavetermdev/waveterm/pkg/wcore"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wlayout"
|
"github.com/wavetermdev/waveterm/pkg/wlayout"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/wsl"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wstore"
|
"github.com/wavetermdev/waveterm/pkg/wstore"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -77,7 +78,9 @@ func (cs *ClientService) MakeWindow(ctx context.Context) (*waveobj.Window, error
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (cs *ClientService) GetAllConnStatus(ctx context.Context) ([]wshrpc.ConnStatus, error) {
|
func (cs *ClientService) GetAllConnStatus(ctx context.Context) ([]wshrpc.ConnStatus, error) {
|
||||||
return conncontroller.GetAllConnStatus(), nil
|
sshStatuses := conncontroller.GetAllConnStatus()
|
||||||
|
wslStatuses := wsl.GetAllConnStatus()
|
||||||
|
return append(sshStatuses, wslStatuses...), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// moves the window to the front of the windowId stack
|
// moves the window to the front of the windowId stack
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/creack/pty"
|
"github.com/creack/pty"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/wsl"
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -129,3 +130,42 @@ func (sw SessionWrap) StderrPipe() (io.ReadCloser, error) {
|
|||||||
func (sw SessionWrap) SetSize(h int, w int) error {
|
func (sw SessionWrap) SetSize(h int, w int) error {
|
||||||
return sw.Session.WindowChange(h, w)
|
return sw.Session.WindowChange(h, w)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type WslCmdWrap struct {
|
||||||
|
*wsl.WslCmd
|
||||||
|
Tty pty.Tty
|
||||||
|
pty.Pty
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wcw WslCmdWrap) Kill() {
|
||||||
|
wcw.Tty.Close()
|
||||||
|
wcw.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wcw WslCmdWrap) KillGraceful(timeout time.Duration) {
|
||||||
|
process := wcw.WslCmd.GetProcess()
|
||||||
|
if process == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
processState := wcw.WslCmd.GetProcessState()
|
||||||
|
if processState != nil && processState.Exited() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
process.Signal(os.Interrupt)
|
||||||
|
go func() {
|
||||||
|
time.Sleep(timeout)
|
||||||
|
process := wcw.WslCmd.GetProcess()
|
||||||
|
processState := wcw.WslCmd.GetProcessState()
|
||||||
|
if processState == nil || !processState.Exited() {
|
||||||
|
process.Kill() // force kill if it is already not exited
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* SetSize does nothing for WslCmdWrap as there
|
||||||
|
* is no pty to manage.
|
||||||
|
**/
|
||||||
|
func (wcw WslCmdWrap) SetSize(w int, h int) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -5,6 +5,7 @@ package shellexec
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
@ -25,6 +26,7 @@ import (
|
|||||||
"github.com/wavetermdev/waveterm/pkg/wavebase"
|
"github.com/wavetermdev/waveterm/pkg/wavebase"
|
||||||
"github.com/wavetermdev/waveterm/pkg/waveobj"
|
"github.com/wavetermdev/waveterm/pkg/waveobj"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/wsl"
|
||||||
)
|
)
|
||||||
|
|
||||||
const DefaultGracefulKillWait = 400 * time.Millisecond
|
const DefaultGracefulKillWait = 400 * time.Millisecond
|
||||||
@ -141,6 +143,96 @@ func (pp *PipePty) WriteString(s string) (n int, err error) {
|
|||||||
return pp.Write([]byte(s))
|
return pp.Write([]byte(s))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func StartWslShellProc(ctx context.Context, termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *wsl.WslConn) (*ShellProc, error) {
|
||||||
|
client := conn.GetClient()
|
||||||
|
shellPath := cmdOpts.ShellPath
|
||||||
|
if shellPath == "" {
|
||||||
|
remoteShellPath, err := wsl.DetectShell(conn.Context, client)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
shellPath = remoteShellPath
|
||||||
|
}
|
||||||
|
var shellOpts []string
|
||||||
|
log.Printf("detected shell: %s", shellPath)
|
||||||
|
|
||||||
|
err := wsl.InstallClientRcFiles(conn.Context, client)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("error installing rc files: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
homeDir := wsl.GetHomeDir(conn.Context, client)
|
||||||
|
shellOpts = append(shellOpts, "~", "-d", client.Name())
|
||||||
|
|
||||||
|
if isZshShell(shellPath) {
|
||||||
|
shellOpts = append(shellOpts, fmt.Sprintf(`ZDOTDIR="%s/.waveterm/%s"`, homeDir, shellutil.ZshIntegrationDir))
|
||||||
|
}
|
||||||
|
var subShellOpts []string
|
||||||
|
|
||||||
|
if cmdStr == "" {
|
||||||
|
/* transform command in order to inject environment vars */
|
||||||
|
if isBashShell(shellPath) {
|
||||||
|
log.Printf("recognized as bash shell")
|
||||||
|
// add --rcfile
|
||||||
|
// cant set -l or -i with --rcfile
|
||||||
|
subShellOpts = append(subShellOpts, "--rcfile", fmt.Sprintf(`%s/.waveterm/%s/.bashrc`, homeDir, shellutil.BashIntegrationDir))
|
||||||
|
} else if isFishShell(shellPath) {
|
||||||
|
carg := fmt.Sprintf(`"set -x PATH \"%s\"/.waveterm/%s $PATH"`, homeDir, shellutil.WaveHomeBinDir)
|
||||||
|
subShellOpts = append(subShellOpts, "-C", carg)
|
||||||
|
} else if wsl.IsPowershell(shellPath) {
|
||||||
|
// powershell is weird about quoted path executables and requires an ampersand first
|
||||||
|
shellPath = "& " + shellPath
|
||||||
|
subShellOpts = append(subShellOpts, "-ExecutionPolicy", "Bypass", "-NoExit", "-File", homeDir+fmt.Sprintf("/.waveterm/%s/wavepwsh.ps1", shellutil.PwshIntegrationDir))
|
||||||
|
} else {
|
||||||
|
if cmdOpts.Login {
|
||||||
|
subShellOpts = append(subShellOpts, "-l")
|
||||||
|
}
|
||||||
|
if cmdOpts.Interactive {
|
||||||
|
subShellOpts = append(subShellOpts, "-i")
|
||||||
|
}
|
||||||
|
// can't set environment vars this way
|
||||||
|
// will try to do later if possible
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
shellPath = cmdStr
|
||||||
|
if cmdOpts.Login {
|
||||||
|
subShellOpts = append(subShellOpts, "-l")
|
||||||
|
}
|
||||||
|
if cmdOpts.Interactive {
|
||||||
|
subShellOpts = append(subShellOpts, "-i")
|
||||||
|
}
|
||||||
|
subShellOpts = append(subShellOpts, "-c", cmdStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
jwtToken, ok := cmdOpts.Env[wshutil.WaveJwtTokenVarName]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("no jwt token provided to connection")
|
||||||
|
}
|
||||||
|
if remote.IsPowershell(shellPath) {
|
||||||
|
shellOpts = append(shellOpts, "--", fmt.Sprintf(`$env:%s=%s;`, wshutil.WaveJwtTokenVarName, jwtToken))
|
||||||
|
} else {
|
||||||
|
shellOpts = append(shellOpts, "--", fmt.Sprintf(`%s=%s`, wshutil.WaveJwtTokenVarName, jwtToken))
|
||||||
|
}
|
||||||
|
shellOpts = append(shellOpts, shellPath)
|
||||||
|
shellOpts = append(shellOpts, subShellOpts...)
|
||||||
|
log.Printf("full cmd is: %s %s", "wsl.exe", strings.Join(shellOpts, " "))
|
||||||
|
|
||||||
|
ecmd := exec.Command("wsl.exe", shellOpts...)
|
||||||
|
if termSize.Rows == 0 || termSize.Cols == 0 {
|
||||||
|
termSize.Rows = shellutil.DefaultTermRows
|
||||||
|
termSize.Cols = shellutil.DefaultTermCols
|
||||||
|
}
|
||||||
|
if termSize.Rows <= 0 || termSize.Cols <= 0 {
|
||||||
|
return nil, fmt.Errorf("invalid term size: %v", termSize)
|
||||||
|
}
|
||||||
|
cmdPty, err := pty.StartWithSize(ecmd, &pty.Winsize{Rows: uint16(termSize.Rows), Cols: uint16(termSize.Cols)})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &ShellProc{Cmd: CmdWrap{ecmd, cmdPty}, ConnName: conn.GetName(), CloseOnce: &sync.Once{}, DoneCh: make(chan any)}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func StartRemoteShellProc(termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *conncontroller.SSHConn) (*ShellProc, error) {
|
func StartRemoteShellProc(termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *conncontroller.SSHConn) (*ShellProc, error) {
|
||||||
client := conn.GetClient()
|
client := conn.GetClient()
|
||||||
shellPath := cmdOpts.ShellPath
|
shellPath := cmdOpts.ShellPath
|
||||||
|
58
pkg/util/packetparser/packetparser.go
Normal file
58
pkg/util/packetparser/packetparser.go
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
// Copyright 2024, Command Line Inc.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package packetparser
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PacketParser struct {
|
||||||
|
Reader io.Reader
|
||||||
|
Ch chan []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func Parse(input io.Reader, packetCh chan []byte, rawCh chan []byte) error {
|
||||||
|
bufReader := bufio.NewReader(input)
|
||||||
|
defer close(packetCh)
|
||||||
|
defer close(rawCh)
|
||||||
|
for {
|
||||||
|
line, err := bufReader.ReadBytes('\n')
|
||||||
|
if err == io.EOF {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if len(line) <= 1 {
|
||||||
|
// just a blank line
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if bytes.HasPrefix(line, []byte{'#', '#', 'N', '{'}) && bytes.HasSuffix(line, []byte{'}', '\n'}) {
|
||||||
|
// strip off the leading "##" and trailing "\n" (single byte)
|
||||||
|
packetCh <- line[3 : len(line)-1]
|
||||||
|
} else {
|
||||||
|
rawCh <- line
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WritePacket(output io.Writer, packet []byte) error {
|
||||||
|
if len(packet) < 2 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if packet[0] != '{' || packet[len(packet)-1] != '}' {
|
||||||
|
return fmt.Errorf("invalid packet, must start with '{' and end with '}'")
|
||||||
|
}
|
||||||
|
fullPacket := make([]byte, 0, len(packet)+5)
|
||||||
|
// we add the extra newline to make sure the ## appears at the beginning of the line
|
||||||
|
// since writer isn't buffered, we want to send this all at once
|
||||||
|
fullPacket = append(fullPacket, '\n', '#', '#', 'N')
|
||||||
|
fullPacket = append(fullPacket, packet...)
|
||||||
|
fullPacket = append(fullPacket, '\n')
|
||||||
|
_, err := output.Write(fullPacket)
|
||||||
|
return err
|
||||||
|
}
|
@ -30,10 +30,13 @@ const WaveDataHomeEnvVar = "WAVETERM_DATA_HOME"
|
|||||||
const WaveDevVarName = "WAVETERM_DEV"
|
const WaveDevVarName = "WAVETERM_DEV"
|
||||||
const WaveLockFile = "wave.lock"
|
const WaveLockFile = "wave.lock"
|
||||||
const DomainSocketBaseName = "wave.sock"
|
const DomainSocketBaseName = "wave.sock"
|
||||||
|
const RemoteDomainSocketBaseName = "wave-remote.sock"
|
||||||
const WaveDBDir = "db"
|
const WaveDBDir = "db"
|
||||||
const JwtSecret = "waveterm" // TODO generate and store this
|
const JwtSecret = "waveterm" // TODO generate and store this
|
||||||
const ConfigDir = "config"
|
const ConfigDir = "config"
|
||||||
|
|
||||||
|
var RemoteWaveHome = ExpandHomeDirSafe("~/.waveterm")
|
||||||
|
|
||||||
const WaveAppPathVarName = "WAVETERM_APP_PATH"
|
const WaveAppPathVarName = "WAVETERM_APP_PATH"
|
||||||
const AppPathBinDir = "bin"
|
const AppPathBinDir = "bin"
|
||||||
|
|
||||||
@ -101,6 +104,10 @@ func GetDomainSocketName() string {
|
|||||||
return filepath.Join(GetWaveDataDir(), DomainSocketBaseName)
|
return filepath.Join(GetWaveDataDir(), DomainSocketBaseName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetRemoteDomainSocketName() string {
|
||||||
|
return filepath.Join(RemoteWaveHome, RemoteDomainSocketBaseName)
|
||||||
|
}
|
||||||
|
|
||||||
func GetWaveDataDir() string {
|
func GetWaveDataDir() string {
|
||||||
retVal, found := os.LookupEnv(WaveDataHomeEnvVar)
|
retVal, found := os.LookupEnv(WaveDataHomeEnvVar)
|
||||||
if !found {
|
if !found {
|
||||||
|
@ -431,7 +431,7 @@ func MakeTCPListener(serviceName string) (net.Listener, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func MakeUnixListener() (net.Listener, error) {
|
func MakeUnixListener() (net.Listener, error) {
|
||||||
serverAddr := wavebase.GetWaveDataDir() + "/wave.sock"
|
serverAddr := wavebase.GetDomainSocketName()
|
||||||
os.Remove(serverAddr) // ignore error
|
os.Remove(serverAddr) // ignore error
|
||||||
rtn, err := net.Listen("unix", serverAddr)
|
rtn, err := net.Listen("unix", serverAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -252,7 +252,7 @@ func registerConn(wsConnId string, routeId string, wproxy *wshutil.WshRpcProxy)
|
|||||||
wshutil.DefaultRouter.UnregisterRoute(routeId)
|
wshutil.DefaultRouter.UnregisterRoute(routeId)
|
||||||
}
|
}
|
||||||
RouteToConnMap[routeId] = wsConnId
|
RouteToConnMap[routeId] = wsConnId
|
||||||
wshutil.DefaultRouter.RegisterRoute(routeId, wproxy)
|
wshutil.DefaultRouter.RegisterRoute(routeId, wproxy, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func unregisterConn(wsConnId string, routeId string) {
|
func unregisterConn(wsConnId string, routeId string) {
|
||||||
|
@ -92,6 +92,12 @@ func DeleteBlockCommand(w *wshutil.WshRpc, data wshrpc.CommandDeleteBlockData, o
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// command "dispose", wshserver.DisposeCommand
|
||||||
|
func DisposeCommand(w *wshutil.WshRpc, data wshrpc.CommandDisposeData, opts *wshrpc.RpcOpts) error {
|
||||||
|
_, err := sendRpcRequestCallHelper[any](w, "dispose", data, opts)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// command "eventpublish", wshserver.EventPublishCommand
|
// command "eventpublish", wshserver.EventPublishCommand
|
||||||
func EventPublishCommand(w *wshutil.WshRpc, data wps.WaveEvent, opts *wshrpc.RpcOpts) error {
|
func EventPublishCommand(w *wshutil.WshRpc, data wps.WaveEvent, opts *wshrpc.RpcOpts) error {
|
||||||
_, err := sendRpcRequestCallHelper[any](w, "eventpublish", data, opts)
|
_, err := sendRpcRequestCallHelper[any](w, "eventpublish", data, opts)
|
||||||
@ -285,4 +291,22 @@ func WebSelectorCommand(w *wshutil.WshRpc, data wshrpc.CommandWebSelectorData, o
|
|||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// command "wsldefaultdistro", wshserver.WslDefaultDistroCommand
|
||||||
|
func WslDefaultDistroCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) (string, error) {
|
||||||
|
resp, err := sendRpcRequestCallHelper[string](w, "wsldefaultdistro", nil, opts)
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// command "wsllist", wshserver.WslListCommand
|
||||||
|
func WslListCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) ([]string, error) {
|
||||||
|
resp, err := sendRpcRequestCallHelper[[]string](w, "wsllist", nil, opts)
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// command "wslstatus", wshserver.WslStatusCommand
|
||||||
|
func WslStatusCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) ([]wshrpc.ConnStatus, error) {
|
||||||
|
resp, err := sendRpcRequestCallHelper[[]wshrpc.ConnStatus](w, "wslstatus", nil, opts)
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -28,6 +28,7 @@ const (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
Command_Authenticate = "authenticate" // special
|
Command_Authenticate = "authenticate" // special
|
||||||
|
Command_Dispose = "dispose" // special (disposes of the route, for multiproxy only)
|
||||||
Command_RouteAnnounce = "routeannounce" // special (for routing)
|
Command_RouteAnnounce = "routeannounce" // special (for routing)
|
||||||
Command_RouteUnannounce = "routeunannounce" // special (for routing)
|
Command_RouteUnannounce = "routeunannounce" // special (for routing)
|
||||||
Command_Message = "message"
|
Command_Message = "message"
|
||||||
@ -62,11 +63,15 @@ const (
|
|||||||
Command_RemoteFileDelete = "remotefiledelete"
|
Command_RemoteFileDelete = "remotefiledelete"
|
||||||
Command_RemoteFileJoiin = "remotefilejoin"
|
Command_RemoteFileJoiin = "remotefilejoin"
|
||||||
|
|
||||||
|
Command_ConnStatus = "connstatus"
|
||||||
|
Command_WslStatus = "wslstatus"
|
||||||
Command_ConnEnsure = "connensure"
|
Command_ConnEnsure = "connensure"
|
||||||
Command_ConnReinstallWsh = "connreinstallwsh"
|
Command_ConnReinstallWsh = "connreinstallwsh"
|
||||||
Command_ConnConnect = "connconnect"
|
Command_ConnConnect = "connconnect"
|
||||||
Command_ConnDisconnect = "conndisconnect"
|
Command_ConnDisconnect = "conndisconnect"
|
||||||
Command_ConnList = "connlist"
|
Command_ConnList = "connlist"
|
||||||
|
Command_WslList = "wsllist"
|
||||||
|
Command_WslDefaultDistro = "wsldefaultdistro"
|
||||||
|
|
||||||
Command_WebSelector = "webselector"
|
Command_WebSelector = "webselector"
|
||||||
Command_Notify = "notify"
|
Command_Notify = "notify"
|
||||||
@ -83,6 +88,7 @@ type RespOrErrorUnion[T any] struct {
|
|||||||
|
|
||||||
type WshRpcInterface interface {
|
type WshRpcInterface interface {
|
||||||
AuthenticateCommand(ctx context.Context, data string) (CommandAuthenticateRtnData, error)
|
AuthenticateCommand(ctx context.Context, data string) (CommandAuthenticateRtnData, error)
|
||||||
|
DisposeCommand(ctx context.Context, data CommandDisposeData) error
|
||||||
RouteAnnounceCommand(ctx context.Context) error // (special) announces a new route to the main router
|
RouteAnnounceCommand(ctx context.Context) error // (special) announces a new route to the main router
|
||||||
RouteUnannounceCommand(ctx context.Context) error // (special) unannounces a route to the main router
|
RouteUnannounceCommand(ctx context.Context) error // (special) unannounces a route to the main router
|
||||||
|
|
||||||
@ -114,11 +120,14 @@ type WshRpcInterface interface {
|
|||||||
|
|
||||||
// connection functions
|
// connection functions
|
||||||
ConnStatusCommand(ctx context.Context) ([]ConnStatus, error)
|
ConnStatusCommand(ctx context.Context) ([]ConnStatus, error)
|
||||||
|
WslStatusCommand(ctx context.Context) ([]ConnStatus, error)
|
||||||
ConnEnsureCommand(ctx context.Context, connName string) error
|
ConnEnsureCommand(ctx context.Context, connName string) error
|
||||||
ConnReinstallWshCommand(ctx context.Context, connName string) error
|
ConnReinstallWshCommand(ctx context.Context, connName string) error
|
||||||
ConnConnectCommand(ctx context.Context, connName string) error
|
ConnConnectCommand(ctx context.Context, connName string) error
|
||||||
ConnDisconnectCommand(ctx context.Context, connName string) error
|
ConnDisconnectCommand(ctx context.Context, connName string) error
|
||||||
ConnListCommand(ctx context.Context) ([]string, error)
|
ConnListCommand(ctx context.Context) ([]string, error)
|
||||||
|
WslListCommand(ctx context.Context) ([]string, error)
|
||||||
|
WslDefaultDistroCommand(ctx context.Context) (string, error)
|
||||||
|
|
||||||
// eventrecv is special, it's handled internally by WshRpc with EventListener
|
// eventrecv is special, it's handled internally by WshRpc with EventListener
|
||||||
EventRecvCommand(ctx context.Context, data wps.WaveEvent) error
|
EventRecvCommand(ctx context.Context, data wps.WaveEvent) error
|
||||||
@ -200,7 +209,13 @@ func HackRpcContextIntoData(dataPtr any, rpcContext RpcContext) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type CommandAuthenticateRtnData struct {
|
type CommandAuthenticateRtnData struct {
|
||||||
|
RouteId string `json:"routeid"`
|
||||||
|
AuthToken string `json:"authtoken,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CommandDisposeData struct {
|
||||||
RouteId string `json:"routeid"`
|
RouteId string `json:"routeid"`
|
||||||
|
// auth token travels in the packet directly
|
||||||
}
|
}
|
||||||
|
|
||||||
type CommandMessageData struct {
|
type CommandMessageData struct {
|
||||||
|
@ -21,6 +21,7 @@ import (
|
|||||||
"github.com/wavetermdev/waveterm/pkg/filestore"
|
"github.com/wavetermdev/waveterm/pkg/filestore"
|
||||||
"github.com/wavetermdev/waveterm/pkg/remote"
|
"github.com/wavetermdev/waveterm/pkg/remote"
|
||||||
"github.com/wavetermdev/waveterm/pkg/remote/conncontroller"
|
"github.com/wavetermdev/waveterm/pkg/remote/conncontroller"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
|
||||||
"github.com/wavetermdev/waveterm/pkg/waveai"
|
"github.com/wavetermdev/waveterm/pkg/waveai"
|
||||||
"github.com/wavetermdev/waveterm/pkg/waveobj"
|
"github.com/wavetermdev/waveterm/pkg/waveobj"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wconfig"
|
"github.com/wavetermdev/waveterm/pkg/wconfig"
|
||||||
@ -29,6 +30,7 @@ import (
|
|||||||
"github.com/wavetermdev/waveterm/pkg/wps"
|
"github.com/wavetermdev/waveterm/pkg/wps"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/wsl"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wstore"
|
"github.com/wavetermdev/waveterm/pkg/wstore"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -36,6 +38,7 @@ const SimpleId_This = "this"
|
|||||||
const SimpleId_Tab = "tab"
|
const SimpleId_Tab = "tab"
|
||||||
|
|
||||||
var SimpleId_BlockNum_Regex = regexp.MustCompile(`^\d+$`)
|
var SimpleId_BlockNum_Regex = regexp.MustCompile(`^\d+$`)
|
||||||
|
var InvalidWslDistroNames = []string{"docker-desktop", "docker-desktop-data"}
|
||||||
|
|
||||||
type WshServer struct{}
|
type WshServer struct{}
|
||||||
|
|
||||||
@ -463,11 +466,28 @@ func (ws *WshServer) ConnStatusCommand(ctx context.Context) ([]wshrpc.ConnStatus
|
|||||||
return rtn, nil
|
return rtn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ws *WshServer) WslStatusCommand(ctx context.Context) ([]wshrpc.ConnStatus, error) {
|
||||||
|
rtn := wsl.GetAllConnStatus()
|
||||||
|
return rtn, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (ws *WshServer) ConnEnsureCommand(ctx context.Context, connName string) error {
|
func (ws *WshServer) ConnEnsureCommand(ctx context.Context, connName string) error {
|
||||||
|
if strings.HasPrefix(connName, "wsl://") {
|
||||||
|
distroName := strings.TrimPrefix(connName, "wsl://")
|
||||||
|
return wsl.EnsureConnection(ctx, distroName)
|
||||||
|
}
|
||||||
return conncontroller.EnsureConnection(ctx, connName)
|
return conncontroller.EnsureConnection(ctx, connName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ws *WshServer) ConnDisconnectCommand(ctx context.Context, connName string) error {
|
func (ws *WshServer) ConnDisconnectCommand(ctx context.Context, connName string) error {
|
||||||
|
if strings.HasPrefix(connName, "wsl://") {
|
||||||
|
distroName := strings.TrimPrefix(connName, "wsl://")
|
||||||
|
conn := wsl.GetWslConn(ctx, distroName, false)
|
||||||
|
if conn == nil {
|
||||||
|
return fmt.Errorf("distro not found: %s", connName)
|
||||||
|
}
|
||||||
|
return conn.Close()
|
||||||
|
}
|
||||||
connOpts, err := remote.ParseOpts(connName)
|
connOpts, err := remote.ParseOpts(connName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error parsing connection name: %w", err)
|
return fmt.Errorf("error parsing connection name: %w", err)
|
||||||
@ -480,6 +500,14 @@ func (ws *WshServer) ConnDisconnectCommand(ctx context.Context, connName string)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ws *WshServer) ConnConnectCommand(ctx context.Context, connName string) error {
|
func (ws *WshServer) ConnConnectCommand(ctx context.Context, connName string) error {
|
||||||
|
if strings.HasPrefix(connName, "wsl://") {
|
||||||
|
distroName := strings.TrimPrefix(connName, "wsl://")
|
||||||
|
conn := wsl.GetWslConn(ctx, distroName, false)
|
||||||
|
if conn == nil {
|
||||||
|
return fmt.Errorf("connection not found: %s", connName)
|
||||||
|
}
|
||||||
|
return conn.Connect(ctx)
|
||||||
|
}
|
||||||
connOpts, err := remote.ParseOpts(connName)
|
connOpts, err := remote.ParseOpts(connName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error parsing connection name: %w", err)
|
return fmt.Errorf("error parsing connection name: %w", err)
|
||||||
@ -492,6 +520,14 @@ func (ws *WshServer) ConnConnectCommand(ctx context.Context, connName string) er
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ws *WshServer) ConnReinstallWshCommand(ctx context.Context, connName string) error {
|
func (ws *WshServer) ConnReinstallWshCommand(ctx context.Context, connName string) error {
|
||||||
|
if strings.HasPrefix(connName, "wsl://") {
|
||||||
|
distroName := strings.TrimPrefix(connName, "wsl://")
|
||||||
|
conn := wsl.GetWslConn(ctx, distroName, false)
|
||||||
|
if conn == nil {
|
||||||
|
return fmt.Errorf("connection not found: %s", connName)
|
||||||
|
}
|
||||||
|
return conn.CheckAndInstallWsh(ctx, connName, &wsl.WshInstallOpts{Force: true, NoUserPrompt: true})
|
||||||
|
}
|
||||||
connOpts, err := remote.ParseOpts(connName)
|
connOpts, err := remote.ParseOpts(connName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error parsing connection name: %w", err)
|
return fmt.Errorf("error parsing connection name: %w", err)
|
||||||
@ -507,6 +543,33 @@ func (ws *WshServer) ConnListCommand(ctx context.Context) ([]string, error) {
|
|||||||
return conncontroller.GetConnectionsList()
|
return conncontroller.GetConnectionsList()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ws *WshServer) WslListCommand(ctx context.Context) ([]string, error) {
|
||||||
|
distros, err := wsl.RegisteredDistros(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var distroNames []string
|
||||||
|
for _, distro := range distros {
|
||||||
|
distroName := distro.Name()
|
||||||
|
if utilfn.ContainsStr(InvalidWslDistroNames, distroName) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
distroNames = append(distroNames, distroName)
|
||||||
|
}
|
||||||
|
return distroNames, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *WshServer) WslDefaultDistroCommand(ctx context.Context) (string, error) {
|
||||||
|
distro, ok, err := wsl.DefaultDistro(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("unable to determine default distro: %w", err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("unable to determine default distro")
|
||||||
|
}
|
||||||
|
return distro.Name(), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (ws *WshServer) BlockInfoCommand(ctx context.Context, blockId string) (*wshrpc.BlockInfoData, error) {
|
func (ws *WshServer) BlockInfoCommand(ctx context.Context, blockId string) (*wshrpc.BlockInfoData, error) {
|
||||||
blockData, err := wstore.DBMustGet[*waveobj.Block](ctx, blockId)
|
blockData, err := wstore.DBMustGet[*waveobj.Block](ctx, blockId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
151
pkg/wshutil/wshmultiproxy.go
Normal file
151
pkg/wshutil/wshmultiproxy.go
Normal file
@ -0,0 +1,151 @@
|
|||||||
|
// Copyright 2024, Command Line Inc.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package wshutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||||
|
)
|
||||||
|
|
||||||
|
type multiProxyRouteInfo struct {
|
||||||
|
RouteId string
|
||||||
|
AuthToken string
|
||||||
|
Proxy *WshRpcProxy
|
||||||
|
RpcContext *wshrpc.RpcContext
|
||||||
|
}
|
||||||
|
|
||||||
|
// handles messages from multiple unauthenitcated clients
|
||||||
|
type WshRpcMultiProxy struct {
|
||||||
|
Lock *sync.Mutex
|
||||||
|
RouteInfo map[string]*multiProxyRouteInfo // authtoken to info
|
||||||
|
ToRemoteCh chan []byte
|
||||||
|
FromRemoteRawCh chan []byte // raw message from the remote
|
||||||
|
}
|
||||||
|
|
||||||
|
func MakeRpcMultiProxy() *WshRpcMultiProxy {
|
||||||
|
return &WshRpcMultiProxy{
|
||||||
|
Lock: &sync.Mutex{},
|
||||||
|
RouteInfo: make(map[string]*multiProxyRouteInfo),
|
||||||
|
ToRemoteCh: make(chan []byte, DefaultInputChSize),
|
||||||
|
FromRemoteRawCh: make(chan []byte, DefaultOutputChSize),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *WshRpcMultiProxy) DisposeRoutes() {
|
||||||
|
p.Lock.Lock()
|
||||||
|
defer p.Lock.Unlock()
|
||||||
|
for authToken, routeInfo := range p.RouteInfo {
|
||||||
|
DefaultRouter.UnregisterRoute(routeInfo.RouteId)
|
||||||
|
delete(p.RouteInfo, authToken)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *WshRpcMultiProxy) getRouteInfo(authToken string) *multiProxyRouteInfo {
|
||||||
|
p.Lock.Lock()
|
||||||
|
defer p.Lock.Unlock()
|
||||||
|
return p.RouteInfo[authToken]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *WshRpcMultiProxy) setRouteInfo(authToken string, routeInfo *multiProxyRouteInfo) {
|
||||||
|
p.Lock.Lock()
|
||||||
|
defer p.Lock.Unlock()
|
||||||
|
p.RouteInfo[authToken] = routeInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *WshRpcMultiProxy) removeRouteInfo(authToken string) {
|
||||||
|
p.Lock.Lock()
|
||||||
|
defer p.Lock.Unlock()
|
||||||
|
delete(p.RouteInfo, authToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *WshRpcMultiProxy) sendResponseError(msg RpcMessage, sendErr error) {
|
||||||
|
if msg.ReqId == "" {
|
||||||
|
// no response needed
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp := RpcMessage{
|
||||||
|
ResId: msg.ReqId,
|
||||||
|
Error: sendErr.Error(),
|
||||||
|
}
|
||||||
|
respBytes, _ := json.Marshal(resp)
|
||||||
|
p.ToRemoteCh <- respBytes
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *WshRpcMultiProxy) sendAuthResponse(msg RpcMessage, routeId string, authToken string) {
|
||||||
|
if msg.ReqId == "" {
|
||||||
|
// no response needed
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp := RpcMessage{
|
||||||
|
ResId: msg.ReqId,
|
||||||
|
Data: wshrpc.CommandAuthenticateRtnData{RouteId: routeId, AuthToken: authToken},
|
||||||
|
}
|
||||||
|
respBytes, _ := json.Marshal(resp)
|
||||||
|
p.ToRemoteCh <- respBytes
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *WshRpcMultiProxy) handleUnauthMessage(msgBytes []byte) {
|
||||||
|
var msg RpcMessage
|
||||||
|
err := json.Unmarshal(msgBytes, &msg)
|
||||||
|
if err != nil {
|
||||||
|
// nothing to do here, malformed message
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if msg.Command == wshrpc.Command_Authenticate {
|
||||||
|
rpcContext, routeId, err := handleAuthenticationCommand(msg)
|
||||||
|
if err != nil {
|
||||||
|
p.sendResponseError(msg, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
routeInfo := &multiProxyRouteInfo{
|
||||||
|
RouteId: routeId,
|
||||||
|
AuthToken: uuid.New().String(),
|
||||||
|
RpcContext: rpcContext,
|
||||||
|
}
|
||||||
|
routeInfo.Proxy = MakeRpcProxy()
|
||||||
|
routeInfo.Proxy.SetRpcContext(rpcContext)
|
||||||
|
p.setRouteInfo(routeInfo.AuthToken, routeInfo)
|
||||||
|
p.sendAuthResponse(msg, routeId, routeInfo.AuthToken)
|
||||||
|
go func() {
|
||||||
|
for msgBytes := range routeInfo.Proxy.ToRemoteCh {
|
||||||
|
p.ToRemoteCh <- msgBytes
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
DefaultRouter.RegisterRoute(routeId, routeInfo.Proxy, true)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if msg.AuthToken == "" {
|
||||||
|
p.sendResponseError(msg, fmt.Errorf("no auth token"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
routeInfo := p.getRouteInfo(msg.AuthToken)
|
||||||
|
if routeInfo == nil {
|
||||||
|
p.sendResponseError(msg, fmt.Errorf("invalid auth token"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if msg.Command != "" && msg.Source != routeInfo.RouteId {
|
||||||
|
p.sendResponseError(msg, fmt.Errorf("invalid source route for auth token"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if msg.Command == wshrpc.Command_Dispose {
|
||||||
|
DefaultRouter.UnregisterRoute(routeInfo.RouteId)
|
||||||
|
p.removeRouteInfo(msg.AuthToken)
|
||||||
|
close(routeInfo.Proxy.ToRemoteCh)
|
||||||
|
close(routeInfo.Proxy.FromRemoteCh)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
routeInfo.Proxy.FromRemoteCh <- msgBytes
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *WshRpcMultiProxy) RunUnauthLoop() {
|
||||||
|
// loop over unauthenticated message
|
||||||
|
// handle Authenicate commands, and pass authenticated messages to the AuthCh
|
||||||
|
for msgBytes := range p.FromRemoteRawCh {
|
||||||
|
p.handleUnauthMessage(msgBytes)
|
||||||
|
}
|
||||||
|
}
|
@ -6,7 +6,6 @@ package wshutil
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
@ -18,6 +17,7 @@ type WshRpcProxy struct {
|
|||||||
RpcContext *wshrpc.RpcContext
|
RpcContext *wshrpc.RpcContext
|
||||||
ToRemoteCh chan []byte
|
ToRemoteCh chan []byte
|
||||||
FromRemoteCh chan []byte
|
FromRemoteCh chan []byte
|
||||||
|
AuthToken string
|
||||||
}
|
}
|
||||||
|
|
||||||
func MakeRpcProxy() *WshRpcProxy {
|
func MakeRpcProxy() *WshRpcProxy {
|
||||||
@ -40,6 +40,18 @@ func (p *WshRpcProxy) GetRpcContext() *wshrpc.RpcContext {
|
|||||||
return p.RpcContext
|
return p.RpcContext
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *WshRpcProxy) SetAuthToken(authToken string) {
|
||||||
|
p.Lock.Lock()
|
||||||
|
defer p.Lock.Unlock()
|
||||||
|
p.AuthToken = authToken
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *WshRpcProxy) GetAuthToken() string {
|
||||||
|
p.Lock.Lock()
|
||||||
|
defer p.Lock.Unlock()
|
||||||
|
return p.AuthToken
|
||||||
|
}
|
||||||
|
|
||||||
func (p *WshRpcProxy) sendResponseError(msg RpcMessage, sendErr error) {
|
func (p *WshRpcProxy) sendResponseError(msg RpcMessage, sendErr error) {
|
||||||
if msg.ReqId == "" {
|
if msg.ReqId == "" {
|
||||||
// no response needed
|
// no response needed
|
||||||
@ -54,7 +66,7 @@ func (p *WshRpcProxy) sendResponseError(msg RpcMessage, sendErr error) {
|
|||||||
p.SendRpcMessage(respBytes)
|
p.SendRpcMessage(respBytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *WshRpcProxy) sendResponse(msg RpcMessage, routeId string) {
|
func (p *WshRpcProxy) sendAuthenticateResponse(msg RpcMessage, routeId string) {
|
||||||
if msg.ReqId == "" {
|
if msg.ReqId == "" {
|
||||||
// no response needed
|
// no response needed
|
||||||
return
|
return
|
||||||
@ -98,6 +110,49 @@ func handleAuthenticationCommand(msg RpcMessage) (*wshrpc.RpcContext, string, er
|
|||||||
return newCtx, routeId, nil
|
return newCtx, routeId, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// runs on the client (stdio client)
|
||||||
|
func (p *WshRpcProxy) HandleClientProxyAuth(router *WshRouter) (string, error) {
|
||||||
|
for {
|
||||||
|
msgBytes, ok := <-p.FromRemoteCh
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("remote closed, not authenticated")
|
||||||
|
}
|
||||||
|
var origMsg RpcMessage
|
||||||
|
err := json.Unmarshal(msgBytes, &origMsg)
|
||||||
|
if err != nil {
|
||||||
|
// nothing to do, can't even send a response since we don't have Source or ReqId
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if origMsg.Command == "" {
|
||||||
|
// this message is not allowed (protocol error at this point), ignore
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// we only allow one command "authenticate", everything else returns an error
|
||||||
|
if origMsg.Command != wshrpc.Command_Authenticate {
|
||||||
|
respErr := fmt.Errorf("connection not authenticated")
|
||||||
|
p.sendResponseError(origMsg, respErr)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
authRtn, err := router.HandleProxyAuth(origMsg.Data)
|
||||||
|
if err != nil {
|
||||||
|
respErr := fmt.Errorf("error handling proxy auth: %w", err)
|
||||||
|
p.sendResponseError(origMsg, respErr)
|
||||||
|
return "", respErr
|
||||||
|
}
|
||||||
|
p.SetAuthToken(authRtn.AuthToken)
|
||||||
|
announceMsg := RpcMessage{
|
||||||
|
Command: wshrpc.Command_RouteAnnounce,
|
||||||
|
Source: authRtn.RouteId,
|
||||||
|
AuthToken: authRtn.AuthToken,
|
||||||
|
}
|
||||||
|
announceBytes, _ := json.Marshal(announceMsg)
|
||||||
|
router.InjectMessage(announceBytes, authRtn.RouteId)
|
||||||
|
p.sendAuthenticateResponse(origMsg, authRtn.RouteId)
|
||||||
|
return authRtn.RouteId, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// runs on the server
|
||||||
func (p *WshRpcProxy) HandleAuthentication() (*wshrpc.RpcContext, error) {
|
func (p *WshRpcProxy) HandleAuthentication() (*wshrpc.RpcContext, error) {
|
||||||
for {
|
for {
|
||||||
msgBytes, ok := <-p.FromRemoteCh
|
msgBytes, ok := <-p.FromRemoteCh
|
||||||
@ -122,11 +177,10 @@ func (p *WshRpcProxy) HandleAuthentication() (*wshrpc.RpcContext, error) {
|
|||||||
}
|
}
|
||||||
newCtx, routeId, err := handleAuthenticationCommand(msg)
|
newCtx, routeId, err := handleAuthenticationCommand(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("error handling authentication: %v\n", err)
|
|
||||||
p.sendResponseError(msg, err)
|
p.sendResponseError(msg, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
p.sendResponse(msg, routeId)
|
p.sendAuthenticateResponse(msg, routeId)
|
||||||
return newCtx, nil
|
return newCtx, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -136,9 +190,10 @@ func (p *WshRpcProxy) SendRpcMessage(msg []byte) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *WshRpcProxy) RecvRpcMessage() ([]byte, bool) {
|
func (p *WshRpcProxy) RecvRpcMessage() ([]byte, bool) {
|
||||||
msgBytes, ok := <-p.FromRemoteCh
|
msgBytes, more := <-p.FromRemoteCh
|
||||||
if !ok || p.RpcContext == nil {
|
authToken := p.GetAuthToken()
|
||||||
return msgBytes, ok
|
if !more || (p.RpcContext == nil && authToken == "") {
|
||||||
|
return msgBytes, more
|
||||||
}
|
}
|
||||||
var msg RpcMessage
|
var msg RpcMessage
|
||||||
err := json.Unmarshal(msgBytes, &msg)
|
err := json.Unmarshal(msgBytes, &msg)
|
||||||
@ -146,10 +201,15 @@ func (p *WshRpcProxy) RecvRpcMessage() ([]byte, bool) {
|
|||||||
// nothing to do here -- will error out at another level
|
// nothing to do here -- will error out at another level
|
||||||
return msgBytes, true
|
return msgBytes, true
|
||||||
}
|
}
|
||||||
msg.Data, err = recodeCommandData(msg.Command, msg.Data, p.RpcContext)
|
if p.RpcContext != nil {
|
||||||
if err != nil {
|
msg.Data, err = recodeCommandData(msg.Command, msg.Data, p.RpcContext)
|
||||||
// nothing to do here -- will error out at another level
|
if err != nil {
|
||||||
return msgBytes, true
|
// nothing to do here -- will error out at another level
|
||||||
|
return msgBytes, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if msg.AuthToken == "" {
|
||||||
|
msg.AuthToken = authToken
|
||||||
}
|
}
|
||||||
newBytes, err := json.Marshal(msg)
|
newBytes, err := json.Marshal(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -12,11 +12,14 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wps"
|
"github.com/wavetermdev/waveterm/pkg/wps"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||||
)
|
)
|
||||||
|
|
||||||
const DefaultRoute = "wavesrv"
|
const DefaultRoute = "wavesrv"
|
||||||
|
const UpstreamRoute = "upstream"
|
||||||
const SysRoute = "sys" // this route doesn't exist, just a placeholder for system messages
|
const SysRoute = "sys" // this route doesn't exist, just a placeholder for system messages
|
||||||
const ElectronRoute = "electron"
|
const ElectronRoute = "electron"
|
||||||
|
|
||||||
@ -36,12 +39,13 @@ type msgAndRoute struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type WshRouter struct {
|
type WshRouter struct {
|
||||||
Lock *sync.Mutex
|
Lock *sync.Mutex
|
||||||
RouteMap map[string]AbstractRpcClient // routeid => client
|
RouteMap map[string]AbstractRpcClient // routeid => client
|
||||||
UpstreamClient AbstractRpcClient // upstream client (if we are not the terminal router)
|
UpstreamClient AbstractRpcClient // upstream client (if we are not the terminal router)
|
||||||
AnnouncedRoutes map[string]string // routeid => local routeid
|
AnnouncedRoutes map[string]string // routeid => local routeid
|
||||||
RpcMap map[string]*routeInfo // rpcid => routeinfo
|
RpcMap map[string]*routeInfo // rpcid => routeinfo
|
||||||
InputCh chan msgAndRoute
|
SimpleRequestMap map[string]chan *RpcMessage // simple reqid => response channel
|
||||||
|
InputCh chan msgAndRoute
|
||||||
}
|
}
|
||||||
|
|
||||||
func MakeConnectionRouteId(connId string) string {
|
func MakeConnectionRouteId(connId string) string {
|
||||||
@ -68,11 +72,12 @@ var DefaultRouter = NewWshRouter()
|
|||||||
|
|
||||||
func NewWshRouter() *WshRouter {
|
func NewWshRouter() *WshRouter {
|
||||||
rtn := &WshRouter{
|
rtn := &WshRouter{
|
||||||
Lock: &sync.Mutex{},
|
Lock: &sync.Mutex{},
|
||||||
RouteMap: make(map[string]AbstractRpcClient),
|
RouteMap: make(map[string]AbstractRpcClient),
|
||||||
AnnouncedRoutes: make(map[string]string),
|
AnnouncedRoutes: make(map[string]string),
|
||||||
RpcMap: make(map[string]*routeInfo),
|
RpcMap: make(map[string]*routeInfo),
|
||||||
InputCh: make(chan msgAndRoute, DefaultInputChSize),
|
SimpleRequestMap: make(map[string]chan *RpcMessage),
|
||||||
|
InputCh: make(chan msgAndRoute, DefaultInputChSize),
|
||||||
}
|
}
|
||||||
go rtn.runServer()
|
go rtn.runServer()
|
||||||
return rtn
|
return rtn
|
||||||
@ -237,6 +242,10 @@ func (router *WshRouter) runServer() {
|
|||||||
router.sendRoutedMessage(msgBytes, routeInfo.DestRouteId)
|
router.sendRoutedMessage(msgBytes, routeInfo.DestRouteId)
|
||||||
continue
|
continue
|
||||||
} else if msg.ResId != "" {
|
} else if msg.ResId != "" {
|
||||||
|
ok := router.trySimpleResponse(&msg)
|
||||||
|
if ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
routeInfo := router.getRouteInfo(msg.ResId)
|
routeInfo := router.getRouteInfo(msg.ResId)
|
||||||
if routeInfo == nil {
|
if routeInfo == nil {
|
||||||
// no route info, nothing to do
|
// no route info, nothing to do
|
||||||
@ -269,10 +278,10 @@ func (router *WshRouter) WaitForRegister(ctx context.Context, routeId string) er
|
|||||||
}
|
}
|
||||||
|
|
||||||
// this will also consume the output channel of the abstract client
|
// this will also consume the output channel of the abstract client
|
||||||
func (router *WshRouter) RegisterRoute(routeId string, rpc AbstractRpcClient) {
|
func (router *WshRouter) RegisterRoute(routeId string, rpc AbstractRpcClient, shouldAnnounce bool) {
|
||||||
if routeId == SysRoute {
|
if routeId == SysRoute || routeId == UpstreamRoute {
|
||||||
// cannot register sys route
|
// cannot register sys route
|
||||||
log.Printf("error: WshRouter cannot register sys route\n")
|
log.Printf("error: WshRouter cannot register %s route\n", routeId)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Printf("[router] registering wsh route %q\n", routeId)
|
log.Printf("[router] registering wsh route %q\n", routeId)
|
||||||
@ -285,7 +294,7 @@ func (router *WshRouter) RegisterRoute(routeId string, rpc AbstractRpcClient) {
|
|||||||
router.RouteMap[routeId] = rpc
|
router.RouteMap[routeId] = rpc
|
||||||
go func() {
|
go func() {
|
||||||
// announce
|
// announce
|
||||||
if !alreadyExists && router.GetUpstreamClient() != nil {
|
if shouldAnnounce && !alreadyExists && router.GetUpstreamClient() != nil {
|
||||||
announceMsg := RpcMessage{Command: wshrpc.Command_RouteAnnounce, Source: routeId}
|
announceMsg := RpcMessage{Command: wshrpc.Command_RouteAnnounce, Source: routeId}
|
||||||
announceBytes, _ := json.Marshal(announceMsg)
|
announceBytes, _ := json.Marshal(announceMsg)
|
||||||
router.GetUpstreamClient().SendRpcMessage(announceBytes)
|
router.GetUpstreamClient().SendRpcMessage(announceBytes)
|
||||||
@ -352,3 +361,97 @@ func (router *WshRouter) GetUpstreamClient() AbstractRpcClient {
|
|||||||
defer router.Lock.Unlock()
|
defer router.Lock.Unlock()
|
||||||
return router.UpstreamClient
|
return router.UpstreamClient
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (router *WshRouter) InjectMessage(msgBytes []byte, fromRouteId string) {
|
||||||
|
router.InputCh <- msgAndRoute{msgBytes: msgBytes, fromRouteId: fromRouteId}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (router *WshRouter) registerSimpleRequest(reqId string) chan *RpcMessage {
|
||||||
|
router.Lock.Lock()
|
||||||
|
defer router.Lock.Unlock()
|
||||||
|
rtn := make(chan *RpcMessage, 1)
|
||||||
|
router.SimpleRequestMap[reqId] = rtn
|
||||||
|
return rtn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (router *WshRouter) trySimpleResponse(msg *RpcMessage) bool {
|
||||||
|
router.Lock.Lock()
|
||||||
|
defer router.Lock.Unlock()
|
||||||
|
respCh := router.SimpleRequestMap[msg.ResId]
|
||||||
|
if respCh == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
respCh <- msg
|
||||||
|
delete(router.SimpleRequestMap, msg.ResId)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (router *WshRouter) clearSimpleRequest(reqId string) {
|
||||||
|
router.Lock.Lock()
|
||||||
|
defer router.Lock.Unlock()
|
||||||
|
delete(router.SimpleRequestMap, reqId)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (router *WshRouter) RunSimpleRawCommand(ctx context.Context, msg RpcMessage, fromRouteId string) (*RpcMessage, error) {
|
||||||
|
if msg.Command == "" {
|
||||||
|
return nil, errors.New("no command")
|
||||||
|
}
|
||||||
|
msgBytes, err := json.Marshal(msg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var respCh chan *RpcMessage
|
||||||
|
if msg.ReqId != "" {
|
||||||
|
respCh = router.registerSimpleRequest(msg.ReqId)
|
||||||
|
}
|
||||||
|
router.InjectMessage(msgBytes, fromRouteId)
|
||||||
|
if respCh == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
router.clearSimpleRequest(msg.ReqId)
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case resp := <-respCh:
|
||||||
|
if resp.Error != "" {
|
||||||
|
return nil, errors.New(resp.Error)
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (router *WshRouter) HandleProxyAuth(jwtTokenAny any) (*wshrpc.CommandAuthenticateRtnData, error) {
|
||||||
|
if jwtTokenAny == nil {
|
||||||
|
return nil, errors.New("no jwt token")
|
||||||
|
}
|
||||||
|
jwtToken, ok := jwtTokenAny.(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("jwt token not a string")
|
||||||
|
}
|
||||||
|
if jwtToken == "" {
|
||||||
|
return nil, errors.New("empty jwt token")
|
||||||
|
}
|
||||||
|
msg := RpcMessage{
|
||||||
|
Command: wshrpc.Command_Authenticate,
|
||||||
|
ReqId: uuid.New().String(),
|
||||||
|
Data: jwtToken,
|
||||||
|
}
|
||||||
|
ctx, cancelFn := context.WithTimeout(context.Background(), DefaultTimeoutMs*time.Millisecond)
|
||||||
|
defer cancelFn()
|
||||||
|
resp, err := router.RunSimpleRawCommand(ctx, msg, "")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if resp == nil || resp.Data == nil {
|
||||||
|
return nil, errors.New("no data in authenticate response")
|
||||||
|
}
|
||||||
|
var respData wshrpc.CommandAuthenticateRtnData
|
||||||
|
err = utilfn.ReUnmarshal(&respData, resp.Data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error unmarshalling authenticate response: %v", err)
|
||||||
|
}
|
||||||
|
if respData.AuthToken == "" {
|
||||||
|
return nil, errors.New("no auth token in authenticate response")
|
||||||
|
}
|
||||||
|
return &respData, nil
|
||||||
|
}
|
||||||
|
@ -45,10 +45,13 @@ type WshRpc struct {
|
|||||||
InputCh chan []byte
|
InputCh chan []byte
|
||||||
OutputCh chan []byte
|
OutputCh chan []byte
|
||||||
RpcContext *atomic.Pointer[wshrpc.RpcContext]
|
RpcContext *atomic.Pointer[wshrpc.RpcContext]
|
||||||
|
AuthToken string
|
||||||
RpcMap map[string]*rpcData
|
RpcMap map[string]*rpcData
|
||||||
ServerImpl ServerImpl
|
ServerImpl ServerImpl
|
||||||
EventListener *EventListener
|
EventListener *EventListener
|
||||||
ResponseHandlerMap map[string]*RpcResponseHandler // reqId => handler
|
ResponseHandlerMap map[string]*RpcResponseHandler // reqId => handler
|
||||||
|
Debug bool
|
||||||
|
DebugName string
|
||||||
}
|
}
|
||||||
|
|
||||||
type wshRpcContextKey struct{}
|
type wshRpcContextKey struct{}
|
||||||
@ -104,17 +107,18 @@ func (w *WshRpc) RecvRpcMessage() ([]byte, bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type RpcMessage struct {
|
type RpcMessage struct {
|
||||||
Command string `json:"command,omitempty"`
|
Command string `json:"command,omitempty"`
|
||||||
ReqId string `json:"reqid,omitempty"`
|
ReqId string `json:"reqid,omitempty"`
|
||||||
ResId string `json:"resid,omitempty"`
|
ResId string `json:"resid,omitempty"`
|
||||||
Timeout int `json:"timeout,omitempty"`
|
Timeout int `json:"timeout,omitempty"`
|
||||||
Route string `json:"route,omitempty"` // to route/forward requests to alternate servers
|
Route string `json:"route,omitempty"` // to route/forward requests to alternate servers
|
||||||
Source string `json:"source,omitempty"` // source route id
|
AuthToken string `json:"authtoken,omitempty"` // needed for routing unauthenticated requests (WshRpcMultiProxy)
|
||||||
Cont bool `json:"cont,omitempty"` // flag if additional requests/responses are forthcoming
|
Source string `json:"source,omitempty"` // source route id
|
||||||
Cancel bool `json:"cancel,omitempty"` // used to cancel a streaming request or response (sent from the side that is not streaming)
|
Cont bool `json:"cont,omitempty"` // flag if additional requests/responses are forthcoming
|
||||||
Error string `json:"error,omitempty"`
|
Cancel bool `json:"cancel,omitempty"` // used to cancel a streaming request or response (sent from the side that is not streaming)
|
||||||
DataType string `json:"datatype,omitempty"`
|
Error string `json:"error,omitempty"`
|
||||||
Data any `json:"data,omitempty"`
|
DataType string `json:"datatype,omitempty"`
|
||||||
|
Data any `json:"data,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RpcMessage) IsRpcRequest() bool {
|
func (r *RpcMessage) IsRpcRequest() bool {
|
||||||
@ -226,6 +230,14 @@ func (w *WshRpc) SetRpcContext(ctx wshrpc.RpcContext) {
|
|||||||
w.RpcContext.Store(&ctx)
|
w.RpcContext.Store(&ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *WshRpc) SetAuthToken(token string) {
|
||||||
|
w.AuthToken = token
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *WshRpc) GetAuthToken() string {
|
||||||
|
return w.AuthToken
|
||||||
|
}
|
||||||
|
|
||||||
func (w *WshRpc) registerResponseHandler(reqId string, handler *RpcResponseHandler) {
|
func (w *WshRpc) registerResponseHandler(reqId string, handler *RpcResponseHandler) {
|
||||||
w.Lock.Lock()
|
w.Lock.Lock()
|
||||||
defer w.Lock.Unlock()
|
defer w.Lock.Unlock()
|
||||||
@ -323,6 +335,9 @@ func (w *WshRpc) handleRequest(req *RpcMessage) {
|
|||||||
func (w *WshRpc) runServer() {
|
func (w *WshRpc) runServer() {
|
||||||
defer close(w.OutputCh)
|
defer close(w.OutputCh)
|
||||||
for msgBytes := range w.InputCh {
|
for msgBytes := range w.InputCh {
|
||||||
|
if w.Debug {
|
||||||
|
log.Printf("[%s] received message: %s\n", w.DebugName, string(msgBytes))
|
||||||
|
}
|
||||||
var msg RpcMessage
|
var msg RpcMessage
|
||||||
err := json.Unmarshal(msgBytes, &msg)
|
err := json.Unmarshal(msgBytes, &msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -455,8 +470,9 @@ func (handler *RpcRequestHandler) SendCancel() {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
msg := &RpcMessage{
|
msg := &RpcMessage{
|
||||||
Cancel: true,
|
Cancel: true,
|
||||||
ReqId: handler.reqId,
|
ReqId: handler.reqId,
|
||||||
|
AuthToken: handler.w.GetAuthToken(),
|
||||||
}
|
}
|
||||||
barr, _ := json.Marshal(msg) // will never fail
|
barr, _ := json.Marshal(msg) // will never fail
|
||||||
handler.w.OutputCh <- barr
|
handler.w.OutputCh <- barr
|
||||||
@ -550,6 +566,7 @@ func (handler *RpcResponseHandler) SendMessage(msg string) {
|
|||||||
Data: wshrpc.CommandMessageData{
|
Data: wshrpc.CommandMessageData{
|
||||||
Message: msg,
|
Message: msg,
|
||||||
},
|
},
|
||||||
|
AuthToken: handler.w.GetAuthToken(),
|
||||||
}
|
}
|
||||||
msgBytes, _ := json.Marshal(rpcMsg) // will never fail
|
msgBytes, _ := json.Marshal(rpcMsg) // will never fail
|
||||||
handler.w.OutputCh <- msgBytes
|
handler.w.OutputCh <- msgBytes
|
||||||
@ -573,9 +590,10 @@ func (handler *RpcResponseHandler) SendResponse(data any, done bool) error {
|
|||||||
defer handler.close()
|
defer handler.close()
|
||||||
}
|
}
|
||||||
msg := &RpcMessage{
|
msg := &RpcMessage{
|
||||||
ResId: handler.reqId,
|
ResId: handler.reqId,
|
||||||
Data: data,
|
Data: data,
|
||||||
Cont: !done,
|
Cont: !done,
|
||||||
|
AuthToken: handler.w.GetAuthToken(),
|
||||||
}
|
}
|
||||||
barr, err := json.Marshal(msg)
|
barr, err := json.Marshal(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -598,8 +616,9 @@ func (handler *RpcResponseHandler) SendResponseError(err error) {
|
|||||||
}
|
}
|
||||||
defer handler.close()
|
defer handler.close()
|
||||||
msg := &RpcMessage{
|
msg := &RpcMessage{
|
||||||
ResId: handler.reqId,
|
ResId: handler.reqId,
|
||||||
Error: err.Error(),
|
Error: err.Error(),
|
||||||
|
AuthToken: handler.w.GetAuthToken(),
|
||||||
}
|
}
|
||||||
barr, _ := json.Marshal(msg) // will never fail
|
barr, _ := json.Marshal(msg) // will never fail
|
||||||
handler.w.OutputCh <- barr
|
handler.w.OutputCh <- barr
|
||||||
@ -660,11 +679,12 @@ func (w *WshRpc) SendComplexRequest(command string, data any, opts *wshrpc.RpcOp
|
|||||||
handler.reqId = uuid.New().String()
|
handler.reqId = uuid.New().String()
|
||||||
}
|
}
|
||||||
req := &RpcMessage{
|
req := &RpcMessage{
|
||||||
Command: command,
|
Command: command,
|
||||||
ReqId: handler.reqId,
|
ReqId: handler.reqId,
|
||||||
Data: data,
|
Data: data,
|
||||||
Timeout: timeoutMs,
|
Timeout: timeoutMs,
|
||||||
Route: opts.Route,
|
Route: opts.Route,
|
||||||
|
AuthToken: w.GetAuthToken(),
|
||||||
}
|
}
|
||||||
barr, err := json.Marshal(req)
|
barr, err := json.Marshal(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -19,6 +19,7 @@ import (
|
|||||||
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/util/packetparser"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wavebase"
|
"github.com/wavetermdev/waveterm/pkg/wavebase"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||||
"golang.org/x/term"
|
"golang.org/x/term"
|
||||||
@ -204,11 +205,26 @@ func SetupTerminalRpcClient(serverImpl ServerImpl) (*WshRpc, io.Reader) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
os.Stdout.Write(barr)
|
os.Stdout.Write(barr)
|
||||||
|
os.Stdout.Write([]byte{'\n'})
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return rpcClient, ptyBuf
|
return rpcClient, ptyBuf
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func SetupPacketRpcClient(input io.Reader, output io.Writer, serverImpl ServerImpl) (*WshRpc, chan []byte) {
|
||||||
|
messageCh := make(chan []byte, DefaultInputChSize)
|
||||||
|
outputCh := make(chan []byte, DefaultOutputChSize)
|
||||||
|
rawCh := make(chan []byte, DefaultOutputChSize)
|
||||||
|
rpcClient := MakeWshRpc(messageCh, outputCh, wshrpc.RpcContext{}, serverImpl)
|
||||||
|
go packetparser.Parse(input, messageCh, rawCh)
|
||||||
|
go func() {
|
||||||
|
for msg := range outputCh {
|
||||||
|
packetparser.WritePacket(output, msg)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return rpcClient, rawCh
|
||||||
|
}
|
||||||
|
|
||||||
func SetupConnRpcClient(conn net.Conn, serverImpl ServerImpl) (*WshRpc, chan error, error) {
|
func SetupConnRpcClient(conn net.Conn, serverImpl ServerImpl) (*WshRpc, chan error, error) {
|
||||||
inputCh := make(chan []byte, DefaultInputChSize)
|
inputCh := make(chan []byte, DefaultInputChSize)
|
||||||
outputCh := make(chan []byte, DefaultOutputChSize)
|
outputCh := make(chan []byte, DefaultOutputChSize)
|
||||||
@ -229,10 +245,22 @@ func SetupConnRpcClient(conn net.Conn, serverImpl ServerImpl) (*WshRpc, chan err
|
|||||||
return rtn, writeErrCh, nil
|
return rtn, writeErrCh, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func SetupDomainSocketRpcClient(sockName string, serverImpl ServerImpl) (*WshRpc, error) {
|
func tryTcpSocket(sockName string) (net.Conn, error) {
|
||||||
conn, err := net.Dial("unix", sockName)
|
addr, err := net.ResolveTCPAddr("tcp", sockName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to connect to Unix domain socket: %w", err)
|
return nil, err
|
||||||
|
}
|
||||||
|
return net.DialTCP("tcp", nil, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetupDomainSocketRpcClient(sockName string, serverImpl ServerImpl) (*WshRpc, error) {
|
||||||
|
conn, tcpErr := tryTcpSocket(sockName)
|
||||||
|
var unixErr error
|
||||||
|
if tcpErr != nil {
|
||||||
|
conn, unixErr = net.Dial("unix", sockName)
|
||||||
|
}
|
||||||
|
if tcpErr != nil && unixErr != nil {
|
||||||
|
return nil, fmt.Errorf("failed to connect to tcp or unix domain socket: tcp err:%w: unix socket err: %w", tcpErr, unixErr)
|
||||||
}
|
}
|
||||||
rtn, errCh, err := SetupConnRpcClient(conn, serverImpl)
|
rtn, errCh, err := SetupConnRpcClient(conn, serverImpl)
|
||||||
go func() {
|
go func() {
|
||||||
@ -363,6 +391,46 @@ func MakeRouteIdFromCtx(rpcCtx *wshrpc.RpcContext) (string, error) {
|
|||||||
return MakeProcRouteId(procId), nil
|
return MakeProcRouteId(procId), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type WriteFlusher interface {
|
||||||
|
Write([]byte) (int, error)
|
||||||
|
Flush() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// blocking, returns if there is an error, or on EOF of input
|
||||||
|
func HandleStdIOClient(logName string, input io.Reader, output io.Writer) {
|
||||||
|
proxy := MakeRpcMultiProxy()
|
||||||
|
rawCh := make(chan []byte, DefaultInputChSize)
|
||||||
|
go packetparser.Parse(input, proxy.FromRemoteRawCh, rawCh)
|
||||||
|
doneCh := make(chan struct{})
|
||||||
|
var doneOnce sync.Once
|
||||||
|
closeDoneCh := func() {
|
||||||
|
doneOnce.Do(func() {
|
||||||
|
close(doneCh)
|
||||||
|
})
|
||||||
|
proxy.DisposeRoutes()
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
proxy.RunUnauthLoop()
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
defer closeDoneCh()
|
||||||
|
for msg := range proxy.ToRemoteCh {
|
||||||
|
err := packetparser.WritePacket(output, msg)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[%s] error writing to output: %v\n", logName, err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
defer closeDoneCh()
|
||||||
|
for msg := range rawCh {
|
||||||
|
log.Printf("[%s:stdout] %s", logName, msg)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
<-doneCh
|
||||||
|
}
|
||||||
|
|
||||||
func handleDomainSocketClient(conn net.Conn) {
|
func handleDomainSocketClient(conn net.Conn) {
|
||||||
var routeIdContainer atomic.Pointer[string]
|
var routeIdContainer atomic.Pointer[string]
|
||||||
proxy := MakeRpcProxy()
|
proxy := MakeRpcProxy()
|
||||||
@ -399,7 +467,7 @@ func handleDomainSocketClient(conn net.Conn) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
routeIdContainer.Store(&routeId)
|
routeIdContainer.Store(&routeId)
|
||||||
DefaultRouter.RegisterRoute(routeId, proxy)
|
DefaultRouter.RegisterRoute(routeId, proxy, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
// only for use on client
|
// only for use on client
|
||||||
@ -433,5 +501,6 @@ func ExtractUnverifiedSocketName(tokenStr string) (string, error) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return "", fmt.Errorf("sock claim is missing or invalid")
|
return "", fmt.Errorf("sock claim is missing or invalid")
|
||||||
}
|
}
|
||||||
|
sockName = wavebase.ExpandHomeDirSafe(sockName)
|
||||||
return sockName, nil
|
return sockName, nil
|
||||||
}
|
}
|
||||||
|
67
pkg/wsl/wsl-unix.go
Normal file
67
pkg/wsl/wsl-unix.go
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
// Copyright 2024, Command Line Inc.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package wsl
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
)
|
||||||
|
|
||||||
|
func RegisteredDistros(ctx context.Context) (distros []Distro, err error) {
|
||||||
|
return nil, fmt.Errorf("RegisteredDistros not implemented on this system")
|
||||||
|
}
|
||||||
|
|
||||||
|
func DefaultDistro(ctx context.Context) (d Distro, ok bool, err error) {
|
||||||
|
return d, false, fmt.Errorf("DefaultDistro not implemented on this system")
|
||||||
|
}
|
||||||
|
|
||||||
|
type Distro struct{}
|
||||||
|
|
||||||
|
func (d *Distro) Name() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Distro) WslCommand(ctx context.Context, cmd string) *WslCmd {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// just use the regular cmd since it's
|
||||||
|
// similar enough to not cause issues
|
||||||
|
// type WslCmd = exec.Cmd
|
||||||
|
type WslCmd struct {
|
||||||
|
exec.Cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wc *WslCmd) GetProcess() *os.Process {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wc *WslCmd) GetProcessState() *os.ProcessState {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WslCmd) SetStdin(stdin io.Reader) {
|
||||||
|
c.Stdin = stdin
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WslCmd) SetStdout(stdout io.Writer) {
|
||||||
|
c.Stdout = stdout
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WslCmd) SetStderr(stderr io.Writer) {
|
||||||
|
c.Stdout = stderr
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetDistroCmd(ctx context.Context, wslDistroName string, cmd string) (*WslCmd, error) {
|
||||||
|
return nil, fmt.Errorf("GetDistroCmd not implemented on this system")
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetDistro(ctx context.Context, wslDistroName WslName) (*Distro, error) {
|
||||||
|
return nil, fmt.Errorf("GetDistro not implemented on this system")
|
||||||
|
}
|
296
pkg/wsl/wsl-util.go
Normal file
296
pkg/wsl/wsl-util.go
Normal file
@ -0,0 +1,296 @@
|
|||||||
|
// Copyright 2024, Command Line Inc.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package wsl
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"html/template"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func DetectShell(ctx context.Context, client *Distro) (string, error) {
|
||||||
|
wshPath := GetWshPath(ctx, client)
|
||||||
|
|
||||||
|
cmd := client.WslCommand(ctx, wshPath+" shell")
|
||||||
|
log.Printf("shell detecting using command: %s shell", wshPath)
|
||||||
|
out, err := cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("unable to determine shell. defaulting to /bin/bash: %s", err)
|
||||||
|
return "/bin/bash", nil
|
||||||
|
}
|
||||||
|
log.Printf("detecting shell: %s", out)
|
||||||
|
|
||||||
|
// quoting breaks this particular case
|
||||||
|
return strings.TrimSpace(string(out)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetWshVersion(ctx context.Context, client *Distro) (string, error) {
|
||||||
|
wshPath := GetWshPath(ctx, client)
|
||||||
|
|
||||||
|
cmd := client.WslCommand(ctx, wshPath+" version")
|
||||||
|
out, err := cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.TrimSpace(string(out)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetWshPath(ctx context.Context, client *Distro) string {
|
||||||
|
defaultPath := "~/.waveterm/bin/wsh"
|
||||||
|
|
||||||
|
cmd := client.WslCommand(ctx, "which wsh")
|
||||||
|
out, whichErr := cmd.Output()
|
||||||
|
if whichErr == nil {
|
||||||
|
return strings.TrimSpace(string(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd = client.WslCommand(ctx, "where.exe wsh")
|
||||||
|
out, whereErr := cmd.Output()
|
||||||
|
if whereErr == nil {
|
||||||
|
return strings.TrimSpace(string(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
// check cmd on windows since it requires an absolute path with backslashes
|
||||||
|
cmd = client.WslCommand(ctx, "(dir 2>&1 *``|echo %userprofile%\\.waveterm%\\.waveterm\\bin\\wsh.exe);&<# rem #>echo none")
|
||||||
|
out, cmdErr := cmd.Output()
|
||||||
|
if cmdErr == nil && strings.TrimSpace(string(out)) != "none" {
|
||||||
|
return strings.TrimSpace(string(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
// no custom install, use default path
|
||||||
|
return defaultPath
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasBashInstalled(ctx context.Context, client *Distro) (bool, error) {
|
||||||
|
cmd := client.WslCommand(ctx, "which bash")
|
||||||
|
out, whichErr := cmd.Output()
|
||||||
|
if whichErr == nil && len(out) != 0 {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd = client.WslCommand(ctx, "where.exe bash")
|
||||||
|
out, whereErr := cmd.Output()
|
||||||
|
if whereErr == nil && len(out) != 0 {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// note: we could also check in /bin/bash explicitly
|
||||||
|
// just in case that wasn't added to the path. but if
|
||||||
|
// that's true, we will most likely have worse
|
||||||
|
// problems going forward
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetClientOs(ctx context.Context, client *Distro) (string, error) {
|
||||||
|
cmd := client.WslCommand(ctx, "uname -s")
|
||||||
|
out, unixErr := cmd.Output()
|
||||||
|
if unixErr == nil {
|
||||||
|
formatted := strings.ToLower(string(out))
|
||||||
|
formatted = strings.TrimSpace(formatted)
|
||||||
|
return formatted, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd = client.WslCommand(ctx, "echo %OS%")
|
||||||
|
out, cmdErr := cmd.Output()
|
||||||
|
if cmdErr == nil && strings.TrimSpace(string(out)) != "%OS%" {
|
||||||
|
formatted := strings.ToLower(string(out))
|
||||||
|
formatted = strings.TrimSpace(formatted)
|
||||||
|
return strings.Split(formatted, "_")[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd = client.WslCommand(ctx, "echo $env:OS")
|
||||||
|
out, psErr := cmd.Output()
|
||||||
|
if psErr == nil && strings.TrimSpace(string(out)) != "$env:OS" {
|
||||||
|
formatted := strings.ToLower(string(out))
|
||||||
|
formatted = strings.TrimSpace(formatted)
|
||||||
|
return strings.Split(formatted, "_")[0], nil
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("unable to determine os: {unix: %s, cmd: %s, powershell: %s}", unixErr, cmdErr, psErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetClientArch(ctx context.Context, client *Distro) (string, error) {
|
||||||
|
cmd := client.WslCommand(ctx, "uname -m")
|
||||||
|
out, unixErr := cmd.Output()
|
||||||
|
if unixErr == nil {
|
||||||
|
formatted := strings.ToLower(string(out))
|
||||||
|
formatted = strings.TrimSpace(formatted)
|
||||||
|
if formatted == "x86_64" {
|
||||||
|
return "x64", nil
|
||||||
|
}
|
||||||
|
return formatted, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd = client.WslCommand(ctx, "echo %PROCESSOR_ARCHITECTURE%")
|
||||||
|
out, cmdErr := cmd.Output()
|
||||||
|
if cmdErr == nil && strings.TrimSpace(string(out)) != "%PROCESSOR_ARCHITECTURE%" {
|
||||||
|
formatted := strings.ToLower(string(out))
|
||||||
|
return strings.TrimSpace(formatted), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd = client.WslCommand(ctx, "echo $env:PROCESSOR_ARCHITECTURE")
|
||||||
|
out, psErr := cmd.Output()
|
||||||
|
if psErr == nil && strings.TrimSpace(string(out)) != "$env:PROCESSOR_ARCHITECTURE" {
|
||||||
|
formatted := strings.ToLower(string(out))
|
||||||
|
return strings.TrimSpace(formatted), nil
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("unable to determine architecture: {unix: %s, cmd: %s, powershell: %s}", unixErr, cmdErr, psErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
type CancellableCmd struct {
|
||||||
|
Cmd *WslCmd
|
||||||
|
Cancel func()
|
||||||
|
}
|
||||||
|
|
||||||
|
var installTemplatesRawBash = map[string]string{
|
||||||
|
"mkdir": `bash -c 'mkdir -p {{.installDir}}'`,
|
||||||
|
"cat": `bash -c 'cat > {{.tempPath}}'`,
|
||||||
|
"mv": `bash -c 'mv {{.tempPath}} {{.installPath}}'`,
|
||||||
|
"chmod": `bash -c 'chmod a+x {{.installPath}}'`,
|
||||||
|
}
|
||||||
|
|
||||||
|
var installTemplatesRawDefault = map[string]string{
|
||||||
|
"mkdir": `mkdir -p {{.installDir}}`,
|
||||||
|
"cat": `cat > {{.tempPath}}`,
|
||||||
|
"mv": `mv {{.tempPath}} {{.installPath}}`,
|
||||||
|
"chmod": `chmod a+x {{.installPath}}`,
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeCancellableCommand(ctx context.Context, client *Distro, cmdTemplateRaw string, words map[string]string) (*CancellableCmd, error) {
|
||||||
|
cmdContext, cmdCancel := context.WithCancel(ctx)
|
||||||
|
|
||||||
|
cmdStr := &bytes.Buffer{}
|
||||||
|
cmdTemplate, err := template.New("").Parse(cmdTemplateRaw)
|
||||||
|
if err != nil {
|
||||||
|
cmdCancel()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cmdTemplate.Execute(cmdStr, words)
|
||||||
|
|
||||||
|
cmd := client.WslCommand(cmdContext, cmdStr.String())
|
||||||
|
return &CancellableCmd{cmd, cmdCancel}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func CpHostToRemote(ctx context.Context, client *Distro, sourcePath string, destPath string) error {
|
||||||
|
// warning: does not work on windows remote yet
|
||||||
|
bashInstalled, err := hasBashInstalled(ctx, client)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var selectedTemplatesRaw map[string]string
|
||||||
|
if bashInstalled {
|
||||||
|
selectedTemplatesRaw = installTemplatesRawBash
|
||||||
|
} else {
|
||||||
|
log.Printf("bash is not installed on remote. attempting with default shell")
|
||||||
|
selectedTemplatesRaw = installTemplatesRawDefault
|
||||||
|
}
|
||||||
|
|
||||||
|
// I need to use toSlash here to force unix keybindings
|
||||||
|
// this means we can't guarantee it will work on a remote windows machine
|
||||||
|
var installWords = map[string]string{
|
||||||
|
"installDir": filepath.ToSlash(filepath.Dir(destPath)),
|
||||||
|
"tempPath": destPath + ".temp",
|
||||||
|
"installPath": destPath,
|
||||||
|
}
|
||||||
|
|
||||||
|
installStepCmds := make(map[string]*CancellableCmd)
|
||||||
|
for cmdName, selectedTemplateRaw := range selectedTemplatesRaw {
|
||||||
|
cancellableCmd, err := makeCancellableCommand(ctx, client, selectedTemplateRaw, installWords)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
installStepCmds[cmdName] = cancellableCmd
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = installStepCmds["mkdir"].Cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// the cat part of this is complicated since it requires stdin
|
||||||
|
catCmd := installStepCmds["cat"].Cmd
|
||||||
|
catStdin, err := catCmd.StdinPipe()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = catCmd.Start()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
input, err := os.Open(sourcePath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot open local file %s to send to host: %v", sourcePath, err)
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
io.Copy(catStdin, input)
|
||||||
|
installStepCmds["cat"].Cancel()
|
||||||
|
|
||||||
|
// backup just in case something weird happens
|
||||||
|
// could cause potential race condition, but very
|
||||||
|
// unlikely
|
||||||
|
time.Sleep(time.Second * 1)
|
||||||
|
process := catCmd.GetProcess()
|
||||||
|
if process != nil {
|
||||||
|
process.Kill()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
catErr := catCmd.Wait()
|
||||||
|
if catErr != nil && !errors.Is(catErr, context.Canceled) {
|
||||||
|
return catErr
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = installStepCmds["mv"].Cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = installStepCmds["chmod"].Cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func InstallClientRcFiles(ctx context.Context, client *Distro) error {
|
||||||
|
path := GetWshPath(ctx, client)
|
||||||
|
log.Printf("path to wsh searched is: %s", path)
|
||||||
|
|
||||||
|
cmd := client.WslCommand(ctx, path+" rcfiles")
|
||||||
|
_, err := cmd.Output()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetHomeDir(ctx context.Context, client *Distro) string {
|
||||||
|
// note: also works for powershell
|
||||||
|
cmd := client.WslCommand(ctx, `echo "$HOME"`)
|
||||||
|
out, err := cmd.Output()
|
||||||
|
if err == nil {
|
||||||
|
return strings.TrimSpace(string(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd = client.WslCommand(ctx, `echo %userprofile%`)
|
||||||
|
out, err = cmd.Output()
|
||||||
|
if err == nil {
|
||||||
|
return strings.TrimSpace(string(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
return "~"
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsPowershell(shellPath string) bool {
|
||||||
|
// get the base path, and then check contains
|
||||||
|
shellBase := filepath.Base(shellPath)
|
||||||
|
return strings.Contains(shellBase, "powershell") || strings.Contains(shellBase, "pwsh")
|
||||||
|
}
|
125
pkg/wsl/wsl-win.go
Normal file
125
pkg/wsl/wsl-win.go
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
// Copyright 2024, Command Line Inc.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package wsl
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/ubuntu/gowsl"
|
||||||
|
)
|
||||||
|
|
||||||
|
var RegisteredDistros = gowsl.RegisteredDistros
|
||||||
|
var DefaultDistro = gowsl.DefaultDistro
|
||||||
|
|
||||||
|
type Distro struct {
|
||||||
|
gowsl.Distro
|
||||||
|
}
|
||||||
|
|
||||||
|
type WslCmd struct {
|
||||||
|
c *gowsl.Cmd
|
||||||
|
wg *sync.WaitGroup
|
||||||
|
once *sync.Once
|
||||||
|
lock *sync.Mutex
|
||||||
|
waitErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Distro) WslCommand(ctx context.Context, cmd string) *WslCmd {
|
||||||
|
if ctx == nil {
|
||||||
|
panic("nil Context")
|
||||||
|
}
|
||||||
|
innerCmd := d.Command(ctx, cmd)
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
var lock *sync.Mutex
|
||||||
|
return &WslCmd{innerCmd, &wg, new(sync.Once), lock, nil}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WslCmd) CombinedOutput() (out []byte, err error) {
|
||||||
|
return c.c.CombinedOutput()
|
||||||
|
}
|
||||||
|
func (c *WslCmd) Output() (out []byte, err error) {
|
||||||
|
return c.c.Output()
|
||||||
|
}
|
||||||
|
func (c *WslCmd) Run() error {
|
||||||
|
return c.c.Run()
|
||||||
|
}
|
||||||
|
func (c *WslCmd) Start() (err error) {
|
||||||
|
return c.c.Start()
|
||||||
|
}
|
||||||
|
func (c *WslCmd) StderrPipe() (r io.ReadCloser, err error) {
|
||||||
|
return c.c.StderrPipe()
|
||||||
|
}
|
||||||
|
func (c *WslCmd) StdinPipe() (w io.WriteCloser, err error) {
|
||||||
|
return c.c.StdinPipe()
|
||||||
|
}
|
||||||
|
func (c *WslCmd) StdoutPipe() (r io.ReadCloser, err error) {
|
||||||
|
return c.c.StdoutPipe()
|
||||||
|
}
|
||||||
|
func (c *WslCmd) Wait() (err error) {
|
||||||
|
c.wg.Add(1)
|
||||||
|
c.once.Do(func() {
|
||||||
|
c.waitErr = c.c.Wait()
|
||||||
|
})
|
||||||
|
c.wg.Done()
|
||||||
|
c.wg.Wait()
|
||||||
|
if c.waitErr != nil && c.waitErr.Error() == "not started" {
|
||||||
|
c.once = new(sync.Once)
|
||||||
|
return c.waitErr
|
||||||
|
}
|
||||||
|
return c.waitErr
|
||||||
|
}
|
||||||
|
func (c *WslCmd) GetProcess() *os.Process {
|
||||||
|
return c.c.Process
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WslCmd) GetProcessState() *os.ProcessState {
|
||||||
|
return c.c.ProcessState
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WslCmd) SetStdin(stdin io.Reader) {
|
||||||
|
c.c.Stdin = stdin
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WslCmd) SetStdout(stdout io.Writer) {
|
||||||
|
c.c.Stdout = stdout
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WslCmd) SetStderr(stderr io.Writer) {
|
||||||
|
c.c.Stdout = stderr
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetDistroCmd(ctx context.Context, wslDistroName string, cmd string) (*WslCmd, error) {
|
||||||
|
distros, err := RegisteredDistros(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, distro := range distros {
|
||||||
|
if distro.Name() != wslDistroName {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
wrappedDistro := Distro{distro}
|
||||||
|
return wrappedDistro.WslCommand(ctx, cmd), nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("wsl distro %s not found", wslDistroName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetDistro(ctx context.Context, wslDistroName WslName) (*Distro, error) {
|
||||||
|
distros, err := RegisteredDistros(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, distro := range distros {
|
||||||
|
if distro.Name() != wslDistroName.Distro {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
wrappedDistro := Distro{distro}
|
||||||
|
return &wrappedDistro, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("wsl distro %s not found", wslDistroName)
|
||||||
|
}
|
494
pkg/wsl/wsl.go
Normal file
494
pkg/wsl/wsl.go
Normal file
@ -0,0 +1,494 @@
|
|||||||
|
// Copyright 2024, Command Line Inc.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package wsl
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/userinput"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/util/shellutil"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/wavebase"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/waveobj"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/wconfig"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/wps"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/wshrpc"
|
||||||
|
"github.com/wavetermdev/waveterm/pkg/wshutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
Status_Init = "init"
|
||||||
|
Status_Connecting = "connecting"
|
||||||
|
Status_Connected = "connected"
|
||||||
|
Status_Disconnected = "disconnected"
|
||||||
|
Status_Error = "error"
|
||||||
|
)
|
||||||
|
|
||||||
|
const DefaultConnectionTimeout = 60 * time.Second
|
||||||
|
|
||||||
|
var globalLock = &sync.Mutex{}
|
||||||
|
var clientControllerMap = make(map[string]*WslConn)
|
||||||
|
var activeConnCounter = &atomic.Int32{}
|
||||||
|
|
||||||
|
type WslConn struct {
|
||||||
|
Lock *sync.Mutex
|
||||||
|
Status string
|
||||||
|
Name WslName
|
||||||
|
Client *Distro
|
||||||
|
SockName string
|
||||||
|
DomainSockListener net.Listener
|
||||||
|
ConnController *WslCmd
|
||||||
|
Error string
|
||||||
|
HasWaiter *atomic.Bool
|
||||||
|
LastConnectTime int64
|
||||||
|
ActiveConnNum int
|
||||||
|
Context context.Context
|
||||||
|
cancelFn func()
|
||||||
|
}
|
||||||
|
|
||||||
|
type WslName struct {
|
||||||
|
Distro string `json:"distro"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetAllConnStatus() []wshrpc.ConnStatus {
|
||||||
|
globalLock.Lock()
|
||||||
|
defer globalLock.Unlock()
|
||||||
|
|
||||||
|
var connStatuses []wshrpc.ConnStatus
|
||||||
|
for _, conn := range clientControllerMap {
|
||||||
|
connStatuses = append(connStatuses, conn.DeriveConnStatus())
|
||||||
|
}
|
||||||
|
return connStatuses
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *WslConn) DeriveConnStatus() wshrpc.ConnStatus {
|
||||||
|
conn.Lock.Lock()
|
||||||
|
defer conn.Lock.Unlock()
|
||||||
|
return wshrpc.ConnStatus{
|
||||||
|
Status: conn.Status,
|
||||||
|
Connected: conn.Status == Status_Connected,
|
||||||
|
Connection: conn.GetName(),
|
||||||
|
HasConnected: (conn.LastConnectTime > 0),
|
||||||
|
ActiveConnNum: conn.ActiveConnNum,
|
||||||
|
Error: conn.Error,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *WslConn) FireConnChangeEvent() {
|
||||||
|
status := conn.DeriveConnStatus()
|
||||||
|
event := wps.WaveEvent{
|
||||||
|
Event: wps.Event_ConnChange,
|
||||||
|
Scopes: []string{
|
||||||
|
fmt.Sprintf("connection:%s", conn.GetName()),
|
||||||
|
},
|
||||||
|
Data: status,
|
||||||
|
}
|
||||||
|
log.Printf("sending event: %+#v", event)
|
||||||
|
wps.Broker.Publish(event)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *WslConn) Close() error {
|
||||||
|
defer conn.FireConnChangeEvent()
|
||||||
|
conn.WithLock(func() {
|
||||||
|
if conn.Status == Status_Connected || conn.Status == Status_Connecting {
|
||||||
|
// if status is init, disconnected, or error don't change it
|
||||||
|
conn.Status = Status_Disconnected
|
||||||
|
}
|
||||||
|
conn.close_nolock()
|
||||||
|
})
|
||||||
|
// we must wait for the waiter to complete
|
||||||
|
startTime := time.Now()
|
||||||
|
for conn.HasWaiter.Load() {
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
if time.Since(startTime) > 2*time.Second {
|
||||||
|
return fmt.Errorf("timeout waiting for waiter to complete")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *WslConn) close_nolock() {
|
||||||
|
// does not set status (that should happen at another level)
|
||||||
|
if conn.DomainSockListener != nil {
|
||||||
|
conn.DomainSockListener.Close()
|
||||||
|
conn.DomainSockListener = nil
|
||||||
|
}
|
||||||
|
if conn.ConnController != nil {
|
||||||
|
conn.cancelFn() // this suspends the conn controller
|
||||||
|
conn.ConnController = nil
|
||||||
|
}
|
||||||
|
if conn.Client != nil {
|
||||||
|
// conn.Client.Close() is not relevant here
|
||||||
|
// we do not want to completely close the wsl in case
|
||||||
|
// other applications are using it
|
||||||
|
conn.Client = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *WslConn) GetDomainSocketName() string {
|
||||||
|
conn.Lock.Lock()
|
||||||
|
defer conn.Lock.Unlock()
|
||||||
|
return conn.SockName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *WslConn) GetStatus() string {
|
||||||
|
conn.Lock.Lock()
|
||||||
|
defer conn.Lock.Unlock()
|
||||||
|
return conn.Status
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *WslConn) GetName() string {
|
||||||
|
// no lock required because opts is immutable
|
||||||
|
return "wsl://" + conn.Name.Distro
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This function is does not set a listener for WslConn
|
||||||
|
* It is still required in order to set SockName
|
||||||
|
**/
|
||||||
|
func (conn *WslConn) OpenDomainSocketListener() error {
|
||||||
|
var allowed bool
|
||||||
|
conn.WithLock(func() {
|
||||||
|
if conn.Status != Status_Connecting {
|
||||||
|
allowed = false
|
||||||
|
} else {
|
||||||
|
allowed = true
|
||||||
|
}
|
||||||
|
})
|
||||||
|
if !allowed {
|
||||||
|
return fmt.Errorf("cannot open domain socket for %q when status is %q", conn.GetName(), conn.GetStatus())
|
||||||
|
}
|
||||||
|
conn.WithLock(func() {
|
||||||
|
conn.SockName = "~/.waveterm/wave-remote.sock"
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *WslConn) StartConnServer() error {
|
||||||
|
var allowed bool
|
||||||
|
conn.WithLock(func() {
|
||||||
|
if conn.Status != Status_Connecting {
|
||||||
|
allowed = false
|
||||||
|
} else {
|
||||||
|
allowed = true
|
||||||
|
}
|
||||||
|
})
|
||||||
|
if !allowed {
|
||||||
|
return fmt.Errorf("cannot start conn server for %q when status is %q", conn.GetName(), conn.GetStatus())
|
||||||
|
}
|
||||||
|
client := conn.GetClient()
|
||||||
|
wshPath := GetWshPath(conn.Context, client)
|
||||||
|
rpcCtx := wshrpc.RpcContext{
|
||||||
|
ClientType: wshrpc.ClientType_ConnServer,
|
||||||
|
Conn: conn.GetName(),
|
||||||
|
}
|
||||||
|
sockName := conn.GetDomainSocketName()
|
||||||
|
jwtToken, err := wshutil.MakeClientJWTToken(rpcCtx, sockName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to create jwt token for conn controller: %w", err)
|
||||||
|
}
|
||||||
|
shellPath, err := DetectShell(conn.Context, client)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var cmdStr string
|
||||||
|
if IsPowershell(shellPath) {
|
||||||
|
cmdStr = fmt.Sprintf("$env:%s=\"%s\"; %s connserver --router", wshutil.WaveJwtTokenVarName, jwtToken, wshPath)
|
||||||
|
} else {
|
||||||
|
cmdStr = fmt.Sprintf("%s=\"%s\" %s connserver --router", wshutil.WaveJwtTokenVarName, jwtToken, wshPath)
|
||||||
|
}
|
||||||
|
log.Printf("starting conn controller: %s\n", cmdStr)
|
||||||
|
cmd := client.WslCommand(conn.Context, cmdStr)
|
||||||
|
pipeRead, pipeWrite := io.Pipe()
|
||||||
|
inputPipeRead, inputPipeWrite := io.Pipe()
|
||||||
|
cmd.SetStdout(pipeWrite)
|
||||||
|
cmd.SetStderr(pipeWrite)
|
||||||
|
cmd.SetStdin(inputPipeRead)
|
||||||
|
err = cmd.Start()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to start conn controller: %w", err)
|
||||||
|
}
|
||||||
|
conn.WithLock(func() {
|
||||||
|
conn.ConnController = cmd
|
||||||
|
})
|
||||||
|
// service the I/O
|
||||||
|
go func() {
|
||||||
|
// wait for termination, clear the controller
|
||||||
|
defer conn.WithLock(func() {
|
||||||
|
conn.ConnController = nil
|
||||||
|
})
|
||||||
|
waitErr := cmd.Wait()
|
||||||
|
log.Printf("conn controller (%q) terminated: %v", conn.GetName(), waitErr)
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
logName := fmt.Sprintf("conncontroller:%s", conn.GetName())
|
||||||
|
wshutil.HandleStdIOClient(logName, pipeRead, inputPipeWrite)
|
||||||
|
}()
|
||||||
|
regCtx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancelFn()
|
||||||
|
err = wshutil.DefaultRouter.WaitForRegister(regCtx, wshutil.MakeConnectionRouteId(rpcCtx.Conn))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("timeout waiting for connserver to register")
|
||||||
|
}
|
||||||
|
time.Sleep(300 * time.Millisecond) // TODO remove this sleep (but we need to wait until connserver is "ready")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type WshInstallOpts struct {
|
||||||
|
Force bool
|
||||||
|
NoUserPrompt bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *WslConn) CheckAndInstallWsh(ctx context.Context, clientDisplayName string, opts *WshInstallOpts) error {
|
||||||
|
if opts == nil {
|
||||||
|
opts = &WshInstallOpts{}
|
||||||
|
}
|
||||||
|
client := conn.GetClient()
|
||||||
|
if client == nil {
|
||||||
|
return fmt.Errorf("client is nil")
|
||||||
|
}
|
||||||
|
// check that correct wsh extensions are installed
|
||||||
|
expectedVersion := fmt.Sprintf("wsh v%s", wavebase.WaveVersion)
|
||||||
|
clientVersion, err := GetWshVersion(ctx, client)
|
||||||
|
if err == nil && clientVersion == expectedVersion && !opts.Force {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var queryText string
|
||||||
|
var title string
|
||||||
|
if opts.Force {
|
||||||
|
queryText = fmt.Sprintf("ReInstalling Wave Shell Extensions (%s) on `%s`\n", wavebase.WaveVersion, clientDisplayName)
|
||||||
|
title = "Install Wave Shell Extensions"
|
||||||
|
} else if err != nil {
|
||||||
|
queryText = fmt.Sprintf("Wave requires Wave Shell Extensions to be \n"+
|
||||||
|
"installed on `%s` \n"+
|
||||||
|
"to ensure a seamless experience. \n\n"+
|
||||||
|
"Would you like to install them?", clientDisplayName)
|
||||||
|
title = "Install Wave Shell Extensions"
|
||||||
|
} else {
|
||||||
|
// don't ask for upgrading the version
|
||||||
|
opts.NoUserPrompt = true
|
||||||
|
}
|
||||||
|
if !opts.NoUserPrompt {
|
||||||
|
request := &userinput.UserInputRequest{
|
||||||
|
ResponseType: "confirm",
|
||||||
|
QueryText: queryText,
|
||||||
|
Title: title,
|
||||||
|
Markdown: true,
|
||||||
|
CheckBoxMsg: "Don't show me this again",
|
||||||
|
}
|
||||||
|
response, err := userinput.GetUserInput(ctx, request)
|
||||||
|
if err != nil || !response.Confirm {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if response.CheckboxStat {
|
||||||
|
meta := waveobj.MetaMapType{
|
||||||
|
wconfig.ConfigKey_ConnAskBeforeWshInstall: false,
|
||||||
|
}
|
||||||
|
err := wconfig.SetBaseConfigValue(meta)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error setting conn:askbeforewshinstall value: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.Printf("attempting to install wsh to `%s`", clientDisplayName)
|
||||||
|
clientOs, err := GetClientOs(ctx, client)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
clientArch, err := GetClientArch(ctx, client)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// attempt to install extension
|
||||||
|
wshLocalPath := shellutil.GetWshBinaryPath(wavebase.WaveVersion, clientOs, clientArch)
|
||||||
|
err = CpHostToRemote(ctx, client, wshLocalPath, "~/.waveterm/bin/wsh")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Printf("successfully installed wsh on %s\n", conn.GetName())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *WslConn) GetClient() *Distro {
|
||||||
|
conn.Lock.Lock()
|
||||||
|
defer conn.Lock.Unlock()
|
||||||
|
return conn.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *WslConn) Reconnect(ctx context.Context) error {
|
||||||
|
err := conn.Close()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return conn.Connect(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *WslConn) WaitForConnect(ctx context.Context) error {
|
||||||
|
for {
|
||||||
|
status := conn.DeriveConnStatus()
|
||||||
|
if status.Status == Status_Connected {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if status.Status == Status_Connecting {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return fmt.Errorf("context timeout")
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if status.Status == Status_Init || status.Status == Status_Disconnected {
|
||||||
|
return fmt.Errorf("disconnected")
|
||||||
|
}
|
||||||
|
if status.Status == Status_Error {
|
||||||
|
return fmt.Errorf("error: %v", status.Error)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("unknown status: %q", status.Status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// does not return an error since that error is stored inside of WslConn
|
||||||
|
func (conn *WslConn) Connect(ctx context.Context) error {
|
||||||
|
var connectAllowed bool
|
||||||
|
conn.WithLock(func() {
|
||||||
|
if conn.Status == Status_Connecting || conn.Status == Status_Connected {
|
||||||
|
connectAllowed = false
|
||||||
|
} else {
|
||||||
|
conn.Status = Status_Connecting
|
||||||
|
conn.Error = ""
|
||||||
|
connectAllowed = true
|
||||||
|
}
|
||||||
|
})
|
||||||
|
log.Printf("Connect %s\n", conn.GetName())
|
||||||
|
if !connectAllowed {
|
||||||
|
return fmt.Errorf("cannot connect to %q when status is %q", conn.GetName(), conn.GetStatus())
|
||||||
|
}
|
||||||
|
conn.FireConnChangeEvent()
|
||||||
|
err := conn.connectInternal(ctx)
|
||||||
|
conn.WithLock(func() {
|
||||||
|
if err != nil {
|
||||||
|
conn.Status = Status_Error
|
||||||
|
conn.Error = err.Error()
|
||||||
|
conn.close_nolock()
|
||||||
|
} else {
|
||||||
|
conn.Status = Status_Connected
|
||||||
|
conn.LastConnectTime = time.Now().UnixMilli()
|
||||||
|
if conn.ActiveConnNum == 0 {
|
||||||
|
conn.ActiveConnNum = int(activeConnCounter.Add(1))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
conn.FireConnChangeEvent()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *WslConn) WithLock(fn func()) {
|
||||||
|
conn.Lock.Lock()
|
||||||
|
defer conn.Lock.Unlock()
|
||||||
|
fn()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *WslConn) connectInternal(ctx context.Context) error {
|
||||||
|
client, err := GetDistro(ctx, conn.Name)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
conn.WithLock(func() {
|
||||||
|
conn.Client = client
|
||||||
|
})
|
||||||
|
err = conn.OpenDomainSocketListener()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
config := wconfig.ReadFullConfig()
|
||||||
|
installErr := conn.CheckAndInstallWsh(ctx, conn.GetName(), &WshInstallOpts{NoUserPrompt: !config.Settings.ConnAskBeforeWshInstall})
|
||||||
|
if installErr != nil {
|
||||||
|
return fmt.Errorf("conncontroller %s wsh install error: %v", conn.GetName(), installErr)
|
||||||
|
}
|
||||||
|
csErr := conn.StartConnServer()
|
||||||
|
if csErr != nil {
|
||||||
|
return fmt.Errorf("conncontroller %s start wsh connserver error: %v", conn.GetName(), csErr)
|
||||||
|
}
|
||||||
|
conn.HasWaiter.Store(true)
|
||||||
|
go conn.waitForDisconnect()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *WslConn) waitForDisconnect() {
|
||||||
|
defer conn.FireConnChangeEvent()
|
||||||
|
defer conn.HasWaiter.Store(false)
|
||||||
|
err := conn.ConnController.Wait()
|
||||||
|
conn.WithLock(func() {
|
||||||
|
// disconnects happen for a variety of reasons (like network, etc. and are typically transient)
|
||||||
|
// so we just set the status to "disconnected" here (not error)
|
||||||
|
// don't overwrite any existing error (or error status)
|
||||||
|
if err != nil && conn.Error == "" {
|
||||||
|
conn.Error = err.Error()
|
||||||
|
}
|
||||||
|
if conn.Status != Status_Error {
|
||||||
|
conn.Status = Status_Disconnected
|
||||||
|
}
|
||||||
|
conn.close_nolock()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func getConnInternal(name string) *WslConn {
|
||||||
|
globalLock.Lock()
|
||||||
|
defer globalLock.Unlock()
|
||||||
|
connName := WslName{Distro: name}
|
||||||
|
rtn := clientControllerMap[name]
|
||||||
|
if rtn == nil {
|
||||||
|
ctx, cancelFn := context.WithCancel(context.Background())
|
||||||
|
rtn = &WslConn{Lock: &sync.Mutex{}, Status: Status_Init, Name: connName, HasWaiter: &atomic.Bool{}, Context: ctx, cancelFn: cancelFn}
|
||||||
|
clientControllerMap[name] = rtn
|
||||||
|
}
|
||||||
|
return rtn
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetWslConn(ctx context.Context, name string, shouldConnect bool) *WslConn {
|
||||||
|
conn := getConnInternal(name)
|
||||||
|
if conn.Client == nil && shouldConnect {
|
||||||
|
conn.Connect(ctx)
|
||||||
|
}
|
||||||
|
return conn
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convenience function for ensuring a connection is established
|
||||||
|
func EnsureConnection(ctx context.Context, connName string) error {
|
||||||
|
if connName == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
conn := GetWslConn(ctx, connName, false)
|
||||||
|
if conn == nil {
|
||||||
|
return fmt.Errorf("connection not found: %s", connName)
|
||||||
|
}
|
||||||
|
connStatus := conn.DeriveConnStatus()
|
||||||
|
switch connStatus.Status {
|
||||||
|
case Status_Connected:
|
||||||
|
return nil
|
||||||
|
case Status_Connecting:
|
||||||
|
return conn.WaitForConnect(ctx)
|
||||||
|
case Status_Init, Status_Disconnected:
|
||||||
|
return conn.Connect(ctx)
|
||||||
|
case Status_Error:
|
||||||
|
return fmt.Errorf("connection error: %s", connStatus.Error)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unknown connection status %q", connStatus.Status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func DisconnectClient(connName string) error {
|
||||||
|
conn := getConnInternal(connName)
|
||||||
|
if conn == nil {
|
||||||
|
return fmt.Errorf("client %q not found", connName)
|
||||||
|
}
|
||||||
|
err := conn.Close()
|
||||||
|
return err
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user