From dbacae8a9925e4379d792aae10cbd033f3257329 Mon Sep 17 00:00:00 2001 From: Evan Simkowitz Date: Mon, 23 Dec 2024 13:55:04 -0500 Subject: [PATCH] Rename outdated WaveAI types (#1609) A bunch of the Wave AI types still mentioned OpenAI. Now that most of them are being used for multiple AI backends, we need to update the names to be more generic. --- frontend/app/store/services.ts | 2 +- frontend/app/store/wshclientapi.ts | 2 +- frontend/app/view/waveai/waveai.tsx | 18 ++--- frontend/types/gotypes.d.ts | 92 ++++++++++++------------ pkg/service/blockservice/blockservice.go | 2 +- pkg/waveai/anthropicbackend.go | 28 ++++---- pkg/waveai/cloudbackend.go | 20 +++--- pkg/waveai/openaibackend.go | 35 ++++----- pkg/waveai/perplexitybackend.go | 12 ++-- pkg/waveai/waveai.go | 22 +++--- pkg/wshrpc/wshclient/wshclient.go | 4 +- pkg/wshrpc/wshrpctypes.go | 18 ++--- pkg/wshrpc/wshserver/wshserver.go | 2 +- 13 files changed, 123 insertions(+), 134 deletions(-) diff --git a/frontend/app/store/services.ts b/frontend/app/store/services.ts index 9a9e5f453..5705af21c 100644 --- a/frontend/app/store/services.ts +++ b/frontend/app/store/services.ts @@ -15,7 +15,7 @@ class BlockServiceType { SaveTerminalState(blockId: string, state: string, stateType: string, ptyOffset: number, termSize: TermSize): Promise { return WOS.callBackendService("block", "SaveTerminalState", Array.from(arguments)) } - SaveWaveAiData(arg2: string, arg3: OpenAIPromptMessageType[]): Promise { + SaveWaveAiData(arg2: string, arg3: WaveAIPromptMessageType[]): Promise { return WOS.callBackendService("block", "SaveWaveAiData", Array.from(arguments)) } } diff --git a/frontend/app/store/wshclientapi.ts b/frontend/app/store/wshclientapi.ts index 846091596..b8a422fb0 100644 --- a/frontend/app/store/wshclientapi.ts +++ b/frontend/app/store/wshclientapi.ts @@ -303,7 +303,7 @@ class RpcApiType { } // command "streamwaveai" [responsestream] - StreamWaveAiCommand(client: WshClient, data: OpenAiStreamRequest, opts?: RpcOpts): AsyncGenerator { + StreamWaveAiCommand(client: WshClient, data: WaveAIStreamRequest, opts?: RpcOpts): AsyncGenerator { return client.wshRpcStream("streamwaveai", data, opts); } diff --git a/frontend/app/view/waveai/waveai.tsx b/frontend/app/view/waveai/waveai.tsx index 6ac27e721..6dbd3a03b 100644 --- a/frontend/app/view/waveai/waveai.tsx +++ b/frontend/app/view/waveai/waveai.tsx @@ -35,7 +35,7 @@ interface ChatItemProps { model: WaveAiModel; } -function promptToMsg(prompt: OpenAIPromptMessageType): ChatMessageType { +function promptToMsg(prompt: WaveAIPromptMessageType): ChatMessageType { return { id: crypto.randomUUID(), user: prompt.role, @@ -67,7 +67,7 @@ export class WaveAiModel implements ViewModel { blockAtom: Atom; presetKey: Atom; presetMap: Atom<{ [k: string]: MetaType }>; - aiOpts: Atom; + aiOpts: Atom; viewIcon?: Atom; viewName?: Atom; viewText?: Atom; @@ -167,7 +167,7 @@ export class WaveAiModel implements ViewModel { ...settings, ...meta, }; - const opts: OpenAIOptsType = { + const opts: WaveAIOptsType = { model: settings["ai:model"] ?? null, apitype: settings["ai:apitype"] ?? null, orgid: settings["ai:orgid"] ?? null, @@ -293,12 +293,12 @@ export class WaveAiModel implements ViewModel { globalStore.set(this.messagesAtom, history.map(promptToMsg)); } - async fetchAiData(): Promise> { + async fetchAiData(): Promise> { const { data } = await fetchWaveFile(this.blockId, "aidata"); if (!data) { return []; } - const history: Array = JSON.parse(new TextDecoder().decode(data)); + const history: Array = JSON.parse(new TextDecoder().decode(data)); return history.slice(Math.max(history.length - slidingWindowSize, 0)); } @@ -333,7 +333,7 @@ export class WaveAiModel implements ViewModel { globalStore.set(this.addMessageAtom, newMessage); // send message to backend and get response const opts = globalStore.get(this.aiOpts); - const newPrompt: OpenAIPromptMessageType = { + const newPrompt: WaveAIPromptMessageType = { role: "user", content: text, }; @@ -368,7 +368,7 @@ export class WaveAiModel implements ViewModel { // only save the author's prompt await BlockService.SaveWaveAiData(this.blockId, [...history, newPrompt]); } else { - const responsePrompt: OpenAIPromptMessageType = { + const responsePrompt: WaveAIPromptMessageType = { role: "assistant", content: fullMsg, }; @@ -383,7 +383,7 @@ export class WaveAiModel implements ViewModel { globalStore.set(this.removeLastMessageAtom); } else { globalStore.set(this.updateLastMessageAtom, "", false); - const responsePrompt: OpenAIPromptMessageType = { + const responsePrompt: WaveAIPromptMessageType = { role: "assistant", content: fullMsg, }; @@ -397,7 +397,7 @@ export class WaveAiModel implements ViewModel { }; globalStore.set(this.addMessageAtom, errorMessage); globalStore.set(this.updateLastMessageAtom, "", false); - const errorPrompt: OpenAIPromptMessageType = { + const errorPrompt: WaveAIPromptMessageType = { role: "error", content: errMsg, }; diff --git a/frontend/types/gotypes.d.ts b/frontend/types/gotypes.d.ts index 22f582776..12b83c1ef 100644 --- a/frontend/types/gotypes.d.ts +++ b/frontend/types/gotypes.d.ts @@ -519,52 +519,6 @@ declare global { // waveobj.ORef type ORef = string; - // wshrpc.OpenAIOptsType - type OpenAIOptsType = { - model: string; - apitype?: string; - apitoken: string; - orgid?: string; - apiversion?: string; - baseurl?: string; - maxtokens?: number; - maxchoices?: number; - timeoutms?: number; - }; - - // wshrpc.OpenAIPacketType - type OpenAIPacketType = { - type: string; - model?: string; - created?: number; - finish_reason?: string; - usage?: OpenAIUsageType; - index?: number; - text?: string; - error?: string; - }; - - // wshrpc.OpenAIPromptMessageType - type OpenAIPromptMessageType = { - role: string; - content: string; - name?: string; - }; - - // wshrpc.OpenAIUsageType - type OpenAIUsageType = { - prompt_tokens?: number; - completion_tokens?: number; - total_tokens?: number; - }; - - // wshrpc.OpenAiStreamRequest - type OpenAiStreamRequest = { - clientid?: string; - opts: OpenAIOptsType; - prompt: OpenAIPromptMessageType[]; - }; - // wshrpc.PathCommandData type PathCommandData = { pathtype: string; @@ -1016,6 +970,52 @@ declare global { fullconfig: FullConfigType; }; + // wshrpc.WaveAIOptsType + type WaveAIOptsType = { + model: string; + apitype?: string; + apitoken: string; + orgid?: string; + apiversion?: string; + baseurl?: string; + maxtokens?: number; + maxchoices?: number; + timeoutms?: number; + }; + + // wshrpc.WaveAIPacketType + type WaveAIPacketType = { + type: string; + model?: string; + created?: number; + finish_reason?: string; + usage?: WaveAIUsageType; + index?: number; + text?: string; + error?: string; + }; + + // wshrpc.WaveAIPromptMessageType + type WaveAIPromptMessageType = { + role: string; + content: string; + name?: string; + }; + + // wshrpc.WaveAIStreamRequest + type WaveAIStreamRequest = { + clientid?: string; + opts: WaveAIOptsType; + prompt: WaveAIPromptMessageType[]; + }; + + // wshrpc.WaveAIUsageType + type WaveAIUsageType = { + prompt_tokens?: number; + completion_tokens?: number; + total_tokens?: number; + }; + // wps.WaveEvent type WaveEvent = { event: string; diff --git a/pkg/service/blockservice/blockservice.go b/pkg/service/blockservice/blockservice.go index 92a6ba149..f5fd5f6ff 100644 --- a/pkg/service/blockservice/blockservice.go +++ b/pkg/service/blockservice/blockservice.go @@ -70,7 +70,7 @@ func (bs *BlockService) SaveTerminalState(ctx context.Context, blockId string, s return nil } -func (bs *BlockService) SaveWaveAiData(ctx context.Context, blockId string, history []wshrpc.OpenAIPromptMessageType) error { +func (bs *BlockService) SaveWaveAiData(ctx context.Context, blockId string, history []wshrpc.WaveAIPromptMessageType) error { block, err := wstore.DBMustGet[*waveobj.Block](ctx, blockId) if err != nil { return err diff --git a/pkg/waveai/anthropicbackend.go b/pkg/waveai/anthropicbackend.go index 3ec6f264a..d11fc680b 100644 --- a/pkg/waveai/anthropicbackend.go +++ b/pkg/waveai/anthropicbackend.go @@ -109,8 +109,8 @@ func parseSSE(reader *bufio.Reader) (*sseEvent, error) { } } -func (AnthropicBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] { - rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]) +func (AnthropicBackend) StreamCompletion(ctx context.Context, request wshrpc.WaveAIStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType] { + rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]) go func() { defer func() { @@ -231,23 +231,23 @@ func (AnthropicBackend) StreamCompletion(ctx context.Context, request wshrpc.Ope switch sse.Event { case "message_start": if event.Message != nil { - pk := MakeOpenAIPacket() + pk := MakeWaveAIPacket() pk.Model = event.Message.Model - rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk} + rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: *pk} } case "content_block_start": if event.ContentBlock != nil && event.ContentBlock.Text != "" { - pk := MakeOpenAIPacket() + pk := MakeWaveAIPacket() pk.Text = event.ContentBlock.Text - rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk} + rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: *pk} } case "content_block_delta": if event.Delta != nil && event.Delta.Text != "" { - pk := MakeOpenAIPacket() + pk := MakeWaveAIPacket() pk.Text = event.Delta.Text - rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk} + rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: *pk} } case "content_block_stop": @@ -258,27 +258,27 @@ func (AnthropicBackend) StreamCompletion(ctx context.Context, request wshrpc.Ope case "message_delta": // Update message metadata, usage stats if event.Usage != nil { - pk := MakeOpenAIPacket() - pk.Usage = &wshrpc.OpenAIUsageType{ + pk := MakeWaveAIPacket() + pk.Usage = &wshrpc.WaveAIUsageType{ PromptTokens: event.Usage.InputTokens, CompletionTokens: event.Usage.OutputTokens, TotalTokens: event.Usage.InputTokens + event.Usage.OutputTokens, } - rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk} + rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: *pk} } case "message_stop": if event.Message != nil { - pk := MakeOpenAIPacket() + pk := MakeWaveAIPacket() pk.FinishReason = event.Message.StopReason if event.Message.Usage != nil { - pk.Usage = &wshrpc.OpenAIUsageType{ + pk.Usage = &wshrpc.WaveAIUsageType{ PromptTokens: event.Message.Usage.InputTokens, CompletionTokens: event.Message.Usage.OutputTokens, TotalTokens: event.Message.Usage.InputTokens + event.Message.Usage.OutputTokens, } } - rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk} + rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: *pk} } default: diff --git a/pkg/waveai/cloudbackend.go b/pkg/waveai/cloudbackend.go index ded005702..710730590 100644 --- a/pkg/waveai/cloudbackend.go +++ b/pkg/waveai/cloudbackend.go @@ -20,22 +20,22 @@ type WaveAICloudBackend struct{} var _ AIBackend = WaveAICloudBackend{} -type OpenAICloudReqPacketType struct { +type WaveAICloudReqPacketType struct { Type string `json:"type"` ClientId string `json:"clientid"` - Prompt []wshrpc.OpenAIPromptMessageType `json:"prompt"` + Prompt []wshrpc.WaveAIPromptMessageType `json:"prompt"` MaxTokens int `json:"maxtokens,omitempty"` MaxChoices int `json:"maxchoices,omitempty"` } -func MakeOpenAICloudReqPacket() *OpenAICloudReqPacketType { - return &OpenAICloudReqPacketType{ +func MakeWaveAICloudReqPacket() *WaveAICloudReqPacketType { + return &WaveAICloudReqPacketType{ Type: OpenAICloudReqStr, } } -func (WaveAICloudBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] { - rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]) +func (WaveAICloudBackend) StreamCompletion(ctx context.Context, request wshrpc.WaveAIStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType] { + rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]) wsEndpoint := wcloud.GetWSEndpoint() go func() { defer func() { @@ -69,14 +69,14 @@ func (WaveAICloudBackend) StreamCompletion(ctx context.Context, request wshrpc.O rtn <- makeAIError(fmt.Errorf("unable to close openai channel: %v", err)) } }() - var sendablePromptMsgs []wshrpc.OpenAIPromptMessageType + var sendablePromptMsgs []wshrpc.WaveAIPromptMessageType for _, promptMsg := range request.Prompt { if promptMsg.Role == "error" { continue } sendablePromptMsgs = append(sendablePromptMsgs, promptMsg) } - reqPk := MakeOpenAICloudReqPacket() + reqPk := MakeWaveAICloudReqPacket() reqPk.ClientId = request.ClientId reqPk.Prompt = sendablePromptMsgs reqPk.MaxTokens = request.Opts.MaxTokens @@ -101,7 +101,7 @@ func (WaveAICloudBackend) StreamCompletion(ctx context.Context, request wshrpc.O rtn <- makeAIError(fmt.Errorf("OpenAI request, websocket error reading message: %v", err)) break } - var streamResp *wshrpc.OpenAIPacketType + var streamResp *wshrpc.WaveAIPacketType err = json.Unmarshal(socketMessage, &streamResp) if err != nil { rtn <- makeAIError(fmt.Errorf("OpenAI request, websocket response json decode error: %v", err)) @@ -115,7 +115,7 @@ func (WaveAICloudBackend) StreamCompletion(ctx context.Context, request wshrpc.O rtn <- makeAIError(fmt.Errorf("%v", streamResp.Error)) break } - rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *streamResp} + rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: *streamResp} } }() return rtn diff --git a/pkg/waveai/openaibackend.go b/pkg/waveai/openaibackend.go index c5e30b7ad..a334fb523 100644 --- a/pkg/waveai/openaibackend.go +++ b/pkg/waveai/openaibackend.go @@ -25,7 +25,7 @@ func defaultAzureMapperFn(model string) string { return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "") } -func setApiType(opts *wshrpc.OpenAIOptsType, clientConfig *openaiapi.ClientConfig) error { +func setApiType(opts *wshrpc.WaveAIOptsType, clientConfig *openaiapi.ClientConfig) error { ourApiType := strings.ToLower(opts.APIType) if ourApiType == "" || ourApiType == strings.ToLower(string(openaiapi.APITypeOpenAI)) { clientConfig.APIType = openaiapi.APITypeOpenAI @@ -50,7 +50,7 @@ func setApiType(opts *wshrpc.OpenAIOptsType, clientConfig *openaiapi.ClientConfi } } -func convertPrompt(prompt []wshrpc.OpenAIPromptMessageType) []openaiapi.ChatCompletionMessage { +func convertPrompt(prompt []wshrpc.WaveAIPromptMessageType) []openaiapi.ChatCompletionMessage { var rtn []openaiapi.ChatCompletionMessage for _, p := range prompt { msg := openaiapi.ChatCompletionMessage{Role: p.Role, Content: p.Content, Name: p.Name} @@ -59,19 +59,8 @@ func convertPrompt(prompt []wshrpc.OpenAIPromptMessageType) []openaiapi.ChatComp return rtn } -func convertUsage(resp openaiapi.ChatCompletionResponse) *wshrpc.OpenAIUsageType { - if resp.Usage.TotalTokens == 0 { - return nil - } - return &wshrpc.OpenAIUsageType{ - PromptTokens: resp.Usage.PromptTokens, - CompletionTokens: resp.Usage.CompletionTokens, - TotalTokens: resp.Usage.TotalTokens, - } -} - -func (OpenAIBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] { - rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]) +func (OpenAIBackend) StreamCompletion(ctx context.Context, request wshrpc.WaveAIStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType] { + rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]) go func() { defer func() { panicErr := panichandler.PanicHandler("OpenAIBackend.StreamCompletion") @@ -128,18 +117,18 @@ func (OpenAIBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAi } // Send header packet - headerPk := MakeOpenAIPacket() + headerPk := MakeWaveAIPacket() headerPk.Model = resp.Model headerPk.Created = resp.Created - rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *headerPk} + rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: *headerPk} // Send content packet(s) for i, choice := range resp.Choices { - pk := MakeOpenAIPacket() + pk := MakeWaveAIPacket() pk.Index = i pk.Text = choice.Message.Content pk.FinishReason = string(choice.FinishReason) - rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk} + rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: *pk} } return } @@ -167,18 +156,18 @@ func (OpenAIBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAi break } if streamResp.Model != "" && !sentHeader { - pk := MakeOpenAIPacket() + pk := MakeWaveAIPacket() pk.Model = streamResp.Model pk.Created = streamResp.Created - rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk} + rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: *pk} sentHeader = true } for _, choice := range streamResp.Choices { - pk := MakeOpenAIPacket() + pk := MakeWaveAIPacket() pk.Index = choice.Index pk.Text = choice.Delta.Content pk.FinishReason = string(choice.FinishReason) - rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk} + rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: *pk} } } }() diff --git a/pkg/waveai/perplexitybackend.go b/pkg/waveai/perplexitybackend.go index 991c87098..436d953d4 100644 --- a/pkg/waveai/perplexitybackend.go +++ b/pkg/waveai/perplexitybackend.go @@ -49,8 +49,8 @@ type perplexityResponse struct { Model string `json:"model"` } -func (PerplexityBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] { - rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]) +func (PerplexityBackend) StreamCompletion(ctx context.Context, request wshrpc.WaveAIStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType] { + rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]) go func() { defer func() { @@ -160,17 +160,17 @@ func (PerplexityBackend) StreamCompletion(ctx context.Context, request wshrpc.Op } if !sentHeader { - pk := MakeOpenAIPacket() + pk := MakeWaveAIPacket() pk.Model = response.Model - rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk} + rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: *pk} sentHeader = true } for _, choice := range response.Choices { - pk := MakeOpenAIPacket() + pk := MakeWaveAIPacket() pk.Text = choice.Delta.Content pk.FinishReason = choice.FinishReason - rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk} + rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: *pk} } } }() diff --git a/pkg/waveai/waveai.go b/pkg/waveai/waveai.go index 44afda0c1..4ffa56f96 100644 --- a/pkg/waveai/waveai.go +++ b/pkg/waveai/waveai.go @@ -19,7 +19,7 @@ const DefaultAzureAPIVersion = "2023-05-15" const ApiType_Anthropic = "anthropic" const ApiType_Perplexity = "perplexity" -type OpenAICmdInfoPacketOutputType struct { +type WaveAICmdInfoPacketOutputType struct { Model string `json:"model,omitempty"` Created int64 `json:"created,omitempty"` FinishReason string `json:"finish_reason,omitempty"` @@ -27,14 +27,14 @@ type OpenAICmdInfoPacketOutputType struct { Error string `json:"error,omitempty"` } -func MakeOpenAIPacket() *wshrpc.OpenAIPacketType { - return &wshrpc.OpenAIPacketType{Type: OpenAIPacketStr} +func MakeWaveAIPacket() *wshrpc.WaveAIPacketType { + return &wshrpc.WaveAIPacketType{Type: OpenAIPacketStr} } -type OpenAICmdInfoChatMessage struct { +type WaveAICmdInfoChatMessage struct { MessageID int `json:"messageid"` IsAssistantResponse bool `json:"isassistantresponse,omitempty"` - AssistantResponse *OpenAICmdInfoPacketOutputType `json:"assistantresponse,omitempty"` + AssistantResponse *WaveAICmdInfoPacketOutputType `json:"assistantresponse,omitempty"` UserQuery string `json:"userquery,omitempty"` UserEngineeredQuery string `json:"userengineeredquery,omitempty"` } @@ -42,8 +42,8 @@ type OpenAICmdInfoChatMessage struct { type AIBackend interface { StreamCompletion( ctx context.Context, - request wshrpc.OpenAiStreamRequest, - ) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] + request wshrpc.WaveAIStreamRequest, + ) chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType] } const DefaultMaxTokens = 2048 @@ -53,18 +53,18 @@ const WCloudWSEndpointVarName = "WCLOUD_WS_ENDPOINT" const CloudWebsocketConnectTimeout = 1 * time.Minute -func IsCloudAIRequest(opts *wshrpc.OpenAIOptsType) bool { +func IsCloudAIRequest(opts *wshrpc.WaveAIOptsType) bool { if opts == nil { return true } return opts.BaseURL == "" && opts.APIToken == "" } -func makeAIError(err error) wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] { - return wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: err} +func makeAIError(err error) wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType] { + return wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Error: err} } -func RunAICommand(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] { +func RunAICommand(ctx context.Context, request wshrpc.WaveAIStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType] { telemetry.GoUpdateActivityWrap(wshrpc.ActivityUpdate{NumAIReqs: 1}, "RunAICommand") if request.Opts.APIType == ApiType_Anthropic { endpoint := request.Opts.BaseURL diff --git a/pkg/wshrpc/wshclient/wshclient.go b/pkg/wshrpc/wshclient/wshclient.go index 3dce286c4..657e2184a 100644 --- a/pkg/wshrpc/wshclient/wshclient.go +++ b/pkg/wshrpc/wshclient/wshclient.go @@ -364,8 +364,8 @@ func StreamTestCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) chan wshrpc.Resp } // command "streamwaveai", wshserver.StreamWaveAiCommand -func StreamWaveAiCommand(w *wshutil.WshRpc, data wshrpc.OpenAiStreamRequest, opts *wshrpc.RpcOpts) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] { - return sendRpcRequestResponseStreamHelper[wshrpc.OpenAIPacketType](w, "streamwaveai", data, opts) +func StreamWaveAiCommand(w *wshutil.WshRpc, data wshrpc.WaveAIStreamRequest, opts *wshrpc.RpcOpts) chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType] { + return sendRpcRequestResponseStreamHelper[wshrpc.WaveAIPacketType](w, "streamwaveai", data, opts) } // command "test", wshserver.TestCommand diff --git a/pkg/wshrpc/wshrpctypes.go b/pkg/wshrpc/wshrpctypes.go index 71806f6b0..477bb2455 100644 --- a/pkg/wshrpc/wshrpctypes.go +++ b/pkg/wshrpc/wshrpctypes.go @@ -138,7 +138,7 @@ type WshRpcInterface interface { EventUnsubAllCommand(ctx context.Context) error EventReadHistoryCommand(ctx context.Context, data CommandEventReadHistoryData) ([]*wps.WaveEvent, error) StreamTestCommand(ctx context.Context) chan RespOrErrorUnion[int] - StreamWaveAiCommand(ctx context.Context, request OpenAiStreamRequest) chan RespOrErrorUnion[OpenAIPacketType] + StreamWaveAiCommand(ctx context.Context, request WaveAIStreamRequest) chan RespOrErrorUnion[WaveAIPacketType] StreamCpuDataCommand(ctx context.Context, request CpuDataRequest) chan RespOrErrorUnion[TimeSeriesData] TestCommand(ctx context.Context, data string) error SetConfigCommand(ctx context.Context, data MetaSettingsType) error @@ -377,19 +377,19 @@ type CommandEventReadHistoryData struct { MaxItems int `json:"maxitems"` } -type OpenAiStreamRequest struct { +type WaveAIStreamRequest struct { ClientId string `json:"clientid,omitempty"` - Opts *OpenAIOptsType `json:"opts"` - Prompt []OpenAIPromptMessageType `json:"prompt"` + Opts *WaveAIOptsType `json:"opts"` + Prompt []WaveAIPromptMessageType `json:"prompt"` } -type OpenAIPromptMessageType struct { +type WaveAIPromptMessageType struct { Role string `json:"role"` Content string `json:"content"` Name string `json:"name,omitempty"` } -type OpenAIOptsType struct { +type WaveAIOptsType struct { Model string `json:"model"` APIType string `json:"apitype,omitempty"` APIToken string `json:"apitoken"` @@ -401,18 +401,18 @@ type OpenAIOptsType struct { TimeoutMs int `json:"timeoutms,omitempty"` } -type OpenAIPacketType struct { +type WaveAIPacketType struct { Type string `json:"type"` Model string `json:"model,omitempty"` Created int64 `json:"created,omitempty"` FinishReason string `json:"finish_reason,omitempty"` - Usage *OpenAIUsageType `json:"usage,omitempty"` + Usage *WaveAIUsageType `json:"usage,omitempty"` Index int `json:"index,omitempty"` Text string `json:"text,omitempty"` Error string `json:"error,omitempty"` } -type OpenAIUsageType struct { +type WaveAIUsageType struct { PromptTokens int `json:"prompt_tokens,omitempty"` CompletionTokens int `json:"completion_tokens,omitempty"` TotalTokens int `json:"total_tokens,omitempty"` diff --git a/pkg/wshrpc/wshserver/wshserver.go b/pkg/wshrpc/wshserver/wshserver.go index b4e2acfec..be3878177 100644 --- a/pkg/wshrpc/wshserver/wshserver.go +++ b/pkg/wshrpc/wshserver/wshserver.go @@ -73,7 +73,7 @@ func (ws *WshServer) StreamTestCommand(ctx context.Context) chan wshrpc.RespOrEr return rtn } -func (ws *WshServer) StreamWaveAiCommand(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] { +func (ws *WshServer) StreamWaveAiCommand(ctx context.Context, request wshrpc.WaveAIStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType] { return waveai.RunAICommand(ctx, request) }