streaming rpc support (backend streams to the frontend) (#120)

This commit is contained in:
Mike Sawka 2024-07-18 15:56:04 -07:00 committed by GitHub
parent 734a066af8
commit 776ccd7da0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 542 additions and 255 deletions

View File

@ -46,9 +46,11 @@ tasks:
- go run cmd/generatewshclient/main-generatewshclient.go - go run cmd/generatewshclient/main-generatewshclient.go
sources: sources:
- "cmd/generate/*.go" - "cmd/generate/*.go"
- "cmd/generatewshclient/*.go"
- "pkg/service/**/*.go" - "pkg/service/**/*.go"
- "pkg/wstore/*.go" - "pkg/wstore/*.go"
- "pkg/wshrpc/**/*.go" - "pkg/wshrpc/**/*.go"
- "pkg/tsgen/**/*.go"
generates: generates:
- frontend/types/gotypes.d.ts - frontend/types/gotypes.d.ts
- pkg/wshrpc/wshclient/wshclient.go - pkg/wshrpc/wshclient/wshclient.go

View File

@ -12,7 +12,24 @@ import (
"github.com/wavetermdev/thenextwave/pkg/wshutil" "github.com/wavetermdev/thenextwave/pkg/wshutil"
) )
func genMethod(fd *os.File, methodDecl *wshserver.WshServerMethodDecl) { func genMethod_ResponseStream(fd *os.File, methodDecl *wshserver.WshServerMethodDecl) {
fmt.Fprintf(fd, "// command %q, wshserver.%s\n", methodDecl.Command, methodDecl.MethodName)
var dataType string
dataVarName := "nil"
if methodDecl.CommandDataType != nil {
dataType = ", data " + methodDecl.CommandDataType.String()
dataVarName = "data"
}
respType := "any"
if methodDecl.DefaultResponseDataType != nil {
respType = methodDecl.DefaultResponseDataType.String()
}
fmt.Fprintf(fd, "func %s(w *wshutil.WshRpc%s, opts *wshrpc.WshRpcCommandOpts) chan wshrpc.RespOrErrorUnion[%s] {\n", methodDecl.MethodName, dataType, respType)
fmt.Fprintf(fd, " return sendRpcRequestResponseStreamHelper[%s](w, %q, %s, opts)\n", respType, methodDecl.Command, dataVarName)
fmt.Fprintf(fd, "}\n\n")
}
func genMethod_Call(fd *os.File, methodDecl *wshserver.WshServerMethodDecl) {
fmt.Fprintf(fd, "// command %q, wshserver.%s\n", methodDecl.Command, methodDecl.MethodName) fmt.Fprintf(fd, "// command %q, wshserver.%s\n", methodDecl.Command, methodDecl.MethodName)
var dataType string var dataType string
dataVarName := "nil" dataVarName := "nil"
@ -29,16 +46,12 @@ func genMethod(fd *os.File, methodDecl *wshserver.WshServerMethodDecl) {
tParamVal = methodDecl.DefaultResponseDataType.String() tParamVal = methodDecl.DefaultResponseDataType.String()
} }
fmt.Fprintf(fd, "func %s(w *wshutil.WshRpc%s, opts *wshrpc.WshRpcCommandOpts) %s {\n", methodDecl.MethodName, dataType, returnType) fmt.Fprintf(fd, "func %s(w *wshutil.WshRpc%s, opts *wshrpc.WshRpcCommandOpts) %s {\n", methodDecl.MethodName, dataType, returnType)
if methodDecl.CommandType == wshutil.RpcType_Call { fmt.Fprintf(fd, " %s, err := sendRpcRequestCallHelper[%s](w, %q, %s, opts)\n", respName, tParamVal, methodDecl.Command, dataVarName)
fmt.Fprintf(fd, " %s, err := sendRpcRequestHelper[%s](w, %q, %s, opts)\n", respName, tParamVal, methodDecl.Command, dataVarName)
if methodDecl.DefaultResponseDataType != nil { if methodDecl.DefaultResponseDataType != nil {
fmt.Fprintf(fd, " return resp, err\n") fmt.Fprintf(fd, " return resp, err\n")
} else { } else {
fmt.Fprintf(fd, " return err\n") fmt.Fprintf(fd, " return err\n")
} }
} else {
panic("unsupported command type " + methodDecl.CommandType)
}
fmt.Fprintf(fd, "}\n\n") fmt.Fprintf(fd, "}\n\n")
} }
@ -61,7 +74,13 @@ func main() {
for _, key := range utilfn.GetOrderedMapKeys(wshserver.WshServerCommandToDeclMap) { for _, key := range utilfn.GetOrderedMapKeys(wshserver.WshServerCommandToDeclMap) {
methodDecl := wshserver.WshServerCommandToDeclMap[key] methodDecl := wshserver.WshServerCommandToDeclMap[key]
genMethod(fd, methodDecl) if methodDecl.CommandType == wshutil.RpcType_ResponseStream {
genMethod_ResponseStream(fd, methodDecl)
} else if methodDecl.CommandType == wshutil.RpcType_Call {
genMethod_Call(fd, methodDecl)
} else {
panic("unsupported command type " + methodDecl.CommandType)
}
} }
fmt.Fprintf(fd, "\n") fmt.Fprintf(fd, "\n")
} }

View File

@ -107,12 +107,27 @@ function callBackendService(service: string, method: string, args: any[], noUICo
return prtn; return prtn;
} }
function callWshServerRpc( function wshServerRpcHelper_responsestream(
command: string, command: string,
data: any, data: any,
meta: WshServerCommandMeta,
opts: WshRpcCommandOpts opts: WshRpcCommandOpts
): Promise<any> { ): AsyncGenerator<any, void, boolean> {
if (opts?.noresponse) {
throw new Error("noresponse not supported for responsestream calls");
}
let msg: RpcMessage = {
command: command,
data: data,
reqid: uuidv4(),
};
if (opts?.timeout) {
msg.timeout = opts.timeout;
}
const rpcGen = sendRpcCommand(msg);
return rpcGen;
}
function wshServerRpcHelper_call(command: string, data: any, opts: WshRpcCommandOpts): Promise<any> {
let msg: RpcMessage = { let msg: RpcMessage = {
command: command, command: command,
data: data, data: data,
@ -123,32 +138,14 @@ function callWshServerRpc(
if (opts?.timeout) { if (opts?.timeout) {
msg.timeout = opts.timeout; msg.timeout = opts.timeout;
} }
if (meta.commandtype != "call") {
throw new Error("unimplemented wshserver commandtype " + meta.commandtype);
}
const rpcGen = sendRpcCommand(msg); const rpcGen = sendRpcCommand(msg);
if (rpcGen == null) { if (rpcGen == null) {
return null; return null;
} }
let resolveFn: (value: any) => void; const respMsgPromise = rpcGen.next(true); // pass true to force termination of rpc after 1 response (not streaming)
let rejectFn: (reason?: any) => void; return respMsgPromise.then((msg: IteratorResult<any, void>) => {
const prtn = new Promise((resolve, reject) => { return msg.value;
resolveFn = resolve;
rejectFn = reject;
}); });
const respMsg = rpcGen.next(true); // pass true to force termination of rpc after 1 response (not streaing)
respMsg.then((msg: IteratorResult<RpcMessage, void>) => {
if (msg.value == null) {
resolveFn(null);
}
let respMsg: RpcMessage = msg.value as RpcMessage;
if (respMsg.error != null) {
rejectFn(new Error(respMsg.error));
return;
}
resolveFn(respMsg.data);
});
return prtn;
} }
const waveObjectValueCache = new Map<string, WaveObjectValue<any>>(); const waveObjectValueCache = new Map<string, WaveObjectValue<any>>();
@ -368,7 +365,6 @@ function setObjectValue<T extends WaveObj>(value: T, setFn?: jotai.Setter, pushT
export { export {
callBackendService, callBackendService,
callWshServerRpc,
cleanWaveObjectCache, cleanWaveObjectCache,
clearWaveObjectCache, clearWaveObjectCache,
getObjectValue, getObjectValue,
@ -383,4 +379,6 @@ export {
useWaveObjectValue, useWaveObjectValue,
useWaveObjectValueWithSuspense, useWaveObjectValueWithSuspense,
waveObjectValueCache, waveObjectValueCache,
wshServerRpcHelper_call,
wshServerRpcHelper_responsestream,
}; };

View File

@ -16,7 +16,7 @@ async function* rpcResponseGenerator(
command: string, command: string,
reqid: string, reqid: string,
timeout: number timeout: number
): AsyncGenerator<RpcMessage, void, boolean> { ): AsyncGenerator<any, void, boolean> {
const msgQueue: RpcMessage[] = []; const msgQueue: RpcMessage[] = [];
let signalFn: () => void; let signalFn: () => void;
let signalPromise = new Promise<void>((resolve) => (signalFn = resolve)); let signalPromise = new Promise<void>((resolve) => (signalFn = resolve));
@ -39,11 +39,18 @@ async function* rpcResponseGenerator(
command: command, command: command,
msgFn: msgFn, msgFn: msgFn,
}); });
yield null;
try { try {
while (true) { while (true) {
while (msgQueue.length > 0) { while (msgQueue.length > 0) {
const msg = msgQueue.shift()!; const msg = msgQueue.shift()!;
const shouldTerminate = yield msg; if (msg.error != null) {
throw new Error(msg.error);
}
if (!msg.cont && msg.data == null) {
return;
}
const shouldTerminate = yield msg.data;
if (shouldTerminate || !msg.cont) { if (shouldTerminate || !msg.cont) {
return; return;
} }
@ -64,7 +71,9 @@ function sendRpcCommand(msg: RpcMessage): AsyncGenerator<RpcMessage, void, boole
if (msg.reqid == null) { if (msg.reqid == null) {
return null; return null;
} }
return rpcResponseGenerator(msg.command, msg.reqid, msg.timeout); const rtnGen = rpcResponseGenerator(msg.command, msg.reqid, msg.timeout);
rtnGen.next(); // start the generator (run the initialization/registration logic, throw away the result)
return rtnGen;
} }
function handleIncomingRpcMessage(msg: RpcMessage) { function handleIncomingRpcMessage(msg: RpcMessage) {
@ -85,4 +94,22 @@ function handleIncomingRpcMessage(msg: RpcMessage) {
entry.msgFn(msg); entry.msgFn(msg);
} }
async function consumeGenerator(gen: AsyncGenerator<any, any, any>) {
let idx = 0;
try {
for await (const msg of gen) {
console.log("gen", idx, msg);
idx++;
}
const result = await gen.return(undefined);
console.log("gen done", result.value);
} catch (e) {
console.log("gen error", e);
}
}
if (globalThis.window != null) {
globalThis["consumeGenerator"] = consumeGenerator;
}
export { handleIncomingRpcMessage, sendRpcCommand }; export { handleIncomingRpcMessage, sendRpcCommand };

View File

@ -9,62 +9,57 @@ import * as WOS from "./wos";
class WshServerType { class WshServerType {
// command "controller:input" [call] // command "controller:input" [call]
BlockInputCommand(data: CommandBlockInputData, opts?: WshRpcCommandOpts): Promise<void> { BlockInputCommand(data: CommandBlockInputData, opts?: WshRpcCommandOpts): Promise<void> {
const meta: WshServerCommandMeta = {commandtype: "call"}; return WOS.wshServerRpcHelper_call("controller:input", data, opts);
return WOS.callWshServerRpc("controller:input", data, meta, opts);
} }
// command "controller:restart" [call] // command "controller:restart" [call]
BlockRestartCommand(data: CommandBlockRestartData, opts?: WshRpcCommandOpts): Promise<void> { BlockRestartCommand(data: CommandBlockRestartData, opts?: WshRpcCommandOpts): Promise<void> {
const meta: WshServerCommandMeta = {commandtype: "call"}; return WOS.wshServerRpcHelper_call("controller:restart", data, opts);
return WOS.callWshServerRpc("controller:restart", data, meta, opts);
} }
// command "createblock" [call] // command "createblock" [call]
CreateBlockCommand(data: CommandCreateBlockData, opts?: WshRpcCommandOpts): Promise<ORef> { CreateBlockCommand(data: CommandCreateBlockData, opts?: WshRpcCommandOpts): Promise<ORef> {
const meta: WshServerCommandMeta = {commandtype: "call"}; return WOS.wshServerRpcHelper_call("createblock", data, opts);
return WOS.callWshServerRpc("createblock", data, meta, opts);
} }
// command "file:append" [call] // command "file:append" [call]
AppendFileCommand(data: CommandAppendFileData, opts?: WshRpcCommandOpts): Promise<void> { AppendFileCommand(data: CommandAppendFileData, opts?: WshRpcCommandOpts): Promise<void> {
const meta: WshServerCommandMeta = {commandtype: "call"}; return WOS.wshServerRpcHelper_call("file:append", data, opts);
return WOS.callWshServerRpc("file:append", data, meta, opts);
} }
// command "file:appendijson" [call] // command "file:appendijson" [call]
AppendIJsonCommand(data: CommandAppendIJsonData, opts?: WshRpcCommandOpts): Promise<void> { AppendIJsonCommand(data: CommandAppendIJsonData, opts?: WshRpcCommandOpts): Promise<void> {
const meta: WshServerCommandMeta = {commandtype: "call"}; return WOS.wshServerRpcHelper_call("file:appendijson", data, opts);
return WOS.callWshServerRpc("file:appendijson", data, meta, opts);
} }
// command "getmeta" [call] // command "getmeta" [call]
GetMetaCommand(data: CommandGetMetaData, opts?: WshRpcCommandOpts): Promise<MetaType> { GetMetaCommand(data: CommandGetMetaData, opts?: WshRpcCommandOpts): Promise<MetaType> {
const meta: WshServerCommandMeta = {commandtype: "call"}; return WOS.wshServerRpcHelper_call("getmeta", data, opts);
return WOS.callWshServerRpc("getmeta", data, meta, opts);
} }
// command "message" [call] // command "message" [call]
MessageCommand(data: CommandMessageData, opts?: WshRpcCommandOpts): Promise<void> { MessageCommand(data: CommandMessageData, opts?: WshRpcCommandOpts): Promise<void> {
const meta: WshServerCommandMeta = {commandtype: "call"}; return WOS.wshServerRpcHelper_call("message", data, opts);
return WOS.callWshServerRpc("message", data, meta, opts);
} }
// command "resolveids" [call] // command "resolveids" [call]
ResolveIdsCommand(data: CommandResolveIdsData, opts?: WshRpcCommandOpts): Promise<CommandResolveIdsRtnData> { ResolveIdsCommand(data: CommandResolveIdsData, opts?: WshRpcCommandOpts): Promise<CommandResolveIdsRtnData> {
const meta: WshServerCommandMeta = {commandtype: "call"}; return WOS.wshServerRpcHelper_call("resolveids", data, opts);
return WOS.callWshServerRpc("resolveids", data, meta, opts);
} }
// command "setmeta" [call] // command "setmeta" [call]
SetMetaCommand(data: CommandSetMetaData, opts?: WshRpcCommandOpts): Promise<void> { SetMetaCommand(data: CommandSetMetaData, opts?: WshRpcCommandOpts): Promise<void> {
const meta: WshServerCommandMeta = {commandtype: "call"}; return WOS.wshServerRpcHelper_call("setmeta", data, opts);
return WOS.callWshServerRpc("setmeta", data, meta, opts);
} }
// command "setview" [call] // command "setview" [call]
BlockSetViewCommand(data: CommandBlockSetViewData, opts?: WshRpcCommandOpts): Promise<void> { BlockSetViewCommand(data: CommandBlockSetViewData, opts?: WshRpcCommandOpts): Promise<void> {
const meta: WshServerCommandMeta = {commandtype: "call"}; return WOS.wshServerRpcHelper_call("setview", data, opts);
return WOS.callWshServerRpc("setview", data, meta, opts); }
// command "streamtest" [responsestream]
RespStreamTest(opts?: WshRpcCommandOpts): AsyncGenerator<number, void, boolean> {
return WOS.wshServerRpcHelper_responsestream("streamtest", null, opts);
} }
} }

View File

@ -189,6 +189,7 @@ declare global {
resid?: string; resid?: string;
timeout?: number; timeout?: number;
cont?: boolean; cont?: boolean;
cancel?: boolean;
error?: string; error?: string;
datatype?: string; datatype?: string;
data?: any; data?: any;

View File

@ -392,6 +392,38 @@ func GenerateServiceClass(serviceName string, serviceObj any, tsTypesMap map[ref
} }
func GenerateWshServerMethod(methodDecl *wshserver.WshServerMethodDecl, tsTypesMap map[reflect.Type]string) string { func GenerateWshServerMethod(methodDecl *wshserver.WshServerMethodDecl, tsTypesMap map[reflect.Type]string) string {
if methodDecl.CommandType == wshutil.RpcType_ResponseStream {
return GenerateWshServerMethod_ResponseStream(methodDecl, tsTypesMap)
} else if methodDecl.CommandType == wshutil.RpcType_Call {
return GenerateWshServerMethod_Call(methodDecl, tsTypesMap)
} else {
panic(fmt.Sprintf("cannot generate wshserver commandtype %q", methodDecl.CommandType))
}
}
func GenerateWshServerMethod_ResponseStream(methodDecl *wshserver.WshServerMethodDecl, tsTypesMap map[reflect.Type]string) string {
var sb strings.Builder
sb.WriteString(fmt.Sprintf(" // command %q [%s]\n", methodDecl.Command, methodDecl.CommandType))
respType := "any"
if methodDecl.DefaultResponseDataType != nil {
respType, _ = TypeToTSType(methodDecl.DefaultResponseDataType, tsTypesMap)
}
dataName := "null"
if methodDecl.CommandDataType != nil {
dataName = "data"
}
genRespType := fmt.Sprintf("AsyncGenerator<%s, void, boolean>", respType)
if methodDecl.CommandDataType != nil {
sb.WriteString(fmt.Sprintf(" %s(data: %s, opts?: WshRpcCommandOpts): %s {\n", methodDecl.MethodName, methodDecl.CommandDataType.Name(), genRespType))
} else {
sb.WriteString(fmt.Sprintf(" %s(opts?: WshRpcCommandOpts): %s {\n", methodDecl.MethodName, genRespType))
}
sb.WriteString(fmt.Sprintf(" return WOS.wshServerRpcHelper_responsestream(%q, %s, opts);\n", methodDecl.Command, dataName))
sb.WriteString(" }\n")
return sb.String()
}
func GenerateWshServerMethod_Call(methodDecl *wshserver.WshServerMethodDecl, tsTypesMap map[reflect.Type]string) string {
var sb strings.Builder var sb strings.Builder
sb.WriteString(fmt.Sprintf(" // command %q [%s]\n", methodDecl.Command, methodDecl.CommandType)) sb.WriteString(fmt.Sprintf(" // command %q [%s]\n", methodDecl.Command, methodDecl.CommandType))
rtnType := "Promise<void>" rtnType := "Promise<void>"
@ -399,14 +431,16 @@ func GenerateWshServerMethod(methodDecl *wshserver.WshServerMethodDecl, tsTypesM
rtnTypeName, _ := TypeToTSType(methodDecl.DefaultResponseDataType, tsTypesMap) rtnTypeName, _ := TypeToTSType(methodDecl.DefaultResponseDataType, tsTypesMap)
rtnType = fmt.Sprintf("Promise<%s>", rtnTypeName) rtnType = fmt.Sprintf("Promise<%s>", rtnTypeName)
} }
dataName := "null"
if methodDecl.CommandDataType != nil {
dataName = "data"
}
if methodDecl.CommandDataType != nil { if methodDecl.CommandDataType != nil {
sb.WriteString(fmt.Sprintf(" %s(data: %s, opts?: WshRpcCommandOpts): %s {\n", methodDecl.MethodName, methodDecl.CommandDataType.Name(), rtnType)) sb.WriteString(fmt.Sprintf(" %s(data: %s, opts?: WshRpcCommandOpts): %s {\n", methodDecl.MethodName, methodDecl.CommandDataType.Name(), rtnType))
} else { } else {
sb.WriteString(fmt.Sprintf(" %s(opts?: WshRpcCommandOpts): %s {\n", methodDecl.MethodName, rtnType)) sb.WriteString(fmt.Sprintf(" %s(opts?: WshRpcCommandOpts): %s {\n", methodDecl.MethodName, rtnType))
} }
metaData := fmt.Sprintf(" const meta: WshServerCommandMeta = {commandtype: %q};\n", methodDecl.CommandType) methodBody := fmt.Sprintf(" return WOS.wshServerRpcHelper_call(%q, %s, opts);\n", methodDecl.Command, dataName)
methodBody := fmt.Sprintf(" return WOS.callWshServerRpc(%q, data, meta, opts);\n", methodDecl.Command)
sb.WriteString(metaData)
sb.WriteString(methodBody) sb.WriteString(methodBody)
sb.WriteString(" }\n") sb.WriteString(" }\n")
return sb.String() return sb.String()

View File

@ -13,62 +13,67 @@ import (
// command "controller:input", wshserver.BlockInputCommand // command "controller:input", wshserver.BlockInputCommand
func BlockInputCommand(w *wshutil.WshRpc, data wshrpc.CommandBlockInputData, opts *wshrpc.WshRpcCommandOpts) error { func BlockInputCommand(w *wshutil.WshRpc, data wshrpc.CommandBlockInputData, opts *wshrpc.WshRpcCommandOpts) error {
_, err := sendRpcRequestHelper[any](w, "controller:input", data, opts) _, err := sendRpcRequestCallHelper[any](w, "controller:input", data, opts)
return err return err
} }
// command "controller:restart", wshserver.BlockRestartCommand // command "controller:restart", wshserver.BlockRestartCommand
func BlockRestartCommand(w *wshutil.WshRpc, data wshrpc.CommandBlockRestartData, opts *wshrpc.WshRpcCommandOpts) error { func BlockRestartCommand(w *wshutil.WshRpc, data wshrpc.CommandBlockRestartData, opts *wshrpc.WshRpcCommandOpts) error {
_, err := sendRpcRequestHelper[any](w, "controller:restart", data, opts) _, err := sendRpcRequestCallHelper[any](w, "controller:restart", data, opts)
return err return err
} }
// command "createblock", wshserver.CreateBlockCommand // command "createblock", wshserver.CreateBlockCommand
func CreateBlockCommand(w *wshutil.WshRpc, data wshrpc.CommandCreateBlockData, opts *wshrpc.WshRpcCommandOpts) (*waveobj.ORef, error) { func CreateBlockCommand(w *wshutil.WshRpc, data wshrpc.CommandCreateBlockData, opts *wshrpc.WshRpcCommandOpts) (*waveobj.ORef, error) {
resp, err := sendRpcRequestHelper[*waveobj.ORef](w, "createblock", data, opts) resp, err := sendRpcRequestCallHelper[*waveobj.ORef](w, "createblock", data, opts)
return resp, err return resp, err
} }
// command "file:append", wshserver.AppendFileCommand // command "file:append", wshserver.AppendFileCommand
func AppendFileCommand(w *wshutil.WshRpc, data wshrpc.CommandAppendFileData, opts *wshrpc.WshRpcCommandOpts) error { func AppendFileCommand(w *wshutil.WshRpc, data wshrpc.CommandAppendFileData, opts *wshrpc.WshRpcCommandOpts) error {
_, err := sendRpcRequestHelper[any](w, "file:append", data, opts) _, err := sendRpcRequestCallHelper[any](w, "file:append", data, opts)
return err return err
} }
// command "file:appendijson", wshserver.AppendIJsonCommand // command "file:appendijson", wshserver.AppendIJsonCommand
func AppendIJsonCommand(w *wshutil.WshRpc, data wshrpc.CommandAppendIJsonData, opts *wshrpc.WshRpcCommandOpts) error { func AppendIJsonCommand(w *wshutil.WshRpc, data wshrpc.CommandAppendIJsonData, opts *wshrpc.WshRpcCommandOpts) error {
_, err := sendRpcRequestHelper[any](w, "file:appendijson", data, opts) _, err := sendRpcRequestCallHelper[any](w, "file:appendijson", data, opts)
return err return err
} }
// command "getmeta", wshserver.GetMetaCommand // command "getmeta", wshserver.GetMetaCommand
func GetMetaCommand(w *wshutil.WshRpc, data wshrpc.CommandGetMetaData, opts *wshrpc.WshRpcCommandOpts) (map[string]interface {}, error) { func GetMetaCommand(w *wshutil.WshRpc, data wshrpc.CommandGetMetaData, opts *wshrpc.WshRpcCommandOpts) (map[string]interface {}, error) {
resp, err := sendRpcRequestHelper[map[string]interface {}](w, "getmeta", data, opts) resp, err := sendRpcRequestCallHelper[map[string]interface {}](w, "getmeta", data, opts)
return resp, err return resp, err
} }
// command "message", wshserver.MessageCommand // command "message", wshserver.MessageCommand
func MessageCommand(w *wshutil.WshRpc, data wshrpc.CommandMessageData, opts *wshrpc.WshRpcCommandOpts) error { func MessageCommand(w *wshutil.WshRpc, data wshrpc.CommandMessageData, opts *wshrpc.WshRpcCommandOpts) error {
_, err := sendRpcRequestHelper[any](w, "message", data, opts) _, err := sendRpcRequestCallHelper[any](w, "message", data, opts)
return err return err
} }
// command "resolveids", wshserver.ResolveIdsCommand // command "resolveids", wshserver.ResolveIdsCommand
func ResolveIdsCommand(w *wshutil.WshRpc, data wshrpc.CommandResolveIdsData, opts *wshrpc.WshRpcCommandOpts) (wshrpc.CommandResolveIdsRtnData, error) { func ResolveIdsCommand(w *wshutil.WshRpc, data wshrpc.CommandResolveIdsData, opts *wshrpc.WshRpcCommandOpts) (wshrpc.CommandResolveIdsRtnData, error) {
resp, err := sendRpcRequestHelper[wshrpc.CommandResolveIdsRtnData](w, "resolveids", data, opts) resp, err := sendRpcRequestCallHelper[wshrpc.CommandResolveIdsRtnData](w, "resolveids", data, opts)
return resp, err return resp, err
} }
// command "setmeta", wshserver.SetMetaCommand // command "setmeta", wshserver.SetMetaCommand
func SetMetaCommand(w *wshutil.WshRpc, data wshrpc.CommandSetMetaData, opts *wshrpc.WshRpcCommandOpts) error { func SetMetaCommand(w *wshutil.WshRpc, data wshrpc.CommandSetMetaData, opts *wshrpc.WshRpcCommandOpts) error {
_, err := sendRpcRequestHelper[any](w, "setmeta", data, opts) _, err := sendRpcRequestCallHelper[any](w, "setmeta", data, opts)
return err return err
} }
// command "setview", wshserver.BlockSetViewCommand // command "setview", wshserver.BlockSetViewCommand
func BlockSetViewCommand(w *wshutil.WshRpc, data wshrpc.CommandBlockSetViewData, opts *wshrpc.WshRpcCommandOpts) error { func BlockSetViewCommand(w *wshutil.WshRpc, data wshrpc.CommandBlockSetViewData, opts *wshrpc.WshRpcCommandOpts) error {
_, err := sendRpcRequestHelper[any](w, "setview", data, opts) _, err := sendRpcRequestCallHelper[any](w, "setview", data, opts)
return err return err
} }
// command "streamtest", wshserver.RespStreamTest
func RespStreamTest(w *wshutil.WshRpc, opts *wshrpc.WshRpcCommandOpts) chan wshrpc.RespOrErrorUnion[int] {
return sendRpcRequestResponseStreamHelper[int](w, "streamtest", nil, opts)
}

View File

@ -9,7 +9,7 @@ import (
"github.com/wavetermdev/thenextwave/pkg/wshutil" "github.com/wavetermdev/thenextwave/pkg/wshutil"
) )
func sendRpcRequestHelper[T any](w *wshutil.WshRpc, command string, data interface{}, opts *wshrpc.WshRpcCommandOpts) (T, error) { func sendRpcRequestCallHelper[T any](w *wshutil.WshRpc, command string, data interface{}, opts *wshrpc.WshRpcCommandOpts) (T, error) {
var respData T var respData T
if opts.NoResponse { if opts.NoResponse {
err := w.SendCommand(command, data) err := w.SendCommand(command, data)
@ -28,3 +28,36 @@ func sendRpcRequestHelper[T any](w *wshutil.WshRpc, command string, data interfa
} }
return respData, nil return respData, nil
} }
func sendRpcRequestResponseStreamHelper[T any](w *wshutil.WshRpc, command string, data interface{}, opts *wshrpc.WshRpcCommandOpts) chan wshrpc.RespOrErrorUnion[T] {
respChan := make(chan wshrpc.RespOrErrorUnion[T])
reqHandler, err := w.SendComplexRequest(command, data, true, opts.Timeout)
if err != nil {
go func() {
respChan <- wshrpc.RespOrErrorUnion[T]{Error: err}
close(respChan)
}()
} else {
go func() {
defer close(respChan)
for {
if reqHandler.ResponseDone() {
break
}
resp, err := reqHandler.NextResponse()
if err != nil {
respChan <- wshrpc.RespOrErrorUnion[T]{Error: err}
break
}
var respData T
err = utilfn.ReUnmarshal(&respData, resp)
if err != nil {
respChan <- wshrpc.RespOrErrorUnion[T]{Error: err}
break
}
respChan <- wshrpc.RespOrErrorUnion[T]{Response: respData}
}
}()
}
return respChan
}

View File

@ -35,6 +35,11 @@ var DataTypeMap = map[string]reflect.Type{
"oref": reflect.TypeOf(waveobj.ORef{}), "oref": reflect.TypeOf(waveobj.ORef{}),
} }
type RespOrErrorUnion[T any] struct {
Response T
Error error
}
// for frontend // for frontend
type WshServerCommandMeta struct { type WshServerCommandMeta struct {
CommandType string `json:"commandtype"` CommandType string `json:"commandtype"`

View File

@ -3,14 +3,14 @@
package wshserver package wshserver
// this file contains the implementation of the wsh server methods
import ( import (
"context" "context"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"io/fs" "io/fs"
"log" "log"
"net"
"os"
"reflect" "reflect"
"strings" "strings"
"time" "time"
@ -18,33 +18,19 @@ import (
"github.com/wavetermdev/thenextwave/pkg/blockcontroller" "github.com/wavetermdev/thenextwave/pkg/blockcontroller"
"github.com/wavetermdev/thenextwave/pkg/eventbus" "github.com/wavetermdev/thenextwave/pkg/eventbus"
"github.com/wavetermdev/thenextwave/pkg/filestore" "github.com/wavetermdev/thenextwave/pkg/filestore"
"github.com/wavetermdev/thenextwave/pkg/util/utilfn"
"github.com/wavetermdev/thenextwave/pkg/wavebase"
"github.com/wavetermdev/thenextwave/pkg/waveobj" "github.com/wavetermdev/thenextwave/pkg/waveobj"
"github.com/wavetermdev/thenextwave/pkg/wshrpc" "github.com/wavetermdev/thenextwave/pkg/wshrpc"
"github.com/wavetermdev/thenextwave/pkg/wshutil" "github.com/wavetermdev/thenextwave/pkg/wshutil"
"github.com/wavetermdev/thenextwave/pkg/wstore" "github.com/wavetermdev/thenextwave/pkg/wstore"
) )
const ( var RespStreamTest_MethodDecl = &WshServerMethodDecl{
DefaultOutputChSize = 32 Command: "streamtest",
DefaultInputChSize = 32 CommandType: wshutil.RpcType_ResponseStream,
) MethodName: "RespStreamTest",
Method: reflect.ValueOf(WshServerImpl.RespStreamTest),
type WshServer struct{} CommandDataType: nil,
DefaultResponseDataType: reflect.TypeOf((int)(0)),
var WshServerImpl = WshServer{}
var contextRType = reflect.TypeOf((*context.Context)(nil)).Elem()
type WshServerMethodDecl struct {
Command string
CommandType string
MethodName string
Method reflect.Value
CommandDataType reflect.Type
DefaultResponseDataType reflect.Type
RequestDataTypes []reflect.Type // for streaming requests
ResponseDataTypes []reflect.Type // for streaming responses
} }
var WshServerCommandToDeclMap = map[string]*WshServerMethodDecl{ var WshServerCommandToDeclMap = map[string]*WshServerMethodDecl{
@ -58,37 +44,28 @@ var WshServerCommandToDeclMap = map[string]*WshServerMethodDecl{
wshrpc.Command_BlockInput: GetWshServerMethod(wshrpc.Command_BlockInput, wshutil.RpcType_Call, "BlockInputCommand", WshServerImpl.BlockInputCommand), wshrpc.Command_BlockInput: GetWshServerMethod(wshrpc.Command_BlockInput, wshutil.RpcType_Call, "BlockInputCommand", WshServerImpl.BlockInputCommand),
wshrpc.Command_AppendFile: GetWshServerMethod(wshrpc.Command_AppendFile, wshutil.RpcType_Call, "AppendFileCommand", WshServerImpl.AppendFileCommand), wshrpc.Command_AppendFile: GetWshServerMethod(wshrpc.Command_AppendFile, wshutil.RpcType_Call, "AppendFileCommand", WshServerImpl.AppendFileCommand),
wshrpc.Command_AppendIJson: GetWshServerMethod(wshrpc.Command_AppendIJson, wshutil.RpcType_Call, "AppendIJsonCommand", WshServerImpl.AppendIJsonCommand), wshrpc.Command_AppendIJson: GetWshServerMethod(wshrpc.Command_AppendIJson, wshutil.RpcType_Call, "AppendIJsonCommand", WshServerImpl.AppendIJsonCommand),
"streamtest": RespStreamTest_MethodDecl,
} }
func GetWshServerMethod(command string, commandType string, methodName string, methodFunc any) *WshServerMethodDecl { // for testing
methodVal := reflect.ValueOf(methodFunc)
methodType := methodVal.Type()
if methodType.Kind() != reflect.Func {
panic(fmt.Sprintf("methodVal must be a function got [%v]", methodType))
}
if methodType.In(0) != contextRType {
panic(fmt.Sprintf("methodVal must have a context as the first argument %v", methodType))
}
var defResponseType reflect.Type
if methodType.NumOut() > 1 {
defResponseType = methodType.Out(0)
}
rtn := &WshServerMethodDecl{
Command: command,
CommandType: commandType,
MethodName: methodName,
Method: methodVal,
CommandDataType: methodType.In(1),
DefaultResponseDataType: defResponseType,
}
return rtn
}
func (ws *WshServer) MessageCommand(ctx context.Context, data wshrpc.CommandMessageData) error { func (ws *WshServer) MessageCommand(ctx context.Context, data wshrpc.CommandMessageData) error {
log.Printf("MESSAGE: %s | %q\n", data.ORef, data.Message) log.Printf("MESSAGE: %s | %q\n", data.ORef, data.Message)
return nil return nil
} }
// for testing
func (ws *WshServer) RespStreamTest(ctx context.Context) chan wshrpc.RespOrErrorUnion[int] {
rtn := make(chan wshrpc.RespOrErrorUnion[int])
go func() {
for i := 1; i <= 5; i++ {
rtn <- wshrpc.RespOrErrorUnion[int]{Response: i}
time.Sleep(1 * time.Second)
}
close(rtn)
}()
return rtn
}
func (ws *WshServer) GetMetaCommand(ctx context.Context, data wshrpc.CommandGetMetaData) (wshrpc.MetaDataType, error) { func (ws *WshServer) GetMetaCommand(ctx context.Context, data wshrpc.CommandGetMetaData) (wshrpc.MetaDataType, error) {
log.Printf("calling meta: %s\n", data.ORef) log.Printf("calling meta: %s\n", data.ORef)
obj, err := wstore.DBGetORef(ctx, data.ORef) obj, err := wstore.DBGetORef(ctx, data.ORef)
@ -317,101 +294,3 @@ func (ws *WshServer) AppendIJsonCommand(ctx context.Context, data wshrpc.Command
}) })
return nil return nil
} }
func decodeRtnVals(rtnVals []reflect.Value) (any, error) {
switch len(rtnVals) {
case 0:
return nil, nil
case 1:
errIf := rtnVals[0].Interface()
if errIf == nil {
return nil, nil
}
return nil, errIf.(error)
case 2:
errIf := rtnVals[1].Interface()
if errIf == nil {
return rtnVals[0].Interface(), nil
}
return rtnVals[0].Interface(), errIf.(error)
default:
return nil, fmt.Errorf("too many return values: %d", len(rtnVals))
}
}
func mainWshServerHandler(handler *wshutil.RpcResponseHandler) {
command := handler.GetCommand()
methodDecl := WshServerCommandToDeclMap[command]
if methodDecl == nil {
handler.SendResponseError(fmt.Errorf("command %q not found", command))
return
}
var callParams []reflect.Value
callParams = append(callParams, reflect.ValueOf(handler.Context()))
if methodDecl.CommandDataType != nil {
commandData := reflect.New(methodDecl.CommandDataType).Interface()
err := utilfn.ReUnmarshal(commandData, handler.GetCommandRawData())
if err != nil {
handler.SendResponseError(fmt.Errorf("error re-marshalling command data: %w", err))
return
}
wshrpc.HackRpcContextIntoData(commandData, handler.GetRpcContext())
callParams = append(callParams, reflect.ValueOf(commandData).Elem())
}
rtnVals := methodDecl.Method.Call(callParams)
rtnData, rtnErr := decodeRtnVals(rtnVals)
if rtnErr != nil {
handler.SendResponseError(rtnErr)
return
} else {
handler.SendResponse(rtnData, true)
}
}
func MakeUnixListener(sockName string) (net.Listener, error) {
os.Remove(sockName) // ignore error
rtn, err := net.Listen("unix", sockName)
if err != nil {
return nil, fmt.Errorf("error creating listener at %v: %v", sockName, err)
}
os.Chmod(sockName, 0700)
log.Printf("Server listening on %s\n", sockName)
return rtn, nil
}
func runWshRpcWithStream(conn net.Conn) {
defer conn.Close()
inputCh := make(chan []byte, DefaultInputChSize)
outputCh := make(chan []byte, DefaultOutputChSize)
go wshutil.AdaptMsgChToStream(outputCh, conn)
go wshutil.AdaptStreamToMsgCh(conn, inputCh)
wshutil.MakeWshRpc(inputCh, outputCh, wshutil.RpcContext{}, mainWshServerHandler)
}
func RunWshRpcOverListener(listener net.Listener) {
go func() {
for {
conn, err := listener.Accept()
if err != nil {
log.Printf("error accepting connection: %v\n", err)
continue
}
go runWshRpcWithStream(conn)
}
}()
}
func RunDomainSocketWshServer() error {
sockName := wavebase.GetDomainSocketName()
listener, err := MakeUnixListener(sockName)
if err != nil {
return fmt.Errorf("error starging unix listener for wsh-server: %w", err)
}
defer listener.Close()
RunWshRpcOverListener(listener)
return nil
}
func MakeWshServer(inputCh chan []byte, outputCh chan []byte, initialCtx wshutil.RpcContext) {
wshutil.MakeWshRpc(inputCh, outputCh, initialCtx, mainWshServerHandler)
}

View File

@ -0,0 +1,193 @@
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package wshserver
import (
"context"
"fmt"
"log"
"net"
"os"
"reflect"
"github.com/wavetermdev/thenextwave/pkg/util/utilfn"
"github.com/wavetermdev/thenextwave/pkg/wavebase"
"github.com/wavetermdev/thenextwave/pkg/wshrpc"
"github.com/wavetermdev/thenextwave/pkg/wshutil"
)
// this file contains the generic types and functions that create and power the WSH server
const (
DefaultOutputChSize = 32
DefaultInputChSize = 32
)
type WshServer struct{}
type WshServerMethodDecl struct {
Command string
CommandType string
MethodName string
Method reflect.Value
CommandDataType reflect.Type
DefaultResponseDataType reflect.Type
RequestDataTypes []reflect.Type // for streaming requests
ResponseDataTypes []reflect.Type // for streaming responses
}
var WshServerImpl = WshServer{}
var contextRType = reflect.TypeOf((*context.Context)(nil)).Elem()
func GetWshServerMethod(command string, commandType string, methodName string, methodFunc any) *WshServerMethodDecl {
methodVal := reflect.ValueOf(methodFunc)
methodType := methodVal.Type()
if methodType.Kind() != reflect.Func {
panic(fmt.Sprintf("methodVal must be a function got [%v]", methodType))
}
if methodType.In(0) != contextRType {
panic(fmt.Sprintf("methodVal must have a context as the first argument %v", methodType))
}
var defResponseType reflect.Type
if methodType.NumOut() > 1 {
defResponseType = methodType.Out(0)
}
rtn := &WshServerMethodDecl{
Command: command,
CommandType: commandType,
MethodName: methodName,
Method: methodVal,
CommandDataType: methodType.In(1),
DefaultResponseDataType: defResponseType,
}
return rtn
}
func decodeRtnVals(rtnVals []reflect.Value) (any, error) {
switch len(rtnVals) {
case 0:
return nil, nil
case 1:
errIf := rtnVals[0].Interface()
if errIf == nil {
return nil, nil
}
return nil, errIf.(error)
case 2:
errIf := rtnVals[1].Interface()
if errIf == nil {
return rtnVals[0].Interface(), nil
}
return rtnVals[0].Interface(), errIf.(error)
default:
return nil, fmt.Errorf("too many return values: %d", len(rtnVals))
}
}
func mainWshServerHandler(handler *wshutil.RpcResponseHandler) bool {
command := handler.GetCommand()
methodDecl := WshServerCommandToDeclMap[command]
if methodDecl == nil {
handler.SendResponseError(fmt.Errorf("command %q not found", command))
return true
}
var callParams []reflect.Value
callParams = append(callParams, reflect.ValueOf(handler.Context()))
if methodDecl.CommandDataType != nil {
commandData := reflect.New(methodDecl.CommandDataType).Interface()
err := utilfn.ReUnmarshal(commandData, handler.GetCommandRawData())
if err != nil {
handler.SendResponseError(fmt.Errorf("error re-marshalling command data: %w", err))
return true
}
wshrpc.HackRpcContextIntoData(commandData, handler.GetRpcContext())
callParams = append(callParams, reflect.ValueOf(commandData).Elem())
}
if methodDecl.CommandType == wshutil.RpcType_Call {
rtnVals := methodDecl.Method.Call(callParams)
rtnData, rtnErr := decodeRtnVals(rtnVals)
if rtnErr != nil {
handler.SendResponseError(rtnErr)
return true
}
handler.SendResponse(rtnData, true)
return true
} else if methodDecl.CommandType == wshutil.RpcType_ResponseStream {
rtnVals := methodDecl.Method.Call(callParams)
rtnChVal := rtnVals[0]
if rtnChVal.IsNil() {
handler.SendResponse(nil, true)
return true
}
go func() {
defer handler.Finalize()
// must use reflection here because we don't know the generic type of RespOrErrorUnion
for {
respVal, ok := rtnChVal.Recv()
if !ok {
break
}
errorVal := respVal.FieldByName("Error")
if !errorVal.IsNil() {
handler.SendResponseError(errorVal.Interface().(error))
break
}
respData := respVal.FieldByName("Response").Interface()
handler.SendResponse(respData, false)
}
}()
return false
} else {
handler.SendResponseError(fmt.Errorf("unsupported command type %q", methodDecl.CommandType))
return true
}
}
func MakeUnixListener(sockName string) (net.Listener, error) {
os.Remove(sockName) // ignore error
rtn, err := net.Listen("unix", sockName)
if err != nil {
return nil, fmt.Errorf("error creating listener at %v: %v", sockName, err)
}
os.Chmod(sockName, 0700)
log.Printf("Server listening on %s\n", sockName)
return rtn, nil
}
func runWshRpcWithStream(conn net.Conn) {
defer conn.Close()
inputCh := make(chan []byte, DefaultInputChSize)
outputCh := make(chan []byte, DefaultOutputChSize)
go wshutil.AdaptMsgChToStream(outputCh, conn)
go wshutil.AdaptStreamToMsgCh(conn, inputCh)
wshutil.MakeWshRpc(inputCh, outputCh, wshutil.RpcContext{}, mainWshServerHandler)
}
func RunWshRpcOverListener(listener net.Listener) {
go func() {
for {
conn, err := listener.Accept()
if err != nil {
log.Printf("error accepting connection: %v\n", err)
continue
}
go runWshRpcWithStream(conn)
}
}()
}
func RunDomainSocketWshServer() error {
sockName := wavebase.GetDomainSocketName()
listener, err := MakeUnixListener(sockName)
if err != nil {
return fmt.Errorf("error starging unix listener for wsh-server: %w", err)
}
defer listener.Close()
RunWshRpcOverListener(listener)
return nil
}
func MakeWshServer(inputCh chan []byte, outputCh chan []byte, initialCtx wshutil.RpcContext) {
wshutil.MakeWshRpc(inputCh, outputCh, initialCtx, mainWshServerHandler)
}

View File

@ -29,7 +29,9 @@ const (
) )
type ResponseFnType = func(any) error type ResponseFnType = func(any) error
type CommandHandlerFnType = func(*RpcResponseHandler)
// returns true if handler is complete, false for an async handler
type CommandHandlerFnType = func(*RpcResponseHandler) bool
type wshRpcContextKey struct{} type wshRpcContextKey struct{}
@ -50,7 +52,8 @@ type RpcMessage struct {
ReqId string `json:"reqid,omitempty"` ReqId string `json:"reqid,omitempty"`
ResId string `json:"resid,omitempty"` ResId string `json:"resid,omitempty"`
Timeout int `json:"timeout,omitempty"` Timeout int `json:"timeout,omitempty"`
Cont bool `json:"cont,omitempty"` Cont bool `json:"cont,omitempty"` // flag if additional requests/responses are forthcoming
Cancel bool `json:"cancel,omitempty"` // used to cancel a streaming request or response (sent from the side that is not streaming)
Error string `json:"error,omitempty"` Error string `json:"error,omitempty"`
DataType string `json:"datatype,omitempty"` DataType string `json:"datatype,omitempty"`
Data any `json:"data,omitempty"` Data any `json:"data,omitempty"`
@ -61,6 +64,21 @@ func (r *RpcMessage) IsRpcRequest() bool {
} }
func (r *RpcMessage) Validate() error { func (r *RpcMessage) Validate() error {
if r.ReqId != "" && r.ResId != "" {
return fmt.Errorf("request packets may not have both reqid and resid set")
}
if r.Cancel {
if r.Command != "" {
return fmt.Errorf("cancel packets may not have command set")
}
if r.ReqId == "" && r.ResId == "" {
return fmt.Errorf("cancel packets must have reqid or resid set")
}
if r.Data != nil {
return fmt.Errorf("cancel packets may not have data set")
}
return nil
}
if r.Command != "" { if r.Command != "" {
if r.ResId != "" { if r.ResId != "" {
return fmt.Errorf("command packets may not have resid set") return fmt.Errorf("command packets may not have resid set")
@ -110,6 +128,8 @@ type WshRpc struct {
RpcContext *atomic.Pointer[RpcContext] RpcContext *atomic.Pointer[RpcContext]
RpcMap map[string]*rpcData RpcMap map[string]*rpcData
HandlerFn CommandHandlerFnType HandlerFn CommandHandlerFnType
ResponseHandlerMap map[string]*RpcResponseHandler // reqId => handler
} }
type rpcData struct { type rpcData struct {
@ -127,6 +147,7 @@ func MakeWshRpc(inputCh chan []byte, outputCh chan []byte, rpcCtx RpcContext, co
RpcMap: make(map[string]*rpcData), RpcMap: make(map[string]*rpcData),
RpcContext: &atomic.Pointer[RpcContext]{}, RpcContext: &atomic.Pointer[RpcContext]{},
HandlerFn: commandHandlerFn, HandlerFn: commandHandlerFn,
ResponseHandlerMap: make(map[string]*RpcResponseHandler),
} }
rtn.RpcContext.Store(&rpcCtx) rtn.RpcContext.Store(&rpcCtx)
go rtn.runServer() go rtn.runServer()
@ -142,6 +163,31 @@ func (w *WshRpc) SetRpcContext(ctx RpcContext) {
w.RpcContext.Store(&ctx) w.RpcContext.Store(&ctx)
} }
func (w *WshRpc) registerResponseHandler(reqId string, handler *RpcResponseHandler) {
w.Lock.Lock()
defer w.Lock.Unlock()
w.ResponseHandlerMap[reqId] = handler
}
func (w *WshRpc) unregisterResponseHandler(reqId string) {
w.Lock.Lock()
defer w.Lock.Unlock()
delete(w.ResponseHandlerMap, reqId)
}
func (w *WshRpc) cancelRequest(reqId string) {
if reqId == "" {
return
}
w.Lock.Lock()
defer w.Lock.Unlock()
handler := w.ResponseHandlerMap[reqId]
if handler != nil {
handler.canceled.Store(true)
}
}
func (w *WshRpc) handleRequest(req *RpcMessage) { func (w *WshRpc) handleRequest(req *RpcMessage) {
var respHandler *RpcResponseHandler var respHandler *RpcResponseHandler
defer func() { defer func() {
@ -159,7 +205,6 @@ func (w *WshRpc) handleRequest(req *RpcMessage) {
} }
ctx, cancelFn := context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond) ctx, cancelFn := context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond)
ctx = withWshRpcContext(ctx, w) ctx = withWshRpcContext(ctx, w)
defer cancelFn()
respHandler = &RpcResponseHandler{ respHandler = &RpcResponseHandler{
w: w, w: w,
ctx: ctx, ctx: ctx,
@ -167,18 +212,31 @@ func (w *WshRpc) handleRequest(req *RpcMessage) {
command: req.Command, command: req.Command,
commandData: req.Data, commandData: req.Data,
done: &atomic.Bool{}, done: &atomic.Bool{},
canceled: &atomic.Bool{},
contextCancelFn: &atomic.Pointer[context.CancelFunc]{},
rpcCtx: w.GetRpcContext(), rpcCtx: w.GetRpcContext(),
} }
respHandler.contextCancelFn.Store(&cancelFn)
w.registerResponseHandler(req.ReqId, respHandler)
isAsync := false
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
log.Printf("panic in handleRequest: %v\n", r) log.Printf("panic in handleRequest: %v\n", r)
debug.PrintStack() debug.PrintStack()
respHandler.SendResponseError(fmt.Errorf("panic: %v", r)) respHandler.SendResponseError(fmt.Errorf("panic: %v", r))
} }
respHandler.finalize() if isAsync {
go func() {
<-ctx.Done()
respHandler.Finalize()
}()
} else {
cancelFn()
respHandler.Finalize()
}
}() }()
if w.HandlerFn != nil { if w.HandlerFn != nil {
w.HandlerFn(respHandler) isAsync = !w.HandlerFn(respHandler)
} }
} }
@ -191,6 +249,12 @@ func (w *WshRpc) runServer() {
log.Printf("wshrpc received bad message: %v\n", err) log.Printf("wshrpc received bad message: %v\n", err)
continue continue
} }
if msg.Cancel {
if msg.ReqId != "" {
w.cancelRequest(msg.ReqId)
}
continue
}
if msg.IsRpcRequest() { if msg.IsRpcRequest() {
w.handleRequest(&msg) w.handleRequest(&msg)
} else { } else {
@ -281,7 +345,7 @@ func (w *WshRpc) SendRpcRequest(command string, data any, timeoutMs int) (any, e
type RpcRequestHandler struct { type RpcRequestHandler struct {
w *WshRpc w *WshRpc
ctx context.Context ctx context.Context
cancelFn func() ctxCancelFn *atomic.Pointer[context.CancelFunc]
reqId string reqId string
respCh chan *RpcMessage respCh chan *RpcMessage
} }
@ -290,6 +354,16 @@ func (handler *RpcRequestHandler) Context() context.Context {
return handler.ctx return handler.ctx
} }
func (handler *RpcRequestHandler) SendCancel() {
msg := &RpcMessage{
Cancel: true,
ReqId: handler.reqId,
}
barr, _ := json.Marshal(msg) // will never fail
handler.w.OutputCh <- barr
handler.finalize()
}
func (handler *RpcRequestHandler) ResponseDone() bool { func (handler *RpcRequestHandler) ResponseDone() bool {
select { select {
case _, more := <-handler.respCh: case _, more := <-handler.respCh:
@ -308,8 +382,10 @@ func (handler *RpcRequestHandler) NextResponse() (any, error) {
} }
func (handler *RpcRequestHandler) finalize() { func (handler *RpcRequestHandler) finalize() {
if handler.cancelFn != nil { cancelFnPtr := handler.ctxCancelFn.Load()
handler.cancelFn() if cancelFnPtr != nil && *cancelFnPtr != nil {
(*cancelFnPtr)()
handler.ctxCancelFn.Store(nil)
} }
if handler.reqId != "" { if handler.reqId != "" {
handler.w.unregisterRpc(handler.reqId, nil) handler.w.unregisterRpc(handler.reqId, nil)
@ -319,10 +395,12 @@ func (handler *RpcRequestHandler) finalize() {
type RpcResponseHandler struct { type RpcResponseHandler struct {
w *WshRpc w *WshRpc
ctx context.Context ctx context.Context
contextCancelFn *atomic.Pointer[context.CancelFunc]
reqId string reqId string
command string command string
commandData any commandData any
rpcCtx RpcContext rpcCtx RpcContext
canceled *atomic.Bool // canceled by requestor
done *atomic.Bool done *atomic.Bool
} }
@ -350,7 +428,7 @@ func (handler *RpcResponseHandler) SendResponse(data any, done bool) error {
return fmt.Errorf("request already done, cannot send additional response") return fmt.Errorf("request already done, cannot send additional response")
} }
if done { if done {
handler.done.Store(true) defer handler.close()
} }
msg := &RpcMessage{ msg := &RpcMessage{
ResId: handler.reqId, ResId: handler.reqId,
@ -369,7 +447,7 @@ func (handler *RpcResponseHandler) SendResponseError(err error) {
if handler.reqId == "" || handler.done.Load() { if handler.reqId == "" || handler.done.Load() {
return return
} }
handler.done.Store(true) defer handler.close()
msg := &RpcMessage{ msg := &RpcMessage{
ResId: handler.reqId, ResId: handler.reqId,
Error: err.Error(), Error: err.Error(),
@ -378,12 +456,27 @@ func (handler *RpcResponseHandler) SendResponseError(err error) {
handler.w.OutputCh <- barr handler.w.OutputCh <- barr
} }
func (handler *RpcResponseHandler) finalize() { func (handler *RpcResponseHandler) IsCanceled() bool {
return handler.canceled.Load()
}
func (handler *RpcResponseHandler) close() {
cancelFn := handler.contextCancelFn.Load()
if cancelFn != nil && *cancelFn != nil {
(*cancelFn)()
handler.contextCancelFn.Store(nil)
}
handler.done.Store(true)
}
// if async, caller must call finalize
func (handler *RpcResponseHandler) Finalize() {
if handler.reqId == "" || handler.done.Load() { if handler.reqId == "" || handler.done.Load() {
return return
} }
handler.done.Store(true)
handler.SendResponse(nil, true) handler.SendResponse(nil, true)
handler.close()
handler.w.unregisterResponseHandler(handler.reqId)
} }
func (handler *RpcResponseHandler) IsDone() bool { func (handler *RpcResponseHandler) IsDone() bool {
@ -396,11 +489,14 @@ func (w *WshRpc) SendComplexRequest(command string, data any, expectsResponse bo
} }
handler := &RpcRequestHandler{ handler := &RpcRequestHandler{
w: w, w: w,
ctxCancelFn: &atomic.Pointer[context.CancelFunc]{},
} }
if timeoutMs < 0 { if timeoutMs < 0 {
handler.ctx = context.Background() handler.ctx = context.Background()
} else { } else {
handler.ctx, handler.cancelFn = context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond) var cancelFn context.CancelFunc
handler.ctx, cancelFn = context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond)
handler.ctxCancelFn.Store(&cancelFn)
} }
if expectsResponse { if expectsResponse {
handler.reqId = uuid.New().String() handler.reqId = uuid.New().String()