From 0e46b79c226ae2554ae58d1d303f2aae4232d7dc Mon Sep 17 00:00:00 2001 From: Sylvie Crowe <107814465+oneirocosm@users.noreply.github.com> Date: Thu, 25 Jul 2024 02:30:49 -0700 Subject: [PATCH] Open Ai Port (#154) This brings over a simplified version of the open ai feature from the previous app but in widget form. It still needs some work to reach parity with that version, but this includes all of the basic building blocks to get that working. --- .../main-generatewshclient.go | 1 + frontend/app/block/block.tsx | 6 +- frontend/app/store/waveai.ts | 101 ------ frontend/app/store/wshserver.ts | 5 + frontend/app/view/waveai.tsx | 157 ++++++++- frontend/types/gotypes.d.ts | 43 +++ go.mod | 1 + go.sum | 2 + pkg/waveai/waveai.go | 306 ++++++++++++++++++ pkg/wconfig/settingsconfig.go | 9 + pkg/wshrpc/wshclient/wshclient.go | 6 + pkg/wshrpc/wshrpctypes.go | 27 +- pkg/wshrpc/wshserver/wshserver.go | 46 ++- 13 files changed, 576 insertions(+), 134 deletions(-) delete mode 100644 frontend/app/store/waveai.ts create mode 100644 pkg/waveai/waveai.go diff --git a/cmd/generatewshclient/main-generatewshclient.go b/cmd/generatewshclient/main-generatewshclient.go index c62694aae..c63036e8c 100644 --- a/cmd/generatewshclient/main-generatewshclient.go +++ b/cmd/generatewshclient/main-generatewshclient.go @@ -70,6 +70,7 @@ func main() { fmt.Fprintf(fd, " \"github.com/wavetermdev/thenextwave/pkg/wshutil\"\n") fmt.Fprintf(fd, " \"github.com/wavetermdev/thenextwave/pkg/wshrpc\"\n") fmt.Fprintf(fd, " \"github.com/wavetermdev/thenextwave/pkg/waveobj\"\n") + fmt.Fprintf(fd, " \"github.com/wavetermdev/thenextwave/pkg/waveai\"\n") fmt.Fprintf(fd, ")\n\n") for _, key := range utilfn.GetOrderedMapKeys(wshserver.WshServerCommandToDeclMap) { diff --git a/frontend/app/block/block.tsx b/frontend/app/block/block.tsx index fca7bab27..136f2542c 100644 --- a/frontend/app/block/block.tsx +++ b/frontend/app/block/block.tsx @@ -14,7 +14,7 @@ import * as util from "@/util/util"; import { PlotView } from "@/view/plotview"; import { PreviewView, makePreviewModel } from "@/view/preview"; import { TerminalView, makeTerminalModel } from "@/view/term/term"; -import { WaveAi } from "@/view/waveai"; +import { WaveAi, makeWaveAiViewModel } from "@/view/waveai"; import { WebView, makeWebViewModel } from "@/view/webview"; import clsx from "clsx"; import * as jotai from "jotai"; @@ -516,7 +516,9 @@ function getViewElemAndModel( viewElem = ; viewModel = webviewModel; } else if (blockView === "waveai") { - viewElem = ; + const waveAiModel = makeWaveAiViewModel(blockId); + viewElem = ; + viewModel = waveAiModel; } if (viewModel == null) { viewModel = makeDefaultViewModel(blockId); diff --git a/frontend/app/store/waveai.ts b/frontend/app/store/waveai.ts deleted file mode 100644 index cafde4f54..000000000 --- a/frontend/app/store/waveai.ts +++ /dev/null @@ -1,101 +0,0 @@ -// Copyright 2024, Command Line Inc. -// SPDX-License-Identifier: Apache-2.0 - -import { atom, useAtom } from "jotai"; -import { v4 as uuidv4 } from "uuid"; - -interface ChatMessageType { - id: string; - user: string; - text: string; - isAssistant: boolean; - isUpdating?: boolean; - isError?: string; -} - -const defaultMessage: ChatMessageType = { - id: uuidv4(), - user: "assistant", - text: `

Hello, how may I help you with this command?
-(Cmd-Shift-Space: open/close, Ctrl+L: clear chat buffer, Up/Down: select code blocks, Enter: to copy a selected code block to the command input)

`, - isAssistant: true, -}; - -const messagesAtom = atom([defaultMessage]); - -const addMessageAtom = atom(null, (get, set, message: ChatMessageType) => { - const messages = get(messagesAtom); - set(messagesAtom, [...messages, message]); -}); - -const updateLastMessageAtom = atom(null, (get, set, text: string, isUpdating: boolean) => { - const messages = get(messagesAtom); - const lastMessage = messages[messages.length - 1]; - if (lastMessage.isAssistant && !lastMessage.isError) { - const updatedMessage = { ...lastMessage, text: lastMessage.text + text, isUpdating }; - set(messagesAtom, [...messages.slice(0, -1), updatedMessage]); - } -}); - -const simulateAssistantResponseAtom = atom(null, (get, set, userMessage: ChatMessageType) => { - const responseText = `Here is an example of a simple bash script: - -\`\`\`bash -#!/bin/bash -# This is a comment -echo "Hello, World!" -\`\`\` - -You can run this script by saving it to a file, for example, \`hello.sh\`, and then running \`chmod +x hello.sh\` to make it executable. Finally, run it with \`./hello.sh\`.`; - - const typingMessage: ChatMessageType = { - id: uuidv4(), - user: "assistant", - text: "", - isAssistant: true, - }; - - // Add a typing indicator - set(addMessageAtom, typingMessage); - - setTimeout(() => { - const parts = responseText.split(" "); - let currentPart = 0; - - const intervalId = setInterval(() => { - if (currentPart < parts.length) { - const part = parts[currentPart] + " "; - set(updateLastMessageAtom, part, true); - currentPart++; - } else { - clearInterval(intervalId); - set(updateLastMessageAtom, "", false); - } - }, 100); - }, 1500); -}); - -const useWaveAi = () => { - const [messages] = useAtom(messagesAtom); - const [, addMessage] = useAtom(addMessageAtom); - const [, simulateResponse] = useAtom(simulateAssistantResponseAtom); - - const sendMessage = (text: string, user: string = "user") => { - const newMessage: ChatMessageType = { - id: uuidv4(), - user, - text, - isAssistant: false, - }; - addMessage(newMessage); - simulateResponse(newMessage); - }; - - return { - messages, - sendMessage, - }; -}; - -export { useWaveAi }; -export type { ChatMessageType }; diff --git a/frontend/app/store/wshserver.ts b/frontend/app/store/wshserver.ts index ac45fcdc4..d2c399c29 100644 --- a/frontend/app/store/wshserver.ts +++ b/frontend/app/store/wshserver.ts @@ -72,6 +72,11 @@ class WshServerType { return WOS.wshServerRpcHelper_call("setview", data, opts); } + // command "stream:waveai" [responsestream] + RespStreamWaveAi(data: OpenAiStreamRequest, opts?: WshRpcCommandOpts): AsyncGenerator { + return WOS.wshServerRpcHelper_responsestream("stream:waveai", data, opts); + } + // command "streamtest" [responsestream] RespStreamTest(opts?: WshRpcCommandOpts): AsyncGenerator { return WOS.wshServerRpcHelper_responsestream("streamtest", null, opts); diff --git a/frontend/app/view/waveai.tsx b/frontend/app/view/waveai.tsx index 0193a21df..fb590e73a 100644 --- a/frontend/app/view/waveai.tsx +++ b/frontend/app/view/waveai.tsx @@ -3,21 +3,170 @@ import { Markdown } from "@/app/element/markdown"; import { TypingIndicator } from "@/app/element/typingindicator"; -import { ChatMessageType, useWaveAi } from "@/app/store/waveai"; +import { WOS, atoms } from "@/store/global"; +import { WshServer } from "@/store/wshserver"; +import * as jotai from "jotai"; import type { OverlayScrollbars } from "overlayscrollbars"; import { OverlayScrollbarsComponent, OverlayScrollbarsComponentRef } from "overlayscrollbars-react"; import React, { forwardRef, useCallback, useEffect, useImperativeHandle, useRef, useState } from "react"; import tinycolor from "tinycolor2"; +import { v4 as uuidv4 } from "uuid"; import "./waveai.less"; +interface ChatMessageType { + id: string; + user: string; + text: string; + isAssistant: boolean; + isUpdating?: boolean; + isError?: string; +} + const outline = "2px solid var(--accent-color)"; +const defaultMessage: ChatMessageType = { + id: uuidv4(), + user: "assistant", + text: `

Hello, how may I help you with this command?
+(Cmd-Shift-Space: open/close, Ctrl+L: clear chat buffer, Up/Down: select code blocks, Enter: to copy a selected code block to the command input)

`, + isAssistant: true, +}; + interface ChatItemProps { chatItem: ChatMessageType; itemCount: number; } +export class WaveAiModel implements ViewModel { + blockId: string; + blockAtom: jotai.Atom; + viewIcon?: jotai.Atom; + viewName?: jotai.Atom; + viewText?: jotai.Atom; + preIconButton?: jotai.Atom; + endIconButtons?: jotai.Atom; + messagesAtom: jotai.PrimitiveAtom>; + addMessageAtom: jotai.WritableAtom; + updateLastMessageAtom: jotai.WritableAtom; + simulateAssistantResponseAtom: jotai.WritableAtom; + + constructor(blockId: string) { + this.blockId = blockId; + this.blockAtom = WOS.getWaveObjectAtom(`block:${blockId}`); + this.viewIcon = jotai.atom((get) => { + return "sparkles"; // should not be hardcoded + }); + this.viewName = jotai.atom("Ai"); + this.messagesAtom = jotai.atom([defaultMessage]); + + this.addMessageAtom = jotai.atom(null, (get, set, message: ChatMessageType) => { + const messages = get(this.messagesAtom); + set(this.messagesAtom, [...messages, message]); + }); + + this.updateLastMessageAtom = jotai.atom(null, (get, set, text: string, isUpdating: boolean) => { + const messages = get(this.messagesAtom); + const lastMessage = messages[messages.length - 1]; + if (lastMessage.isAssistant && !lastMessage.isError) { + const updatedMessage = { ...lastMessage, text: lastMessage.text + text, isUpdating }; + set(this.messagesAtom, [...messages.slice(0, -1), updatedMessage]); + } + }); + this.simulateAssistantResponseAtom = jotai.atom(null, (get, set, userMessage: ChatMessageType) => { + const typingMessage: ChatMessageType = { + id: uuidv4(), + user: "assistant", + text: "", + isAssistant: true, + }; + + // Add a typing indicator + set(this.addMessageAtom, typingMessage); + + setTimeout(() => { + const parts = userMessage.text.split(" "); + let currentPart = 0; + + const intervalId = setInterval(() => { + if (currentPart < parts.length) { + const part = parts[currentPart] + " "; + set(this.updateLastMessageAtom, part, true); + currentPart++; + } else { + clearInterval(intervalId); + set(this.updateLastMessageAtom, "", false); + } + }, 100); + }, 1500); + }); + } + + useWaveAi() { + const [messages] = jotai.useAtom(this.messagesAtom); + const [, addMessage] = jotai.useAtom(this.addMessageAtom); + const [, simulateResponse] = jotai.useAtom(this.simulateAssistantResponseAtom); + const metadata = jotai.useAtomValue(this.blockAtom).meta; + const clientId = jotai.useAtomValue(atoms.clientId); + + const sendMessage = (text: string, user: string = "user") => { + const newMessage: ChatMessageType = { + id: uuidv4(), + user, + text, + isAssistant: false, + }; + addMessage(newMessage); + // send message to backend and get response + const opts: OpenAIOptsType = { + model: "gpt-3.5-turbo", + apitoken: metadata?.apitoken as string, + maxtokens: 1000, + timeout: 10, + baseurl: metadata?.baseurl as string, + }; + const prompt: Array = [ + { + role: "user", + content: text, + name: (metadata?.name as string) || "user", + }, + ]; + console.log("opts.apitoken:", opts.apitoken); + const beMsg: OpenAiStreamRequest = { + clientid: clientId, + opts: opts, + prompt: prompt, + }; + const aiGen = WshServer.RespStreamWaveAi(beMsg); + let temp = async () => { + let fullMsg = ""; + for await (const msg of aiGen) { + fullMsg += msg.text ?? ""; + } + const response: ChatMessageType = { + id: newMessage.id, + user: newMessage.user, + text: fullMsg, + isAssistant: true, + }; + simulateResponse(response); + }; + temp(); + }; + + return { + messages, + sendMessage, + }; + } +} + +function makeWaveAiViewModel(blockId): WaveAiModel { + const waveAiModel = new WaveAiModel(blockId); + return waveAiModel; +} + const ChatItem = ({ chatItem, itemCount }: ChatItemProps) => { const { isAssistant, text, isError } = chatItem; const senderClassName = isAssistant ? "chat-msg-assistant" : "chat-msg-user"; @@ -208,8 +357,8 @@ const ChatInput = forwardRef( } ); -const WaveAi = () => { - const { messages, sendMessage } = useWaveAi(); +const WaveAi = ({ model }: { model: WaveAiModel }) => { + const { messages, sendMessage } = model.useWaveAi(); const waveaiRef = useRef(null); const chatWindowRef = useRef(null); const osRef = useRef(null); @@ -407,4 +556,4 @@ const WaveAi = () => { ); }; -export { WaveAi }; +export { WaveAi, makeWaveAiViewModel }; diff --git a/frontend/types/gotypes.d.ts b/frontend/types/gotypes.d.ts index f5f185329..95b25e619 100644 --- a/frontend/types/gotypes.d.ts +++ b/frontend/types/gotypes.d.ts @@ -187,6 +187,49 @@ declare global { // waveobj.ORef type ORef = string; + // waveai.OpenAIOptsType + type OpenAIOptsType = { + model: string; + apitoken: string; + baseurl?: string; + maxtokens?: number; + maxchoices?: number; + timeout?: number; + }; + + // waveai.OpenAIPacketType + type OpenAIPacketType = { + type: string; + model?: string; + created?: number; + finish_reason?: string; + usage?: OpenAIUsageType; + index?: number; + text?: string; + error?: string; + }; + + // waveai.OpenAIPromptMessageType + type OpenAIPromptMessageType = { + role: string; + content: string; + name?: string; + }; + + // waveai.OpenAIUsageType + type OpenAIUsageType = { + prompt_tokens?: number; + completion_tokens?: number; + total_tokens?: number; + }; + + // waveai.OpenAiStreamRequest + type OpenAiStreamRequest = { + clientid?: string; + opts: OpenAIOptsType; + prompt: OpenAIPromptMessageType[]; + }; + // wstore.Point type Point = { x: number; diff --git a/go.mod b/go.mod index 4ccf8b8bf..8ae012fe5 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( github.com/kevinburke/ssh_config v1.2.0 github.com/mattn/go-sqlite3 v1.14.22 github.com/mitchellh/mapstructure v1.5.0 + github.com/sashabaranov/go-openai v1.27.0 github.com/sawka/txwrap v0.2.0 github.com/spf13/cobra v1.8.1 github.com/wavetermdev/htmltoken v0.1.0 diff --git a/go.sum b/go.sum index f5b9bc7d3..e4a82752c 100644 --- a/go.sum +++ b/go.sum @@ -42,6 +42,8 @@ github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RR github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/sashabaranov/go-openai v1.27.0 h1:L3hO6650YUbKrbGUC6yCjsUluhKZ9h1/jcgbTItI8Mo= +github.com/sashabaranov/go-openai v1.27.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/sawka/txwrap v0.2.0 h1:V3LfvKVLULxcYSxdMguLwFyQFMEU9nFDJopg0ZkL+94= github.com/sawka/txwrap v0.2.0/go.mod h1:wwQ2SQiN4U+6DU/iVPhbvr7OzXAtgZlQCIGuvOswEfA= github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= diff --git a/pkg/waveai/waveai.go b/pkg/waveai/waveai.go new file mode 100644 index 000000000..671877414 --- /dev/null +++ b/pkg/waveai/waveai.go @@ -0,0 +1,306 @@ +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package waveai + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log" + "os" + "time" + + openaiapi "github.com/sashabaranov/go-openai" + "github.com/wavetermdev/thenextwave/pkg/wavebase" + "github.com/wavetermdev/thenextwave/pkg/wshrpc" + + "github.com/gorilla/websocket" +) + +const OpenAIPacketStr = "openai" +const OpenAICloudReqStr = "openai-cloudreq" +const PacketEOFStr = "EOF" + +type OpenAIUsageType struct { + PromptTokens int `json:"prompt_tokens,omitempty"` + CompletionTokens int `json:"completion_tokens,omitempty"` + TotalTokens int `json:"total_tokens,omitempty"` +} + +type OpenAICmdInfoPacketOutputType struct { + Model string `json:"model,omitempty"` + Created int64 `json:"created,omitempty"` + FinishReason string `json:"finish_reason,omitempty"` + Message string `json:"message,omitempty"` + Error string `json:"error,omitempty"` +} + +type OpenAIPacketType struct { + Type string `json:"type"` + Model string `json:"model,omitempty"` + Created int64 `json:"created,omitempty"` + FinishReason string `json:"finish_reason,omitempty"` + Usage *OpenAIUsageType `json:"usage,omitempty"` + Index int `json:"index,omitempty"` + Text string `json:"text,omitempty"` + Error string `json:"error,omitempty"` +} + +func MakeOpenAIPacket() *OpenAIPacketType { + return &OpenAIPacketType{Type: OpenAIPacketStr} +} + +type OpenAICmdInfoChatMessage struct { + MessageID int `json:"messageid"` + IsAssistantResponse bool `json:"isassistantresponse,omitempty"` + AssistantResponse *OpenAICmdInfoPacketOutputType `json:"assistantresponse,omitempty"` + UserQuery string `json:"userquery,omitempty"` + UserEngineeredQuery string `json:"userengineeredquery,omitempty"` +} + +type OpenAIPromptMessageType struct { + Role string `json:"role"` + Content string `json:"content"` + Name string `json:"name,omitempty"` +} + +type OpenAICloudReqPacketType struct { + Type string `json:"type"` + ClientId string `json:"clientid"` + Prompt []OpenAIPromptMessageType `json:"prompt"` + MaxTokens int `json:"maxtokens,omitempty"` + MaxChoices int `json:"maxchoices,omitempty"` +} + +type OpenAIOptsType struct { + Model string `json:"model"` + APIToken string `json:"apitoken"` + BaseURL string `json:"baseurl,omitempty"` + MaxTokens int `json:"maxtokens,omitempty"` + MaxChoices int `json:"maxchoices,omitempty"` + Timeout int `json:"timeout,omitempty"` +} + +func MakeOpenAICloudReqPacket() *OpenAICloudReqPacketType { + return &OpenAICloudReqPacketType{ + Type: OpenAICloudReqStr, + } +} + +type OpenAiStreamRequest struct { + ClientId string `json:"clientid,omitempty"` + Opts *OpenAIOptsType `json:"opts"` + Prompt []OpenAIPromptMessageType `json:"prompt"` +} + +func GetWSEndpoint() string { + return PCloudWSEndpoint + if !wavebase.IsDevMode() { + return PCloudWSEndpoint + } else { + endpoint := os.Getenv(PCloudWSEndpointVarName) + if endpoint == "" { + panic("Invalid PCloud ws dev endpoint, PCLOUD_WS_ENDPOINT not set or invalid") + } + return endpoint + } +} + +const DefaultMaxTokens = 1000 +const DefaultModel = "gpt-3.5-turbo" +const DefaultStreamChanSize = 10 +const PCloudWSEndpoint = "wss://wsapi.waveterm.dev/" +const PCloudWSEndpointVarName = "PCLOUD_WS_ENDPOINT" + +const CloudWebsocketConnectTimeout = 1 * time.Minute + +func convertUsage(resp openaiapi.ChatCompletionResponse) *OpenAIUsageType { + if resp.Usage.TotalTokens == 0 { + return nil + } + return &OpenAIUsageType{ + PromptTokens: resp.Usage.PromptTokens, + CompletionTokens: resp.Usage.CompletionTokens, + TotalTokens: resp.Usage.TotalTokens, + } +} + +func ConvertPrompt(prompt []OpenAIPromptMessageType) []openaiapi.ChatCompletionMessage { + var rtn []openaiapi.ChatCompletionMessage + for _, p := range prompt { + msg := openaiapi.ChatCompletionMessage{Role: p.Role, Content: p.Content, Name: p.Name} + rtn = append(rtn, msg) + } + return rtn +} + +func RunCloudCompletionStream(ctx context.Context, request OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[OpenAIPacketType] { + rtn := make(chan wshrpc.RespOrErrorUnion[OpenAIPacketType]) + go func() { + log.Printf("start: %v", request) + defer close(rtn) + if request.Opts == nil { + rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("no openai opts found")} + return + } + websocketContext, dialCancelFn := context.WithTimeout(context.Background(), CloudWebsocketConnectTimeout) + defer dialCancelFn() + conn, _, err := websocket.DefaultDialer.DialContext(websocketContext, GetWSEndpoint(), nil) + defer func() { + err = conn.Close() + if err != nil { + rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("unable to close openai channel: %v", err)} + } + }() + if err == context.DeadlineExceeded { + rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, timed out connecting to cloud server: %v", err)} + return + } else if err != nil { + rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket connect error: %v", err)} + return + } + reqPk := MakeOpenAICloudReqPacket() + reqPk.ClientId = request.ClientId + reqPk.Prompt = request.Prompt + reqPk.MaxTokens = request.Opts.MaxTokens + reqPk.MaxChoices = request.Opts.MaxChoices + configMessageBuf, err := json.Marshal(reqPk) + if err != nil { + rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, packet marshal error: %v", err)} + return + } + err = conn.WriteMessage(websocket.TextMessage, configMessageBuf) + if err != nil { + rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket write config error: %v", err)} + return + } + for { + log.Printf("loop") + _, socketMessage, err := conn.ReadMessage() + if err == io.EOF { + break + } + if err != nil { + log.Printf("err received: %v", err) + rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket error reading message: %v", err)} + break + } + var streamResp *OpenAIPacketType + err = json.Unmarshal(socketMessage, &streamResp) + log.Printf("ai resp: %v", streamResp) + if err != nil { + rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket response json decode error: %v", err)} + break + } + if streamResp.Error == PacketEOFStr { + // got eof packet from socket + break + } else if streamResp.Error != "" { + // use error from server directly + rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("%v", streamResp.Error)} + break + } + rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Response: *streamResp} + } + }() + return rtn +} + +func RunLocalCompletionStream(ctx context.Context, request OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[OpenAIPacketType] { + rtn := make(chan wshrpc.RespOrErrorUnion[OpenAIPacketType]) + go func() { + log.Printf("start2: %v", request) + defer close(rtn) + if request.Opts == nil { + rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("no openai opts found")} + return + } + if request.Opts.Model == "" { + rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("no openai model specified")} + return + } + if request.Opts.BaseURL == "" && request.Opts.APIToken == "" { + rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("no api token")} + return + } + clientConfig := openaiapi.DefaultConfig(request.Opts.APIToken) + if request.Opts.BaseURL != "" { + clientConfig.BaseURL = request.Opts.BaseURL + } + client := openaiapi.NewClientWithConfig(clientConfig) + req := openaiapi.ChatCompletionRequest{ + Model: request.Opts.Model, + Messages: ConvertPrompt(request.Prompt), + MaxTokens: request.Opts.MaxTokens, + Stream: true, + } + if request.Opts.MaxChoices > 1 { + req.N = request.Opts.MaxChoices + } + apiResp, err := client.CreateChatCompletionStream(ctx, req) + if err != nil { + rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("error calling openai API: %v", err)} + return + } + sentHeader := false + for { + log.Printf("loop2") + streamResp, err := apiResp.Recv() + if err == io.EOF { + break + } + if err != nil { + log.Printf("err received2: %v", err) + rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket error reading message: %v", err)} + break + } + if streamResp.Model != "" && !sentHeader { + pk := MakeOpenAIPacket() + pk.Model = streamResp.Model + pk.Created = streamResp.Created + rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Response: *pk} + sentHeader = true + } + for _, choice := range streamResp.Choices { + pk := MakeOpenAIPacket() + pk.Index = choice.Index + pk.Text = choice.Delta.Content + pk.FinishReason = string(choice.FinishReason) + rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Response: *pk} + } + } + }() + return rtn +} + +func marshalResponse(resp openaiapi.ChatCompletionResponse) []*OpenAIPacketType { + var rtn []*OpenAIPacketType + headerPk := MakeOpenAIPacket() + headerPk.Model = resp.Model + headerPk.Created = resp.Created + headerPk.Usage = convertUsage(resp) + rtn = append(rtn, headerPk) + for _, choice := range resp.Choices { + choicePk := MakeOpenAIPacket() + choicePk.Index = choice.Index + choicePk.Text = choice.Message.Content + choicePk.FinishReason = string(choice.FinishReason) + rtn = append(rtn, choicePk) + } + return rtn +} + +func CreateErrorPacket(errStr string) *OpenAIPacketType { + errPk := MakeOpenAIPacket() + errPk.FinishReason = "error" + errPk.Error = errStr + return errPk +} + +func CreateTextPacket(text string) *OpenAIPacketType { + pk := MakeOpenAIPacket() + pk.Text = text + return pk +} diff --git a/pkg/wconfig/settingsconfig.go b/pkg/wconfig/settingsconfig.go index b5dd02620..6777868e4 100644 --- a/pkg/wconfig/settingsconfig.go +++ b/pkg/wconfig/settingsconfig.go @@ -4,6 +4,7 @@ package wconfig import ( + "os/user" "path/filepath" "github.com/wavetermdev/thenextwave/pkg/wavebase" @@ -141,6 +142,13 @@ func applyDefaultSettings(settings *SettingsConfigType) { IntervalMs: 3600000, } } + var userName string + currentUser, err := user.Current() + if err != nil { + userName = "user" + } else { + userName = currentUser.Username + } defaultWidgets := []WidgetsConfigType{ { Icon: "files", @@ -170,6 +178,7 @@ func applyDefaultSettings(settings *SettingsConfigType) { Label: "waveai", BlockDef: wstore.BlockDef{ View: "waveai", + Meta: map[string]any{"name": userName, "baseurl": "", "apitoken": ""}, }, }, } diff --git a/pkg/wshrpc/wshclient/wshclient.go b/pkg/wshrpc/wshclient/wshclient.go index 197f96df5..0e4e9e8d4 100644 --- a/pkg/wshrpc/wshclient/wshclient.go +++ b/pkg/wshrpc/wshclient/wshclient.go @@ -9,6 +9,7 @@ import ( "github.com/wavetermdev/thenextwave/pkg/wshutil" "github.com/wavetermdev/thenextwave/pkg/wshrpc" "github.com/wavetermdev/thenextwave/pkg/waveobj" + "github.com/wavetermdev/thenextwave/pkg/waveai" ) // command "controller:input", wshserver.BlockInputCommand @@ -89,6 +90,11 @@ func BlockSetViewCommand(w *wshutil.WshRpc, data wshrpc.CommandBlockSetViewData, return err } +// command "stream:waveai", wshserver.RespStreamWaveAi +func RespStreamWaveAi(w *wshutil.WshRpc, data waveai.OpenAiStreamRequest, opts *wshrpc.WshRpcCommandOpts) chan wshrpc.RespOrErrorUnion[waveai.OpenAIPacketType] { + return sendRpcRequestResponseStreamHelper[waveai.OpenAIPacketType](w, "stream:waveai", data, opts) +} + // command "streamtest", wshserver.RespStreamTest func RespStreamTest(w *wshutil.WshRpc, opts *wshrpc.WshRpcCommandOpts) chan wshrpc.RespOrErrorUnion[int] { return sendRpcRequestResponseStreamHelper[int](w, "streamtest", nil, opts) diff --git a/pkg/wshrpc/wshrpctypes.go b/pkg/wshrpc/wshrpctypes.go index a0d224783..1679b95c9 100644 --- a/pkg/wshrpc/wshrpctypes.go +++ b/pkg/wshrpc/wshrpctypes.go @@ -15,19 +15,20 @@ import ( ) const ( - Command_Message = "message" - Command_SetView = "setview" - Command_SetMeta = "setmeta" - Command_GetMeta = "getmeta" - Command_BlockInput = "controller:input" - Command_Restart = "controller:restart" - Command_AppendFile = "file:append" - Command_AppendIJson = "file:appendijson" - Command_ResolveIds = "resolveids" - Command_CreateBlock = "createblock" - Command_DeleteBlock = "deleteblock" - Command_WriteFile = "file:write" - Command_ReadFile = "file:read" + Command_Message = "message" + Command_SetView = "setview" + Command_SetMeta = "setmeta" + Command_GetMeta = "getmeta" + Command_BlockInput = "controller:input" + Command_Restart = "controller:restart" + Command_AppendFile = "file:append" + Command_AppendIJson = "file:appendijson" + Command_ResolveIds = "resolveids" + Command_CreateBlock = "createblock" + Command_DeleteBlock = "deleteblock" + Command_WriteFile = "file:write" + Command_ReadFile = "file:read" + Command_StreamWaveAi = "stream:waveai" ) type MetaDataType = map[string]any diff --git a/pkg/wshrpc/wshserver/wshserver.go b/pkg/wshrpc/wshserver/wshserver.go index 18706238f..07fbf0e66 100644 --- a/pkg/wshrpc/wshserver/wshserver.go +++ b/pkg/wshrpc/wshserver/wshserver.go @@ -18,6 +18,7 @@ import ( "github.com/wavetermdev/thenextwave/pkg/blockcontroller" "github.com/wavetermdev/thenextwave/pkg/eventbus" "github.com/wavetermdev/thenextwave/pkg/filestore" + "github.com/wavetermdev/thenextwave/pkg/waveai" "github.com/wavetermdev/thenextwave/pkg/waveobj" "github.com/wavetermdev/thenextwave/pkg/wshrpc" "github.com/wavetermdev/thenextwave/pkg/wshutil" @@ -33,21 +34,31 @@ var RespStreamTest_MethodDecl = &WshServerMethodDecl{ DefaultResponseDataType: reflect.TypeOf((int)(0)), } +var RespStreamWaveAi_MethodDecl = &WshServerMethodDecl{ + Command: wshrpc.Command_StreamWaveAi, + CommandType: wshutil.RpcType_ResponseStream, + MethodName: "RespStreamWaveAi", + Method: reflect.ValueOf(WshServerImpl.RespStreamWaveAi), + CommandDataType: reflect.TypeOf(waveai.OpenAiStreamRequest{}), + DefaultResponseDataType: reflect.TypeOf(waveai.OpenAIPacketType{}), +} + var WshServerCommandToDeclMap = map[string]*WshServerMethodDecl{ - wshrpc.Command_Message: GetWshServerMethod(wshrpc.Command_Message, wshutil.RpcType_Call, "MessageCommand", WshServerImpl.MessageCommand), - wshrpc.Command_SetView: GetWshServerMethod(wshrpc.Command_SetView, wshutil.RpcType_Call, "BlockSetViewCommand", WshServerImpl.BlockSetViewCommand), - wshrpc.Command_SetMeta: GetWshServerMethod(wshrpc.Command_SetMeta, wshutil.RpcType_Call, "SetMetaCommand", WshServerImpl.SetMetaCommand), - wshrpc.Command_GetMeta: GetWshServerMethod(wshrpc.Command_GetMeta, wshutil.RpcType_Call, "GetMetaCommand", WshServerImpl.GetMetaCommand), - wshrpc.Command_ResolveIds: GetWshServerMethod(wshrpc.Command_ResolveIds, wshutil.RpcType_Call, "ResolveIdsCommand", WshServerImpl.ResolveIdsCommand), - wshrpc.Command_CreateBlock: GetWshServerMethod(wshrpc.Command_CreateBlock, wshutil.RpcType_Call, "CreateBlockCommand", WshServerImpl.CreateBlockCommand), - wshrpc.Command_Restart: GetWshServerMethod(wshrpc.Command_Restart, wshutil.RpcType_Call, "BlockRestartCommand", WshServerImpl.BlockRestartCommand), - wshrpc.Command_BlockInput: GetWshServerMethod(wshrpc.Command_BlockInput, wshutil.RpcType_Call, "BlockInputCommand", WshServerImpl.BlockInputCommand), - wshrpc.Command_AppendFile: GetWshServerMethod(wshrpc.Command_AppendFile, wshutil.RpcType_Call, "AppendFileCommand", WshServerImpl.AppendFileCommand), - wshrpc.Command_AppendIJson: GetWshServerMethod(wshrpc.Command_AppendIJson, wshutil.RpcType_Call, "AppendIJsonCommand", WshServerImpl.AppendIJsonCommand), - wshrpc.Command_DeleteBlock: GetWshServerMethod(wshrpc.Command_DeleteBlock, wshutil.RpcType_Call, "DeleteBlockCommand", WshServerImpl.DeleteBlockCommand), - wshrpc.Command_WriteFile: GetWshServerMethod(wshrpc.Command_WriteFile, wshutil.RpcType_Call, "WriteFile", WshServerImpl.WriteFile), - wshrpc.Command_ReadFile: GetWshServerMethod(wshrpc.Command_ReadFile, wshutil.RpcType_Call, "ReadFile", WshServerImpl.ReadFile), - "streamtest": RespStreamTest_MethodDecl, + wshrpc.Command_Message: GetWshServerMethod(wshrpc.Command_Message, wshutil.RpcType_Call, "MessageCommand", WshServerImpl.MessageCommand), + wshrpc.Command_SetView: GetWshServerMethod(wshrpc.Command_SetView, wshutil.RpcType_Call, "BlockSetViewCommand", WshServerImpl.BlockSetViewCommand), + wshrpc.Command_SetMeta: GetWshServerMethod(wshrpc.Command_SetMeta, wshutil.RpcType_Call, "SetMetaCommand", WshServerImpl.SetMetaCommand), + wshrpc.Command_GetMeta: GetWshServerMethod(wshrpc.Command_GetMeta, wshutil.RpcType_Call, "GetMetaCommand", WshServerImpl.GetMetaCommand), + wshrpc.Command_ResolveIds: GetWshServerMethod(wshrpc.Command_ResolveIds, wshutil.RpcType_Call, "ResolveIdsCommand", WshServerImpl.ResolveIdsCommand), + wshrpc.Command_CreateBlock: GetWshServerMethod(wshrpc.Command_CreateBlock, wshutil.RpcType_Call, "CreateBlockCommand", WshServerImpl.CreateBlockCommand), + wshrpc.Command_Restart: GetWshServerMethod(wshrpc.Command_Restart, wshutil.RpcType_Call, "BlockRestartCommand", WshServerImpl.BlockRestartCommand), + wshrpc.Command_BlockInput: GetWshServerMethod(wshrpc.Command_BlockInput, wshutil.RpcType_Call, "BlockInputCommand", WshServerImpl.BlockInputCommand), + wshrpc.Command_AppendFile: GetWshServerMethod(wshrpc.Command_AppendFile, wshutil.RpcType_Call, "AppendFileCommand", WshServerImpl.AppendFileCommand), + wshrpc.Command_AppendIJson: GetWshServerMethod(wshrpc.Command_AppendIJson, wshutil.RpcType_Call, "AppendIJsonCommand", WshServerImpl.AppendIJsonCommand), + wshrpc.Command_DeleteBlock: GetWshServerMethod(wshrpc.Command_DeleteBlock, wshutil.RpcType_Call, "DeleteBlockCommand", WshServerImpl.DeleteBlockCommand), + wshrpc.Command_WriteFile: GetWshServerMethod(wshrpc.Command_WriteFile, wshutil.RpcType_Call, "WriteFile", WshServerImpl.WriteFile), + wshrpc.Command_ReadFile: GetWshServerMethod(wshrpc.Command_ReadFile, wshutil.RpcType_Call, "ReadFile", WshServerImpl.ReadFile), + wshrpc.Command_StreamWaveAi: RespStreamWaveAi_MethodDecl, + "streamtest": RespStreamTest_MethodDecl, } // for testing @@ -69,6 +80,13 @@ func (ws *WshServer) RespStreamTest(ctx context.Context) chan wshrpc.RespOrError return rtn } +func (ws *WshServer) RespStreamWaveAi(ctx context.Context, request waveai.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[waveai.OpenAIPacketType] { + if request.Opts.BaseURL == "" && request.Opts.APIToken == "" { + return waveai.RunCloudCompletionStream(ctx, request) + } + return waveai.RunLocalCompletionStream(ctx, request) +} + func (ws *WshServer) GetMetaCommand(ctx context.Context, data wshrpc.CommandGetMetaData) (wshrpc.MetaDataType, error) { log.Printf("calling meta: %s\n", data.ORef) obj, err := wstore.DBGetORef(ctx, data.ORef)