diff --git a/cmd/generatewshclient/main-generatewshclient.go b/cmd/generatewshclient/main-generatewshclient.go
index c62694aae..c63036e8c 100644
--- a/cmd/generatewshclient/main-generatewshclient.go
+++ b/cmd/generatewshclient/main-generatewshclient.go
@@ -70,6 +70,7 @@ func main() {
fmt.Fprintf(fd, " \"github.com/wavetermdev/thenextwave/pkg/wshutil\"\n")
fmt.Fprintf(fd, " \"github.com/wavetermdev/thenextwave/pkg/wshrpc\"\n")
fmt.Fprintf(fd, " \"github.com/wavetermdev/thenextwave/pkg/waveobj\"\n")
+ fmt.Fprintf(fd, " \"github.com/wavetermdev/thenextwave/pkg/waveai\"\n")
fmt.Fprintf(fd, ")\n\n")
for _, key := range utilfn.GetOrderedMapKeys(wshserver.WshServerCommandToDeclMap) {
diff --git a/frontend/app/block/block.tsx b/frontend/app/block/block.tsx
index fca7bab27..136f2542c 100644
--- a/frontend/app/block/block.tsx
+++ b/frontend/app/block/block.tsx
@@ -14,7 +14,7 @@ import * as util from "@/util/util";
import { PlotView } from "@/view/plotview";
import { PreviewView, makePreviewModel } from "@/view/preview";
import { TerminalView, makeTerminalModel } from "@/view/term/term";
-import { WaveAi } from "@/view/waveai";
+import { WaveAi, makeWaveAiViewModel } from "@/view/waveai";
import { WebView, makeWebViewModel } from "@/view/webview";
import clsx from "clsx";
import * as jotai from "jotai";
@@ -516,7 +516,9 @@ function getViewElemAndModel(
viewElem = <WebView key={blockId} model={webviewModel} />;
viewModel = webviewModel;
} else if (blockView === "waveai") {
- viewElem = <WaveAi key={blockId} />;
+ const waveAiModel = makeWaveAiViewModel(blockId);
+ viewElem = <WaveAi key={blockId} model={waveAiModel} />;
+ viewModel = waveAiModel;
}
if (viewModel == null) {
viewModel = makeDefaultViewModel(blockId);
diff --git a/frontend/app/store/waveai.ts b/frontend/app/store/waveai.ts
deleted file mode 100644
index cafde4f54..000000000
--- a/frontend/app/store/waveai.ts
+++ /dev/null
@@ -1,101 +0,0 @@
-// Copyright 2024, Command Line Inc.
-// SPDX-License-Identifier: Apache-2.0
-
-import { atom, useAtom } from "jotai";
-import { v4 as uuidv4 } from "uuid";
-
-interface ChatMessageType {
- id: string;
- user: string;
- text: string;
- isAssistant: boolean;
- isUpdating?: boolean;
- isError?: string;
-}
-
-const defaultMessage: ChatMessageType = {
- id: uuidv4(),
- user: "assistant",
- text: `
-Hello, how may I help you with this command?
-(Cmd-Shift-Space: open/close, Ctrl+L: clear chat buffer, Up/Down: select code blocks, Enter: to copy a selected code block to the command input)
-`,
- isAssistant: true,
-};
-
-const messagesAtom = atom([defaultMessage]);
-
-const addMessageAtom = atom(null, (get, set, message: ChatMessageType) => {
- const messages = get(messagesAtom);
- set(messagesAtom, [...messages, message]);
-});
-
-const updateLastMessageAtom = atom(null, (get, set, text: string, isUpdating: boolean) => {
- const messages = get(messagesAtom);
- const lastMessage = messages[messages.length - 1];
- if (lastMessage.isAssistant && !lastMessage.isError) {
- const updatedMessage = { ...lastMessage, text: lastMessage.text + text, isUpdating };
- set(messagesAtom, [...messages.slice(0, -1), updatedMessage]);
- }
-});
-
-const simulateAssistantResponseAtom = atom(null, (get, set, userMessage: ChatMessageType) => {
- const responseText = `Here is an example of a simple bash script:
-
-\`\`\`bash
-#!/bin/bash
-# This is a comment
-echo "Hello, World!"
-\`\`\`
-
-You can run this script by saving it to a file, for example, \`hello.sh\`, and then running \`chmod +x hello.sh\` to make it executable. Finally, run it with \`./hello.sh\`.`;
-
- const typingMessage: ChatMessageType = {
- id: uuidv4(),
- user: "assistant",
- text: "",
- isAssistant: true,
- };
-
- // Add a typing indicator
- set(addMessageAtom, typingMessage);
-
- setTimeout(() => {
- const parts = responseText.split(" ");
- let currentPart = 0;
-
- const intervalId = setInterval(() => {
- if (currentPart < parts.length) {
- const part = parts[currentPart] + " ";
- set(updateLastMessageAtom, part, true);
- currentPart++;
- } else {
- clearInterval(intervalId);
- set(updateLastMessageAtom, "", false);
- }
- }, 100);
- }, 1500);
-});
-
-const useWaveAi = () => {
- const [messages] = useAtom(messagesAtom);
- const [, addMessage] = useAtom(addMessageAtom);
- const [, simulateResponse] = useAtom(simulateAssistantResponseAtom);
-
- const sendMessage = (text: string, user: string = "user") => {
- const newMessage: ChatMessageType = {
- id: uuidv4(),
- user,
- text,
- isAssistant: false,
- };
- addMessage(newMessage);
- simulateResponse(newMessage);
- };
-
- return {
- messages,
- sendMessage,
- };
-};
-
-export { useWaveAi };
-export type { ChatMessageType };
diff --git a/frontend/app/store/wshserver.ts b/frontend/app/store/wshserver.ts
index ac45fcdc4..d2c399c29 100644
--- a/frontend/app/store/wshserver.ts
+++ b/frontend/app/store/wshserver.ts
@@ -72,6 +72,11 @@ class WshServerType {
return WOS.wshServerRpcHelper_call("setview", data, opts);
}
+ // command "stream:waveai" [responsestream]
+ RespStreamWaveAi(data: OpenAiStreamRequest, opts?: WshRpcCommandOpts): AsyncGenerator<OpenAIPacketType, void, boolean> {
+ return WOS.wshServerRpcHelper_responsestream("stream:waveai", data, opts);
+ }
+
// command "streamtest" [responsestream]
RespStreamTest(opts?: WshRpcCommandOpts): AsyncGenerator<number, void, boolean> {
return WOS.wshServerRpcHelper_responsestream("streamtest", null, opts);
diff --git a/frontend/app/view/waveai.tsx b/frontend/app/view/waveai.tsx
index 0193a21df..fb590e73a 100644
--- a/frontend/app/view/waveai.tsx
+++ b/frontend/app/view/waveai.tsx
@@ -3,21 +3,170 @@
import { Markdown } from "@/app/element/markdown";
import { TypingIndicator } from "@/app/element/typingindicator";
-import { ChatMessageType, useWaveAi } from "@/app/store/waveai";
+import { WOS, atoms } from "@/store/global";
+import { WshServer } from "@/store/wshserver";
+import * as jotai from "jotai";
import type { OverlayScrollbars } from "overlayscrollbars";
import { OverlayScrollbarsComponent, OverlayScrollbarsComponentRef } from "overlayscrollbars-react";
import React, { forwardRef, useCallback, useEffect, useImperativeHandle, useRef, useState } from "react";
import tinycolor from "tinycolor2";
+import { v4 as uuidv4 } from "uuid";
import "./waveai.less";
+interface ChatMessageType {
+ id: string;
+ user: string;
+ text: string;
+ isAssistant: boolean;
+ isUpdating?: boolean;
+ isError?: string;
+}
+
const outline = "2px solid var(--accent-color)";
+const defaultMessage: ChatMessageType = {
+ id: uuidv4(),
+ user: "assistant",
+ text: `Hello, how may I help you with this command?
+(Cmd-Shift-Space: open/close, Ctrl+L: clear chat buffer, Up/Down: select code blocks, Enter: to copy a selected code block to the command input)
+`,
+ isAssistant: true,
+};
+
interface ChatItemProps {
chatItem: ChatMessageType;
itemCount: number;
}
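+// WaveAiModel holds the per-block chat state as jotai atoms and exposes useWaveAi(), which the WaveAi view uses to read messages and send prompts.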
+export class WaveAiModel implements ViewModel {
+ blockId: string;
+ blockAtom: jotai.Atom<Block>;
+ viewIcon?: jotai.Atom<string | HeaderIconButton>;
+ viewName?: jotai.Atom<string>;
+ viewText?: jotai.Atom<string | HeaderElem[]>;
+ preIconButton?: jotai.Atom<HeaderIconButton>;
+ endIconButtons?: jotai.Atom<HeaderIconButton[]>;
+ messagesAtom: jotai.PrimitiveAtom<Array<ChatMessageType>>;
+ addMessageAtom: jotai.WritableAtom<unknown, [message: ChatMessageType], void>;
+ updateLastMessageAtom: jotai.WritableAtom<unknown, [text: string, isUpdating: boolean], void>;
+ simulateAssistantResponseAtom: jotai.WritableAtom<unknown, [userMessage: ChatMessageType], void>;
+
+ constructor(blockId: string) {
+ this.blockId = blockId;
+ this.blockAtom = WOS.getWaveObjectAtom(`block:${blockId}`);
+ this.viewIcon = jotai.atom((get) => {
+ return "sparkles"; // should not be hardcoded
+ });
+ this.viewName = jotai.atom("Ai");
+ this.messagesAtom = jotai.atom([defaultMessage]);
+
+ this.addMessageAtom = jotai.atom(null, (get, set, message: ChatMessageType) => {
+ const messages = get(this.messagesAtom);
+ set(this.messagesAtom, [...messages, message]);
+ });
+
+ this.updateLastMessageAtom = jotai.atom(null, (get, set, text: string, isUpdating: boolean) => {
+ const messages = get(this.messagesAtom);
+ const lastMessage = messages[messages.length - 1];
+ if (lastMessage.isAssistant && !lastMessage.isError) {
+ const updatedMessage = { ...lastMessage, text: lastMessage.text + text, isUpdating };
+ set(this.messagesAtom, [...messages.slice(0, -1), updatedMessage]);
+ }
+ });
+ this.simulateAssistantResponseAtom = jotai.atom(null, (get, set, userMessage: ChatMessageType) => {
+ const typingMessage: ChatMessageType = {
+ id: uuidv4(),
+ user: "assistant",
+ text: "",
+ isAssistant: true,
+ };
+
+ // Add a typing indicator
+ set(this.addMessageAtom, typingMessage);
+
+ setTimeout(() => {
+ const parts = userMessage.text.split(" ");
+ let currentPart = 0;
+
+ const intervalId = setInterval(() => {
+ if (currentPart < parts.length) {
+ const part = parts[currentPart] + " ";
+ set(this.updateLastMessageAtom, part, true);
+ currentPart++;
+ } else {
+ clearInterval(intervalId);
+ set(this.updateLastMessageAtom, "", false);
+ }
+ }, 100);
+ }, 1500);
+ });
+ }
+
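+ // useWaveAi is a React hook: it subscribes to the message atoms and returns sendMessage(), which appends the user message, streams the backend response via WshServer.RespStreamWaveAi, and then replays the full response through the typing simulation.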
+ useWaveAi() {
+ const [messages] = jotai.useAtom(this.messagesAtom);
+ const [, addMessage] = jotai.useAtom(this.addMessageAtom);
+ const [, simulateResponse] = jotai.useAtom(this.simulateAssistantResponseAtom);
+ const metadata = jotai.useAtomValue(this.blockAtom).meta;
+ const clientId = jotai.useAtomValue(atoms.clientId);
+
+ const sendMessage = (text: string, user: string = "user") => {
+ const newMessage: ChatMessageType = {
+ id: uuidv4(),
+ user,
+ text,
+ isAssistant: false,
+ };
+ addMessage(newMessage);
+ // send message to backend and get response
+ const opts: OpenAIOptsType = {
+ model: "gpt-3.5-turbo",
+ apitoken: metadata?.apitoken as string,
+ maxtokens: 1000,
+ timeout: 10,
+ baseurl: metadata?.baseurl as string,
+ };
+ const prompt: Array<OpenAIPromptMessageType> = [
+ {
+ role: "user",
+ content: text,
+ name: (metadata?.name as string) || "user",
+ },
+ ];
+ console.log("opts.apitoken:", opts.apitoken);
+ const beMsg: OpenAiStreamRequest = {
+ clientid: clientId,
+ opts: opts,
+ prompt: prompt,
+ };
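+ // Stream the backend response, accumulate it into one message, then hand it to the typing simulation for display.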
+ const aiGen = WshServer.RespStreamWaveAi(beMsg);
+ let temp = async () => {
+ let fullMsg = "";
+ for await (const msg of aiGen) {
+ fullMsg += msg.text ?? "";
+ }
+ const response: ChatMessageType = {
+ id: newMessage.id,
+ user: newMessage.user,
+ text: fullMsg,
+ isAssistant: true,
+ };
+ simulateResponse(response);
+ };
+ temp();
+ };
+
+ return {
+ messages,
+ sendMessage,
+ };
+ }
+}
+
+function makeWaveAiViewModel(blockId: string): WaveAiModel {
+ const waveAiModel = new WaveAiModel(blockId);
+ return waveAiModel;
+}
+
const ChatItem = ({ chatItem, itemCount }: ChatItemProps) => {
const { isAssistant, text, isError } = chatItem;
const senderClassName = isAssistant ? "chat-msg-assistant" : "chat-msg-user";
@@ -208,8 +357,8 @@ const ChatInput = forwardRef(
}
);
-const WaveAi = () => {
- const { messages, sendMessage } = useWaveAi();
+const WaveAi = ({ model }: { model: WaveAiModel }) => {
+ const { messages, sendMessage } = model.useWaveAi();
const waveaiRef = useRef<HTMLDivElement>(null);
const chatWindowRef = useRef<HTMLDivElement>(null);
const osRef = useRef<OverlayScrollbarsComponentRef>(null);
@@ -407,4 +556,4 @@ const WaveAi = () => {
);
};
-export { WaveAi };
+export { WaveAi, makeWaveAiViewModel };
diff --git a/frontend/types/gotypes.d.ts b/frontend/types/gotypes.d.ts
index f5f185329..95b25e619 100644
--- a/frontend/types/gotypes.d.ts
+++ b/frontend/types/gotypes.d.ts
@@ -187,6 +187,49 @@ declare global {
// waveobj.ORef
type ORef = string;
+ // waveai.OpenAIOptsType
+ type OpenAIOptsType = {
+ model: string;
+ apitoken: string;
+ baseurl?: string;
+ maxtokens?: number;
+ maxchoices?: number;
+ timeout?: number;
+ };
+
+ // waveai.OpenAIPacketType
+ type OpenAIPacketType = {
+ type: string;
+ model?: string;
+ created?: number;
+ finish_reason?: string;
+ usage?: OpenAIUsageType;
+ index?: number;
+ text?: string;
+ error?: string;
+ };
+
+ // waveai.OpenAIPromptMessageType
+ type OpenAIPromptMessageType = {
+ role: string;
+ content: string;
+ name?: string;
+ };
+
+ // waveai.OpenAIUsageType
+ type OpenAIUsageType = {
+ prompt_tokens?: number;
+ completion_tokens?: number;
+ total_tokens?: number;
+ };
+
+ // waveai.OpenAiStreamRequest
+ type OpenAiStreamRequest = {
+ clientid?: string;
+ opts: OpenAIOptsType;
+ prompt: OpenAIPromptMessageType[];
+ };
+
// wstore.Point
type Point = {
x: number;
diff --git a/go.mod b/go.mod
index 4ccf8b8bf..8ae012fe5 100644
--- a/go.mod
+++ b/go.mod
@@ -15,6 +15,7 @@ require (
github.com/kevinburke/ssh_config v1.2.0
github.com/mattn/go-sqlite3 v1.14.22
github.com/mitchellh/mapstructure v1.5.0
+ github.com/sashabaranov/go-openai v1.27.0
github.com/sawka/txwrap v0.2.0
github.com/spf13/cobra v1.8.1
github.com/wavetermdev/htmltoken v0.1.0
diff --git a/go.sum b/go.sum
index f5b9bc7d3..e4a82752c 100644
--- a/go.sum
+++ b/go.sum
@@ -42,6 +42,8 @@ github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RR
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
+github.com/sashabaranov/go-openai v1.27.0 h1:L3hO6650YUbKrbGUC6yCjsUluhKZ9h1/jcgbTItI8Mo=
+github.com/sashabaranov/go-openai v1.27.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sawka/txwrap v0.2.0 h1:V3LfvKVLULxcYSxdMguLwFyQFMEU9nFDJopg0ZkL+94=
github.com/sawka/txwrap v0.2.0/go.mod h1:wwQ2SQiN4U+6DU/iVPhbvr7OzXAtgZlQCIGuvOswEfA=
github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
diff --git a/pkg/waveai/waveai.go b/pkg/waveai/waveai.go
new file mode 100644
index 000000000..671877414
--- /dev/null
+++ b/pkg/waveai/waveai.go
@@ -0,0 +1,306 @@
+// Copyright 2024, Command Line Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package waveai
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "log"
+ "os"
+ "time"
+
+ openaiapi "github.com/sashabaranov/go-openai"
+ "github.com/wavetermdev/thenextwave/pkg/wavebase"
+ "github.com/wavetermdev/thenextwave/pkg/wshrpc"
+
+ "github.com/gorilla/websocket"
+)
+
+const OpenAIPacketStr = "openai"
+const OpenAICloudReqStr = "openai-cloudreq"
+const PacketEOFStr = "EOF"
+
+type OpenAIUsageType struct {
+ PromptTokens int `json:"prompt_tokens,omitempty"`
+ CompletionTokens int `json:"completion_tokens,omitempty"`
+ TotalTokens int `json:"total_tokens,omitempty"`
+}
+
+type OpenAICmdInfoPacketOutputType struct {
+ Model string `json:"model,omitempty"`
+ Created int64 `json:"created,omitempty"`
+ FinishReason string `json:"finish_reason,omitempty"`
+ Message string `json:"message,omitempty"`
+ Error string `json:"error,omitempty"`
+}
+
+type OpenAIPacketType struct {
+ Type string `json:"type"`
+ Model string `json:"model,omitempty"`
+ Created int64 `json:"created,omitempty"`
+ FinishReason string `json:"finish_reason,omitempty"`
+ Usage *OpenAIUsageType `json:"usage,omitempty"`
+ Index int `json:"index,omitempty"`
+ Text string `json:"text,omitempty"`
+ Error string `json:"error,omitempty"`
+}
+
+func MakeOpenAIPacket() *OpenAIPacketType {
+ return &OpenAIPacketType{Type: OpenAIPacketStr}
+}
+
+type OpenAICmdInfoChatMessage struct {
+ MessageID int `json:"messageid"`
+ IsAssistantResponse bool `json:"isassistantresponse,omitempty"`
+ AssistantResponse *OpenAICmdInfoPacketOutputType `json:"assistantresponse,omitempty"`
+ UserQuery string `json:"userquery,omitempty"`
+ UserEngineeredQuery string `json:"userengineeredquery,omitempty"`
+}
+
+type OpenAIPromptMessageType struct {
+ Role string `json:"role"`
+ Content string `json:"content"`
+ Name string `json:"name,omitempty"`
+}
+
+type OpenAICloudReqPacketType struct {
+ Type string `json:"type"`
+ ClientId string `json:"clientid"`
+ Prompt []OpenAIPromptMessageType `json:"prompt"`
+ MaxTokens int `json:"maxtokens,omitempty"`
+ MaxChoices int `json:"maxchoices,omitempty"`
+}
+
+type OpenAIOptsType struct {
+ Model string `json:"model"`
+ APIToken string `json:"apitoken"`
+ BaseURL string `json:"baseurl,omitempty"`
+ MaxTokens int `json:"maxtokens,omitempty"`
+ MaxChoices int `json:"maxchoices,omitempty"`
+ Timeout int `json:"timeout,omitempty"`
+}
+
+func MakeOpenAICloudReqPacket() *OpenAICloudReqPacketType {
+ return &OpenAICloudReqPacketType{
+ Type: OpenAICloudReqStr,
+ }
+}
+
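+// OpenAiStreamRequest is the rpc payload for "stream:waveai": client id, connection opts, and the chat prompt.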
+type OpenAiStreamRequest struct {
+ ClientId string `json:"clientid,omitempty"`
+ Opts *OpenAIOptsType `json:"opts"`
+ Prompt []OpenAIPromptMessageType `json:"prompt"`
+}
+
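+// GetWSEndpoint returns the cloud websocket endpoint; in dev mode it comes from PCLOUD_WS_ENDPOINT.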
+func GetWSEndpoint() string {
+ if !wavebase.IsDevMode() {
+ return PCloudWSEndpoint
+ } else {
+ endpoint := os.Getenv(PCloudWSEndpointVarName)
+ if endpoint == "" {
+ panic("Invalid PCloud ws dev endpoint, PCLOUD_WS_ENDPOINT not set or invalid")
+ }
+ return endpoint
+ }
+}
+
+const DefaultMaxTokens = 1000
+const DefaultModel = "gpt-3.5-turbo"
+const DefaultStreamChanSize = 10
+const PCloudWSEndpoint = "wss://wsapi.waveterm.dev/"
+const PCloudWSEndpointVarName = "PCLOUD_WS_ENDPOINT"
+
+const CloudWebsocketConnectTimeout = 1 * time.Minute
+
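+// convertUsage maps the OpenAI usage block onto OpenAIUsageType, returning nil when no usage was reported.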
+func convertUsage(resp openaiapi.ChatCompletionResponse) *OpenAIUsageType {
+ if resp.Usage.TotalTokens == 0 {
+ return nil
+ }
+ return &OpenAIUsageType{
+ PromptTokens: resp.Usage.PromptTokens,
+ CompletionTokens: resp.Usage.CompletionTokens,
+ TotalTokens: resp.Usage.TotalTokens,
+ }
+}
+
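+// ConvertPrompt translates our prompt messages into go-openai chat completion messages.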
+func ConvertPrompt(prompt []OpenAIPromptMessageType) []openaiapi.ChatCompletionMessage {
+ var rtn []openaiapi.ChatCompletionMessage
+ for _, p := range prompt {
+ msg := openaiapi.ChatCompletionMessage{Role: p.Role, Content: p.Content, Name: p.Name}
+ rtn = append(rtn, msg)
+ }
+ return rtn
+}
+
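+// RunCloudCompletionStream proxies the request over a websocket to the Wave cloud endpoint and forwards each response packet on the returned channel until EOF or error.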
+func RunCloudCompletionStream(ctx context.Context, request OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[OpenAIPacketType] {
+ rtn := make(chan wshrpc.RespOrErrorUnion[OpenAIPacketType])
+ go func() {
+ log.Printf("start: %v", request)
+ defer close(rtn)
+ if request.Opts == nil {
+ rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("no openai opts found")}
+ return
+ }
+ websocketContext, dialCancelFn := context.WithTimeout(context.Background(), CloudWebsocketConnectTimeout)
+ defer dialCancelFn()
+ conn, _, err := websocket.DefaultDialer.DialContext(websocketContext, GetWSEndpoint(), nil)
+ if err == context.DeadlineExceeded {
+ rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, timed out connecting to cloud server: %v", err)}
+ return
+ } else if err != nil {
+ rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket connect error: %v", err)}
+ return
+ }
+ defer func() {
+ closeErr := conn.Close()
+ if closeErr != nil {
+ rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("unable to close openai channel: %v", closeErr)}
+ }
+ }()
+ reqPk := MakeOpenAICloudReqPacket()
+ reqPk.ClientId = request.ClientId
+ reqPk.Prompt = request.Prompt
+ reqPk.MaxTokens = request.Opts.MaxTokens
+ reqPk.MaxChoices = request.Opts.MaxChoices
+ configMessageBuf, err := json.Marshal(reqPk)
+ if err != nil {
+ rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, packet marshal error: %v", err)}
+ return
+ }
+ err = conn.WriteMessage(websocket.TextMessage, configMessageBuf)
+ if err != nil {
+ rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket write config error: %v", err)}
+ return
+ }
+ for {
+ log.Printf("loop")
+ _, socketMessage, err := conn.ReadMessage()
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ log.Printf("err received: %v", err)
+ rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket error reading message: %v", err)}
+ break
+ }
+ var streamResp *OpenAIPacketType
+ err = json.Unmarshal(socketMessage, &streamResp)
+ log.Printf("ai resp: %v", streamResp)
+ if err != nil {
+ rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket response json decode error: %v", err)}
+ break
+ }
+ if streamResp.Error == PacketEOFStr {
+ // got eof packet from socket
+ break
+ } else if streamResp.Error != "" {
+ // use error from server directly
+ rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("%v", streamResp.Error)}
+ break
+ }
+ rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Response: *streamResp}
+ }
+ }()
+ return rtn
+}
+
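+// RunLocalCompletionStream talks directly to an OpenAI-compatible API (using the user's apitoken/baseurl) and streams completion deltas on the returned channel.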
+func RunLocalCompletionStream(ctx context.Context, request OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[OpenAIPacketType] {
+ rtn := make(chan wshrpc.RespOrErrorUnion[OpenAIPacketType])
+ go func() {
+ log.Printf("start2: %v", request)
+ defer close(rtn)
+ if request.Opts == nil {
+ rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("no openai opts found")}
+ return
+ }
+ if request.Opts.Model == "" {
+ rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("no openai model specified")}
+ return
+ }
+ if request.Opts.BaseURL == "" && request.Opts.APIToken == "" {
+ rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("no api token")}
+ return
+ }
+ clientConfig := openaiapi.DefaultConfig(request.Opts.APIToken)
+ if request.Opts.BaseURL != "" {
+ clientConfig.BaseURL = request.Opts.BaseURL
+ }
+ client := openaiapi.NewClientWithConfig(clientConfig)
+ req := openaiapi.ChatCompletionRequest{
+ Model: request.Opts.Model,
+ Messages: ConvertPrompt(request.Prompt),
+ MaxTokens: request.Opts.MaxTokens,
+ Stream: true,
+ }
+ if request.Opts.MaxChoices > 1 {
+ req.N = request.Opts.MaxChoices
+ }
+ apiResp, err := client.CreateChatCompletionStream(ctx, req)
+ if err != nil {
+ rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("error calling openai API: %v", err)}
+ return
+ }
+ sentHeader := false
+ for {
+ log.Printf("loop2")
+ streamResp, err := apiResp.Recv()
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ log.Printf("err received2: %v", err)
+ rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket error reading message: %v", err)}
+ break
+ }
+ if streamResp.Model != "" && !sentHeader {
+ pk := MakeOpenAIPacket()
+ pk.Model = streamResp.Model
+ pk.Created = streamResp.Created
+ rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Response: *pk}
+ sentHeader = true
+ }
+ for _, choice := range streamResp.Choices {
+ pk := MakeOpenAIPacket()
+ pk.Index = choice.Index
+ pk.Text = choice.Delta.Content
+ pk.FinishReason = string(choice.FinishReason)
+ rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Response: *pk}
+ }
+ }
+ }()
+ return rtn
+}
+
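+// marshalResponse converts a non-streaming chat completion response into a header packet plus one packet per choice.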
+func marshalResponse(resp openaiapi.ChatCompletionResponse) []*OpenAIPacketType {
+ var rtn []*OpenAIPacketType
+ headerPk := MakeOpenAIPacket()
+ headerPk.Model = resp.Model
+ headerPk.Created = resp.Created
+ headerPk.Usage = convertUsage(resp)
+ rtn = append(rtn, headerPk)
+ for _, choice := range resp.Choices {
+ choicePk := MakeOpenAIPacket()
+ choicePk.Index = choice.Index
+ choicePk.Text = choice.Message.Content
+ choicePk.FinishReason = string(choice.FinishReason)
+ rtn = append(rtn, choicePk)
+ }
+ return rtn
+}
+
+func CreateErrorPacket(errStr string) *OpenAIPacketType {
+ errPk := MakeOpenAIPacket()
+ errPk.FinishReason = "error"
+ errPk.Error = errStr
+ return errPk
+}
+
+func CreateTextPacket(text string) *OpenAIPacketType {
+ pk := MakeOpenAIPacket()
+ pk.Text = text
+ return pk
+}
diff --git a/pkg/wconfig/settingsconfig.go b/pkg/wconfig/settingsconfig.go
index b5dd02620..6777868e4 100644
--- a/pkg/wconfig/settingsconfig.go
+++ b/pkg/wconfig/settingsconfig.go
@@ -4,6 +4,7 @@
package wconfig
import (
+ "os/user"
"path/filepath"
"github.com/wavetermdev/thenextwave/pkg/wavebase"
@@ -141,6 +142,13 @@ func applyDefaultSettings(settings *SettingsConfigType) {
IntervalMs: 3600000,
}
}
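+ // default the waveai widget's "name" meta to the local username, falling back to "user"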
+ var userName string
+ currentUser, err := user.Current()
+ if err != nil {
+ userName = "user"
+ } else {
+ userName = currentUser.Username
+ }
defaultWidgets := []WidgetsConfigType{
{
Icon: "files",
@@ -170,6 +178,7 @@ func applyDefaultSettings(settings *SettingsConfigType) {
Label: "waveai",
BlockDef: wstore.BlockDef{
View: "waveai",
+ Meta: map[string]any{"name": userName, "baseurl": "", "apitoken": ""},
},
},
}
diff --git a/pkg/wshrpc/wshclient/wshclient.go b/pkg/wshrpc/wshclient/wshclient.go
index 197f96df5..0e4e9e8d4 100644
--- a/pkg/wshrpc/wshclient/wshclient.go
+++ b/pkg/wshrpc/wshclient/wshclient.go
@@ -9,6 +9,7 @@ import (
"github.com/wavetermdev/thenextwave/pkg/wshutil"
"github.com/wavetermdev/thenextwave/pkg/wshrpc"
"github.com/wavetermdev/thenextwave/pkg/waveobj"
+ "github.com/wavetermdev/thenextwave/pkg/waveai"
)
// command "controller:input", wshserver.BlockInputCommand
@@ -89,6 +90,11 @@ func BlockSetViewCommand(w *wshutil.WshRpc, data wshrpc.CommandBlockSetViewData,
return err
}
+// command "stream:waveai", wshserver.RespStreamWaveAi
+func RespStreamWaveAi(w *wshutil.WshRpc, data waveai.OpenAiStreamRequest, opts *wshrpc.WshRpcCommandOpts) chan wshrpc.RespOrErrorUnion[waveai.OpenAIPacketType] {
+ return sendRpcRequestResponseStreamHelper[waveai.OpenAIPacketType](w, "stream:waveai", data, opts)
+}
+
// command "streamtest", wshserver.RespStreamTest
func RespStreamTest(w *wshutil.WshRpc, opts *wshrpc.WshRpcCommandOpts) chan wshrpc.RespOrErrorUnion[int] {
return sendRpcRequestResponseStreamHelper[int](w, "streamtest", nil, opts)
diff --git a/pkg/wshrpc/wshrpctypes.go b/pkg/wshrpc/wshrpctypes.go
index a0d224783..1679b95c9 100644
--- a/pkg/wshrpc/wshrpctypes.go
+++ b/pkg/wshrpc/wshrpctypes.go
@@ -15,19 +15,20 @@ import (
)
const (
- Command_Message = "message"
- Command_SetView = "setview"
- Command_SetMeta = "setmeta"
- Command_GetMeta = "getmeta"
- Command_BlockInput = "controller:input"
- Command_Restart = "controller:restart"
- Command_AppendFile = "file:append"
- Command_AppendIJson = "file:appendijson"
- Command_ResolveIds = "resolveids"
- Command_CreateBlock = "createblock"
- Command_DeleteBlock = "deleteblock"
- Command_WriteFile = "file:write"
- Command_ReadFile = "file:read"
+ Command_Message = "message"
+ Command_SetView = "setview"
+ Command_SetMeta = "setmeta"
+ Command_GetMeta = "getmeta"
+ Command_BlockInput = "controller:input"
+ Command_Restart = "controller:restart"
+ Command_AppendFile = "file:append"
+ Command_AppendIJson = "file:appendijson"
+ Command_ResolveIds = "resolveids"
+ Command_CreateBlock = "createblock"
+ Command_DeleteBlock = "deleteblock"
+ Command_WriteFile = "file:write"
+ Command_ReadFile = "file:read"
+ Command_StreamWaveAi = "stream:waveai"
)
type MetaDataType = map[string]any
diff --git a/pkg/wshrpc/wshserver/wshserver.go b/pkg/wshrpc/wshserver/wshserver.go
index 18706238f..07fbf0e66 100644
--- a/pkg/wshrpc/wshserver/wshserver.go
+++ b/pkg/wshrpc/wshserver/wshserver.go
@@ -18,6 +18,7 @@ import (
"github.com/wavetermdev/thenextwave/pkg/blockcontroller"
"github.com/wavetermdev/thenextwave/pkg/eventbus"
"github.com/wavetermdev/thenextwave/pkg/filestore"
+ "github.com/wavetermdev/thenextwave/pkg/waveai"
"github.com/wavetermdev/thenextwave/pkg/waveobj"
"github.com/wavetermdev/thenextwave/pkg/wshrpc"
"github.com/wavetermdev/thenextwave/pkg/wshutil"
@@ -33,21 +34,31 @@ var RespStreamTest_MethodDecl = &WshServerMethodDecl{
DefaultResponseDataType: reflect.TypeOf((int)(0)),
}
+var RespStreamWaveAi_MethodDecl = &WshServerMethodDecl{
+ Command: wshrpc.Command_StreamWaveAi,
+ CommandType: wshutil.RpcType_ResponseStream,
+ MethodName: "RespStreamWaveAi",
+ Method: reflect.ValueOf(WshServerImpl.RespStreamWaveAi),
+ CommandDataType: reflect.TypeOf(waveai.OpenAiStreamRequest{}),
+ DefaultResponseDataType: reflect.TypeOf(waveai.OpenAIPacketType{}),
+}
+
var WshServerCommandToDeclMap = map[string]*WshServerMethodDecl{
- wshrpc.Command_Message: GetWshServerMethod(wshrpc.Command_Message, wshutil.RpcType_Call, "MessageCommand", WshServerImpl.MessageCommand),
- wshrpc.Command_SetView: GetWshServerMethod(wshrpc.Command_SetView, wshutil.RpcType_Call, "BlockSetViewCommand", WshServerImpl.BlockSetViewCommand),
- wshrpc.Command_SetMeta: GetWshServerMethod(wshrpc.Command_SetMeta, wshutil.RpcType_Call, "SetMetaCommand", WshServerImpl.SetMetaCommand),
- wshrpc.Command_GetMeta: GetWshServerMethod(wshrpc.Command_GetMeta, wshutil.RpcType_Call, "GetMetaCommand", WshServerImpl.GetMetaCommand),
- wshrpc.Command_ResolveIds: GetWshServerMethod(wshrpc.Command_ResolveIds, wshutil.RpcType_Call, "ResolveIdsCommand", WshServerImpl.ResolveIdsCommand),
- wshrpc.Command_CreateBlock: GetWshServerMethod(wshrpc.Command_CreateBlock, wshutil.RpcType_Call, "CreateBlockCommand", WshServerImpl.CreateBlockCommand),
- wshrpc.Command_Restart: GetWshServerMethod(wshrpc.Command_Restart, wshutil.RpcType_Call, "BlockRestartCommand", WshServerImpl.BlockRestartCommand),
- wshrpc.Command_BlockInput: GetWshServerMethod(wshrpc.Command_BlockInput, wshutil.RpcType_Call, "BlockInputCommand", WshServerImpl.BlockInputCommand),
- wshrpc.Command_AppendFile: GetWshServerMethod(wshrpc.Command_AppendFile, wshutil.RpcType_Call, "AppendFileCommand", WshServerImpl.AppendFileCommand),
- wshrpc.Command_AppendIJson: GetWshServerMethod(wshrpc.Command_AppendIJson, wshutil.RpcType_Call, "AppendIJsonCommand", WshServerImpl.AppendIJsonCommand),
- wshrpc.Command_DeleteBlock: GetWshServerMethod(wshrpc.Command_DeleteBlock, wshutil.RpcType_Call, "DeleteBlockCommand", WshServerImpl.DeleteBlockCommand),
- wshrpc.Command_WriteFile: GetWshServerMethod(wshrpc.Command_WriteFile, wshutil.RpcType_Call, "WriteFile", WshServerImpl.WriteFile),
- wshrpc.Command_ReadFile: GetWshServerMethod(wshrpc.Command_ReadFile, wshutil.RpcType_Call, "ReadFile", WshServerImpl.ReadFile),
- "streamtest": RespStreamTest_MethodDecl,
+ wshrpc.Command_Message: GetWshServerMethod(wshrpc.Command_Message, wshutil.RpcType_Call, "MessageCommand", WshServerImpl.MessageCommand),
+ wshrpc.Command_SetView: GetWshServerMethod(wshrpc.Command_SetView, wshutil.RpcType_Call, "BlockSetViewCommand", WshServerImpl.BlockSetViewCommand),
+ wshrpc.Command_SetMeta: GetWshServerMethod(wshrpc.Command_SetMeta, wshutil.RpcType_Call, "SetMetaCommand", WshServerImpl.SetMetaCommand),
+ wshrpc.Command_GetMeta: GetWshServerMethod(wshrpc.Command_GetMeta, wshutil.RpcType_Call, "GetMetaCommand", WshServerImpl.GetMetaCommand),
+ wshrpc.Command_ResolveIds: GetWshServerMethod(wshrpc.Command_ResolveIds, wshutil.RpcType_Call, "ResolveIdsCommand", WshServerImpl.ResolveIdsCommand),
+ wshrpc.Command_CreateBlock: GetWshServerMethod(wshrpc.Command_CreateBlock, wshutil.RpcType_Call, "CreateBlockCommand", WshServerImpl.CreateBlockCommand),
+ wshrpc.Command_Restart: GetWshServerMethod(wshrpc.Command_Restart, wshutil.RpcType_Call, "BlockRestartCommand", WshServerImpl.BlockRestartCommand),
+ wshrpc.Command_BlockInput: GetWshServerMethod(wshrpc.Command_BlockInput, wshutil.RpcType_Call, "BlockInputCommand", WshServerImpl.BlockInputCommand),
+ wshrpc.Command_AppendFile: GetWshServerMethod(wshrpc.Command_AppendFile, wshutil.RpcType_Call, "AppendFileCommand", WshServerImpl.AppendFileCommand),
+ wshrpc.Command_AppendIJson: GetWshServerMethod(wshrpc.Command_AppendIJson, wshutil.RpcType_Call, "AppendIJsonCommand", WshServerImpl.AppendIJsonCommand),
+ wshrpc.Command_DeleteBlock: GetWshServerMethod(wshrpc.Command_DeleteBlock, wshutil.RpcType_Call, "DeleteBlockCommand", WshServerImpl.DeleteBlockCommand),
+ wshrpc.Command_WriteFile: GetWshServerMethod(wshrpc.Command_WriteFile, wshutil.RpcType_Call, "WriteFile", WshServerImpl.WriteFile),
+ wshrpc.Command_ReadFile: GetWshServerMethod(wshrpc.Command_ReadFile, wshutil.RpcType_Call, "ReadFile", WshServerImpl.ReadFile),
+ wshrpc.Command_StreamWaveAi: RespStreamWaveAi_MethodDecl,
+ "streamtest": RespStreamTest_MethodDecl,
}
// for testing
@@ -69,6 +80,13 @@ func (ws *WshServer) RespStreamTest(ctx context.Context) chan wshrpc.RespOrError
return rtn
}
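+// RespStreamWaveAi routes "stream:waveai": with no baseurl/apitoken set the request goes through the Wave cloud proxy, otherwise it hits the OpenAI-compatible API directly.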
+func (ws *WshServer) RespStreamWaveAi(ctx context.Context, request waveai.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[waveai.OpenAIPacketType] {
+ if request.Opts == nil || (request.Opts.BaseURL == "" && request.Opts.APIToken == "") {
+ return waveai.RunCloudCompletionStream(ctx, request)
+ }
+ return waveai.RunLocalCompletionStream(ctx, request)
+}
+
func (ws *WshServer) GetMetaCommand(ctx context.Context, data wshrpc.CommandGetMetaData) (wshrpc.MetaDataType, error) {
log.Printf("calling meta: %s\n", data.ORef)
obj, err := wstore.DBGetORef(ctx, data.ORef)