OpenAI Port (#154)

This brings over a simplified version of the OpenAI feature from the
previous app, but in widget form. It still needs some work to reach
parity with that version, but it includes all of the basic building
blocks needed to get that working.
This commit is contained in:
Sylvie Crowe 2024-07-25 02:30:49 -07:00 committed by GitHub
parent dcb4d5f2bf
commit 0e46b79c22
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 576 additions and 134 deletions

View File

@ -70,6 +70,7 @@ func main() {
fmt.Fprintf(fd, " \"github.com/wavetermdev/thenextwave/pkg/wshutil\"\n") fmt.Fprintf(fd, " \"github.com/wavetermdev/thenextwave/pkg/wshutil\"\n")
fmt.Fprintf(fd, " \"github.com/wavetermdev/thenextwave/pkg/wshrpc\"\n") fmt.Fprintf(fd, " \"github.com/wavetermdev/thenextwave/pkg/wshrpc\"\n")
fmt.Fprintf(fd, " \"github.com/wavetermdev/thenextwave/pkg/waveobj\"\n") fmt.Fprintf(fd, " \"github.com/wavetermdev/thenextwave/pkg/waveobj\"\n")
fmt.Fprintf(fd, " \"github.com/wavetermdev/thenextwave/pkg/waveai\"\n")
fmt.Fprintf(fd, ")\n\n") fmt.Fprintf(fd, ")\n\n")
for _, key := range utilfn.GetOrderedMapKeys(wshserver.WshServerCommandToDeclMap) { for _, key := range utilfn.GetOrderedMapKeys(wshserver.WshServerCommandToDeclMap) {

View File

@ -14,7 +14,7 @@ import * as util from "@/util/util";
import { PlotView } from "@/view/plotview"; import { PlotView } from "@/view/plotview";
import { PreviewView, makePreviewModel } from "@/view/preview"; import { PreviewView, makePreviewModel } from "@/view/preview";
import { TerminalView, makeTerminalModel } from "@/view/term/term"; import { TerminalView, makeTerminalModel } from "@/view/term/term";
import { WaveAi } from "@/view/waveai"; import { WaveAi, makeWaveAiViewModel } from "@/view/waveai";
import { WebView, makeWebViewModel } from "@/view/webview"; import { WebView, makeWebViewModel } from "@/view/webview";
import clsx from "clsx"; import clsx from "clsx";
import * as jotai from "jotai"; import * as jotai from "jotai";
@ -516,7 +516,9 @@ function getViewElemAndModel(
viewElem = <WebView key={blockId} parentRef={blockRef} model={webviewModel} />; viewElem = <WebView key={blockId} parentRef={blockRef} model={webviewModel} />;
viewModel = webviewModel; viewModel = webviewModel;
} else if (blockView === "waveai") { } else if (blockView === "waveai") {
viewElem = <WaveAi key={blockId} />; const waveAiModel = makeWaveAiViewModel(blockId);
viewElem = <WaveAi key={blockId} model={waveAiModel} />;
viewModel = waveAiModel;
} }
if (viewModel == null) { if (viewModel == null) {
viewModel = makeDefaultViewModel(blockId); viewModel = makeDefaultViewModel(blockId);

View File

@ -1,101 +0,0 @@
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
import { atom, useAtom } from "jotai";
import { v4 as uuidv4 } from "uuid";
// shape of a single chat transcript entry
interface ChatMessageType {
    id: string; // stable uuid, used as the react list key
    user: string;
    text: string;
    isAssistant: boolean;
    isUpdating?: boolean; // true while the assistant reply is still streaming in
    isError?: string; // error text; set only on failed assistant messages
}
// greeting shown before the user has sent anything
const defaultMessage: ChatMessageType = {
    id: uuidv4(),
    user: "assistant",
    text: `<p>Hello, how may I help you with this command?<br>
(Cmd-Shift-Space: open/close, Ctrl+L: clear chat buffer, Up/Down: select code blocks, Enter: to copy a selected code block to the command input)</p>`,
    isAssistant: true,
};
// full chat transcript, oldest message first
const messagesAtom = atom<ChatMessageType[]>([defaultMessage]);
// write-only atom: append one message to the transcript
const addMessageAtom = atom(null, (get, set, message: ChatMessageType) => {
    const messages = get(messagesAtom);
    set(messagesAtom, [...messages, message]);
});
// write-only atom: append streamed text to the newest message; no-op unless
// the newest message is a non-errored assistant message
const updateLastMessageAtom = atom(null, (get, set, text: string, isUpdating: boolean) => {
    const messages = get(messagesAtom);
    const lastMessage = messages[messages.length - 1];
    if (lastMessage.isAssistant && !lastMessage.isError) {
        const updatedMessage = { ...lastMessage, text: lastMessage.text + text, isUpdating };
        set(messagesAtom, [...messages.slice(0, -1), updatedMessage]);
    }
});
// write-only atom: fake a backend reply for development. Appends a
// typing-indicator message, then after a 1.5s "thinking" delay streams a
// canned answer into it word-by-word (one word per 100ms tick).
const simulateAssistantResponseAtom = atom(null, (get, set, userMessage: ChatMessageType) => {
    const responseText = `Here is an example of a simple bash script:
\`\`\`bash
#!/bin/bash
# This is a comment
echo "Hello, World!"
\`\`\`
You can run this script by saving it to a file, for example, \`hello.sh\`, and then running \`chmod +x hello.sh\` to make it executable. Finally, run it with \`./hello.sh\`.`;
    const typingMessage: ChatMessageType = {
        id: uuidv4(),
        user: "assistant",
        text: "",
        isAssistant: true,
    };
    // Add a typing indicator (empty assistant message the stream writes into)
    set(addMessageAtom, typingMessage);
    setTimeout(() => {
        const parts = responseText.split(" ");
        let currentPart = 0;
        const intervalId = setInterval(() => {
            if (currentPart < parts.length) {
                const part = parts[currentPart] + " ";
                set(updateLastMessageAtom, part, true);
                currentPart++;
            } else {
                // done streaming: clear the isUpdating flag
                clearInterval(intervalId);
                set(updateLastMessageAtom, "", false);
            }
        }, 100);
    }, 1500);
});
// React hook: exposes the chat transcript plus sendMessage(text, user?),
// which appends the user's message and kicks off the simulated reply.
const useWaveAi = () => {
    const [messages] = useAtom(messagesAtom);
    const [, addMessage] = useAtom(addMessageAtom);
    const [, simulateResponse] = useAtom(simulateAssistantResponseAtom);
    const sendMessage = (text: string, user: string = "user") => {
        const newMessage: ChatMessageType = {
            id: uuidv4(),
            user,
            text,
            isAssistant: false,
        };
        addMessage(newMessage);
        simulateResponse(newMessage);
    };
    return {
        messages,
        sendMessage,
    };
};
export { useWaveAi };
export type { ChatMessageType };

View File

@ -72,6 +72,11 @@ class WshServerType {
return WOS.wshServerRpcHelper_call("setview", data, opts); return WOS.wshServerRpcHelper_call("setview", data, opts);
} }
// command "stream:waveai" [responsestream]
// Streams AI completion packets for the given request; consume with `for await`.
// Errors from the backend arrive as generator exceptions.
RespStreamWaveAi(data: OpenAiStreamRequest, opts?: WshRpcCommandOpts): AsyncGenerator<OpenAIPacketType, void, boolean> {
    return WOS.wshServerRpcHelper_responsestream("stream:waveai", data, opts);
}
// command "streamtest" [responsestream] // command "streamtest" [responsestream]
RespStreamTest(opts?: WshRpcCommandOpts): AsyncGenerator<number, void, boolean> { RespStreamTest(opts?: WshRpcCommandOpts): AsyncGenerator<number, void, boolean> {
return WOS.wshServerRpcHelper_responsestream("streamtest", null, opts); return WOS.wshServerRpcHelper_responsestream("streamtest", null, opts);

View File

@ -3,21 +3,170 @@
import { Markdown } from "@/app/element/markdown"; import { Markdown } from "@/app/element/markdown";
import { TypingIndicator } from "@/app/element/typingindicator"; import { TypingIndicator } from "@/app/element/typingindicator";
import { ChatMessageType, useWaveAi } from "@/app/store/waveai"; import { WOS, atoms } from "@/store/global";
import { WshServer } from "@/store/wshserver";
import * as jotai from "jotai";
import type { OverlayScrollbars } from "overlayscrollbars"; import type { OverlayScrollbars } from "overlayscrollbars";
import { OverlayScrollbarsComponent, OverlayScrollbarsComponentRef } from "overlayscrollbars-react"; import { OverlayScrollbarsComponent, OverlayScrollbarsComponentRef } from "overlayscrollbars-react";
import React, { forwardRef, useCallback, useEffect, useImperativeHandle, useRef, useState } from "react"; import React, { forwardRef, useCallback, useEffect, useImperativeHandle, useRef, useState } from "react";
import tinycolor from "tinycolor2"; import tinycolor from "tinycolor2";
import { v4 as uuidv4 } from "uuid";
import "./waveai.less"; import "./waveai.less";
// shape of a single chat transcript entry
interface ChatMessageType {
    id: string; // stable uuid, used as the react list key
    user: string;
    text: string;
    isAssistant: boolean;
    isUpdating?: boolean; // true while the assistant reply is still streaming in
    isError?: string; // error text; set only on failed assistant messages
}
const outline = "2px solid var(--accent-color)"; const outline = "2px solid var(--accent-color)";
// greeting message shown before the user has sent anything
const defaultMessage: ChatMessageType = {
    id: uuidv4(),
    user: "assistant",
    text: `<p>Hello, how may I help you with this command?<br>
(Cmd-Shift-Space: open/close, Ctrl+L: clear chat buffer, Up/Down: select code blocks, Enter: to copy a selected code block to the command input)</p>`,
    isAssistant: true,
};
interface ChatItemProps { interface ChatItemProps {
chatItem: ChatMessageType; chatItem: ChatMessageType;
itemCount: number; itemCount: number;
} }
/**
 * View model for a WaveAI chat block. Owns the chat transcript (as jotai
 * atoms) and sends user prompts to the backend "stream:waveai" RPC, folding
 * the streamed response back into the transcript.
 */
export class WaveAiModel implements ViewModel {
    blockId: string;
    blockAtom: jotai.Atom<Block>;
    viewIcon?: jotai.Atom<string | HeaderIconButton>;
    viewName?: jotai.Atom<string>;
    viewText?: jotai.Atom<string | HeaderElem[]>;
    preIconButton?: jotai.Atom<HeaderIconButton>;
    endIconButtons?: jotai.Atom<HeaderIconButton[]>;
    // full transcript, oldest first; seeded with the greeting message
    messagesAtom: jotai.PrimitiveAtom<Array<ChatMessageType>>;
    // write-only: append one message to the transcript
    addMessageAtom: jotai.WritableAtom<unknown, [message: ChatMessageType], void>;
    // write-only: append streamed text to the newest (assistant) message
    updateLastMessageAtom: jotai.WritableAtom<unknown, [text: string, isUpdating: boolean], void>;
    // write-only: replay a response into the transcript word-by-word with a typing effect
    simulateAssistantResponseAtom: jotai.WritableAtom<unknown, [userMessage: ChatMessageType], void>;

    constructor(blockId: string) {
        this.blockId = blockId;
        this.blockAtom = WOS.getWaveObjectAtom<Block>(`block:${blockId}`);
        this.viewIcon = jotai.atom((get) => {
            return "sparkles"; // TODO: should not be hardcoded
        });
        this.viewName = jotai.atom("Ai");
        this.messagesAtom = jotai.atom([defaultMessage]);
        this.addMessageAtom = jotai.atom(null, (get, set, message: ChatMessageType) => {
            const messages = get(this.messagesAtom);
            set(this.messagesAtom, [...messages, message]);
        });
        this.updateLastMessageAtom = jotai.atom(null, (get, set, text: string, isUpdating: boolean) => {
            const messages = get(this.messagesAtom);
            const lastMessage = messages[messages.length - 1];
            // only grow a live assistant message; never touch user/errored entries
            if (lastMessage.isAssistant && !lastMessage.isError) {
                const updatedMessage = { ...lastMessage, text: lastMessage.text + text, isUpdating };
                set(this.messagesAtom, [...messages.slice(0, -1), updatedMessage]);
            }
        });
        this.simulateAssistantResponseAtom = jotai.atom(null, (get, set, userMessage: ChatMessageType) => {
            const typingMessage: ChatMessageType = {
                id: uuidv4(),
                user: "assistant",
                text: "",
                isAssistant: true,
            };
            // Add a typing indicator, then stream the text in word-by-word
            // (100ms per word after a 1.5s "thinking" delay)
            set(this.addMessageAtom, typingMessage);
            setTimeout(() => {
                const parts = userMessage.text.split(" ");
                let currentPart = 0;
                const intervalId = setInterval(() => {
                    if (currentPart < parts.length) {
                        const part = parts[currentPart] + " ";
                        set(this.updateLastMessageAtom, part, true);
                        currentPart++;
                    } else {
                        clearInterval(intervalId);
                        set(this.updateLastMessageAtom, "", false);
                    }
                }, 100);
            }, 1500);
        });
    }

    /**
     * React hook exposing the transcript plus a sendMessage(text, user?)
     * callback that appends the user message, calls the backend AI stream,
     * and displays the accumulated response.
     */
    useWaveAi() {
        const [messages] = jotai.useAtom(this.messagesAtom);
        const [, addMessage] = jotai.useAtom(this.addMessageAtom);
        const [, simulateResponse] = jotai.useAtom(this.simulateAssistantResponseAtom);
        const metadata = jotai.useAtomValue(this.blockAtom).meta;
        const clientId = jotai.useAtomValue(atoms.clientId);
        const sendMessage = (text: string, user: string = "user") => {
            const newMessage: ChatMessageType = {
                id: uuidv4(),
                user,
                text,
                isAssistant: false,
            };
            addMessage(newMessage);
            // send message to backend and get response
            const opts: OpenAIOptsType = {
                model: "gpt-3.5-turbo",
                apitoken: metadata?.apitoken as string,
                maxtokens: 1000,
                timeout: 10,
                baseurl: metadata?.baseurl as string,
            };
            const prompt: Array<OpenAIPromptMessageType> = [
                {
                    role: "user",
                    content: text,
                    name: (metadata?.name as string) || "user",
                },
            ];
            // NOTE: do not log opts here — it contains the API token
            const beMsg: OpenAiStreamRequest = {
                clientid: clientId,
                opts: opts,
                prompt: prompt,
            };
            const aiGen = WshServer.RespStreamWaveAi(beMsg);
            const readStream = async () => {
                let fullMsg = "";
                for await (const msg of aiGen) {
                    fullMsg += msg.text ?? "";
                }
                const response: ChatMessageType = {
                    id: newMessage.id,
                    user: newMessage.user,
                    text: fullMsg,
                    isAssistant: true,
                };
                simulateResponse(response);
            };
            // handle stream failures instead of leaving the rejection unhandled
            readStream().catch((e) => {
                console.error("waveai: error reading ai response stream", e);
            });
        };
        return {
            messages,
            sendMessage,
        };
    }
}
/**
 * Factory used by the block registry to create the WaveAI view model.
 * (blockId was previously an implicit `any`; annotated as string.)
 */
function makeWaveAiViewModel(blockId: string): WaveAiModel {
    const waveAiModel = new WaveAiModel(blockId);
    return waveAiModel;
}
const ChatItem = ({ chatItem, itemCount }: ChatItemProps) => { const ChatItem = ({ chatItem, itemCount }: ChatItemProps) => {
const { isAssistant, text, isError } = chatItem; const { isAssistant, text, isError } = chatItem;
const senderClassName = isAssistant ? "chat-msg-assistant" : "chat-msg-user"; const senderClassName = isAssistant ? "chat-msg-assistant" : "chat-msg-user";
@ -208,8 +357,8 @@ const ChatInput = forwardRef<HTMLTextAreaElement, ChatInputProps>(
} }
); );
const WaveAi = () => { const WaveAi = ({ model }: { model: WaveAiModel }) => {
const { messages, sendMessage } = useWaveAi(); const { messages, sendMessage } = model.useWaveAi();
const waveaiRef = useRef<HTMLDivElement>(null); const waveaiRef = useRef<HTMLDivElement>(null);
const chatWindowRef = useRef<HTMLDivElement>(null); const chatWindowRef = useRef<HTMLDivElement>(null);
const osRef = useRef<OverlayScrollbarsComponentRef>(null); const osRef = useRef<OverlayScrollbarsComponentRef>(null);
@ -407,4 +556,4 @@ const WaveAi = () => {
); );
}; };
export { WaveAi }; export { WaveAi, makeWaveAiViewModel };

View File

@ -187,6 +187,49 @@ declare global {
// waveobj.ORef // waveobj.ORef
type ORef = string; type ORef = string;
// waveai.OpenAIOptsType
// connection/config options for an AI completion request
type OpenAIOptsType = {
    model: string;
    apitoken: string;
    baseurl?: string; // empty => request is proxied through the Wave cloud
    maxtokens?: number;
    maxchoices?: number;
    timeout?: number; // assumed seconds — TODO confirm unit against backend
};
// waveai.OpenAIPacketType
// one streamed response packet; `text` carries the content delta
type OpenAIPacketType = {
    type: string;
    model?: string;
    created?: number;
    finish_reason?: string;
    usage?: OpenAIUsageType;
    index?: number;
    text?: string;
    error?: string;
};
// waveai.OpenAIPromptMessageType
// a single chat message in the prompt (OpenAI chat format)
type OpenAIPromptMessageType = {
    role: string;
    content: string;
    name?: string;
};
// waveai.OpenAIUsageType
// token accounting reported by the API
type OpenAIUsageType = {
    prompt_tokens?: number;
    completion_tokens?: number;
    total_tokens?: number;
};
// waveai.OpenAiStreamRequest
// payload for the "stream:waveai" RPC
type OpenAiStreamRequest = {
    clientid?: string;
    opts: OpenAIOptsType;
    prompt: OpenAIPromptMessageType[];
};
// wstore.Point // wstore.Point
type Point = { type Point = {
x: number; x: number;

1
go.mod
View File

@ -15,6 +15,7 @@ require (
github.com/kevinburke/ssh_config v1.2.0 github.com/kevinburke/ssh_config v1.2.0
github.com/mattn/go-sqlite3 v1.14.22 github.com/mattn/go-sqlite3 v1.14.22
github.com/mitchellh/mapstructure v1.5.0 github.com/mitchellh/mapstructure v1.5.0
github.com/sashabaranov/go-openai v1.27.0
github.com/sawka/txwrap v0.2.0 github.com/sawka/txwrap v0.2.0
github.com/spf13/cobra v1.8.1 github.com/spf13/cobra v1.8.1
github.com/wavetermdev/htmltoken v0.1.0 github.com/wavetermdev/htmltoken v0.1.0

2
go.sum
View File

@ -42,6 +42,8 @@ github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RR
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sashabaranov/go-openai v1.27.0 h1:L3hO6650YUbKrbGUC6yCjsUluhKZ9h1/jcgbTItI8Mo=
github.com/sashabaranov/go-openai v1.27.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sawka/txwrap v0.2.0 h1:V3LfvKVLULxcYSxdMguLwFyQFMEU9nFDJopg0ZkL+94= github.com/sawka/txwrap v0.2.0 h1:V3LfvKVLULxcYSxdMguLwFyQFMEU9nFDJopg0ZkL+94=
github.com/sawka/txwrap v0.2.0/go.mod h1:wwQ2SQiN4U+6DU/iVPhbvr7OzXAtgZlQCIGuvOswEfA= github.com/sawka/txwrap v0.2.0/go.mod h1:wwQ2SQiN4U+6DU/iVPhbvr7OzXAtgZlQCIGuvOswEfA=
github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=

306
pkg/waveai/waveai.go Normal file
View File

@ -0,0 +1,306 @@
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package waveai
import (
"context"
"encoding/json"
"fmt"
"io"
"log"
"os"
"time"
openaiapi "github.com/sashabaranov/go-openai"
"github.com/wavetermdev/thenextwave/pkg/wavebase"
"github.com/wavetermdev/thenextwave/pkg/wshrpc"
"github.com/gorilla/websocket"
)
// packet-type discriminator strings used on the wire
const OpenAIPacketStr = "openai"
const OpenAICloudReqStr = "openai-cloudreq"
const PacketEOFStr = "EOF"

// OpenAIUsageType mirrors the OpenAI API token-usage accounting block.
type OpenAIUsageType struct {
	PromptTokens     int `json:"prompt_tokens,omitempty"`
	CompletionTokens int `json:"completion_tokens,omitempty"`
	TotalTokens      int `json:"total_tokens,omitempty"`
}

// OpenAICmdInfoPacketOutputType is the assistant-response portion of a
// cmd-info chat entry.
type OpenAICmdInfoPacketOutputType struct {
	Model        string `json:"model,omitempty"`
	Created      int64  `json:"created,omitempty"`
	FinishReason string `json:"finish_reason,omitempty"`
	Message      string `json:"message,omitempty"`
	Error        string `json:"error,omitempty"`
}

// OpenAIPacketType is one streamed response packet; Text carries the
// content delta, Error carries an in-band error (PacketEOFStr marks EOF).
type OpenAIPacketType struct {
	Type         string           `json:"type"`
	Model        string           `json:"model,omitempty"`
	Created      int64            `json:"created,omitempty"`
	FinishReason string           `json:"finish_reason,omitempty"`
	Usage        *OpenAIUsageType `json:"usage,omitempty"`
	Index        int              `json:"index,omitempty"`
	Text         string           `json:"text,omitempty"`
	Error        string           `json:"error,omitempty"`
}

// MakeOpenAIPacket returns an empty packet pre-tagged with the "openai" type.
func MakeOpenAIPacket() *OpenAIPacketType {
	return &OpenAIPacketType{Type: OpenAIPacketStr}
}

// OpenAICmdInfoChatMessage is one entry of the cmd-info chat transcript.
type OpenAICmdInfoChatMessage struct {
	MessageID           int                            `json:"messageid"`
	IsAssistantResponse bool                           `json:"isassistantresponse,omitempty"`
	AssistantResponse   *OpenAICmdInfoPacketOutputType `json:"assistantresponse,omitempty"`
	UserQuery           string                         `json:"userquery,omitempty"`
	UserEngineeredQuery string                         `json:"userengineeredquery,omitempty"`
}

// OpenAIPromptMessageType is a single chat message in OpenAI chat format.
type OpenAIPromptMessageType struct {
	Role    string `json:"role"`
	Content string `json:"content"`
	Name    string `json:"name,omitempty"`
}

// OpenAICloudReqPacketType is the request sent to the Wave cloud proxy
// (no API token: the proxy supplies its own credentials).
type OpenAICloudReqPacketType struct {
	Type       string                    `json:"type"`
	ClientId   string                    `json:"clientid"`
	Prompt     []OpenAIPromptMessageType `json:"prompt"`
	MaxTokens  int                       `json:"maxtokens,omitempty"`
	MaxChoices int                       `json:"maxchoices,omitempty"`
}

// OpenAIOptsType holds per-request model/connection options. An empty
// BaseURL together with an empty APIToken routes through the cloud proxy.
type OpenAIOptsType struct {
	Model      string `json:"model"`
	APIToken   string `json:"apitoken"`
	BaseURL    string `json:"baseurl,omitempty"`
	MaxTokens  int    `json:"maxtokens,omitempty"`
	MaxChoices int    `json:"maxchoices,omitempty"`
	Timeout    int    `json:"timeout,omitempty"`
}

// MakeOpenAICloudReqPacket returns an empty cloud request pre-tagged with
// the "openai-cloudreq" type.
func MakeOpenAICloudReqPacket() *OpenAICloudReqPacketType {
	return &OpenAICloudReqPacketType{
		Type: OpenAICloudReqStr,
	}
}

// OpenAiStreamRequest is the payload of the "stream:waveai" RPC.
type OpenAiStreamRequest struct {
	ClientId string                    `json:"clientid,omitempty"`
	Opts     *OpenAIOptsType           `json:"opts"`
	Prompt   []OpenAIPromptMessageType `json:"prompt"`
}
func GetWSEndpoint() string {
return PCloudWSEndpoint
if !wavebase.IsDevMode() {
return PCloudWSEndpoint
} else {
endpoint := os.Getenv(PCloudWSEndpointVarName)
if endpoint == "" {
panic("Invalid PCloud ws dev endpoint, PCLOUD_WS_ENDPOINT not set or invalid")
}
return endpoint
}
}
// defaults applied when the request does not specify them
const DefaultMaxTokens = 1000
const DefaultModel = "gpt-3.5-turbo"
// buffer size for response-streaming channels
const DefaultStreamChanSize = 10
// production cloud proxy endpoint; dev builds override via env var below
const PCloudWSEndpoint = "wss://wsapi.waveterm.dev/"
const PCloudWSEndpointVarName = "PCLOUD_WS_ENDPOINT"
const CloudWebsocketConnectTimeout = 1 * time.Minute
// convertUsage pulls the token-usage block out of a chat completion
// response; returns nil when the API reported no usage at all.
func convertUsage(resp openaiapi.ChatCompletionResponse) *OpenAIUsageType {
	usage := resp.Usage
	if usage.TotalTokens == 0 {
		// nothing reported; omit the block entirely
		return nil
	}
	converted := &OpenAIUsageType{}
	converted.PromptTokens = usage.PromptTokens
	converted.CompletionTokens = usage.CompletionTokens
	converted.TotalTokens = usage.TotalTokens
	return converted
}
// ConvertPrompt maps our wire-format prompt messages onto the go-openai
// client's ChatCompletionMessage type (nil in, nil out).
func ConvertPrompt(prompt []OpenAIPromptMessageType) []openaiapi.ChatCompletionMessage {
	var converted []openaiapi.ChatCompletionMessage
	for idx := range prompt {
		p := &prompt[idx]
		converted = append(converted, openaiapi.ChatCompletionMessage{
			Role:    p.Role,
			Content: p.Content,
			Name:    p.Name,
		})
	}
	return converted
}
// RunCloudCompletionStream proxies a completion request through the Wave
// cloud websocket endpoint and streams response packets on the returned
// channel. The channel is closed when the stream ends; errors are delivered
// in-band as RespOrErrorUnion entries rather than returned.
//
// Fixes: the close-defer was registered before the dial error was checked,
// so a failed dial (nil conn) panicked in the defer; the request (which
// includes the API token via Opts) was also logged verbatim.
func RunCloudCompletionStream(ctx context.Context, request OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[OpenAIPacketType] {
	rtn := make(chan wshrpc.RespOrErrorUnion[OpenAIPacketType])
	go func() {
		defer close(rtn)
		if request.Opts == nil {
			rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("no openai opts found")}
			return
		}
		websocketContext, dialCancelFn := context.WithTimeout(context.Background(), CloudWebsocketConnectTimeout)
		defer dialCancelFn()
		conn, _, err := websocket.DefaultDialer.DialContext(websocketContext, GetWSEndpoint(), nil)
		// check the dial error BEFORE registering the close defer: on a
		// failed dial conn may be nil and conn.Close() would panic
		if err == context.DeadlineExceeded {
			rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, timed out connecting to cloud server: %v", err)}
			return
		} else if err != nil {
			rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket connect error: %v", err)}
			return
		}
		defer func() {
			if closeErr := conn.Close(); closeErr != nil {
				rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("unable to close openai channel: %v", closeErr)}
			}
		}()
		reqPk := MakeOpenAICloudReqPacket()
		reqPk.ClientId = request.ClientId
		reqPk.Prompt = request.Prompt
		reqPk.MaxTokens = request.Opts.MaxTokens
		reqPk.MaxChoices = request.Opts.MaxChoices
		configMessageBuf, err := json.Marshal(reqPk)
		if err != nil {
			rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, packet marshal error: %v", err)}
			return
		}
		err = conn.WriteMessage(websocket.TextMessage, configMessageBuf)
		if err != nil {
			rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket write config error: %v", err)}
			return
		}
		for {
			_, socketMessage, err := conn.ReadMessage()
			if err == io.EOF {
				break
			}
			if err != nil {
				log.Printf("err received: %v", err)
				rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket error reading message: %v", err)}
				break
			}
			var streamResp *OpenAIPacketType
			err = json.Unmarshal(socketMessage, &streamResp)
			if err != nil {
				rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket response json decode error: %v", err)}
				break
			}
			if streamResp.Error == PacketEOFStr {
				// got eof packet from socket
				break
			} else if streamResp.Error != "" {
				// use error from server directly
				rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("%v", streamResp.Error)}
				break
			}
			rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Response: *streamResp}
		}
	}()
	return rtn
}
// RunLocalCompletionStream runs a streaming chat completion directly against
// the configured OpenAI-compatible endpoint using the caller-supplied API
// token / base URL, and forwards packets on the returned channel. The channel
// is closed when the stream ends; errors are delivered in-band.
//
// Fixes: the completion stream was never closed (connection leak); the
// request (which includes the API token via Opts) was logged verbatim; the
// recv error message wrongly mentioned "websocket" (copy-paste from the
// cloud path).
func RunLocalCompletionStream(ctx context.Context, request OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[OpenAIPacketType] {
	rtn := make(chan wshrpc.RespOrErrorUnion[OpenAIPacketType])
	go func() {
		defer close(rtn)
		if request.Opts == nil {
			rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("no openai opts found")}
			return
		}
		if request.Opts.Model == "" {
			rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("no openai model specified")}
			return
		}
		if request.Opts.BaseURL == "" && request.Opts.APIToken == "" {
			rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("no api token")}
			return
		}
		clientConfig := openaiapi.DefaultConfig(request.Opts.APIToken)
		if request.Opts.BaseURL != "" {
			clientConfig.BaseURL = request.Opts.BaseURL
		}
		client := openaiapi.NewClientWithConfig(clientConfig)
		req := openaiapi.ChatCompletionRequest{
			Model:     request.Opts.Model,
			Messages:  ConvertPrompt(request.Prompt),
			MaxTokens: request.Opts.MaxTokens,
			Stream:    true,
		}
		if request.Opts.MaxChoices > 1 {
			req.N = request.Opts.MaxChoices
		}
		apiResp, err := client.CreateChatCompletionStream(ctx, req)
		if err != nil {
			rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("error calling openai API: %v", err)}
			return
		}
		// release the underlying HTTP connection when done (was leaked)
		defer apiResp.Close()
		sentHeader := false
		for {
			streamResp, err := apiResp.Recv()
			if err == io.EOF {
				break
			}
			if err != nil {
				log.Printf("openai stream recv error: %v", err)
				rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, error reading stream response: %v", err)}
				break
			}
			// emit a one-time header packet carrying model/created metadata
			if streamResp.Model != "" && !sentHeader {
				pk := MakeOpenAIPacket()
				pk.Model = streamResp.Model
				pk.Created = streamResp.Created
				rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Response: *pk}
				sentHeader = true
			}
			for _, choice := range streamResp.Choices {
				pk := MakeOpenAIPacket()
				pk.Index = choice.Index
				pk.Text = choice.Delta.Content
				pk.FinishReason = string(choice.FinishReason)
				rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Response: *pk}
			}
		}
	}()
	return rtn
}
// marshalResponse flattens a non-streaming chat completion response into the
// packet sequence a streaming client would produce: one header packet
// (model/created/usage) followed by one packet per choice.
func marshalResponse(resp openaiapi.ChatCompletionResponse) []*OpenAIPacketType {
	header := MakeOpenAIPacket()
	header.Model = resp.Model
	header.Created = resp.Created
	header.Usage = convertUsage(resp)
	packets := []*OpenAIPacketType{header}
	for _, choice := range resp.Choices {
		pk := MakeOpenAIPacket()
		pk.Index = choice.Index
		pk.Text = choice.Message.Content
		pk.FinishReason = string(choice.FinishReason)
		packets = append(packets, pk)
	}
	return packets
}
// CreateErrorPacket builds a terminal packet carrying an error message
// (FinishReason is set to "error").
func CreateErrorPacket(errStr string) *OpenAIPacketType {
	pk := MakeOpenAIPacket()
	pk.FinishReason = "error"
	pk.Error = errStr
	return pk
}
// CreateTextPacket wraps a plain text fragment in an OpenAI packet.
func CreateTextPacket(text string) *OpenAIPacketType {
	rtn := MakeOpenAIPacket()
	rtn.Text = text
	return rtn
}

View File

@ -4,6 +4,7 @@
package wconfig package wconfig
import ( import (
"os/user"
"path/filepath" "path/filepath"
"github.com/wavetermdev/thenextwave/pkg/wavebase" "github.com/wavetermdev/thenextwave/pkg/wavebase"
@ -141,6 +142,13 @@ func applyDefaultSettings(settings *SettingsConfigType) {
IntervalMs: 3600000, IntervalMs: 3600000,
} }
} }
var userName string
currentUser, err := user.Current()
if err != nil {
userName = "user"
} else {
userName = currentUser.Username
}
defaultWidgets := []WidgetsConfigType{ defaultWidgets := []WidgetsConfigType{
{ {
Icon: "files", Icon: "files",
@ -170,6 +178,7 @@ func applyDefaultSettings(settings *SettingsConfigType) {
Label: "waveai", Label: "waveai",
BlockDef: wstore.BlockDef{ BlockDef: wstore.BlockDef{
View: "waveai", View: "waveai",
Meta: map[string]any{"name": userName, "baseurl": "", "apitoken": ""},
}, },
}, },
} }

View File

@ -9,6 +9,7 @@ import (
"github.com/wavetermdev/thenextwave/pkg/wshutil" "github.com/wavetermdev/thenextwave/pkg/wshutil"
"github.com/wavetermdev/thenextwave/pkg/wshrpc" "github.com/wavetermdev/thenextwave/pkg/wshrpc"
"github.com/wavetermdev/thenextwave/pkg/waveobj" "github.com/wavetermdev/thenextwave/pkg/waveobj"
"github.com/wavetermdev/thenextwave/pkg/waveai"
) )
// command "controller:input", wshserver.BlockInputCommand // command "controller:input", wshserver.BlockInputCommand
@ -89,6 +90,11 @@ func BlockSetViewCommand(w *wshutil.WshRpc, data wshrpc.CommandBlockSetViewData,
return err return err
} }
// command "stream:waveai", wshserver.RespStreamWaveAi
// RespStreamWaveAi is the client-side wrapper for the "stream:waveai" RPC;
// response packets (and in-band errors) arrive on the returned channel.
func RespStreamWaveAi(w *wshutil.WshRpc, data waveai.OpenAiStreamRequest, opts *wshrpc.WshRpcCommandOpts) chan wshrpc.RespOrErrorUnion[waveai.OpenAIPacketType] {
	return sendRpcRequestResponseStreamHelper[waveai.OpenAIPacketType](w, "stream:waveai", data, opts)
}
// command "streamtest", wshserver.RespStreamTest // command "streamtest", wshserver.RespStreamTest
func RespStreamTest(w *wshutil.WshRpc, opts *wshrpc.WshRpcCommandOpts) chan wshrpc.RespOrErrorUnion[int] { func RespStreamTest(w *wshutil.WshRpc, opts *wshrpc.WshRpcCommandOpts) chan wshrpc.RespOrErrorUnion[int] {
return sendRpcRequestResponseStreamHelper[int](w, "streamtest", nil, opts) return sendRpcRequestResponseStreamHelper[int](w, "streamtest", nil, opts)

View File

@ -28,6 +28,7 @@ const (
Command_DeleteBlock = "deleteblock" Command_DeleteBlock = "deleteblock"
Command_WriteFile = "file:write" Command_WriteFile = "file:write"
Command_ReadFile = "file:read" Command_ReadFile = "file:read"
Command_StreamWaveAi = "stream:waveai"
) )
type MetaDataType = map[string]any type MetaDataType = map[string]any

View File

@ -18,6 +18,7 @@ import (
"github.com/wavetermdev/thenextwave/pkg/blockcontroller" "github.com/wavetermdev/thenextwave/pkg/blockcontroller"
"github.com/wavetermdev/thenextwave/pkg/eventbus" "github.com/wavetermdev/thenextwave/pkg/eventbus"
"github.com/wavetermdev/thenextwave/pkg/filestore" "github.com/wavetermdev/thenextwave/pkg/filestore"
"github.com/wavetermdev/thenextwave/pkg/waveai"
"github.com/wavetermdev/thenextwave/pkg/waveobj" "github.com/wavetermdev/thenextwave/pkg/waveobj"
"github.com/wavetermdev/thenextwave/pkg/wshrpc" "github.com/wavetermdev/thenextwave/pkg/wshrpc"
"github.com/wavetermdev/thenextwave/pkg/wshutil" "github.com/wavetermdev/thenextwave/pkg/wshutil"
@ -33,6 +34,15 @@ var RespStreamTest_MethodDecl = &WshServerMethodDecl{
DefaultResponseDataType: reflect.TypeOf((int)(0)), DefaultResponseDataType: reflect.TypeOf((int)(0)),
} }
// RespStreamWaveAi_MethodDecl registers the "stream:waveai" command as a
// response-stream RPC backed by WshServerImpl.RespStreamWaveAi.
var RespStreamWaveAi_MethodDecl = &WshServerMethodDecl{
	Command:                 wshrpc.Command_StreamWaveAi,
	CommandType:             wshutil.RpcType_ResponseStream,
	MethodName:              "RespStreamWaveAi",
	Method:                  reflect.ValueOf(WshServerImpl.RespStreamWaveAi),
	CommandDataType:         reflect.TypeOf(waveai.OpenAiStreamRequest{}),
	DefaultResponseDataType: reflect.TypeOf(waveai.OpenAIPacketType{}),
}
var WshServerCommandToDeclMap = map[string]*WshServerMethodDecl{ var WshServerCommandToDeclMap = map[string]*WshServerMethodDecl{
wshrpc.Command_Message: GetWshServerMethod(wshrpc.Command_Message, wshutil.RpcType_Call, "MessageCommand", WshServerImpl.MessageCommand), wshrpc.Command_Message: GetWshServerMethod(wshrpc.Command_Message, wshutil.RpcType_Call, "MessageCommand", WshServerImpl.MessageCommand),
wshrpc.Command_SetView: GetWshServerMethod(wshrpc.Command_SetView, wshutil.RpcType_Call, "BlockSetViewCommand", WshServerImpl.BlockSetViewCommand), wshrpc.Command_SetView: GetWshServerMethod(wshrpc.Command_SetView, wshutil.RpcType_Call, "BlockSetViewCommand", WshServerImpl.BlockSetViewCommand),
@ -47,6 +57,7 @@ var WshServerCommandToDeclMap = map[string]*WshServerMethodDecl{
wshrpc.Command_DeleteBlock: GetWshServerMethod(wshrpc.Command_DeleteBlock, wshutil.RpcType_Call, "DeleteBlockCommand", WshServerImpl.DeleteBlockCommand), wshrpc.Command_DeleteBlock: GetWshServerMethod(wshrpc.Command_DeleteBlock, wshutil.RpcType_Call, "DeleteBlockCommand", WshServerImpl.DeleteBlockCommand),
wshrpc.Command_WriteFile: GetWshServerMethod(wshrpc.Command_WriteFile, wshutil.RpcType_Call, "WriteFile", WshServerImpl.WriteFile), wshrpc.Command_WriteFile: GetWshServerMethod(wshrpc.Command_WriteFile, wshutil.RpcType_Call, "WriteFile", WshServerImpl.WriteFile),
wshrpc.Command_ReadFile: GetWshServerMethod(wshrpc.Command_ReadFile, wshutil.RpcType_Call, "ReadFile", WshServerImpl.ReadFile), wshrpc.Command_ReadFile: GetWshServerMethod(wshrpc.Command_ReadFile, wshutil.RpcType_Call, "ReadFile", WshServerImpl.ReadFile),
wshrpc.Command_StreamWaveAi: RespStreamWaveAi_MethodDecl,
"streamtest": RespStreamTest_MethodDecl, "streamtest": RespStreamTest_MethodDecl,
} }
@ -69,6 +80,13 @@ func (ws *WshServer) RespStreamTest(ctx context.Context) chan wshrpc.RespOrError
return rtn return rtn
} }
// RespStreamWaveAi streams an AI completion back to the client. Requests
// with neither a BaseURL nor an APIToken are routed through the Wave cloud
// proxy; otherwise the request goes directly to the configured endpoint with
// the caller-supplied credentials.
//
// Fix: request.Opts was dereferenced without a nil check, panicking before
// either stream implementation could report the error in-band.
func (ws *WshServer) RespStreamWaveAi(ctx context.Context, request waveai.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[waveai.OpenAIPacketType] {
	if request.Opts == nil {
		// the stream impl emits a proper "no openai opts found" error packet
		return waveai.RunLocalCompletionStream(ctx, request)
	}
	if request.Opts.BaseURL == "" && request.Opts.APIToken == "" {
		return waveai.RunCloudCompletionStream(ctx, request)
	}
	return waveai.RunLocalCompletionStream(ctx, request)
}
func (ws *WshServer) GetMetaCommand(ctx context.Context, data wshrpc.CommandGetMetaData) (wshrpc.MetaDataType, error) { func (ws *WshServer) GetMetaCommand(ctx context.Context, data wshrpc.CommandGetMetaData) (wshrpc.MetaDataType, error) {
log.Printf("calling meta: %s\n", data.ORef) log.Printf("calling meta: %s\n", data.ORef)
obj, err := wstore.DBGetORef(ctx, data.ORef) obj, err := wstore.DBGetORef(ctx, data.ORef)