mirror of
https://github.com/wavetermdev/waveterm.git
synced 2024-12-22 16:48:23 +01:00
194 lines
5.5 KiB
Go
194 lines
5.5 KiB
Go
// Copyright 2024, Command Line Inc.
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
|
|
package wshserver
|
|
|
|
import (
	"context"
	"errors"
	"fmt"
	"log"
	"net"
	"os"
	"reflect"

	"github.com/wavetermdev/thenextwave/pkg/util/utilfn"
	"github.com/wavetermdev/thenextwave/pkg/wavebase"
	"github.com/wavetermdev/thenextwave/pkg/wshrpc"
	"github.com/wavetermdev/thenextwave/pkg/wshutil"
)
|
|
|
|
// this file contains the generic types and functions that create and power the WSH server
|
|
|
|
// Buffer sizes for the per-connection message channels created in
// runWshRpcWithStream.
const (
	// DefaultInputChSize is the buffer size of the input (incoming) message channel.
	DefaultInputChSize = 32
	// DefaultOutputChSize is the buffer size of the output (outgoing) message channel.
	DefaultOutputChSize = 32
)
|
|
|
|
type (
	// WshServer is the receiver type for the WSH server. It has no fields here.
	// NOTE(review): its command methods are presumably declared in other files
	// of this package — confirm.
	WshServer struct{}

	// WshServerMethodDecl describes one WSH server command and the Go function
	// implementing it. mainWshServerHandler uses it to decode a request and to
	// invoke the function through reflection.
	WshServerMethodDecl struct {
		// Command is the command name passed to GetWshServerMethod.
		Command string
		// CommandType selects the dispatch mode, e.g. wshutil.RpcType_Call or
		// wshutil.RpcType_ResponseStream.
		CommandType string
		// MethodName is the name of the implementing method.
		MethodName string
		// Method is the callable; its first parameter is a context.Context.
		Method reflect.Value
		// CommandDataType is the type the request data is unmarshalled into.
		// When nil, mainWshServerHandler calls Method with the context only.
		CommandDataType reflect.Type
		// DefaultResponseDataType is the first result type when Method has more
		// than one result (as set by GetWshServerMethod); otherwise nil.
		DefaultResponseDataType reflect.Type
		// RequestDataTypes is for streaming requests.
		RequestDataTypes []reflect.Type
		// ResponseDataTypes is for streaming responses.
		ResponseDataTypes []reflect.Type
	}
)

var (
	// WshServerImpl is the package-level WshServer value.
	WshServerImpl = WshServer{}

	// contextRType is the reflect.Type of the context.Context interface; it is
	// compared against a method's first parameter type.
	contextRType = reflect.TypeOf((*context.Context)(nil)).Elem()
)
|
|
|
|
func GetWshServerMethod(command string, commandType string, methodName string, methodFunc any) *WshServerMethodDecl {
|
|
methodVal := reflect.ValueOf(methodFunc)
|
|
methodType := methodVal.Type()
|
|
if methodType.Kind() != reflect.Func {
|
|
panic(fmt.Sprintf("methodVal must be a function got [%v]", methodType))
|
|
}
|
|
if methodType.In(0) != contextRType {
|
|
panic(fmt.Sprintf("methodVal must have a context as the first argument %v", methodType))
|
|
}
|
|
var defResponseType reflect.Type
|
|
if methodType.NumOut() > 1 {
|
|
defResponseType = methodType.Out(0)
|
|
}
|
|
rtn := &WshServerMethodDecl{
|
|
Command: command,
|
|
CommandType: commandType,
|
|
MethodName: methodName,
|
|
Method: methodVal,
|
|
CommandDataType: methodType.In(1),
|
|
DefaultResponseDataType: defResponseType,
|
|
}
|
|
return rtn
|
|
}
|
|
|
|
// decodeRtnVals converts the reflected results of a WSH server method call into
// a (data, error) pair:
//
//   - no results:  (nil, nil)
//   - one result:  treated as the error result; returns (nil, err)
//   - two results: (first result, error from the second result)
//   - more:        (nil, an error describing the unsupported arity)
//
// A nil error result yields a nil error. A result in the error position whose
// dynamic value does not implement error is reported as an error instead of
// panicking on a type assertion.
func decodeRtnVals(rtnVals []reflect.Value) (any, error) {
	// asError extracts the error carried by an error-position result.
	asError := func(val reflect.Value) error {
		errIf := val.Interface()
		if errIf == nil {
			return nil
		}
		err, ok := errIf.(error)
		if !ok {
			return fmt.Errorf("expected error return value, got %T", errIf)
		}
		return err
	}
	switch len(rtnVals) {
	case 0:
		return nil, nil
	case 1:
		return nil, asError(rtnVals[0])
	case 2:
		return rtnVals[0].Interface(), asError(rtnVals[1])
	default:
		return nil, fmt.Errorf("too many return values: %d", len(rtnVals))
	}
}
|
|
|
|
func mainWshServerHandler(handler *wshutil.RpcResponseHandler) bool {
|
|
command := handler.GetCommand()
|
|
methodDecl := WshServerCommandToDeclMap[command]
|
|
if methodDecl == nil {
|
|
handler.SendResponseError(fmt.Errorf("command %q not found", command))
|
|
return true
|
|
}
|
|
var callParams []reflect.Value
|
|
callParams = append(callParams, reflect.ValueOf(handler.Context()))
|
|
if methodDecl.CommandDataType != nil {
|
|
commandData := reflect.New(methodDecl.CommandDataType).Interface()
|
|
err := utilfn.ReUnmarshal(commandData, handler.GetCommandRawData())
|
|
if err != nil {
|
|
handler.SendResponseError(fmt.Errorf("error re-marshalling command data: %w", err))
|
|
return true
|
|
}
|
|
wshrpc.HackRpcContextIntoData(commandData, handler.GetRpcContext())
|
|
callParams = append(callParams, reflect.ValueOf(commandData).Elem())
|
|
}
|
|
if methodDecl.CommandType == wshutil.RpcType_Call {
|
|
rtnVals := methodDecl.Method.Call(callParams)
|
|
rtnData, rtnErr := decodeRtnVals(rtnVals)
|
|
if rtnErr != nil {
|
|
handler.SendResponseError(rtnErr)
|
|
return true
|
|
}
|
|
handler.SendResponse(rtnData, true)
|
|
return true
|
|
} else if methodDecl.CommandType == wshutil.RpcType_ResponseStream {
|
|
rtnVals := methodDecl.Method.Call(callParams)
|
|
rtnChVal := rtnVals[0]
|
|
if rtnChVal.IsNil() {
|
|
handler.SendResponse(nil, true)
|
|
return true
|
|
}
|
|
go func() {
|
|
defer handler.Finalize()
|
|
// must use reflection here because we don't know the generic type of RespOrErrorUnion
|
|
for {
|
|
respVal, ok := rtnChVal.Recv()
|
|
if !ok {
|
|
break
|
|
}
|
|
errorVal := respVal.FieldByName("Error")
|
|
if !errorVal.IsNil() {
|
|
handler.SendResponseError(errorVal.Interface().(error))
|
|
break
|
|
}
|
|
respData := respVal.FieldByName("Response").Interface()
|
|
handler.SendResponse(respData, false)
|
|
}
|
|
}()
|
|
return false
|
|
} else {
|
|
handler.SendResponseError(fmt.Errorf("unsupported command type %q", methodDecl.CommandType))
|
|
return true
|
|
}
|
|
}
|
|
|
|
// MakeUnixListener creates a Unix domain socket listener at sockName.
//
// Any existing file at sockName is removed first (best-effort; the error is
// ignored) so the bind does not fail on a pre-existing path. After the socket
// is created, its permissions are restricted to 0700. If that chmod fails, the
// listener is closed and an error is returned instead of serving on a socket
// with the default, possibly broader, permissions.
//
// Errors wrap the underlying cause (%w), so callers can use errors.Is/As.
func MakeUnixListener(sockName string) (net.Listener, error) {
	_ = os.Remove(sockName) // best-effort: the file usually does not exist
	rtn, err := net.Listen("unix", sockName)
	if err != nil {
		return nil, fmt.Errorf("error creating listener at %v: %w", sockName, err)
	}
	if err := os.Chmod(sockName, 0700); err != nil {
		_ = rtn.Close()
		return nil, fmt.Errorf("error setting permissions on %v: %w", sockName, err)
	}
	log.Printf("Server listening on %s\n", sockName)
	return rtn, nil
}
|
|
|
|
func runWshRpcWithStream(conn net.Conn) {
|
|
defer conn.Close()
|
|
inputCh := make(chan []byte, DefaultInputChSize)
|
|
outputCh := make(chan []byte, DefaultOutputChSize)
|
|
go wshutil.AdaptMsgChToStream(outputCh, conn)
|
|
go wshutil.AdaptStreamToMsgCh(conn, inputCh)
|
|
wshutil.MakeWshRpc(inputCh, outputCh, wshutil.RpcContext{}, mainWshServerHandler)
|
|
}
|
|
|
|
func RunWshRpcOverListener(listener net.Listener) {
|
|
go func() {
|
|
for {
|
|
conn, err := listener.Accept()
|
|
if err != nil {
|
|
log.Printf("error accepting connection: %v\n", err)
|
|
continue
|
|
}
|
|
go runWshRpcWithStream(conn)
|
|
}
|
|
}()
|
|
}
|
|
|
|
func RunDomainSocketWshServer() error {
|
|
sockName := wavebase.GetDomainSocketName()
|
|
listener, err := MakeUnixListener(sockName)
|
|
if err != nil {
|
|
return fmt.Errorf("error starging unix listener for wsh-server: %w", err)
|
|
}
|
|
defer listener.Close()
|
|
RunWshRpcOverListener(listener)
|
|
return nil
|
|
}
|
|
|
|
func MakeWshServer(inputCh chan []byte, outputCh chan []byte, initialCtx wshutil.RpcContext) {
|
|
wshutil.MakeWshRpc(inputCh, outputCh, initialCtx, mainWshServerHandler)
|
|
}
|