diff --git a/cmd/generatewshclient/main-generatewshclient.go b/cmd/generatewshclient/main-generatewshclient.go index c62694aae..c63036e8c 100644 --- a/cmd/generatewshclient/main-generatewshclient.go +++ b/cmd/generatewshclient/main-generatewshclient.go @@ -70,6 +70,7 @@ func main() { fmt.Fprintf(fd, " \"github.com/wavetermdev/thenextwave/pkg/wshutil\"\n") fmt.Fprintf(fd, " \"github.com/wavetermdev/thenextwave/pkg/wshrpc\"\n") fmt.Fprintf(fd, " \"github.com/wavetermdev/thenextwave/pkg/waveobj\"\n") + fmt.Fprintf(fd, " \"github.com/wavetermdev/thenextwave/pkg/waveai\"\n") fmt.Fprintf(fd, ")\n\n") for _, key := range utilfn.GetOrderedMapKeys(wshserver.WshServerCommandToDeclMap) { diff --git a/frontend/app/block/block.tsx b/frontend/app/block/block.tsx index fca7bab27..136f2542c 100644 --- a/frontend/app/block/block.tsx +++ b/frontend/app/block/block.tsx @@ -14,7 +14,7 @@ import * as util from "@/util/util"; import { PlotView } from "@/view/plotview"; import { PreviewView, makePreviewModel } from "@/view/preview"; import { TerminalView, makeTerminalModel } from "@/view/term/term"; -import { WaveAi } from "@/view/waveai"; +import { WaveAi, makeWaveAiViewModel } from "@/view/waveai"; import { WebView, makeWebViewModel } from "@/view/webview"; import clsx from "clsx"; import * as jotai from "jotai"; @@ -516,7 +516,9 @@ function getViewElemAndModel( viewElem = ; viewModel = webviewModel; } else if (blockView === "waveai") { - viewElem = ; + const waveAiModel = makeWaveAiViewModel(blockId); + viewElem = ; + viewModel = waveAiModel; } if (viewModel == null) { viewModel = makeDefaultViewModel(blockId); diff --git a/frontend/app/store/waveai.ts b/frontend/app/store/waveai.ts deleted file mode 100644 index cafde4f54..000000000 --- a/frontend/app/store/waveai.ts +++ /dev/null @@ -1,101 +0,0 @@ -// Copyright 2024, Command Line Inc. -// SPDX-License-Identifier: Apache-2.0 - -import { atom, useAtom } from "jotai"; -import { v4 as uuidv4 } from "uuid"; - -interface ChatMessageType { - id: string; - user: string; - text: string; - isAssistant: boolean; - isUpdating?: boolean; - isError?: string; -} - -const defaultMessage: ChatMessageType = { - id: uuidv4(), - user: "assistant", - text: `

Hello, how may I help you with this command?
-(Cmd-Shift-Space: open/close, Ctrl+L: clear chat buffer, Up/Down: select code blocks, Enter: to copy a selected code block to the command input)

`, - isAssistant: true, -}; - -const messagesAtom = atom([defaultMessage]); - -const addMessageAtom = atom(null, (get, set, message: ChatMessageType) => { - const messages = get(messagesAtom); - set(messagesAtom, [...messages, message]); -}); - -const updateLastMessageAtom = atom(null, (get, set, text: string, isUpdating: boolean) => { - const messages = get(messagesAtom); - const lastMessage = messages[messages.length - 1]; - if (lastMessage.isAssistant && !lastMessage.isError) { - const updatedMessage = { ...lastMessage, text: lastMessage.text + text, isUpdating }; - set(messagesAtom, [...messages.slice(0, -1), updatedMessage]); - } -}); - -const simulateAssistantResponseAtom = atom(null, (get, set, userMessage: ChatMessageType) => { - const responseText = `Here is an example of a simple bash script: - -\`\`\`bash -#!/bin/bash -# This is a comment -echo "Hello, World!" -\`\`\` - -You can run this script by saving it to a file, for example, \`hello.sh\`, and then running \`chmod +x hello.sh\` to make it executable. Finally, run it with \`./hello.sh\`.`; - - const typingMessage: ChatMessageType = { - id: uuidv4(), - user: "assistant", - text: "", - isAssistant: true, - }; - - // Add a typing indicator - set(addMessageAtom, typingMessage); - - setTimeout(() => { - const parts = responseText.split(" "); - let currentPart = 0; - - const intervalId = setInterval(() => { - if (currentPart < parts.length) { - const part = parts[currentPart] + " "; - set(updateLastMessageAtom, part, true); - currentPart++; - } else { - clearInterval(intervalId); - set(updateLastMessageAtom, "", false); - } - }, 100); - }, 1500); -}); - -const useWaveAi = () => { - const [messages] = useAtom(messagesAtom); - const [, addMessage] = useAtom(addMessageAtom); - const [, simulateResponse] = useAtom(simulateAssistantResponseAtom); - - const sendMessage = (text: string, user: string = "user") => { - const newMessage: ChatMessageType = { - id: uuidv4(), - user, - text, - isAssistant: false, - }; - addMessage(newMessage); - simulateResponse(newMessage); - }; - - return { - messages, - sendMessage, - }; -}; - -export { useWaveAi }; -export type { ChatMessageType }; diff --git a/frontend/app/store/wshserver.ts b/frontend/app/store/wshserver.ts index ac45fcdc4..d2c399c29 100644 --- a/frontend/app/store/wshserver.ts +++ b/frontend/app/store/wshserver.ts @@ -72,6 +72,11 @@ class WshServerType { return WOS.wshServerRpcHelper_call("setview", data, opts); } + // command "stream:waveai" [responsestream] + RespStreamWaveAi(data: OpenAiStreamRequest, opts?: WshRpcCommandOpts): AsyncGenerator { + return WOS.wshServerRpcHelper_responsestream("stream:waveai", data, opts); + } + // command "streamtest" [responsestream] RespStreamTest(opts?: WshRpcCommandOpts): AsyncGenerator { return WOS.wshServerRpcHelper_responsestream("streamtest", null, opts); diff --git a/frontend/app/view/waveai.tsx b/frontend/app/view/waveai.tsx index 0193a21df..fb590e73a 100644 --- a/frontend/app/view/waveai.tsx +++ b/frontend/app/view/waveai.tsx @@ -3,21 +3,170 @@ import { Markdown } from "@/app/element/markdown"; import { TypingIndicator } from "@/app/element/typingindicator"; -import { ChatMessageType, useWaveAi } from "@/app/store/waveai"; +import { WOS, atoms } from "@/store/global"; +import { WshServer } from "@/store/wshserver"; +import * as jotai from "jotai"; import type { OverlayScrollbars } from "overlayscrollbars"; import { OverlayScrollbarsComponent, OverlayScrollbarsComponentRef } from "overlayscrollbars-react"; import React, { forwardRef, useCallback, useEffect, useImperativeHandle, useRef, useState } from "react"; import tinycolor from "tinycolor2"; +import { v4 as uuidv4 } from "uuid"; import "./waveai.less"; +interface ChatMessageType { + id: string; + user: string; + text: string; + isAssistant: boolean; + isUpdating?: boolean; + isError?: string; +} + const outline = "2px solid var(--accent-color)"; +const defaultMessage: ChatMessageType = { + id: uuidv4(), + user: "assistant", + text: `

Hello, how may I help you with this command?
+(Cmd-Shift-Space: open/close, Ctrl+L: clear chat buffer, Up/Down: select code blocks, Enter: to copy a selected code block to the command input)

`, + isAssistant: true, +}; + interface ChatItemProps { chatItem: ChatMessageType; itemCount: number; } +export class WaveAiModel implements ViewModel { + blockId: string; + blockAtom: jotai.Atom; + viewIcon?: jotai.Atom; + viewName?: jotai.Atom; + viewText?: jotai.Atom; + preIconButton?: jotai.Atom; + endIconButtons?: jotai.Atom; + messagesAtom: jotai.PrimitiveAtom>; + addMessageAtom: jotai.WritableAtom; + updateLastMessageAtom: jotai.WritableAtom; + simulateAssistantResponseAtom: jotai.WritableAtom; + + constructor(blockId: string) { + this.blockId = blockId; + this.blockAtom = WOS.getWaveObjectAtom(`block:${blockId}`); + this.viewIcon = jotai.atom((get) => { + return "sparkles"; // should not be hardcoded + }); + this.viewName = jotai.atom("Ai"); + this.messagesAtom = jotai.atom([defaultMessage]); + + this.addMessageAtom = jotai.atom(null, (get, set, message: ChatMessageType) => { + const messages = get(this.messagesAtom); + set(this.messagesAtom, [...messages, message]); + }); + + this.updateLastMessageAtom = jotai.atom(null, (get, set, text: string, isUpdating: boolean) => { + const messages = get(this.messagesAtom); + const lastMessage = messages[messages.length - 1]; + if (lastMessage.isAssistant && !lastMessage.isError) { + const updatedMessage = { ...lastMessage, text: lastMessage.text + text, isUpdating }; + set(this.messagesAtom, [...messages.slice(0, -1), updatedMessage]); + } + }); + this.simulateAssistantResponseAtom = jotai.atom(null, (get, set, userMessage: ChatMessageType) => { + const typingMessage: ChatMessageType = { + id: uuidv4(), + user: "assistant", + text: "", + isAssistant: true, + }; + + // Add a typing indicator + set(this.addMessageAtom, typingMessage); + + setTimeout(() => { + const parts = userMessage.text.split(" "); + let currentPart = 0; + + const intervalId = setInterval(() => { + if (currentPart < parts.length) { + const part = parts[currentPart] + " "; + set(this.updateLastMessageAtom, part, true); + currentPart++; + } else { + clearInterval(intervalId); + set(this.updateLastMessageAtom, "", false); + } + }, 100); + }, 1500); + }); + } + + useWaveAi() { + const [messages] = jotai.useAtom(this.messagesAtom); + const [, addMessage] = jotai.useAtom(this.addMessageAtom); + const [, simulateResponse] = jotai.useAtom(this.simulateAssistantResponseAtom); + const metadata = jotai.useAtomValue(this.blockAtom).meta; + const clientId = jotai.useAtomValue(atoms.clientId); + + const sendMessage = (text: string, user: string = "user") => { + const newMessage: ChatMessageType = { + id: uuidv4(), + user, + text, + isAssistant: false, + }; + addMessage(newMessage); + // send message to backend and get response + const opts: OpenAIOptsType = { + model: "gpt-3.5-turbo", + apitoken: metadata?.apitoken as string, + maxtokens: 1000, + timeout: 10, + baseurl: metadata?.baseurl as string, + }; + const prompt: Array = [ + { + role: "user", + content: text, + name: (metadata?.name as string) || "user", + }, + ]; + console.log("opts.apitoken:", opts.apitoken); + const beMsg: OpenAiStreamRequest = { + clientid: clientId, + opts: opts, + prompt: prompt, + }; + const aiGen = WshServer.RespStreamWaveAi(beMsg); + let temp = async () => { + let fullMsg = ""; + for await (const msg of aiGen) { + fullMsg += msg.text ?? ""; + } + const response: ChatMessageType = { + id: newMessage.id, + user: newMessage.user, + text: fullMsg, + isAssistant: true, + }; + simulateResponse(response); + }; + temp(); + }; + + return { + messages, + sendMessage, + }; + } +} + +function makeWaveAiViewModel(blockId): WaveAiModel { + const waveAiModel = new WaveAiModel(blockId); + return waveAiModel; +} + const ChatItem = ({ chatItem, itemCount }: ChatItemProps) => { const { isAssistant, text, isError } = chatItem; const senderClassName = isAssistant ? "chat-msg-assistant" : "chat-msg-user"; @@ -208,8 +357,8 @@ const ChatInput = forwardRef( } ); -const WaveAi = () => { - const { messages, sendMessage } = useWaveAi(); +const WaveAi = ({ model }: { model: WaveAiModel }) => { + const { messages, sendMessage } = model.useWaveAi(); const waveaiRef = useRef(null); const chatWindowRef = useRef(null); const osRef = useRef(null); @@ -407,4 +556,4 @@ const WaveAi = () => { ); }; -export { WaveAi }; +export { WaveAi, makeWaveAiViewModel }; diff --git a/frontend/types/gotypes.d.ts b/frontend/types/gotypes.d.ts index f5f185329..95b25e619 100644 --- a/frontend/types/gotypes.d.ts +++ b/frontend/types/gotypes.d.ts @@ -187,6 +187,49 @@ declare global { // waveobj.ORef type ORef = string; + // waveai.OpenAIOptsType + type OpenAIOptsType = { + model: string; + apitoken: string; + baseurl?: string; + maxtokens?: number; + maxchoices?: number; + timeout?: number; + }; + + // waveai.OpenAIPacketType + type OpenAIPacketType = { + type: string; + model?: string; + created?: number; + finish_reason?: string; + usage?: OpenAIUsageType; + index?: number; + text?: string; + error?: string; + }; + + // waveai.OpenAIPromptMessageType + type OpenAIPromptMessageType = { + role: string; + content: string; + name?: string; + }; + + // waveai.OpenAIUsageType + type OpenAIUsageType = { + prompt_tokens?: number; + completion_tokens?: number; + total_tokens?: number; + }; + + // waveai.OpenAiStreamRequest + type OpenAiStreamRequest = { + clientid?: string; + opts: OpenAIOptsType; + prompt: OpenAIPromptMessageType[]; + }; + // wstore.Point type Point = { x: number; diff --git a/go.mod b/go.mod index 4ccf8b8bf..8ae012fe5 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( github.com/kevinburke/ssh_config v1.2.0 github.com/mattn/go-sqlite3 v1.14.22 github.com/mitchellh/mapstructure v1.5.0 + github.com/sashabaranov/go-openai v1.27.0 github.com/sawka/txwrap v0.2.0 github.com/spf13/cobra v1.8.1 github.com/wavetermdev/htmltoken v0.1.0 diff --git a/go.sum b/go.sum index f5b9bc7d3..e4a82752c 100644 --- a/go.sum +++ b/go.sum @@ -42,6 +42,8 @@ github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RR github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/sashabaranov/go-openai v1.27.0 h1:L3hO6650YUbKrbGUC6yCjsUluhKZ9h1/jcgbTItI8Mo= +github.com/sashabaranov/go-openai v1.27.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/sawka/txwrap v0.2.0 h1:V3LfvKVLULxcYSxdMguLwFyQFMEU9nFDJopg0ZkL+94= github.com/sawka/txwrap v0.2.0/go.mod h1:wwQ2SQiN4U+6DU/iVPhbvr7OzXAtgZlQCIGuvOswEfA= github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= diff --git a/pkg/waveai/waveai.go b/pkg/waveai/waveai.go new file mode 100644 index 000000000..671877414 --- /dev/null +++ b/pkg/waveai/waveai.go @@ -0,0 +1,306 @@ +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package waveai + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log" + "os" + "time" + + openaiapi "github.com/sashabaranov/go-openai" + "github.com/wavetermdev/thenextwave/pkg/wavebase" + "github.com/wavetermdev/thenextwave/pkg/wshrpc" + + "github.com/gorilla/websocket" +) + +const OpenAIPacketStr = "openai" +const OpenAICloudReqStr = "openai-cloudreq" +const PacketEOFStr = "EOF" + +type OpenAIUsageType struct { + PromptTokens int `json:"prompt_tokens,omitempty"` + CompletionTokens int `json:"completion_tokens,omitempty"` + TotalTokens int `json:"total_tokens,omitempty"` +} + +type OpenAICmdInfoPacketOutputType struct { + Model string `json:"model,omitempty"` + Created int64 `json:"created,omitempty"` + FinishReason string `json:"finish_reason,omitempty"` + Message string `json:"message,omitempty"` + Error string `json:"error,omitempty"` +} + +type OpenAIPacketType struct { + Type string `json:"type"` + Model string `json:"model,omitempty"` + Created int64 `json:"created,omitempty"` + FinishReason string `json:"finish_reason,omitempty"` + Usage *OpenAIUsageType `json:"usage,omitempty"` + Index int `json:"index,omitempty"` + Text string `json:"text,omitempty"` + Error string `json:"error,omitempty"` +} + +func MakeOpenAIPacket() *OpenAIPacketType { + return &OpenAIPacketType{Type: OpenAIPacketStr} +} + +type OpenAICmdInfoChatMessage struct { + MessageID int `json:"messageid"` + IsAssistantResponse bool `json:"isassistantresponse,omitempty"` + AssistantResponse *OpenAICmdInfoPacketOutputType `json:"assistantresponse,omitempty"` + UserQuery string `json:"userquery,omitempty"` + UserEngineeredQuery string `json:"userengineeredquery,omitempty"` +} + +type OpenAIPromptMessageType struct { + Role string `json:"role"` + Content string `json:"content"` + Name string `json:"name,omitempty"` +} + +type OpenAICloudReqPacketType struct { + Type string `json:"type"` + ClientId string `json:"clientid"` + Prompt []OpenAIPromptMessageType `json:"prompt"` + MaxTokens int `json:"maxtokens,omitempty"` + MaxChoices int `json:"maxchoices,omitempty"` +} + +type OpenAIOptsType struct { + Model string `json:"model"` + APIToken string `json:"apitoken"` + BaseURL string `json:"baseurl,omitempty"` + MaxTokens int `json:"maxtokens,omitempty"` + MaxChoices int `json:"maxchoices,omitempty"` + Timeout int `json:"timeout,omitempty"` +} + +func MakeOpenAICloudReqPacket() *OpenAICloudReqPacketType { + return &OpenAICloudReqPacketType{ + Type: OpenAICloudReqStr, + } +} + +type OpenAiStreamRequest struct { + ClientId string `json:"clientid,omitempty"` + Opts *OpenAIOptsType `json:"opts"` + Prompt []OpenAIPromptMessageType `json:"prompt"` +} + +func GetWSEndpoint() string { + return PCloudWSEndpoint + if !wavebase.IsDevMode() { + return PCloudWSEndpoint + } else { + endpoint := os.Getenv(PCloudWSEndpointVarName) + if endpoint == "" { + panic("Invalid PCloud ws dev endpoint, PCLOUD_WS_ENDPOINT not set or invalid") + } + return endpoint + } +} + +const DefaultMaxTokens = 1000 +const DefaultModel = "gpt-3.5-turbo" +const DefaultStreamChanSize = 10 +const PCloudWSEndpoint = "wss://wsapi.waveterm.dev/" +const PCloudWSEndpointVarName = "PCLOUD_WS_ENDPOINT" + +const CloudWebsocketConnectTimeout = 1 * time.Minute + +func convertUsage(resp openaiapi.ChatCompletionResponse) *OpenAIUsageType { + if resp.Usage.TotalTokens == 0 { + return nil + } + return &OpenAIUsageType{ + PromptTokens: resp.Usage.PromptTokens, + CompletionTokens: resp.Usage.CompletionTokens, + TotalTokens: resp.Usage.TotalTokens, + } +} + +func ConvertPrompt(prompt []OpenAIPromptMessageType) []openaiapi.ChatCompletionMessage { + var rtn []openaiapi.ChatCompletionMessage + for _, p := range prompt { + msg := openaiapi.ChatCompletionMessage{Role: p.Role, Content: p.Content, Name: p.Name} + rtn = append(rtn, msg) + } + return rtn +} + +func RunCloudCompletionStream(ctx context.Context, request OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[OpenAIPacketType] { + rtn := make(chan wshrpc.RespOrErrorUnion[OpenAIPacketType]) + go func() { + log.Printf("start: %v", request) + defer close(rtn) + if request.Opts == nil { + rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("no openai opts found")} + return + } + websocketContext, dialCancelFn := context.WithTimeout(context.Background(), CloudWebsocketConnectTimeout) + defer dialCancelFn() + conn, _, err := websocket.DefaultDialer.DialContext(websocketContext, GetWSEndpoint(), nil) + defer func() { + err = conn.Close() + if err != nil { + rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("unable to close openai channel: %v", err)} + } + }() + if err == context.DeadlineExceeded { + rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, timed out connecting to cloud server: %v", err)} + return + } else if err != nil { + rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket connect error: %v", err)} + return + } + reqPk := MakeOpenAICloudReqPacket() + reqPk.ClientId = request.ClientId + reqPk.Prompt = request.Prompt + reqPk.MaxTokens = request.Opts.MaxTokens + reqPk.MaxChoices = request.Opts.MaxChoices + configMessageBuf, err := json.Marshal(reqPk) + if err != nil { + rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, packet marshal error: %v", err)} + return + } + err = conn.WriteMessage(websocket.TextMessage, configMessageBuf) + if err != nil { + rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket write config error: %v", err)} + return + } + for { + log.Printf("loop") + _, socketMessage, err := conn.ReadMessage() + if err == io.EOF { + break + } + if err != nil { + log.Printf("err received: %v", err) + rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket error reading message: %v", err)} + break + } + var streamResp *OpenAIPacketType + err = json.Unmarshal(socketMessage, &streamResp) + log.Printf("ai resp: %v", streamResp) + if err != nil { + rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket response json decode error: %v", err)} + break + } + if streamResp.Error == PacketEOFStr { + // got eof packet from socket + break + } else if streamResp.Error != "" { + // use error from server directly + rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("%v", streamResp.Error)} + break + } + rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Response: *streamResp} + } + }() + return rtn +} + +func RunLocalCompletionStream(ctx context.Context, request OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[OpenAIPacketType] { + rtn := make(chan wshrpc.RespOrErrorUnion[OpenAIPacketType]) + go func() { + log.Printf("start2: %v", request) + defer close(rtn) + if request.Opts == nil { + rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("no openai opts found")} + return + } + if request.Opts.Model == "" { + rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("no openai model specified")} + return + } + if request.Opts.BaseURL == "" && request.Opts.APIToken == "" { + rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("no api token")} + return + } + clientConfig := openaiapi.DefaultConfig(request.Opts.APIToken) + if request.Opts.BaseURL != "" { + clientConfig.BaseURL = request.Opts.BaseURL + } + client := openaiapi.NewClientWithConfig(clientConfig) + req := openaiapi.ChatCompletionRequest{ + Model: request.Opts.Model, + Messages: ConvertPrompt(request.Prompt), + MaxTokens: request.Opts.MaxTokens, + Stream: true, + } + if request.Opts.MaxChoices > 1 { + req.N = request.Opts.MaxChoices + } + apiResp, err := client.CreateChatCompletionStream(ctx, req) + if err != nil { + rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("error calling openai API: %v", err)} + return + } + sentHeader := false + for { + log.Printf("loop2") + streamResp, err := apiResp.Recv() + if err == io.EOF { + break + } + if err != nil { + log.Printf("err received2: %v", err) + rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket error reading message: %v", err)} + break + } + if streamResp.Model != "" && !sentHeader { + pk := MakeOpenAIPacket() + pk.Model = streamResp.Model + pk.Created = streamResp.Created + rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Response: *pk} + sentHeader = true + } + for _, choice := range streamResp.Choices { + pk := MakeOpenAIPacket() + pk.Index = choice.Index + pk.Text = choice.Delta.Content + pk.FinishReason = string(choice.FinishReason) + rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Response: *pk} + } + } + }() + return rtn +} + +func marshalResponse(resp openaiapi.ChatCompletionResponse) []*OpenAIPacketType { + var rtn []*OpenAIPacketType + headerPk := MakeOpenAIPacket() + headerPk.Model = resp.Model + headerPk.Created = resp.Created + headerPk.Usage = convertUsage(resp) + rtn = append(rtn, headerPk) + for _, choice := range resp.Choices { + choicePk := MakeOpenAIPacket() + choicePk.Index = choice.Index + choicePk.Text = choice.Message.Content + choicePk.FinishReason = string(choice.FinishReason) + rtn = append(rtn, choicePk) + } + return rtn +} + +func CreateErrorPacket(errStr string) *OpenAIPacketType { + errPk := MakeOpenAIPacket() + errPk.FinishReason = "error" + errPk.Error = errStr + return errPk +} + +func CreateTextPacket(text string) *OpenAIPacketType { + pk := MakeOpenAIPacket() + pk.Text = text + return pk +} diff --git a/pkg/wconfig/settingsconfig.go b/pkg/wconfig/settingsconfig.go index b5dd02620..6777868e4 100644 --- a/pkg/wconfig/settingsconfig.go +++ b/pkg/wconfig/settingsconfig.go @@ -4,6 +4,7 @@ package wconfig import ( + "os/user" "path/filepath" "github.com/wavetermdev/thenextwave/pkg/wavebase" @@ -141,6 +142,13 @@ func applyDefaultSettings(settings *SettingsConfigType) { IntervalMs: 3600000, } } + var userName string + currentUser, err := user.Current() + if err != nil { + userName = "user" + } else { + userName = currentUser.Username + } defaultWidgets := []WidgetsConfigType{ { Icon: "files", @@ -170,6 +178,7 @@ func applyDefaultSettings(settings *SettingsConfigType) { Label: "waveai", BlockDef: wstore.BlockDef{ View: "waveai", + Meta: map[string]any{"name": userName, "baseurl": "", "apitoken": ""}, }, }, } diff --git a/pkg/wshrpc/wshclient/wshclient.go b/pkg/wshrpc/wshclient/wshclient.go index 197f96df5..0e4e9e8d4 100644 --- a/pkg/wshrpc/wshclient/wshclient.go +++ b/pkg/wshrpc/wshclient/wshclient.go @@ -9,6 +9,7 @@ import ( "github.com/wavetermdev/thenextwave/pkg/wshutil" "github.com/wavetermdev/thenextwave/pkg/wshrpc" "github.com/wavetermdev/thenextwave/pkg/waveobj" + "github.com/wavetermdev/thenextwave/pkg/waveai" ) // command "controller:input", wshserver.BlockInputCommand @@ -89,6 +90,11 @@ func BlockSetViewCommand(w *wshutil.WshRpc, data wshrpc.CommandBlockSetViewData, return err } +// command "stream:waveai", wshserver.RespStreamWaveAi +func RespStreamWaveAi(w *wshutil.WshRpc, data waveai.OpenAiStreamRequest, opts *wshrpc.WshRpcCommandOpts) chan wshrpc.RespOrErrorUnion[waveai.OpenAIPacketType] { + return sendRpcRequestResponseStreamHelper[waveai.OpenAIPacketType](w, "stream:waveai", data, opts) +} + // command "streamtest", wshserver.RespStreamTest func RespStreamTest(w *wshutil.WshRpc, opts *wshrpc.WshRpcCommandOpts) chan wshrpc.RespOrErrorUnion[int] { return sendRpcRequestResponseStreamHelper[int](w, "streamtest", nil, opts) diff --git a/pkg/wshrpc/wshrpctypes.go b/pkg/wshrpc/wshrpctypes.go index a0d224783..1679b95c9 100644 --- a/pkg/wshrpc/wshrpctypes.go +++ b/pkg/wshrpc/wshrpctypes.go @@ -15,19 +15,20 @@ import ( ) const ( - Command_Message = "message" - Command_SetView = "setview" - Command_SetMeta = "setmeta" - Command_GetMeta = "getmeta" - Command_BlockInput = "controller:input" - Command_Restart = "controller:restart" - Command_AppendFile = "file:append" - Command_AppendIJson = "file:appendijson" - Command_ResolveIds = "resolveids" - Command_CreateBlock = "createblock" - Command_DeleteBlock = "deleteblock" - Command_WriteFile = "file:write" - Command_ReadFile = "file:read" + Command_Message = "message" + Command_SetView = "setview" + Command_SetMeta = "setmeta" + Command_GetMeta = "getmeta" + Command_BlockInput = "controller:input" + Command_Restart = "controller:restart" + Command_AppendFile = "file:append" + Command_AppendIJson = "file:appendijson" + Command_ResolveIds = "resolveids" + Command_CreateBlock = "createblock" + Command_DeleteBlock = "deleteblock" + Command_WriteFile = "file:write" + Command_ReadFile = "file:read" + Command_StreamWaveAi = "stream:waveai" ) type MetaDataType = map[string]any diff --git a/pkg/wshrpc/wshserver/wshserver.go b/pkg/wshrpc/wshserver/wshserver.go index 18706238f..07fbf0e66 100644 --- a/pkg/wshrpc/wshserver/wshserver.go +++ b/pkg/wshrpc/wshserver/wshserver.go @@ -18,6 +18,7 @@ import ( "github.com/wavetermdev/thenextwave/pkg/blockcontroller" "github.com/wavetermdev/thenextwave/pkg/eventbus" "github.com/wavetermdev/thenextwave/pkg/filestore" + "github.com/wavetermdev/thenextwave/pkg/waveai" "github.com/wavetermdev/thenextwave/pkg/waveobj" "github.com/wavetermdev/thenextwave/pkg/wshrpc" "github.com/wavetermdev/thenextwave/pkg/wshutil" @@ -33,21 +34,31 @@ var RespStreamTest_MethodDecl = &WshServerMethodDecl{ DefaultResponseDataType: reflect.TypeOf((int)(0)), } +var RespStreamWaveAi_MethodDecl = &WshServerMethodDecl{ + Command: wshrpc.Command_StreamWaveAi, + CommandType: wshutil.RpcType_ResponseStream, + MethodName: "RespStreamWaveAi", + Method: reflect.ValueOf(WshServerImpl.RespStreamWaveAi), + CommandDataType: reflect.TypeOf(waveai.OpenAiStreamRequest{}), + DefaultResponseDataType: reflect.TypeOf(waveai.OpenAIPacketType{}), +} + var WshServerCommandToDeclMap = map[string]*WshServerMethodDecl{ - wshrpc.Command_Message: GetWshServerMethod(wshrpc.Command_Message, wshutil.RpcType_Call, "MessageCommand", WshServerImpl.MessageCommand), - wshrpc.Command_SetView: GetWshServerMethod(wshrpc.Command_SetView, wshutil.RpcType_Call, "BlockSetViewCommand", WshServerImpl.BlockSetViewCommand), - wshrpc.Command_SetMeta: GetWshServerMethod(wshrpc.Command_SetMeta, wshutil.RpcType_Call, "SetMetaCommand", WshServerImpl.SetMetaCommand), - wshrpc.Command_GetMeta: GetWshServerMethod(wshrpc.Command_GetMeta, wshutil.RpcType_Call, "GetMetaCommand", WshServerImpl.GetMetaCommand), - wshrpc.Command_ResolveIds: GetWshServerMethod(wshrpc.Command_ResolveIds, wshutil.RpcType_Call, "ResolveIdsCommand", WshServerImpl.ResolveIdsCommand), - wshrpc.Command_CreateBlock: GetWshServerMethod(wshrpc.Command_CreateBlock, wshutil.RpcType_Call, "CreateBlockCommand", WshServerImpl.CreateBlockCommand), - wshrpc.Command_Restart: GetWshServerMethod(wshrpc.Command_Restart, wshutil.RpcType_Call, "BlockRestartCommand", WshServerImpl.BlockRestartCommand), - wshrpc.Command_BlockInput: GetWshServerMethod(wshrpc.Command_BlockInput, wshutil.RpcType_Call, "BlockInputCommand", WshServerImpl.BlockInputCommand), - wshrpc.Command_AppendFile: GetWshServerMethod(wshrpc.Command_AppendFile, wshutil.RpcType_Call, "AppendFileCommand", WshServerImpl.AppendFileCommand), - wshrpc.Command_AppendIJson: GetWshServerMethod(wshrpc.Command_AppendIJson, wshutil.RpcType_Call, "AppendIJsonCommand", WshServerImpl.AppendIJsonCommand), - wshrpc.Command_DeleteBlock: GetWshServerMethod(wshrpc.Command_DeleteBlock, wshutil.RpcType_Call, "DeleteBlockCommand", WshServerImpl.DeleteBlockCommand), - wshrpc.Command_WriteFile: GetWshServerMethod(wshrpc.Command_WriteFile, wshutil.RpcType_Call, "WriteFile", WshServerImpl.WriteFile), - wshrpc.Command_ReadFile: GetWshServerMethod(wshrpc.Command_ReadFile, wshutil.RpcType_Call, "ReadFile", WshServerImpl.ReadFile), - "streamtest": RespStreamTest_MethodDecl, + wshrpc.Command_Message: GetWshServerMethod(wshrpc.Command_Message, wshutil.RpcType_Call, "MessageCommand", WshServerImpl.MessageCommand), + wshrpc.Command_SetView: GetWshServerMethod(wshrpc.Command_SetView, wshutil.RpcType_Call, "BlockSetViewCommand", WshServerImpl.BlockSetViewCommand), + wshrpc.Command_SetMeta: GetWshServerMethod(wshrpc.Command_SetMeta, wshutil.RpcType_Call, "SetMetaCommand", WshServerImpl.SetMetaCommand), + wshrpc.Command_GetMeta: GetWshServerMethod(wshrpc.Command_GetMeta, wshutil.RpcType_Call, "GetMetaCommand", WshServerImpl.GetMetaCommand), + wshrpc.Command_ResolveIds: GetWshServerMethod(wshrpc.Command_ResolveIds, wshutil.RpcType_Call, "ResolveIdsCommand", WshServerImpl.ResolveIdsCommand), + wshrpc.Command_CreateBlock: GetWshServerMethod(wshrpc.Command_CreateBlock, wshutil.RpcType_Call, "CreateBlockCommand", WshServerImpl.CreateBlockCommand), + wshrpc.Command_Restart: GetWshServerMethod(wshrpc.Command_Restart, wshutil.RpcType_Call, "BlockRestartCommand", WshServerImpl.BlockRestartCommand), + wshrpc.Command_BlockInput: GetWshServerMethod(wshrpc.Command_BlockInput, wshutil.RpcType_Call, "BlockInputCommand", WshServerImpl.BlockInputCommand), + wshrpc.Command_AppendFile: GetWshServerMethod(wshrpc.Command_AppendFile, wshutil.RpcType_Call, "AppendFileCommand", WshServerImpl.AppendFileCommand), + wshrpc.Command_AppendIJson: GetWshServerMethod(wshrpc.Command_AppendIJson, wshutil.RpcType_Call, "AppendIJsonCommand", WshServerImpl.AppendIJsonCommand), + wshrpc.Command_DeleteBlock: GetWshServerMethod(wshrpc.Command_DeleteBlock, wshutil.RpcType_Call, "DeleteBlockCommand", WshServerImpl.DeleteBlockCommand), + wshrpc.Command_WriteFile: GetWshServerMethod(wshrpc.Command_WriteFile, wshutil.RpcType_Call, "WriteFile", WshServerImpl.WriteFile), + wshrpc.Command_ReadFile: GetWshServerMethod(wshrpc.Command_ReadFile, wshutil.RpcType_Call, "ReadFile", WshServerImpl.ReadFile), + wshrpc.Command_StreamWaveAi: RespStreamWaveAi_MethodDecl, + "streamtest": RespStreamTest_MethodDecl, } // for testing @@ -69,6 +80,13 @@ func (ws *WshServer) RespStreamTest(ctx context.Context) chan wshrpc.RespOrError return rtn } +func (ws *WshServer) RespStreamWaveAi(ctx context.Context, request waveai.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[waveai.OpenAIPacketType] { + if request.Opts.BaseURL == "" && request.Opts.APIToken == "" { + return waveai.RunCloudCompletionStream(ctx, request) + } + return waveai.RunLocalCompletionStream(ctx, request) +} + func (ws *WshServer) GetMetaCommand(ctx context.Context, data wshrpc.CommandGetMetaData) (wshrpc.MetaDataType, error) { log.Printf("calling meta: %s\n", data.ORef) obj, err := wstore.DBGetORef(ctx, data.ORef)