waveterm/pkg/wshrpc/wshserver/wshserverutil.go

194 lines
5.5 KiB
Go
Raw Normal View History

// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package wshserver
import (
	"context"
	"errors"
	"fmt"
	"log"
	"net"
	"os"
	"reflect"
	"time"

	"github.com/wavetermdev/thenextwave/pkg/util/utilfn"
	"github.com/wavetermdev/thenextwave/pkg/wavebase"
	"github.com/wavetermdev/thenextwave/pkg/wshrpc"
	"github.com/wavetermdev/thenextwave/pkg/wshutil"
)
// this file contains the generic types and functions that create and power the WSH server
// Default buffer sizes for the WSH RPC message channels (runWshRpcWithStream
// allocates one input and one output channel of these sizes per connection).
const (
	DefaultInputChSize  = 32
	DefaultOutputChSize = 32
)

type (
	// WshServer is an empty (stateless) server type; WshServerImpl is its
	// package-level instance.
	WshServer struct{}

	// WshServerMethodDecl describes one RPC command and the reflected Go
	// function that implements it.
	WshServerMethodDecl struct {
		Command                 string        // command name this declaration handles
		CommandType             string        // dispatch kind, compared against wshutil.RpcType_* values
		MethodName              string        // name of the implementing method (presumably for diagnostics)
		Method                  reflect.Value // the function invoked with (ctx[, commandData])
		CommandDataType         reflect.Type  // type command data is unmarshalled into; nil means no data argument
		DefaultResponseDataType reflect.Type  // type of the first result when the method returns 2+ values, else nil
		RequestDataTypes        []reflect.Type // for streaming requests
		ResponseDataTypes       []reflect.Type // for streaming responses
	}
)

var (
	// WshServerImpl is the shared zero-value WshServer instance.
	WshServerImpl WshServer

	// contextRType is the reflect.Type of context.Context, used to check that
	// registered methods take a context as their first parameter.
	contextRType = reflect.TypeOf((*context.Context)(nil)).Elem()
)
// GetWshServerMethod builds a WshServerMethodDecl for methodFunc via reflection.
//
// methodFunc must be a function whose first parameter is a context.Context.
// An optional second parameter gives the type that incoming command data is
// unmarshalled into. A context-only function yields a nil CommandDataType,
// which mainWshServerHandler treats as "no data argument". When the function
// returns more than one value, the type of its first result is recorded as
// DefaultResponseDataType.
//
// It panics if methodFunc is nil, not a function, or does not take a
// context.Context first. These are programming errors in the method
// registration, so they are caught when the declaration is built.
func GetWshServerMethod(command string, commandType string, methodName string, methodFunc any) *WshServerMethodDecl {
	methodVal := reflect.ValueOf(methodFunc)
	if !methodVal.IsValid() {
		// reflect.ValueOf(nil) is the zero Value; calling Type() on it would panic opaquely.
		panic(fmt.Sprintf("methodVal must be a function got [%v]", methodFunc))
	}
	methodType := methodVal.Type()
	if methodType.Kind() != reflect.Func {
		panic(fmt.Sprintf("methodVal must be a function got [%v]", methodType))
	}
	if methodType.NumIn() < 1 || methodType.In(0) != contextRType {
		panic(fmt.Sprintf("methodVal must have a context as the first argument %v", methodType))
	}
	// The command-data parameter is optional; In(1) would panic for a context-only method.
	var cmdDataType reflect.Type
	if methodType.NumIn() > 1 {
		cmdDataType = methodType.In(1)
	}
	var defResponseType reflect.Type
	if methodType.NumOut() > 1 {
		defResponseType = methodType.Out(0)
	}
	return &WshServerMethodDecl{
		Command:                 command,
		CommandType:             commandType,
		MethodName:              methodName,
		Method:                  methodVal,
		CommandDataType:         cmdDataType,
		DefaultResponseDataType: defResponseType,
	}
}
// decodeRtnVals converts the reflected results of a method call into a
// (data, error) pair.
//
// Supported result shapes:
//   - no results: (nil, nil)
//   - one result, which must be an error: (nil, thatError)
//   - two results, the second an error: (first, thatError); the data value is
//     returned even when the error is non-nil.
//
// If a result in the error position holds a non-nil value that is not an
// error, or there are more than two results, a descriptive error is returned
// instead of panicking.
func decodeRtnVals(rtnVals []reflect.Value) (any, error) {
	switch len(rtnVals) {
	case 0:
		return nil, nil
	case 1:
		methodErr, decodeErr := extractRtnError(rtnVals[0])
		if decodeErr != nil {
			return nil, decodeErr
		}
		return nil, methodErr
	case 2:
		methodErr, decodeErr := extractRtnError(rtnVals[1])
		if decodeErr != nil {
			return nil, decodeErr
		}
		return rtnVals[0].Interface(), methodErr
	default:
		return nil, fmt.Errorf("too many return values: %d", len(rtnVals))
	}
}

// extractRtnError interprets val as a method's error result. methodErr is the
// error val holds (nil for a nil interface); decodeErr is set instead when val
// holds a non-nil value that does not implement error.
func extractRtnError(val reflect.Value) (methodErr error, decodeErr error) {
	iface := val.Interface()
	if iface == nil {
		return nil, nil
	}
	err, ok := iface.(error)
	if !ok {
		return nil, fmt.Errorf("invalid return value: expected error, got %T", iface)
	}
	return err, nil
}
// mainWshServerHandler dispatches one incoming RPC to the method declaration
// registered for its command in WshServerCommandToDeclMap.
//
// The method is called with the handler's context and, when the declaration
// has a CommandDataType, with the command data unmarshalled into a fresh value
// of that type. Call-type commands get a single final response (or error).
// Response-stream commands get one response per value read from the returned
// channel. That work happens in a background goroutine, which finalizes the
// handler itself.
//
// It returns false only when such a streaming goroutine was started, and true
// in every other case.
func mainWshServerHandler(handler *wshutil.RpcResponseHandler) bool {
	cmdName := handler.GetCommand()
	decl := WshServerCommandToDeclMap[cmdName]
	if decl == nil {
		handler.SendResponseError(fmt.Errorf("command %q not found", cmdName))
		return true
	}
	args := []reflect.Value{reflect.ValueOf(handler.Context())}
	if decl.CommandDataType != nil {
		dataPtr := reflect.New(decl.CommandDataType).Interface()
		if err := utilfn.ReUnmarshal(dataPtr, handler.GetCommandRawData()); err != nil {
			handler.SendResponseError(fmt.Errorf("error re-marshalling command data: %w", err))
			return true
		}
		wshrpc.HackRpcContextIntoData(dataPtr, handler.GetRpcContext())
		args = append(args, reflect.ValueOf(dataPtr).Elem())
	}
	switch decl.CommandType {
	case wshutil.RpcType_Call:
		rtnData, rtnErr := decodeRtnVals(decl.Method.Call(args))
		if rtnErr != nil {
			handler.SendResponseError(rtnErr)
			return true
		}
		handler.SendResponse(rtnData, true)
		return true
	case wshutil.RpcType_ResponseStream:
		respCh := decl.Method.Call(args)[0]
		if respCh.IsNil() {
			handler.SendResponse(nil, true)
			return true
		}
		go streamReflectedResponses(handler, respCh)
		return false
	default:
		handler.SendResponseError(fmt.Errorf("unsupported command type %q", decl.CommandType))
		return true
	}
}

// streamReflectedResponses forwards each value received on respCh to handler.
// Each value is a RespOrErrorUnion-shaped struct with "Response" and "Error"
// fields. It stops when the channel closes or at the first non-nil Error, and
// always finalizes the handler on exit.
func streamReflectedResponses(handler *wshutil.RpcResponseHandler, respCh reflect.Value) {
	defer handler.Finalize()
	// Reflection is required because RespOrErrorUnion's type parameter is not known here.
	for respVal, ok := respCh.Recv(); ok; respVal, ok = respCh.Recv() {
		if errVal := respVal.FieldByName("Error"); !errVal.IsNil() {
			handler.SendResponseError(errVal.Interface().(error))
			return
		}
		handler.SendResponse(respVal.FieldByName("Response").Interface(), false)
	}
}
// MakeUnixListener creates a Unix domain socket listener at sockName.
//
// Any existing file at sockName (e.g. a stale socket from a previous run) is
// removed first. Once the socket is created, its permissions are restricted to
// 0700 (owner only). If that chmod fails, a warning is logged and the listener
// is still returned. The returned error wraps the underlying net.Listen error,
// so callers can inspect it with errors.Is / errors.As.
func MakeUnixListener(sockName string) (net.Listener, error) {
	// Best-effort: a missing file is the normal case, and real problems surface from net.Listen.
	_ = os.Remove(sockName)
	rtn, err := net.Listen("unix", sockName)
	if err != nil {
		return nil, fmt.Errorf("error creating listener at %v: %w", sockName, err)
	}
	// Restricting access matters: without it other local users might be able to connect.
	if err := os.Chmod(sockName, 0700); err != nil {
		log.Printf("warning: could not restrict permissions on %s: %v\n", sockName, err)
	}
	log.Printf("Server listening on %s\n", sockName)
	return rtn, nil
}
// runWshRpcWithStream serves WSH RPC over a single accepted connection.
//
// It creates buffered input/output message channels and starts two goroutines
// bridging them to conn: AdaptMsgChToStream(outputCh, conn) and
// AdaptStreamToMsgCh(conn, inputCh) (presumably writing outgoing messages to
// conn and reading incoming ones from it, respectively — confirm in wshutil).
// It then attaches mainWshServerHandler to the channels via MakeWshRpc, with an
// empty RpcContext.
//
// NOTE(review): conn is closed when this function returns. If
// wshutil.MakeWshRpc starts its processing in the background and returns
// immediately, the deferred Close runs right away and the connection is torn
// down before any request is served. Confirm MakeWshRpc's blocking behavior;
// if it does not block, close conn when AdaptStreamToMsgCh finishes instead.
func runWshRpcWithStream(conn net.Conn) {
defer conn.Close()
inputCh := make(chan []byte, DefaultInputChSize)
outputCh := make(chan []byte, DefaultOutputChSize)
go wshutil.AdaptMsgChToStream(outputCh, conn)
go wshutil.AdaptStreamToMsgCh(conn, inputCh)
wshutil.MakeWshRpc(inputCh, outputCh, wshutil.RpcContext{}, mainWshServerHandler)
}
// RunWshRpcOverListener starts a background goroutine that accepts
// connections on listener and serves each one with runWshRpcWithStream in its
// own goroutine. It returns immediately.
//
// The accept loop exits once the listener is closed (net.ErrClosed) rather
// than spinning on the error forever. Other accept errors are logged and
// retried after a delay that grows from 5ms up to 1s and resets after a
// successful accept. This keeps a persistent failure (e.g. file-descriptor
// exhaustion) from busy-looping.
func RunWshRpcOverListener(listener net.Listener) {
	go func() {
		const maxAcceptBackoff = time.Second
		var backoff time.Duration
		for {
			conn, err := listener.Accept()
			if err != nil {
				if errors.Is(err, net.ErrClosed) {
					log.Printf("wsh listener closed, stopping accept loop\n")
					return
				}
				log.Printf("error accepting connection: %v\n", err)
				if backoff == 0 {
					backoff = 5 * time.Millisecond
				} else {
					backoff *= 2
				}
				if backoff > maxAcceptBackoff {
					backoff = maxAcceptBackoff
				}
				time.Sleep(backoff)
				continue
			}
			backoff = 0
			go runWshRpcWithStream(conn)
		}
	}()
}
// RunDomainSocketWshServer starts the WSH RPC server on the Unix domain socket
// named by wavebase.GetDomainSocketName. It returns once the listener exists
// and the accept loop has been started in the background; it does not block
// while serving. It returns an error only if the listener cannot be created.
//
// The listener is intentionally not closed here. RunWshRpcOverListener returns
// immediately and keeps using the listener from its own goroutine, so closing
// it when this function returns would shut the server down as soon as it
// started. Nothing in this function closes it; it stays open for the rest of
// the process.
func RunDomainSocketWshServer() error {
	sockName := wavebase.GetDomainSocketName()
	listener, err := MakeUnixListener(sockName)
	if err != nil {
		return fmt.Errorf("error starting unix listener for wsh-server: %w", err)
	}
	RunWshRpcOverListener(listener)
	return nil
}
// MakeWshServer attaches the main WSH server handler (mainWshServerHandler) to
// an existing pair of message channels, using initialCtx as the initial RPC
// context. Unlike runWshRpcWithStream, it does no stream adaptation: the
// caller supplies the channels directly.
func MakeWshServer(inputCh chan []byte, outputCh chan []byte, initialCtx wshutil.RpcContext) {
	handlerFn := mainWshServerHandler
	wshutil.MakeWshRpc(inputCh, outputCh, initialCtx, handlerFn)
}