mirror of
https://github.com/wavetermdev/waveterm.git
synced 2024-12-21 16:38:23 +01:00
streaming rpc support (backend streams to the frontend) (#120)
This commit is contained in:
parent
734a066af8
commit
776ccd7da0
@ -46,9 +46,11 @@ tasks:
|
||||
- go run cmd/generatewshclient/main-generatewshclient.go
|
||||
sources:
|
||||
- "cmd/generate/*.go"
|
||||
- "cmd/generatewshclient/*.go"
|
||||
- "pkg/service/**/*.go"
|
||||
- "pkg/wstore/*.go"
|
||||
- "pkg/wshrpc/**/*.go"
|
||||
- "pkg/tsgen/**/*.go"
|
||||
generates:
|
||||
- frontend/types/gotypes.d.ts
|
||||
- pkg/wshrpc/wshclient/wshclient.go
|
||||
|
@ -12,7 +12,24 @@ import (
|
||||
"github.com/wavetermdev/thenextwave/pkg/wshutil"
|
||||
)
|
||||
|
||||
func genMethod(fd *os.File, methodDecl *wshserver.WshServerMethodDecl) {
|
||||
func genMethod_ResponseStream(fd *os.File, methodDecl *wshserver.WshServerMethodDecl) {
|
||||
fmt.Fprintf(fd, "// command %q, wshserver.%s\n", methodDecl.Command, methodDecl.MethodName)
|
||||
var dataType string
|
||||
dataVarName := "nil"
|
||||
if methodDecl.CommandDataType != nil {
|
||||
dataType = ", data " + methodDecl.CommandDataType.String()
|
||||
dataVarName = "data"
|
||||
}
|
||||
respType := "any"
|
||||
if methodDecl.DefaultResponseDataType != nil {
|
||||
respType = methodDecl.DefaultResponseDataType.String()
|
||||
}
|
||||
fmt.Fprintf(fd, "func %s(w *wshutil.WshRpc%s, opts *wshrpc.WshRpcCommandOpts) chan wshrpc.RespOrErrorUnion[%s] {\n", methodDecl.MethodName, dataType, respType)
|
||||
fmt.Fprintf(fd, " return sendRpcRequestResponseStreamHelper[%s](w, %q, %s, opts)\n", respType, methodDecl.Command, dataVarName)
|
||||
fmt.Fprintf(fd, "}\n\n")
|
||||
}
|
||||
|
||||
func genMethod_Call(fd *os.File, methodDecl *wshserver.WshServerMethodDecl) {
|
||||
fmt.Fprintf(fd, "// command %q, wshserver.%s\n", methodDecl.Command, methodDecl.MethodName)
|
||||
var dataType string
|
||||
dataVarName := "nil"
|
||||
@ -29,15 +46,11 @@ func genMethod(fd *os.File, methodDecl *wshserver.WshServerMethodDecl) {
|
||||
tParamVal = methodDecl.DefaultResponseDataType.String()
|
||||
}
|
||||
fmt.Fprintf(fd, "func %s(w *wshutil.WshRpc%s, opts *wshrpc.WshRpcCommandOpts) %s {\n", methodDecl.MethodName, dataType, returnType)
|
||||
if methodDecl.CommandType == wshutil.RpcType_Call {
|
||||
fmt.Fprintf(fd, " %s, err := sendRpcRequestHelper[%s](w, %q, %s, opts)\n", respName, tParamVal, methodDecl.Command, dataVarName)
|
||||
if methodDecl.DefaultResponseDataType != nil {
|
||||
fmt.Fprintf(fd, " return resp, err\n")
|
||||
} else {
|
||||
fmt.Fprintf(fd, " return err\n")
|
||||
}
|
||||
fmt.Fprintf(fd, " %s, err := sendRpcRequestCallHelper[%s](w, %q, %s, opts)\n", respName, tParamVal, methodDecl.Command, dataVarName)
|
||||
if methodDecl.DefaultResponseDataType != nil {
|
||||
fmt.Fprintf(fd, " return resp, err\n")
|
||||
} else {
|
||||
panic("unsupported command type " + methodDecl.CommandType)
|
||||
fmt.Fprintf(fd, " return err\n")
|
||||
}
|
||||
fmt.Fprintf(fd, "}\n\n")
|
||||
}
|
||||
@ -61,7 +74,13 @@ func main() {
|
||||
|
||||
for _, key := range utilfn.GetOrderedMapKeys(wshserver.WshServerCommandToDeclMap) {
|
||||
methodDecl := wshserver.WshServerCommandToDeclMap[key]
|
||||
genMethod(fd, methodDecl)
|
||||
if methodDecl.CommandType == wshutil.RpcType_ResponseStream {
|
||||
genMethod_ResponseStream(fd, methodDecl)
|
||||
} else if methodDecl.CommandType == wshutil.RpcType_Call {
|
||||
genMethod_Call(fd, methodDecl)
|
||||
} else {
|
||||
panic("unsupported command type " + methodDecl.CommandType)
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(fd, "\n")
|
||||
}
|
||||
|
@ -107,12 +107,27 @@ function callBackendService(service: string, method: string, args: any[], noUICo
|
||||
return prtn;
|
||||
}
|
||||
|
||||
function callWshServerRpc(
|
||||
function wshServerRpcHelper_responsestream(
|
||||
command: string,
|
||||
data: any,
|
||||
meta: WshServerCommandMeta,
|
||||
opts: WshRpcCommandOpts
|
||||
): Promise<any> {
|
||||
): AsyncGenerator<any, void, boolean> {
|
||||
if (opts?.noresponse) {
|
||||
throw new Error("noresponse not supported for responsestream calls");
|
||||
}
|
||||
let msg: RpcMessage = {
|
||||
command: command,
|
||||
data: data,
|
||||
reqid: uuidv4(),
|
||||
};
|
||||
if (opts?.timeout) {
|
||||
msg.timeout = opts.timeout;
|
||||
}
|
||||
const rpcGen = sendRpcCommand(msg);
|
||||
return rpcGen;
|
||||
}
|
||||
|
||||
function wshServerRpcHelper_call(command: string, data: any, opts: WshRpcCommandOpts): Promise<any> {
|
||||
let msg: RpcMessage = {
|
||||
command: command,
|
||||
data: data,
|
||||
@ -123,32 +138,14 @@ function callWshServerRpc(
|
||||
if (opts?.timeout) {
|
||||
msg.timeout = opts.timeout;
|
||||
}
|
||||
if (meta.commandtype != "call") {
|
||||
throw new Error("unimplemented wshserver commandtype " + meta.commandtype);
|
||||
}
|
||||
const rpcGen = sendRpcCommand(msg);
|
||||
if (rpcGen == null) {
|
||||
return null;
|
||||
}
|
||||
let resolveFn: (value: any) => void;
|
||||
let rejectFn: (reason?: any) => void;
|
||||
const prtn = new Promise((resolve, reject) => {
|
||||
resolveFn = resolve;
|
||||
rejectFn = reject;
|
||||
const respMsgPromise = rpcGen.next(true); // pass true to force termination of rpc after 1 response (not streaming)
|
||||
return respMsgPromise.then((msg: IteratorResult<any, void>) => {
|
||||
return msg.value;
|
||||
});
|
||||
const respMsg = rpcGen.next(true); // pass true to force termination of rpc after 1 response (not streaing)
|
||||
respMsg.then((msg: IteratorResult<RpcMessage, void>) => {
|
||||
if (msg.value == null) {
|
||||
resolveFn(null);
|
||||
}
|
||||
let respMsg: RpcMessage = msg.value as RpcMessage;
|
||||
if (respMsg.error != null) {
|
||||
rejectFn(new Error(respMsg.error));
|
||||
return;
|
||||
}
|
||||
resolveFn(respMsg.data);
|
||||
});
|
||||
return prtn;
|
||||
}
|
||||
|
||||
const waveObjectValueCache = new Map<string, WaveObjectValue<any>>();
|
||||
@ -368,7 +365,6 @@ function setObjectValue<T extends WaveObj>(value: T, setFn?: jotai.Setter, pushT
|
||||
|
||||
export {
|
||||
callBackendService,
|
||||
callWshServerRpc,
|
||||
cleanWaveObjectCache,
|
||||
clearWaveObjectCache,
|
||||
getObjectValue,
|
||||
@ -383,4 +379,6 @@ export {
|
||||
useWaveObjectValue,
|
||||
useWaveObjectValueWithSuspense,
|
||||
waveObjectValueCache,
|
||||
wshServerRpcHelper_call,
|
||||
wshServerRpcHelper_responsestream,
|
||||
};
|
||||
|
@ -16,7 +16,7 @@ async function* rpcResponseGenerator(
|
||||
command: string,
|
||||
reqid: string,
|
||||
timeout: number
|
||||
): AsyncGenerator<RpcMessage, void, boolean> {
|
||||
): AsyncGenerator<any, void, boolean> {
|
||||
const msgQueue: RpcMessage[] = [];
|
||||
let signalFn: () => void;
|
||||
let signalPromise = new Promise<void>((resolve) => (signalFn = resolve));
|
||||
@ -39,11 +39,18 @@ async function* rpcResponseGenerator(
|
||||
command: command,
|
||||
msgFn: msgFn,
|
||||
});
|
||||
yield null;
|
||||
try {
|
||||
while (true) {
|
||||
while (msgQueue.length > 0) {
|
||||
const msg = msgQueue.shift()!;
|
||||
const shouldTerminate = yield msg;
|
||||
if (msg.error != null) {
|
||||
throw new Error(msg.error);
|
||||
}
|
||||
if (!msg.cont && msg.data == null) {
|
||||
return;
|
||||
}
|
||||
const shouldTerminate = yield msg.data;
|
||||
if (shouldTerminate || !msg.cont) {
|
||||
return;
|
||||
}
|
||||
@ -64,7 +71,9 @@ function sendRpcCommand(msg: RpcMessage): AsyncGenerator<RpcMessage, void, boole
|
||||
if (msg.reqid == null) {
|
||||
return null;
|
||||
}
|
||||
return rpcResponseGenerator(msg.command, msg.reqid, msg.timeout);
|
||||
const rtnGen = rpcResponseGenerator(msg.command, msg.reqid, msg.timeout);
|
||||
rtnGen.next(); // start the generator (run the initialization/registration logic, throw away the result)
|
||||
return rtnGen;
|
||||
}
|
||||
|
||||
function handleIncomingRpcMessage(msg: RpcMessage) {
|
||||
@ -85,4 +94,22 @@ function handleIncomingRpcMessage(msg: RpcMessage) {
|
||||
entry.msgFn(msg);
|
||||
}
|
||||
|
||||
async function consumeGenerator(gen: AsyncGenerator<any, any, any>) {
|
||||
let idx = 0;
|
||||
try {
|
||||
for await (const msg of gen) {
|
||||
console.log("gen", idx, msg);
|
||||
idx++;
|
||||
}
|
||||
const result = await gen.return(undefined);
|
||||
console.log("gen done", result.value);
|
||||
} catch (e) {
|
||||
console.log("gen error", e);
|
||||
}
|
||||
}
|
||||
|
||||
if (globalThis.window != null) {
|
||||
globalThis["consumeGenerator"] = consumeGenerator;
|
||||
}
|
||||
|
||||
export { handleIncomingRpcMessage, sendRpcCommand };
|
||||
|
@ -9,62 +9,57 @@ import * as WOS from "./wos";
|
||||
class WshServerType {
|
||||
// command "controller:input" [call]
|
||||
BlockInputCommand(data: CommandBlockInputData, opts?: WshRpcCommandOpts): Promise<void> {
|
||||
const meta: WshServerCommandMeta = {commandtype: "call"};
|
||||
return WOS.callWshServerRpc("controller:input", data, meta, opts);
|
||||
return WOS.wshServerRpcHelper_call("controller:input", data, opts);
|
||||
}
|
||||
|
||||
// command "controller:restart" [call]
|
||||
BlockRestartCommand(data: CommandBlockRestartData, opts?: WshRpcCommandOpts): Promise<void> {
|
||||
const meta: WshServerCommandMeta = {commandtype: "call"};
|
||||
return WOS.callWshServerRpc("controller:restart", data, meta, opts);
|
||||
return WOS.wshServerRpcHelper_call("controller:restart", data, opts);
|
||||
}
|
||||
|
||||
// command "createblock" [call]
|
||||
CreateBlockCommand(data: CommandCreateBlockData, opts?: WshRpcCommandOpts): Promise<ORef> {
|
||||
const meta: WshServerCommandMeta = {commandtype: "call"};
|
||||
return WOS.callWshServerRpc("createblock", data, meta, opts);
|
||||
return WOS.wshServerRpcHelper_call("createblock", data, opts);
|
||||
}
|
||||
|
||||
// command "file:append" [call]
|
||||
AppendFileCommand(data: CommandAppendFileData, opts?: WshRpcCommandOpts): Promise<void> {
|
||||
const meta: WshServerCommandMeta = {commandtype: "call"};
|
||||
return WOS.callWshServerRpc("file:append", data, meta, opts);
|
||||
return WOS.wshServerRpcHelper_call("file:append", data, opts);
|
||||
}
|
||||
|
||||
// command "file:appendijson" [call]
|
||||
AppendIJsonCommand(data: CommandAppendIJsonData, opts?: WshRpcCommandOpts): Promise<void> {
|
||||
const meta: WshServerCommandMeta = {commandtype: "call"};
|
||||
return WOS.callWshServerRpc("file:appendijson", data, meta, opts);
|
||||
return WOS.wshServerRpcHelper_call("file:appendijson", data, opts);
|
||||
}
|
||||
|
||||
// command "getmeta" [call]
|
||||
GetMetaCommand(data: CommandGetMetaData, opts?: WshRpcCommandOpts): Promise<MetaType> {
|
||||
const meta: WshServerCommandMeta = {commandtype: "call"};
|
||||
return WOS.callWshServerRpc("getmeta", data, meta, opts);
|
||||
return WOS.wshServerRpcHelper_call("getmeta", data, opts);
|
||||
}
|
||||
|
||||
// command "message" [call]
|
||||
MessageCommand(data: CommandMessageData, opts?: WshRpcCommandOpts): Promise<void> {
|
||||
const meta: WshServerCommandMeta = {commandtype: "call"};
|
||||
return WOS.callWshServerRpc("message", data, meta, opts);
|
||||
return WOS.wshServerRpcHelper_call("message", data, opts);
|
||||
}
|
||||
|
||||
// command "resolveids" [call]
|
||||
ResolveIdsCommand(data: CommandResolveIdsData, opts?: WshRpcCommandOpts): Promise<CommandResolveIdsRtnData> {
|
||||
const meta: WshServerCommandMeta = {commandtype: "call"};
|
||||
return WOS.callWshServerRpc("resolveids", data, meta, opts);
|
||||
return WOS.wshServerRpcHelper_call("resolveids", data, opts);
|
||||
}
|
||||
|
||||
// command "setmeta" [call]
|
||||
SetMetaCommand(data: CommandSetMetaData, opts?: WshRpcCommandOpts): Promise<void> {
|
||||
const meta: WshServerCommandMeta = {commandtype: "call"};
|
||||
return WOS.callWshServerRpc("setmeta", data, meta, opts);
|
||||
return WOS.wshServerRpcHelper_call("setmeta", data, opts);
|
||||
}
|
||||
|
||||
// command "setview" [call]
|
||||
BlockSetViewCommand(data: CommandBlockSetViewData, opts?: WshRpcCommandOpts): Promise<void> {
|
||||
const meta: WshServerCommandMeta = {commandtype: "call"};
|
||||
return WOS.callWshServerRpc("setview", data, meta, opts);
|
||||
return WOS.wshServerRpcHelper_call("setview", data, opts);
|
||||
}
|
||||
|
||||
// command "streamtest" [responsestream]
|
||||
RespStreamTest(opts?: WshRpcCommandOpts): AsyncGenerator<number, void, boolean> {
|
||||
return WOS.wshServerRpcHelper_responsestream("streamtest", null, opts);
|
||||
}
|
||||
|
||||
}
|
||||
|
1
frontend/types/gotypes.d.ts
vendored
1
frontend/types/gotypes.d.ts
vendored
@ -189,6 +189,7 @@ declare global {
|
||||
resid?: string;
|
||||
timeout?: number;
|
||||
cont?: boolean;
|
||||
cancel?: boolean;
|
||||
error?: string;
|
||||
datatype?: string;
|
||||
data?: any;
|
||||
|
@ -392,6 +392,38 @@ func GenerateServiceClass(serviceName string, serviceObj any, tsTypesMap map[ref
|
||||
}
|
||||
|
||||
func GenerateWshServerMethod(methodDecl *wshserver.WshServerMethodDecl, tsTypesMap map[reflect.Type]string) string {
|
||||
if methodDecl.CommandType == wshutil.RpcType_ResponseStream {
|
||||
return GenerateWshServerMethod_ResponseStream(methodDecl, tsTypesMap)
|
||||
} else if methodDecl.CommandType == wshutil.RpcType_Call {
|
||||
return GenerateWshServerMethod_Call(methodDecl, tsTypesMap)
|
||||
} else {
|
||||
panic(fmt.Sprintf("cannot generate wshserver commandtype %q", methodDecl.CommandType))
|
||||
}
|
||||
}
|
||||
|
||||
func GenerateWshServerMethod_ResponseStream(methodDecl *wshserver.WshServerMethodDecl, tsTypesMap map[reflect.Type]string) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString(fmt.Sprintf(" // command %q [%s]\n", methodDecl.Command, methodDecl.CommandType))
|
||||
respType := "any"
|
||||
if methodDecl.DefaultResponseDataType != nil {
|
||||
respType, _ = TypeToTSType(methodDecl.DefaultResponseDataType, tsTypesMap)
|
||||
}
|
||||
dataName := "null"
|
||||
if methodDecl.CommandDataType != nil {
|
||||
dataName = "data"
|
||||
}
|
||||
genRespType := fmt.Sprintf("AsyncGenerator<%s, void, boolean>", respType)
|
||||
if methodDecl.CommandDataType != nil {
|
||||
sb.WriteString(fmt.Sprintf(" %s(data: %s, opts?: WshRpcCommandOpts): %s {\n", methodDecl.MethodName, methodDecl.CommandDataType.Name(), genRespType))
|
||||
} else {
|
||||
sb.WriteString(fmt.Sprintf(" %s(opts?: WshRpcCommandOpts): %s {\n", methodDecl.MethodName, genRespType))
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf(" return WOS.wshServerRpcHelper_responsestream(%q, %s, opts);\n", methodDecl.Command, dataName))
|
||||
sb.WriteString(" }\n")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func GenerateWshServerMethod_Call(methodDecl *wshserver.WshServerMethodDecl, tsTypesMap map[reflect.Type]string) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString(fmt.Sprintf(" // command %q [%s]\n", methodDecl.Command, methodDecl.CommandType))
|
||||
rtnType := "Promise<void>"
|
||||
@ -399,14 +431,16 @@ func GenerateWshServerMethod(methodDecl *wshserver.WshServerMethodDecl, tsTypesM
|
||||
rtnTypeName, _ := TypeToTSType(methodDecl.DefaultResponseDataType, tsTypesMap)
|
||||
rtnType = fmt.Sprintf("Promise<%s>", rtnTypeName)
|
||||
}
|
||||
dataName := "null"
|
||||
if methodDecl.CommandDataType != nil {
|
||||
dataName = "data"
|
||||
}
|
||||
if methodDecl.CommandDataType != nil {
|
||||
sb.WriteString(fmt.Sprintf(" %s(data: %s, opts?: WshRpcCommandOpts): %s {\n", methodDecl.MethodName, methodDecl.CommandDataType.Name(), rtnType))
|
||||
} else {
|
||||
sb.WriteString(fmt.Sprintf(" %s(opts?: WshRpcCommandOpts): %s {\n", methodDecl.MethodName, rtnType))
|
||||
}
|
||||
metaData := fmt.Sprintf(" const meta: WshServerCommandMeta = {commandtype: %q};\n", methodDecl.CommandType)
|
||||
methodBody := fmt.Sprintf(" return WOS.callWshServerRpc(%q, data, meta, opts);\n", methodDecl.Command)
|
||||
sb.WriteString(metaData)
|
||||
methodBody := fmt.Sprintf(" return WOS.wshServerRpcHelper_call(%q, %s, opts);\n", methodDecl.Command, dataName)
|
||||
sb.WriteString(methodBody)
|
||||
sb.WriteString(" }\n")
|
||||
return sb.String()
|
||||
|
@ -13,62 +13,67 @@ import (
|
||||
|
||||
// command "controller:input", wshserver.BlockInputCommand
|
||||
func BlockInputCommand(w *wshutil.WshRpc, data wshrpc.CommandBlockInputData, opts *wshrpc.WshRpcCommandOpts) error {
|
||||
_, err := sendRpcRequestHelper[any](w, "controller:input", data, opts)
|
||||
_, err := sendRpcRequestCallHelper[any](w, "controller:input", data, opts)
|
||||
return err
|
||||
}
|
||||
|
||||
// command "controller:restart", wshserver.BlockRestartCommand
|
||||
func BlockRestartCommand(w *wshutil.WshRpc, data wshrpc.CommandBlockRestartData, opts *wshrpc.WshRpcCommandOpts) error {
|
||||
_, err := sendRpcRequestHelper[any](w, "controller:restart", data, opts)
|
||||
_, err := sendRpcRequestCallHelper[any](w, "controller:restart", data, opts)
|
||||
return err
|
||||
}
|
||||
|
||||
// command "createblock", wshserver.CreateBlockCommand
|
||||
func CreateBlockCommand(w *wshutil.WshRpc, data wshrpc.CommandCreateBlockData, opts *wshrpc.WshRpcCommandOpts) (*waveobj.ORef, error) {
|
||||
resp, err := sendRpcRequestHelper[*waveobj.ORef](w, "createblock", data, opts)
|
||||
resp, err := sendRpcRequestCallHelper[*waveobj.ORef](w, "createblock", data, opts)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// command "file:append", wshserver.AppendFileCommand
|
||||
func AppendFileCommand(w *wshutil.WshRpc, data wshrpc.CommandAppendFileData, opts *wshrpc.WshRpcCommandOpts) error {
|
||||
_, err := sendRpcRequestHelper[any](w, "file:append", data, opts)
|
||||
_, err := sendRpcRequestCallHelper[any](w, "file:append", data, opts)
|
||||
return err
|
||||
}
|
||||
|
||||
// command "file:appendijson", wshserver.AppendIJsonCommand
|
||||
func AppendIJsonCommand(w *wshutil.WshRpc, data wshrpc.CommandAppendIJsonData, opts *wshrpc.WshRpcCommandOpts) error {
|
||||
_, err := sendRpcRequestHelper[any](w, "file:appendijson", data, opts)
|
||||
_, err := sendRpcRequestCallHelper[any](w, "file:appendijson", data, opts)
|
||||
return err
|
||||
}
|
||||
|
||||
// command "getmeta", wshserver.GetMetaCommand
|
||||
func GetMetaCommand(w *wshutil.WshRpc, data wshrpc.CommandGetMetaData, opts *wshrpc.WshRpcCommandOpts) (map[string]interface {}, error) {
|
||||
resp, err := sendRpcRequestHelper[map[string]interface {}](w, "getmeta", data, opts)
|
||||
resp, err := sendRpcRequestCallHelper[map[string]interface {}](w, "getmeta", data, opts)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// command "message", wshserver.MessageCommand
|
||||
func MessageCommand(w *wshutil.WshRpc, data wshrpc.CommandMessageData, opts *wshrpc.WshRpcCommandOpts) error {
|
||||
_, err := sendRpcRequestHelper[any](w, "message", data, opts)
|
||||
_, err := sendRpcRequestCallHelper[any](w, "message", data, opts)
|
||||
return err
|
||||
}
|
||||
|
||||
// command "resolveids", wshserver.ResolveIdsCommand
|
||||
func ResolveIdsCommand(w *wshutil.WshRpc, data wshrpc.CommandResolveIdsData, opts *wshrpc.WshRpcCommandOpts) (wshrpc.CommandResolveIdsRtnData, error) {
|
||||
resp, err := sendRpcRequestHelper[wshrpc.CommandResolveIdsRtnData](w, "resolveids", data, opts)
|
||||
resp, err := sendRpcRequestCallHelper[wshrpc.CommandResolveIdsRtnData](w, "resolveids", data, opts)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// command "setmeta", wshserver.SetMetaCommand
|
||||
func SetMetaCommand(w *wshutil.WshRpc, data wshrpc.CommandSetMetaData, opts *wshrpc.WshRpcCommandOpts) error {
|
||||
_, err := sendRpcRequestHelper[any](w, "setmeta", data, opts)
|
||||
_, err := sendRpcRequestCallHelper[any](w, "setmeta", data, opts)
|
||||
return err
|
||||
}
|
||||
|
||||
// command "setview", wshserver.BlockSetViewCommand
|
||||
func BlockSetViewCommand(w *wshutil.WshRpc, data wshrpc.CommandBlockSetViewData, opts *wshrpc.WshRpcCommandOpts) error {
|
||||
_, err := sendRpcRequestHelper[any](w, "setview", data, opts)
|
||||
_, err := sendRpcRequestCallHelper[any](w, "setview", data, opts)
|
||||
return err
|
||||
}
|
||||
|
||||
// command "streamtest", wshserver.RespStreamTest
|
||||
func RespStreamTest(w *wshutil.WshRpc, opts *wshrpc.WshRpcCommandOpts) chan wshrpc.RespOrErrorUnion[int] {
|
||||
return sendRpcRequestResponseStreamHelper[int](w, "streamtest", nil, opts)
|
||||
}
|
||||
|
||||
|
||||
|
@ -9,7 +9,7 @@ import (
|
||||
"github.com/wavetermdev/thenextwave/pkg/wshutil"
|
||||
)
|
||||
|
||||
func sendRpcRequestHelper[T any](w *wshutil.WshRpc, command string, data interface{}, opts *wshrpc.WshRpcCommandOpts) (T, error) {
|
||||
func sendRpcRequestCallHelper[T any](w *wshutil.WshRpc, command string, data interface{}, opts *wshrpc.WshRpcCommandOpts) (T, error) {
|
||||
var respData T
|
||||
if opts.NoResponse {
|
||||
err := w.SendCommand(command, data)
|
||||
@ -28,3 +28,36 @@ func sendRpcRequestHelper[T any](w *wshutil.WshRpc, command string, data interfa
|
||||
}
|
||||
return respData, nil
|
||||
}
|
||||
|
||||
func sendRpcRequestResponseStreamHelper[T any](w *wshutil.WshRpc, command string, data interface{}, opts *wshrpc.WshRpcCommandOpts) chan wshrpc.RespOrErrorUnion[T] {
|
||||
respChan := make(chan wshrpc.RespOrErrorUnion[T])
|
||||
reqHandler, err := w.SendComplexRequest(command, data, true, opts.Timeout)
|
||||
if err != nil {
|
||||
go func() {
|
||||
respChan <- wshrpc.RespOrErrorUnion[T]{Error: err}
|
||||
close(respChan)
|
||||
}()
|
||||
} else {
|
||||
go func() {
|
||||
defer close(respChan)
|
||||
for {
|
||||
if reqHandler.ResponseDone() {
|
||||
break
|
||||
}
|
||||
resp, err := reqHandler.NextResponse()
|
||||
if err != nil {
|
||||
respChan <- wshrpc.RespOrErrorUnion[T]{Error: err}
|
||||
break
|
||||
}
|
||||
var respData T
|
||||
err = utilfn.ReUnmarshal(&respData, resp)
|
||||
if err != nil {
|
||||
respChan <- wshrpc.RespOrErrorUnion[T]{Error: err}
|
||||
break
|
||||
}
|
||||
respChan <- wshrpc.RespOrErrorUnion[T]{Response: respData}
|
||||
}
|
||||
}()
|
||||
}
|
||||
return respChan
|
||||
}
|
||||
|
@ -35,6 +35,11 @@ var DataTypeMap = map[string]reflect.Type{
|
||||
"oref": reflect.TypeOf(waveobj.ORef{}),
|
||||
}
|
||||
|
||||
type RespOrErrorUnion[T any] struct {
|
||||
Response T
|
||||
Error error
|
||||
}
|
||||
|
||||
// for frontend
|
||||
type WshServerCommandMeta struct {
|
||||
CommandType string `json:"commandtype"`
|
||||
|
@ -3,14 +3,14 @@
|
||||
|
||||
package wshserver
|
||||
|
||||
// this file contains the implementation of the wsh server methods
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
@ -18,33 +18,19 @@ import (
|
||||
"github.com/wavetermdev/thenextwave/pkg/blockcontroller"
|
||||
"github.com/wavetermdev/thenextwave/pkg/eventbus"
|
||||
"github.com/wavetermdev/thenextwave/pkg/filestore"
|
||||
"github.com/wavetermdev/thenextwave/pkg/util/utilfn"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wavebase"
|
||||
"github.com/wavetermdev/thenextwave/pkg/waveobj"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wshrpc"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wshutil"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wstore"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultOutputChSize = 32
|
||||
DefaultInputChSize = 32
|
||||
)
|
||||
|
||||
type WshServer struct{}
|
||||
|
||||
var WshServerImpl = WshServer{}
|
||||
var contextRType = reflect.TypeOf((*context.Context)(nil)).Elem()
|
||||
|
||||
type WshServerMethodDecl struct {
|
||||
Command string
|
||||
CommandType string
|
||||
MethodName string
|
||||
Method reflect.Value
|
||||
CommandDataType reflect.Type
|
||||
DefaultResponseDataType reflect.Type
|
||||
RequestDataTypes []reflect.Type // for streaming requests
|
||||
ResponseDataTypes []reflect.Type // for streaming responses
|
||||
var RespStreamTest_MethodDecl = &WshServerMethodDecl{
|
||||
Command: "streamtest",
|
||||
CommandType: wshutil.RpcType_ResponseStream,
|
||||
MethodName: "RespStreamTest",
|
||||
Method: reflect.ValueOf(WshServerImpl.RespStreamTest),
|
||||
CommandDataType: nil,
|
||||
DefaultResponseDataType: reflect.TypeOf((int)(0)),
|
||||
}
|
||||
|
||||
var WshServerCommandToDeclMap = map[string]*WshServerMethodDecl{
|
||||
@ -58,37 +44,28 @@ var WshServerCommandToDeclMap = map[string]*WshServerMethodDecl{
|
||||
wshrpc.Command_BlockInput: GetWshServerMethod(wshrpc.Command_BlockInput, wshutil.RpcType_Call, "BlockInputCommand", WshServerImpl.BlockInputCommand),
|
||||
wshrpc.Command_AppendFile: GetWshServerMethod(wshrpc.Command_AppendFile, wshutil.RpcType_Call, "AppendFileCommand", WshServerImpl.AppendFileCommand),
|
||||
wshrpc.Command_AppendIJson: GetWshServerMethod(wshrpc.Command_AppendIJson, wshutil.RpcType_Call, "AppendIJsonCommand", WshServerImpl.AppendIJsonCommand),
|
||||
"streamtest": RespStreamTest_MethodDecl,
|
||||
}
|
||||
|
||||
func GetWshServerMethod(command string, commandType string, methodName string, methodFunc any) *WshServerMethodDecl {
|
||||
methodVal := reflect.ValueOf(methodFunc)
|
||||
methodType := methodVal.Type()
|
||||
if methodType.Kind() != reflect.Func {
|
||||
panic(fmt.Sprintf("methodVal must be a function got [%v]", methodType))
|
||||
}
|
||||
if methodType.In(0) != contextRType {
|
||||
panic(fmt.Sprintf("methodVal must have a context as the first argument %v", methodType))
|
||||
}
|
||||
var defResponseType reflect.Type
|
||||
if methodType.NumOut() > 1 {
|
||||
defResponseType = methodType.Out(0)
|
||||
}
|
||||
rtn := &WshServerMethodDecl{
|
||||
Command: command,
|
||||
CommandType: commandType,
|
||||
MethodName: methodName,
|
||||
Method: methodVal,
|
||||
CommandDataType: methodType.In(1),
|
||||
DefaultResponseDataType: defResponseType,
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
// for testing
|
||||
func (ws *WshServer) MessageCommand(ctx context.Context, data wshrpc.CommandMessageData) error {
|
||||
log.Printf("MESSAGE: %s | %q\n", data.ORef, data.Message)
|
||||
return nil
|
||||
}
|
||||
|
||||
// for testing
|
||||
func (ws *WshServer) RespStreamTest(ctx context.Context) chan wshrpc.RespOrErrorUnion[int] {
|
||||
rtn := make(chan wshrpc.RespOrErrorUnion[int])
|
||||
go func() {
|
||||
for i := 1; i <= 5; i++ {
|
||||
rtn <- wshrpc.RespOrErrorUnion[int]{Response: i}
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
close(rtn)
|
||||
}()
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (ws *WshServer) GetMetaCommand(ctx context.Context, data wshrpc.CommandGetMetaData) (wshrpc.MetaDataType, error) {
|
||||
log.Printf("calling meta: %s\n", data.ORef)
|
||||
obj, err := wstore.DBGetORef(ctx, data.ORef)
|
||||
@ -317,101 +294,3 @@ func (ws *WshServer) AppendIJsonCommand(ctx context.Context, data wshrpc.Command
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeRtnVals(rtnVals []reflect.Value) (any, error) {
|
||||
switch len(rtnVals) {
|
||||
case 0:
|
||||
return nil, nil
|
||||
case 1:
|
||||
errIf := rtnVals[0].Interface()
|
||||
if errIf == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, errIf.(error)
|
||||
case 2:
|
||||
errIf := rtnVals[1].Interface()
|
||||
if errIf == nil {
|
||||
return rtnVals[0].Interface(), nil
|
||||
}
|
||||
return rtnVals[0].Interface(), errIf.(error)
|
||||
default:
|
||||
return nil, fmt.Errorf("too many return values: %d", len(rtnVals))
|
||||
}
|
||||
}
|
||||
|
||||
func mainWshServerHandler(handler *wshutil.RpcResponseHandler) {
|
||||
command := handler.GetCommand()
|
||||
methodDecl := WshServerCommandToDeclMap[command]
|
||||
if methodDecl == nil {
|
||||
handler.SendResponseError(fmt.Errorf("command %q not found", command))
|
||||
return
|
||||
}
|
||||
var callParams []reflect.Value
|
||||
callParams = append(callParams, reflect.ValueOf(handler.Context()))
|
||||
if methodDecl.CommandDataType != nil {
|
||||
commandData := reflect.New(methodDecl.CommandDataType).Interface()
|
||||
err := utilfn.ReUnmarshal(commandData, handler.GetCommandRawData())
|
||||
if err != nil {
|
||||
handler.SendResponseError(fmt.Errorf("error re-marshalling command data: %w", err))
|
||||
return
|
||||
}
|
||||
wshrpc.HackRpcContextIntoData(commandData, handler.GetRpcContext())
|
||||
callParams = append(callParams, reflect.ValueOf(commandData).Elem())
|
||||
}
|
||||
rtnVals := methodDecl.Method.Call(callParams)
|
||||
rtnData, rtnErr := decodeRtnVals(rtnVals)
|
||||
if rtnErr != nil {
|
||||
handler.SendResponseError(rtnErr)
|
||||
return
|
||||
} else {
|
||||
handler.SendResponse(rtnData, true)
|
||||
}
|
||||
}
|
||||
|
||||
func MakeUnixListener(sockName string) (net.Listener, error) {
|
||||
os.Remove(sockName) // ignore error
|
||||
rtn, err := net.Listen("unix", sockName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating listener at %v: %v", sockName, err)
|
||||
}
|
||||
os.Chmod(sockName, 0700)
|
||||
log.Printf("Server listening on %s\n", sockName)
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func runWshRpcWithStream(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
inputCh := make(chan []byte, DefaultInputChSize)
|
||||
outputCh := make(chan []byte, DefaultOutputChSize)
|
||||
go wshutil.AdaptMsgChToStream(outputCh, conn)
|
||||
go wshutil.AdaptStreamToMsgCh(conn, inputCh)
|
||||
wshutil.MakeWshRpc(inputCh, outputCh, wshutil.RpcContext{}, mainWshServerHandler)
|
||||
}
|
||||
|
||||
func RunWshRpcOverListener(listener net.Listener) {
|
||||
go func() {
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
log.Printf("error accepting connection: %v\n", err)
|
||||
continue
|
||||
}
|
||||
go runWshRpcWithStream(conn)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func RunDomainSocketWshServer() error {
|
||||
sockName := wavebase.GetDomainSocketName()
|
||||
listener, err := MakeUnixListener(sockName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error starging unix listener for wsh-server: %w", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
RunWshRpcOverListener(listener)
|
||||
return nil
|
||||
}
|
||||
|
||||
func MakeWshServer(inputCh chan []byte, outputCh chan []byte, initialCtx wshutil.RpcContext) {
|
||||
wshutil.MakeWshRpc(inputCh, outputCh, initialCtx, mainWshServerHandler)
|
||||
}
|
||||
|
193
pkg/wshrpc/wshserver/wshserverutil.go
Normal file
193
pkg/wshrpc/wshserver/wshserverutil.go
Normal file
@ -0,0 +1,193 @@
|
||||
// Copyright 2024, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package wshserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"reflect"
|
||||
|
||||
"github.com/wavetermdev/thenextwave/pkg/util/utilfn"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wavebase"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wshrpc"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wshutil"
|
||||
)
|
||||
|
||||
// this file contains the generic types and functions that create and power the WSH server
|
||||
|
||||
const (
|
||||
DefaultOutputChSize = 32
|
||||
DefaultInputChSize = 32
|
||||
)
|
||||
|
||||
type WshServer struct{}
|
||||
|
||||
type WshServerMethodDecl struct {
|
||||
Command string
|
||||
CommandType string
|
||||
MethodName string
|
||||
Method reflect.Value
|
||||
CommandDataType reflect.Type
|
||||
DefaultResponseDataType reflect.Type
|
||||
RequestDataTypes []reflect.Type // for streaming requests
|
||||
ResponseDataTypes []reflect.Type // for streaming responses
|
||||
}
|
||||
|
||||
var WshServerImpl = WshServer{}
|
||||
var contextRType = reflect.TypeOf((*context.Context)(nil)).Elem()
|
||||
|
||||
func GetWshServerMethod(command string, commandType string, methodName string, methodFunc any) *WshServerMethodDecl {
|
||||
methodVal := reflect.ValueOf(methodFunc)
|
||||
methodType := methodVal.Type()
|
||||
if methodType.Kind() != reflect.Func {
|
||||
panic(fmt.Sprintf("methodVal must be a function got [%v]", methodType))
|
||||
}
|
||||
if methodType.In(0) != contextRType {
|
||||
panic(fmt.Sprintf("methodVal must have a context as the first argument %v", methodType))
|
||||
}
|
||||
var defResponseType reflect.Type
|
||||
if methodType.NumOut() > 1 {
|
||||
defResponseType = methodType.Out(0)
|
||||
}
|
||||
rtn := &WshServerMethodDecl{
|
||||
Command: command,
|
||||
CommandType: commandType,
|
||||
MethodName: methodName,
|
||||
Method: methodVal,
|
||||
CommandDataType: methodType.In(1),
|
||||
DefaultResponseDataType: defResponseType,
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func decodeRtnVals(rtnVals []reflect.Value) (any, error) {
|
||||
switch len(rtnVals) {
|
||||
case 0:
|
||||
return nil, nil
|
||||
case 1:
|
||||
errIf := rtnVals[0].Interface()
|
||||
if errIf == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, errIf.(error)
|
||||
case 2:
|
||||
errIf := rtnVals[1].Interface()
|
||||
if errIf == nil {
|
||||
return rtnVals[0].Interface(), nil
|
||||
}
|
||||
return rtnVals[0].Interface(), errIf.(error)
|
||||
default:
|
||||
return nil, fmt.Errorf("too many return values: %d", len(rtnVals))
|
||||
}
|
||||
}
|
||||
|
||||
func mainWshServerHandler(handler *wshutil.RpcResponseHandler) bool {
|
||||
command := handler.GetCommand()
|
||||
methodDecl := WshServerCommandToDeclMap[command]
|
||||
if methodDecl == nil {
|
||||
handler.SendResponseError(fmt.Errorf("command %q not found", command))
|
||||
return true
|
||||
}
|
||||
var callParams []reflect.Value
|
||||
callParams = append(callParams, reflect.ValueOf(handler.Context()))
|
||||
if methodDecl.CommandDataType != nil {
|
||||
commandData := reflect.New(methodDecl.CommandDataType).Interface()
|
||||
err := utilfn.ReUnmarshal(commandData, handler.GetCommandRawData())
|
||||
if err != nil {
|
||||
handler.SendResponseError(fmt.Errorf("error re-marshalling command data: %w", err))
|
||||
return true
|
||||
}
|
||||
wshrpc.HackRpcContextIntoData(commandData, handler.GetRpcContext())
|
||||
callParams = append(callParams, reflect.ValueOf(commandData).Elem())
|
||||
}
|
||||
if methodDecl.CommandType == wshutil.RpcType_Call {
|
||||
rtnVals := methodDecl.Method.Call(callParams)
|
||||
rtnData, rtnErr := decodeRtnVals(rtnVals)
|
||||
if rtnErr != nil {
|
||||
handler.SendResponseError(rtnErr)
|
||||
return true
|
||||
}
|
||||
handler.SendResponse(rtnData, true)
|
||||
return true
|
||||
} else if methodDecl.CommandType == wshutil.RpcType_ResponseStream {
|
||||
rtnVals := methodDecl.Method.Call(callParams)
|
||||
rtnChVal := rtnVals[0]
|
||||
if rtnChVal.IsNil() {
|
||||
handler.SendResponse(nil, true)
|
||||
return true
|
||||
}
|
||||
go func() {
|
||||
defer handler.Finalize()
|
||||
// must use reflection here because we don't know the generic type of RespOrErrorUnion
|
||||
for {
|
||||
respVal, ok := rtnChVal.Recv()
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
errorVal := respVal.FieldByName("Error")
|
||||
if !errorVal.IsNil() {
|
||||
handler.SendResponseError(errorVal.Interface().(error))
|
||||
break
|
||||
}
|
||||
respData := respVal.FieldByName("Response").Interface()
|
||||
handler.SendResponse(respData, false)
|
||||
}
|
||||
}()
|
||||
return false
|
||||
} else {
|
||||
handler.SendResponseError(fmt.Errorf("unsupported command type %q", methodDecl.CommandType))
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func MakeUnixListener(sockName string) (net.Listener, error) {
|
||||
os.Remove(sockName) // ignore error
|
||||
rtn, err := net.Listen("unix", sockName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating listener at %v: %v", sockName, err)
|
||||
}
|
||||
os.Chmod(sockName, 0700)
|
||||
log.Printf("Server listening on %s\n", sockName)
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func runWshRpcWithStream(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
inputCh := make(chan []byte, DefaultInputChSize)
|
||||
outputCh := make(chan []byte, DefaultOutputChSize)
|
||||
go wshutil.AdaptMsgChToStream(outputCh, conn)
|
||||
go wshutil.AdaptStreamToMsgCh(conn, inputCh)
|
||||
wshutil.MakeWshRpc(inputCh, outputCh, wshutil.RpcContext{}, mainWshServerHandler)
|
||||
}
|
||||
|
||||
func RunWshRpcOverListener(listener net.Listener) {
|
||||
go func() {
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
log.Printf("error accepting connection: %v\n", err)
|
||||
continue
|
||||
}
|
||||
go runWshRpcWithStream(conn)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func RunDomainSocketWshServer() error {
|
||||
sockName := wavebase.GetDomainSocketName()
|
||||
listener, err := MakeUnixListener(sockName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error starging unix listener for wsh-server: %w", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
RunWshRpcOverListener(listener)
|
||||
return nil
|
||||
}
|
||||
|
||||
func MakeWshServer(inputCh chan []byte, outputCh chan []byte, initialCtx wshutil.RpcContext) {
|
||||
wshutil.MakeWshRpc(inputCh, outputCh, initialCtx, mainWshServerHandler)
|
||||
}
|
@ -29,7 +29,9 @@ const (
|
||||
)
|
||||
|
||||
type ResponseFnType = func(any) error
|
||||
type CommandHandlerFnType = func(*RpcResponseHandler)
|
||||
|
||||
// returns true if handler is complete, false for an async handler
|
||||
type CommandHandlerFnType = func(*RpcResponseHandler) bool
|
||||
|
||||
type wshRpcContextKey struct{}
|
||||
|
||||
@ -50,7 +52,8 @@ type RpcMessage struct {
|
||||
ReqId string `json:"reqid,omitempty"`
|
||||
ResId string `json:"resid,omitempty"`
|
||||
Timeout int `json:"timeout,omitempty"`
|
||||
Cont bool `json:"cont,omitempty"`
|
||||
Cont bool `json:"cont,omitempty"` // flag if additional requests/responses are forthcoming
|
||||
Cancel bool `json:"cancel,omitempty"` // used to cancel a streaming request or response (sent from the side that is not streaming)
|
||||
Error string `json:"error,omitempty"`
|
||||
DataType string `json:"datatype,omitempty"`
|
||||
Data any `json:"data,omitempty"`
|
||||
@ -61,6 +64,21 @@ func (r *RpcMessage) IsRpcRequest() bool {
|
||||
}
|
||||
|
||||
func (r *RpcMessage) Validate() error {
|
||||
if r.ReqId != "" && r.ResId != "" {
|
||||
return fmt.Errorf("request packets may not have both reqid and resid set")
|
||||
}
|
||||
if r.Cancel {
|
||||
if r.Command != "" {
|
||||
return fmt.Errorf("cancel packets may not have command set")
|
||||
}
|
||||
if r.ReqId == "" && r.ResId == "" {
|
||||
return fmt.Errorf("cancel packets must have reqid or resid set")
|
||||
}
|
||||
if r.Data != nil {
|
||||
return fmt.Errorf("cancel packets may not have data set")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if r.Command != "" {
|
||||
if r.ResId != "" {
|
||||
return fmt.Errorf("command packets may not have resid set")
|
||||
@ -110,6 +128,8 @@ type WshRpc struct {
|
||||
RpcContext *atomic.Pointer[RpcContext]
|
||||
RpcMap map[string]*rpcData
|
||||
HandlerFn CommandHandlerFnType
|
||||
|
||||
ResponseHandlerMap map[string]*RpcResponseHandler // reqId => handler
|
||||
}
|
||||
|
||||
type rpcData struct {
|
||||
@ -121,12 +141,13 @@ type rpcData struct {
|
||||
// closes outputCh when inputCh is closed/done
|
||||
func MakeWshRpc(inputCh chan []byte, outputCh chan []byte, rpcCtx RpcContext, commandHandlerFn CommandHandlerFnType) *WshRpc {
|
||||
rtn := &WshRpc{
|
||||
Lock: &sync.Mutex{},
|
||||
InputCh: inputCh,
|
||||
OutputCh: outputCh,
|
||||
RpcMap: make(map[string]*rpcData),
|
||||
RpcContext: &atomic.Pointer[RpcContext]{},
|
||||
HandlerFn: commandHandlerFn,
|
||||
Lock: &sync.Mutex{},
|
||||
InputCh: inputCh,
|
||||
OutputCh: outputCh,
|
||||
RpcMap: make(map[string]*rpcData),
|
||||
RpcContext: &atomic.Pointer[RpcContext]{},
|
||||
HandlerFn: commandHandlerFn,
|
||||
ResponseHandlerMap: make(map[string]*RpcResponseHandler),
|
||||
}
|
||||
rtn.RpcContext.Store(&rpcCtx)
|
||||
go rtn.runServer()
|
||||
@ -142,6 +163,31 @@ func (w *WshRpc) SetRpcContext(ctx RpcContext) {
|
||||
w.RpcContext.Store(&ctx)
|
||||
}
|
||||
|
||||
func (w *WshRpc) registerResponseHandler(reqId string, handler *RpcResponseHandler) {
|
||||
w.Lock.Lock()
|
||||
defer w.Lock.Unlock()
|
||||
w.ResponseHandlerMap[reqId] = handler
|
||||
}
|
||||
|
||||
func (w *WshRpc) unregisterResponseHandler(reqId string) {
|
||||
w.Lock.Lock()
|
||||
defer w.Lock.Unlock()
|
||||
delete(w.ResponseHandlerMap, reqId)
|
||||
}
|
||||
|
||||
func (w *WshRpc) cancelRequest(reqId string) {
|
||||
if reqId == "" {
|
||||
return
|
||||
}
|
||||
w.Lock.Lock()
|
||||
defer w.Lock.Unlock()
|
||||
handler := w.ResponseHandlerMap[reqId]
|
||||
if handler != nil {
|
||||
handler.canceled.Store(true)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (w *WshRpc) handleRequest(req *RpcMessage) {
|
||||
var respHandler *RpcResponseHandler
|
||||
defer func() {
|
||||
@ -159,26 +205,38 @@ func (w *WshRpc) handleRequest(req *RpcMessage) {
|
||||
}
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond)
|
||||
ctx = withWshRpcContext(ctx, w)
|
||||
defer cancelFn()
|
||||
respHandler = &RpcResponseHandler{
|
||||
w: w,
|
||||
ctx: ctx,
|
||||
reqId: req.ReqId,
|
||||
command: req.Command,
|
||||
commandData: req.Data,
|
||||
done: &atomic.Bool{},
|
||||
rpcCtx: w.GetRpcContext(),
|
||||
w: w,
|
||||
ctx: ctx,
|
||||
reqId: req.ReqId,
|
||||
command: req.Command,
|
||||
commandData: req.Data,
|
||||
done: &atomic.Bool{},
|
||||
canceled: &atomic.Bool{},
|
||||
contextCancelFn: &atomic.Pointer[context.CancelFunc]{},
|
||||
rpcCtx: w.GetRpcContext(),
|
||||
}
|
||||
respHandler.contextCancelFn.Store(&cancelFn)
|
||||
w.registerResponseHandler(req.ReqId, respHandler)
|
||||
isAsync := false
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("panic in handleRequest: %v\n", r)
|
||||
debug.PrintStack()
|
||||
respHandler.SendResponseError(fmt.Errorf("panic: %v", r))
|
||||
}
|
||||
respHandler.finalize()
|
||||
if isAsync {
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
respHandler.Finalize()
|
||||
}()
|
||||
} else {
|
||||
cancelFn()
|
||||
respHandler.Finalize()
|
||||
}
|
||||
}()
|
||||
if w.HandlerFn != nil {
|
||||
w.HandlerFn(respHandler)
|
||||
isAsync = !w.HandlerFn(respHandler)
|
||||
}
|
||||
}
|
||||
|
||||
@ -191,6 +249,12 @@ func (w *WshRpc) runServer() {
|
||||
log.Printf("wshrpc received bad message: %v\n", err)
|
||||
continue
|
||||
}
|
||||
if msg.Cancel {
|
||||
if msg.ReqId != "" {
|
||||
w.cancelRequest(msg.ReqId)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if msg.IsRpcRequest() {
|
||||
w.handleRequest(&msg)
|
||||
} else {
|
||||
@ -279,17 +343,27 @@ func (w *WshRpc) SendRpcRequest(command string, data any, timeoutMs int) (any, e
|
||||
}
|
||||
|
||||
type RpcRequestHandler struct {
|
||||
w *WshRpc
|
||||
ctx context.Context
|
||||
cancelFn func()
|
||||
reqId string
|
||||
respCh chan *RpcMessage
|
||||
w *WshRpc
|
||||
ctx context.Context
|
||||
ctxCancelFn *atomic.Pointer[context.CancelFunc]
|
||||
reqId string
|
||||
respCh chan *RpcMessage
|
||||
}
|
||||
|
||||
func (handler *RpcRequestHandler) Context() context.Context {
|
||||
return handler.ctx
|
||||
}
|
||||
|
||||
func (handler *RpcRequestHandler) SendCancel() {
|
||||
msg := &RpcMessage{
|
||||
Cancel: true,
|
||||
ReqId: handler.reqId,
|
||||
}
|
||||
barr, _ := json.Marshal(msg) // will never fail
|
||||
handler.w.OutputCh <- barr
|
||||
handler.finalize()
|
||||
}
|
||||
|
||||
func (handler *RpcRequestHandler) ResponseDone() bool {
|
||||
select {
|
||||
case _, more := <-handler.respCh:
|
||||
@ -308,8 +382,10 @@ func (handler *RpcRequestHandler) NextResponse() (any, error) {
|
||||
}
|
||||
|
||||
func (handler *RpcRequestHandler) finalize() {
|
||||
if handler.cancelFn != nil {
|
||||
handler.cancelFn()
|
||||
cancelFnPtr := handler.ctxCancelFn.Load()
|
||||
if cancelFnPtr != nil && *cancelFnPtr != nil {
|
||||
(*cancelFnPtr)()
|
||||
handler.ctxCancelFn.Store(nil)
|
||||
}
|
||||
if handler.reqId != "" {
|
||||
handler.w.unregisterRpc(handler.reqId, nil)
|
||||
@ -317,13 +393,15 @@ func (handler *RpcRequestHandler) finalize() {
|
||||
}
|
||||
|
||||
type RpcResponseHandler struct {
|
||||
w *WshRpc
|
||||
ctx context.Context
|
||||
reqId string
|
||||
command string
|
||||
commandData any
|
||||
rpcCtx RpcContext
|
||||
done *atomic.Bool
|
||||
w *WshRpc
|
||||
ctx context.Context
|
||||
contextCancelFn *atomic.Pointer[context.CancelFunc]
|
||||
reqId string
|
||||
command string
|
||||
commandData any
|
||||
rpcCtx RpcContext
|
||||
canceled *atomic.Bool // canceled by requestor
|
||||
done *atomic.Bool
|
||||
}
|
||||
|
||||
func (handler *RpcResponseHandler) Context() context.Context {
|
||||
@ -350,7 +428,7 @@ func (handler *RpcResponseHandler) SendResponse(data any, done bool) error {
|
||||
return fmt.Errorf("request already done, cannot send additional response")
|
||||
}
|
||||
if done {
|
||||
handler.done.Store(true)
|
||||
defer handler.close()
|
||||
}
|
||||
msg := &RpcMessage{
|
||||
ResId: handler.reqId,
|
||||
@ -369,7 +447,7 @@ func (handler *RpcResponseHandler) SendResponseError(err error) {
|
||||
if handler.reqId == "" || handler.done.Load() {
|
||||
return
|
||||
}
|
||||
handler.done.Store(true)
|
||||
defer handler.close()
|
||||
msg := &RpcMessage{
|
||||
ResId: handler.reqId,
|
||||
Error: err.Error(),
|
||||
@ -378,12 +456,27 @@ func (handler *RpcResponseHandler) SendResponseError(err error) {
|
||||
handler.w.OutputCh <- barr
|
||||
}
|
||||
|
||||
func (handler *RpcResponseHandler) finalize() {
|
||||
func (handler *RpcResponseHandler) IsCanceled() bool {
|
||||
return handler.canceled.Load()
|
||||
}
|
||||
|
||||
func (handler *RpcResponseHandler) close() {
|
||||
cancelFn := handler.contextCancelFn.Load()
|
||||
if cancelFn != nil && *cancelFn != nil {
|
||||
(*cancelFn)()
|
||||
handler.contextCancelFn.Store(nil)
|
||||
}
|
||||
handler.done.Store(true)
|
||||
}
|
||||
|
||||
// if async, caller must call finalize
|
||||
func (handler *RpcResponseHandler) Finalize() {
|
||||
if handler.reqId == "" || handler.done.Load() {
|
||||
return
|
||||
}
|
||||
handler.done.Store(true)
|
||||
handler.SendResponse(nil, true)
|
||||
handler.close()
|
||||
handler.w.unregisterResponseHandler(handler.reqId)
|
||||
}
|
||||
|
||||
func (handler *RpcResponseHandler) IsDone() bool {
|
||||
@ -395,12 +488,15 @@ func (w *WshRpc) SendComplexRequest(command string, data any, expectsResponse bo
|
||||
return nil, fmt.Errorf("command cannot be empty")
|
||||
}
|
||||
handler := &RpcRequestHandler{
|
||||
w: w,
|
||||
w: w,
|
||||
ctxCancelFn: &atomic.Pointer[context.CancelFunc]{},
|
||||
}
|
||||
if timeoutMs < 0 {
|
||||
handler.ctx = context.Background()
|
||||
} else {
|
||||
handler.ctx, handler.cancelFn = context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond)
|
||||
var cancelFn context.CancelFunc
|
||||
handler.ctx, cancelFn = context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond)
|
||||
handler.ctxCancelFn.Store(&cancelFn)
|
||||
}
|
||||
if expectsResponse {
|
||||
handler.reqId = uuid.New().String()
|
||||
|
Loading…
Reference in New Issue
Block a user