mirror of
https://github.com/wavetermdev/waveterm.git
synced 2024-12-21 16:38:23 +01:00
Open Ai Port (#154)
This brings over a simplified version of the open ai feature from the previous app but in widget form. It still needs some work to reach parity with that version, but this includes all of the basic building blocks to get that working.
This commit is contained in:
parent
dcb4d5f2bf
commit
0e46b79c22
@ -70,6 +70,7 @@ func main() {
|
||||
fmt.Fprintf(fd, " \"github.com/wavetermdev/thenextwave/pkg/wshutil\"\n")
|
||||
fmt.Fprintf(fd, " \"github.com/wavetermdev/thenextwave/pkg/wshrpc\"\n")
|
||||
fmt.Fprintf(fd, " \"github.com/wavetermdev/thenextwave/pkg/waveobj\"\n")
|
||||
fmt.Fprintf(fd, " \"github.com/wavetermdev/thenextwave/pkg/waveai\"\n")
|
||||
fmt.Fprintf(fd, ")\n\n")
|
||||
|
||||
for _, key := range utilfn.GetOrderedMapKeys(wshserver.WshServerCommandToDeclMap) {
|
||||
|
@ -14,7 +14,7 @@ import * as util from "@/util/util";
|
||||
import { PlotView } from "@/view/plotview";
|
||||
import { PreviewView, makePreviewModel } from "@/view/preview";
|
||||
import { TerminalView, makeTerminalModel } from "@/view/term/term";
|
||||
import { WaveAi } from "@/view/waveai";
|
||||
import { WaveAi, makeWaveAiViewModel } from "@/view/waveai";
|
||||
import { WebView, makeWebViewModel } from "@/view/webview";
|
||||
import clsx from "clsx";
|
||||
import * as jotai from "jotai";
|
||||
@ -516,7 +516,9 @@ function getViewElemAndModel(
|
||||
viewElem = <WebView key={blockId} parentRef={blockRef} model={webviewModel} />;
|
||||
viewModel = webviewModel;
|
||||
} else if (blockView === "waveai") {
|
||||
viewElem = <WaveAi key={blockId} />;
|
||||
const waveAiModel = makeWaveAiViewModel(blockId);
|
||||
viewElem = <WaveAi key={blockId} model={waveAiModel} />;
|
||||
viewModel = waveAiModel;
|
||||
}
|
||||
if (viewModel == null) {
|
||||
viewModel = makeDefaultViewModel(blockId);
|
||||
|
@ -1,101 +0,0 @@
|
||||
// Copyright 2024, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import { atom, useAtom } from "jotai";
|
||||
import { v4 as uuidv4 } from "uuid";
|
||||
|
||||
interface ChatMessageType {
|
||||
id: string;
|
||||
user: string;
|
||||
text: string;
|
||||
isAssistant: boolean;
|
||||
isUpdating?: boolean;
|
||||
isError?: string;
|
||||
}
|
||||
|
||||
const defaultMessage: ChatMessageType = {
|
||||
id: uuidv4(),
|
||||
user: "assistant",
|
||||
text: `<p>Hello, how may I help you with this command?<br>
|
||||
(Cmd-Shift-Space: open/close, Ctrl+L: clear chat buffer, Up/Down: select code blocks, Enter: to copy a selected code block to the command input)</p>`,
|
||||
isAssistant: true,
|
||||
};
|
||||
|
||||
const messagesAtom = atom<ChatMessageType[]>([defaultMessage]);
|
||||
|
||||
const addMessageAtom = atom(null, (get, set, message: ChatMessageType) => {
|
||||
const messages = get(messagesAtom);
|
||||
set(messagesAtom, [...messages, message]);
|
||||
});
|
||||
|
||||
const updateLastMessageAtom = atom(null, (get, set, text: string, isUpdating: boolean) => {
|
||||
const messages = get(messagesAtom);
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
if (lastMessage.isAssistant && !lastMessage.isError) {
|
||||
const updatedMessage = { ...lastMessage, text: lastMessage.text + text, isUpdating };
|
||||
set(messagesAtom, [...messages.slice(0, -1), updatedMessage]);
|
||||
}
|
||||
});
|
||||
|
||||
const simulateAssistantResponseAtom = atom(null, (get, set, userMessage: ChatMessageType) => {
|
||||
const responseText = `Here is an example of a simple bash script:
|
||||
|
||||
\`\`\`bash
|
||||
#!/bin/bash
|
||||
# This is a comment
|
||||
echo "Hello, World!"
|
||||
\`\`\`
|
||||
|
||||
You can run this script by saving it to a file, for example, \`hello.sh\`, and then running \`chmod +x hello.sh\` to make it executable. Finally, run it with \`./hello.sh\`.`;
|
||||
|
||||
const typingMessage: ChatMessageType = {
|
||||
id: uuidv4(),
|
||||
user: "assistant",
|
||||
text: "",
|
||||
isAssistant: true,
|
||||
};
|
||||
|
||||
// Add a typing indicator
|
||||
set(addMessageAtom, typingMessage);
|
||||
|
||||
setTimeout(() => {
|
||||
const parts = responseText.split(" ");
|
||||
let currentPart = 0;
|
||||
|
||||
const intervalId = setInterval(() => {
|
||||
if (currentPart < parts.length) {
|
||||
const part = parts[currentPart] + " ";
|
||||
set(updateLastMessageAtom, part, true);
|
||||
currentPart++;
|
||||
} else {
|
||||
clearInterval(intervalId);
|
||||
set(updateLastMessageAtom, "", false);
|
||||
}
|
||||
}, 100);
|
||||
}, 1500);
|
||||
});
|
||||
|
||||
const useWaveAi = () => {
|
||||
const [messages] = useAtom(messagesAtom);
|
||||
const [, addMessage] = useAtom(addMessageAtom);
|
||||
const [, simulateResponse] = useAtom(simulateAssistantResponseAtom);
|
||||
|
||||
const sendMessage = (text: string, user: string = "user") => {
|
||||
const newMessage: ChatMessageType = {
|
||||
id: uuidv4(),
|
||||
user,
|
||||
text,
|
||||
isAssistant: false,
|
||||
};
|
||||
addMessage(newMessage);
|
||||
simulateResponse(newMessage);
|
||||
};
|
||||
|
||||
return {
|
||||
messages,
|
||||
sendMessage,
|
||||
};
|
||||
};
|
||||
|
||||
export { useWaveAi };
|
||||
export type { ChatMessageType };
|
@ -72,6 +72,11 @@ class WshServerType {
|
||||
return WOS.wshServerRpcHelper_call("setview", data, opts);
|
||||
}
|
||||
|
||||
// command "stream:waveai" [responsestream]
|
||||
RespStreamWaveAi(data: OpenAiStreamRequest, opts?: WshRpcCommandOpts): AsyncGenerator<OpenAIPacketType, void, boolean> {
|
||||
return WOS.wshServerRpcHelper_responsestream("stream:waveai", data, opts);
|
||||
}
|
||||
|
||||
// command "streamtest" [responsestream]
|
||||
RespStreamTest(opts?: WshRpcCommandOpts): AsyncGenerator<number, void, boolean> {
|
||||
return WOS.wshServerRpcHelper_responsestream("streamtest", null, opts);
|
||||
|
@ -3,21 +3,170 @@
|
||||
|
||||
import { Markdown } from "@/app/element/markdown";
|
||||
import { TypingIndicator } from "@/app/element/typingindicator";
|
||||
import { ChatMessageType, useWaveAi } from "@/app/store/waveai";
|
||||
import { WOS, atoms } from "@/store/global";
|
||||
import { WshServer } from "@/store/wshserver";
|
||||
import * as jotai from "jotai";
|
||||
import type { OverlayScrollbars } from "overlayscrollbars";
|
||||
import { OverlayScrollbarsComponent, OverlayScrollbarsComponentRef } from "overlayscrollbars-react";
|
||||
import React, { forwardRef, useCallback, useEffect, useImperativeHandle, useRef, useState } from "react";
|
||||
import tinycolor from "tinycolor2";
|
||||
import { v4 as uuidv4 } from "uuid";
|
||||
|
||||
import "./waveai.less";
|
||||
|
||||
interface ChatMessageType {
|
||||
id: string;
|
||||
user: string;
|
||||
text: string;
|
||||
isAssistant: boolean;
|
||||
isUpdating?: boolean;
|
||||
isError?: string;
|
||||
}
|
||||
|
||||
const outline = "2px solid var(--accent-color)";
|
||||
|
||||
const defaultMessage: ChatMessageType = {
|
||||
id: uuidv4(),
|
||||
user: "assistant",
|
||||
text: `<p>Hello, how may I help you with this command?<br>
|
||||
(Cmd-Shift-Space: open/close, Ctrl+L: clear chat buffer, Up/Down: select code blocks, Enter: to copy a selected code block to the command input)</p>`,
|
||||
isAssistant: true,
|
||||
};
|
||||
|
||||
interface ChatItemProps {
|
||||
chatItem: ChatMessageType;
|
||||
itemCount: number;
|
||||
}
|
||||
|
||||
export class WaveAiModel implements ViewModel {
|
||||
blockId: string;
|
||||
blockAtom: jotai.Atom<Block>;
|
||||
viewIcon?: jotai.Atom<string | HeaderIconButton>;
|
||||
viewName?: jotai.Atom<string>;
|
||||
viewText?: jotai.Atom<string | HeaderElem[]>;
|
||||
preIconButton?: jotai.Atom<HeaderIconButton>;
|
||||
endIconButtons?: jotai.Atom<HeaderIconButton[]>;
|
||||
messagesAtom: jotai.PrimitiveAtom<Array<ChatMessageType>>;
|
||||
addMessageAtom: jotai.WritableAtom<unknown, [message: ChatMessageType], void>;
|
||||
updateLastMessageAtom: jotai.WritableAtom<unknown, [text: string, isUpdating: boolean], void>;
|
||||
simulateAssistantResponseAtom: jotai.WritableAtom<unknown, [userMessage: ChatMessageType], void>;
|
||||
|
||||
constructor(blockId: string) {
|
||||
this.blockId = blockId;
|
||||
this.blockAtom = WOS.getWaveObjectAtom<Block>(`block:${blockId}`);
|
||||
this.viewIcon = jotai.atom((get) => {
|
||||
return "sparkles"; // should not be hardcoded
|
||||
});
|
||||
this.viewName = jotai.atom("Ai");
|
||||
this.messagesAtom = jotai.atom([defaultMessage]);
|
||||
|
||||
this.addMessageAtom = jotai.atom(null, (get, set, message: ChatMessageType) => {
|
||||
const messages = get(this.messagesAtom);
|
||||
set(this.messagesAtom, [...messages, message]);
|
||||
});
|
||||
|
||||
this.updateLastMessageAtom = jotai.atom(null, (get, set, text: string, isUpdating: boolean) => {
|
||||
const messages = get(this.messagesAtom);
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
if (lastMessage.isAssistant && !lastMessage.isError) {
|
||||
const updatedMessage = { ...lastMessage, text: lastMessage.text + text, isUpdating };
|
||||
set(this.messagesAtom, [...messages.slice(0, -1), updatedMessage]);
|
||||
}
|
||||
});
|
||||
this.simulateAssistantResponseAtom = jotai.atom(null, (get, set, userMessage: ChatMessageType) => {
|
||||
const typingMessage: ChatMessageType = {
|
||||
id: uuidv4(),
|
||||
user: "assistant",
|
||||
text: "",
|
||||
isAssistant: true,
|
||||
};
|
||||
|
||||
// Add a typing indicator
|
||||
set(this.addMessageAtom, typingMessage);
|
||||
|
||||
setTimeout(() => {
|
||||
const parts = userMessage.text.split(" ");
|
||||
let currentPart = 0;
|
||||
|
||||
const intervalId = setInterval(() => {
|
||||
if (currentPart < parts.length) {
|
||||
const part = parts[currentPart] + " ";
|
||||
set(this.updateLastMessageAtom, part, true);
|
||||
currentPart++;
|
||||
} else {
|
||||
clearInterval(intervalId);
|
||||
set(this.updateLastMessageAtom, "", false);
|
||||
}
|
||||
}, 100);
|
||||
}, 1500);
|
||||
});
|
||||
}
|
||||
|
||||
useWaveAi() {
|
||||
const [messages] = jotai.useAtom(this.messagesAtom);
|
||||
const [, addMessage] = jotai.useAtom(this.addMessageAtom);
|
||||
const [, simulateResponse] = jotai.useAtom(this.simulateAssistantResponseAtom);
|
||||
const metadata = jotai.useAtomValue(this.blockAtom).meta;
|
||||
const clientId = jotai.useAtomValue(atoms.clientId);
|
||||
|
||||
const sendMessage = (text: string, user: string = "user") => {
|
||||
const newMessage: ChatMessageType = {
|
||||
id: uuidv4(),
|
||||
user,
|
||||
text,
|
||||
isAssistant: false,
|
||||
};
|
||||
addMessage(newMessage);
|
||||
// send message to backend and get response
|
||||
const opts: OpenAIOptsType = {
|
||||
model: "gpt-3.5-turbo",
|
||||
apitoken: metadata?.apitoken as string,
|
||||
maxtokens: 1000,
|
||||
timeout: 10,
|
||||
baseurl: metadata?.baseurl as string,
|
||||
};
|
||||
const prompt: Array<OpenAIPromptMessageType> = [
|
||||
{
|
||||
role: "user",
|
||||
content: text,
|
||||
name: (metadata?.name as string) || "user",
|
||||
},
|
||||
];
|
||||
console.log("opts.apitoken:", opts.apitoken);
|
||||
const beMsg: OpenAiStreamRequest = {
|
||||
clientid: clientId,
|
||||
opts: opts,
|
||||
prompt: prompt,
|
||||
};
|
||||
const aiGen = WshServer.RespStreamWaveAi(beMsg);
|
||||
let temp = async () => {
|
||||
let fullMsg = "";
|
||||
for await (const msg of aiGen) {
|
||||
fullMsg += msg.text ?? "";
|
||||
}
|
||||
const response: ChatMessageType = {
|
||||
id: newMessage.id,
|
||||
user: newMessage.user,
|
||||
text: fullMsg,
|
||||
isAssistant: true,
|
||||
};
|
||||
simulateResponse(response);
|
||||
};
|
||||
temp();
|
||||
};
|
||||
|
||||
return {
|
||||
messages,
|
||||
sendMessage,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
function makeWaveAiViewModel(blockId): WaveAiModel {
|
||||
const waveAiModel = new WaveAiModel(blockId);
|
||||
return waveAiModel;
|
||||
}
|
||||
|
||||
const ChatItem = ({ chatItem, itemCount }: ChatItemProps) => {
|
||||
const { isAssistant, text, isError } = chatItem;
|
||||
const senderClassName = isAssistant ? "chat-msg-assistant" : "chat-msg-user";
|
||||
@ -208,8 +357,8 @@ const ChatInput = forwardRef<HTMLTextAreaElement, ChatInputProps>(
|
||||
}
|
||||
);
|
||||
|
||||
const WaveAi = () => {
|
||||
const { messages, sendMessage } = useWaveAi();
|
||||
const WaveAi = ({ model }: { model: WaveAiModel }) => {
|
||||
const { messages, sendMessage } = model.useWaveAi();
|
||||
const waveaiRef = useRef<HTMLDivElement>(null);
|
||||
const chatWindowRef = useRef<HTMLDivElement>(null);
|
||||
const osRef = useRef<OverlayScrollbarsComponentRef>(null);
|
||||
@ -407,4 +556,4 @@ const WaveAi = () => {
|
||||
);
|
||||
};
|
||||
|
||||
export { WaveAi };
|
||||
export { WaveAi, makeWaveAiViewModel };
|
||||
|
43
frontend/types/gotypes.d.ts
vendored
43
frontend/types/gotypes.d.ts
vendored
@ -187,6 +187,49 @@ declare global {
|
||||
// waveobj.ORef
|
||||
type ORef = string;
|
||||
|
||||
// waveai.OpenAIOptsType
|
||||
type OpenAIOptsType = {
|
||||
model: string;
|
||||
apitoken: string;
|
||||
baseurl?: string;
|
||||
maxtokens?: number;
|
||||
maxchoices?: number;
|
||||
timeout?: number;
|
||||
};
|
||||
|
||||
// waveai.OpenAIPacketType
|
||||
type OpenAIPacketType = {
|
||||
type: string;
|
||||
model?: string;
|
||||
created?: number;
|
||||
finish_reason?: string;
|
||||
usage?: OpenAIUsageType;
|
||||
index?: number;
|
||||
text?: string;
|
||||
error?: string;
|
||||
};
|
||||
|
||||
// waveai.OpenAIPromptMessageType
|
||||
type OpenAIPromptMessageType = {
|
||||
role: string;
|
||||
content: string;
|
||||
name?: string;
|
||||
};
|
||||
|
||||
// waveai.OpenAIUsageType
|
||||
type OpenAIUsageType = {
|
||||
prompt_tokens?: number;
|
||||
completion_tokens?: number;
|
||||
total_tokens?: number;
|
||||
};
|
||||
|
||||
// waveai.OpenAiStreamRequest
|
||||
type OpenAiStreamRequest = {
|
||||
clientid?: string;
|
||||
opts: OpenAIOptsType;
|
||||
prompt: OpenAIPromptMessageType[];
|
||||
};
|
||||
|
||||
// wstore.Point
|
||||
type Point = {
|
||||
x: number;
|
||||
|
1
go.mod
1
go.mod
@ -15,6 +15,7 @@ require (
|
||||
github.com/kevinburke/ssh_config v1.2.0
|
||||
github.com/mattn/go-sqlite3 v1.14.22
|
||||
github.com/mitchellh/mapstructure v1.5.0
|
||||
github.com/sashabaranov/go-openai v1.27.0
|
||||
github.com/sawka/txwrap v0.2.0
|
||||
github.com/spf13/cobra v1.8.1
|
||||
github.com/wavetermdev/htmltoken v0.1.0
|
||||
|
2
go.sum
2
go.sum
@ -42,6 +42,8 @@ github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RR
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/sashabaranov/go-openai v1.27.0 h1:L3hO6650YUbKrbGUC6yCjsUluhKZ9h1/jcgbTItI8Mo=
|
||||
github.com/sashabaranov/go-openai v1.27.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
|
||||
github.com/sawka/txwrap v0.2.0 h1:V3LfvKVLULxcYSxdMguLwFyQFMEU9nFDJopg0ZkL+94=
|
||||
github.com/sawka/txwrap v0.2.0/go.mod h1:wwQ2SQiN4U+6DU/iVPhbvr7OzXAtgZlQCIGuvOswEfA=
|
||||
github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
|
||||
|
306
pkg/waveai/waveai.go
Normal file
306
pkg/waveai/waveai.go
Normal file
@ -0,0 +1,306 @@
|
||||
// Copyright 2024, Command Line Inc.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package waveai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
openaiapi "github.com/sashabaranov/go-openai"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wavebase"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wshrpc"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
const OpenAIPacketStr = "openai"
|
||||
const OpenAICloudReqStr = "openai-cloudreq"
|
||||
const PacketEOFStr = "EOF"
|
||||
|
||||
type OpenAIUsageType struct {
|
||||
PromptTokens int `json:"prompt_tokens,omitempty"`
|
||||
CompletionTokens int `json:"completion_tokens,omitempty"`
|
||||
TotalTokens int `json:"total_tokens,omitempty"`
|
||||
}
|
||||
|
||||
type OpenAICmdInfoPacketOutputType struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Created int64 `json:"created,omitempty"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type OpenAIPacketType struct {
|
||||
Type string `json:"type"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Created int64 `json:"created,omitempty"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
Usage *OpenAIUsageType `json:"usage,omitempty"`
|
||||
Index int `json:"index,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func MakeOpenAIPacket() *OpenAIPacketType {
|
||||
return &OpenAIPacketType{Type: OpenAIPacketStr}
|
||||
}
|
||||
|
||||
type OpenAICmdInfoChatMessage struct {
|
||||
MessageID int `json:"messageid"`
|
||||
IsAssistantResponse bool `json:"isassistantresponse,omitempty"`
|
||||
AssistantResponse *OpenAICmdInfoPacketOutputType `json:"assistantresponse,omitempty"`
|
||||
UserQuery string `json:"userquery,omitempty"`
|
||||
UserEngineeredQuery string `json:"userengineeredquery,omitempty"`
|
||||
}
|
||||
|
||||
type OpenAIPromptMessageType struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Name string `json:"name,omitempty"`
|
||||
}
|
||||
|
||||
type OpenAICloudReqPacketType struct {
|
||||
Type string `json:"type"`
|
||||
ClientId string `json:"clientid"`
|
||||
Prompt []OpenAIPromptMessageType `json:"prompt"`
|
||||
MaxTokens int `json:"maxtokens,omitempty"`
|
||||
MaxChoices int `json:"maxchoices,omitempty"`
|
||||
}
|
||||
|
||||
type OpenAIOptsType struct {
|
||||
Model string `json:"model"`
|
||||
APIToken string `json:"apitoken"`
|
||||
BaseURL string `json:"baseurl,omitempty"`
|
||||
MaxTokens int `json:"maxtokens,omitempty"`
|
||||
MaxChoices int `json:"maxchoices,omitempty"`
|
||||
Timeout int `json:"timeout,omitempty"`
|
||||
}
|
||||
|
||||
func MakeOpenAICloudReqPacket() *OpenAICloudReqPacketType {
|
||||
return &OpenAICloudReqPacketType{
|
||||
Type: OpenAICloudReqStr,
|
||||
}
|
||||
}
|
||||
|
||||
type OpenAiStreamRequest struct {
|
||||
ClientId string `json:"clientid,omitempty"`
|
||||
Opts *OpenAIOptsType `json:"opts"`
|
||||
Prompt []OpenAIPromptMessageType `json:"prompt"`
|
||||
}
|
||||
|
||||
func GetWSEndpoint() string {
|
||||
return PCloudWSEndpoint
|
||||
if !wavebase.IsDevMode() {
|
||||
return PCloudWSEndpoint
|
||||
} else {
|
||||
endpoint := os.Getenv(PCloudWSEndpointVarName)
|
||||
if endpoint == "" {
|
||||
panic("Invalid PCloud ws dev endpoint, PCLOUD_WS_ENDPOINT not set or invalid")
|
||||
}
|
||||
return endpoint
|
||||
}
|
||||
}
|
||||
|
||||
const DefaultMaxTokens = 1000
|
||||
const DefaultModel = "gpt-3.5-turbo"
|
||||
const DefaultStreamChanSize = 10
|
||||
const PCloudWSEndpoint = "wss://wsapi.waveterm.dev/"
|
||||
const PCloudWSEndpointVarName = "PCLOUD_WS_ENDPOINT"
|
||||
|
||||
const CloudWebsocketConnectTimeout = 1 * time.Minute
|
||||
|
||||
func convertUsage(resp openaiapi.ChatCompletionResponse) *OpenAIUsageType {
|
||||
if resp.Usage.TotalTokens == 0 {
|
||||
return nil
|
||||
}
|
||||
return &OpenAIUsageType{
|
||||
PromptTokens: resp.Usage.PromptTokens,
|
||||
CompletionTokens: resp.Usage.CompletionTokens,
|
||||
TotalTokens: resp.Usage.TotalTokens,
|
||||
}
|
||||
}
|
||||
|
||||
func ConvertPrompt(prompt []OpenAIPromptMessageType) []openaiapi.ChatCompletionMessage {
|
||||
var rtn []openaiapi.ChatCompletionMessage
|
||||
for _, p := range prompt {
|
||||
msg := openaiapi.ChatCompletionMessage{Role: p.Role, Content: p.Content, Name: p.Name}
|
||||
rtn = append(rtn, msg)
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func RunCloudCompletionStream(ctx context.Context, request OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[OpenAIPacketType] {
|
||||
rtn := make(chan wshrpc.RespOrErrorUnion[OpenAIPacketType])
|
||||
go func() {
|
||||
log.Printf("start: %v", request)
|
||||
defer close(rtn)
|
||||
if request.Opts == nil {
|
||||
rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("no openai opts found")}
|
||||
return
|
||||
}
|
||||
websocketContext, dialCancelFn := context.WithTimeout(context.Background(), CloudWebsocketConnectTimeout)
|
||||
defer dialCancelFn()
|
||||
conn, _, err := websocket.DefaultDialer.DialContext(websocketContext, GetWSEndpoint(), nil)
|
||||
defer func() {
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("unable to close openai channel: %v", err)}
|
||||
}
|
||||
}()
|
||||
if err == context.DeadlineExceeded {
|
||||
rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, timed out connecting to cloud server: %v", err)}
|
||||
return
|
||||
} else if err != nil {
|
||||
rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket connect error: %v", err)}
|
||||
return
|
||||
}
|
||||
reqPk := MakeOpenAICloudReqPacket()
|
||||
reqPk.ClientId = request.ClientId
|
||||
reqPk.Prompt = request.Prompt
|
||||
reqPk.MaxTokens = request.Opts.MaxTokens
|
||||
reqPk.MaxChoices = request.Opts.MaxChoices
|
||||
configMessageBuf, err := json.Marshal(reqPk)
|
||||
if err != nil {
|
||||
rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, packet marshal error: %v", err)}
|
||||
return
|
||||
}
|
||||
err = conn.WriteMessage(websocket.TextMessage, configMessageBuf)
|
||||
if err != nil {
|
||||
rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket write config error: %v", err)}
|
||||
return
|
||||
}
|
||||
for {
|
||||
log.Printf("loop")
|
||||
_, socketMessage, err := conn.ReadMessage()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
log.Printf("err received: %v", err)
|
||||
rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket error reading message: %v", err)}
|
||||
break
|
||||
}
|
||||
var streamResp *OpenAIPacketType
|
||||
err = json.Unmarshal(socketMessage, &streamResp)
|
||||
log.Printf("ai resp: %v", streamResp)
|
||||
if err != nil {
|
||||
rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket response json decode error: %v", err)}
|
||||
break
|
||||
}
|
||||
if streamResp.Error == PacketEOFStr {
|
||||
// got eof packet from socket
|
||||
break
|
||||
} else if streamResp.Error != "" {
|
||||
// use error from server directly
|
||||
rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("%v", streamResp.Error)}
|
||||
break
|
||||
}
|
||||
rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Response: *streamResp}
|
||||
}
|
||||
}()
|
||||
return rtn
|
||||
}
|
||||
|
||||
func RunLocalCompletionStream(ctx context.Context, request OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[OpenAIPacketType] {
|
||||
rtn := make(chan wshrpc.RespOrErrorUnion[OpenAIPacketType])
|
||||
go func() {
|
||||
log.Printf("start2: %v", request)
|
||||
defer close(rtn)
|
||||
if request.Opts == nil {
|
||||
rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("no openai opts found")}
|
||||
return
|
||||
}
|
||||
if request.Opts.Model == "" {
|
||||
rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("no openai model specified")}
|
||||
return
|
||||
}
|
||||
if request.Opts.BaseURL == "" && request.Opts.APIToken == "" {
|
||||
rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("no api token")}
|
||||
return
|
||||
}
|
||||
clientConfig := openaiapi.DefaultConfig(request.Opts.APIToken)
|
||||
if request.Opts.BaseURL != "" {
|
||||
clientConfig.BaseURL = request.Opts.BaseURL
|
||||
}
|
||||
client := openaiapi.NewClientWithConfig(clientConfig)
|
||||
req := openaiapi.ChatCompletionRequest{
|
||||
Model: request.Opts.Model,
|
||||
Messages: ConvertPrompt(request.Prompt),
|
||||
MaxTokens: request.Opts.MaxTokens,
|
||||
Stream: true,
|
||||
}
|
||||
if request.Opts.MaxChoices > 1 {
|
||||
req.N = request.Opts.MaxChoices
|
||||
}
|
||||
apiResp, err := client.CreateChatCompletionStream(ctx, req)
|
||||
if err != nil {
|
||||
rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("error calling openai API: %v", err)}
|
||||
return
|
||||
}
|
||||
sentHeader := false
|
||||
for {
|
||||
log.Printf("loop2")
|
||||
streamResp, err := apiResp.Recv()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
log.Printf("err received2: %v", err)
|
||||
rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Error: fmt.Errorf("OpenAI request, websocket error reading message: %v", err)}
|
||||
break
|
||||
}
|
||||
if streamResp.Model != "" && !sentHeader {
|
||||
pk := MakeOpenAIPacket()
|
||||
pk.Model = streamResp.Model
|
||||
pk.Created = streamResp.Created
|
||||
rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Response: *pk}
|
||||
sentHeader = true
|
||||
}
|
||||
for _, choice := range streamResp.Choices {
|
||||
pk := MakeOpenAIPacket()
|
||||
pk.Index = choice.Index
|
||||
pk.Text = choice.Delta.Content
|
||||
pk.FinishReason = string(choice.FinishReason)
|
||||
rtn <- wshrpc.RespOrErrorUnion[OpenAIPacketType]{Response: *pk}
|
||||
}
|
||||
}
|
||||
}()
|
||||
return rtn
|
||||
}
|
||||
|
||||
func marshalResponse(resp openaiapi.ChatCompletionResponse) []*OpenAIPacketType {
|
||||
var rtn []*OpenAIPacketType
|
||||
headerPk := MakeOpenAIPacket()
|
||||
headerPk.Model = resp.Model
|
||||
headerPk.Created = resp.Created
|
||||
headerPk.Usage = convertUsage(resp)
|
||||
rtn = append(rtn, headerPk)
|
||||
for _, choice := range resp.Choices {
|
||||
choicePk := MakeOpenAIPacket()
|
||||
choicePk.Index = choice.Index
|
||||
choicePk.Text = choice.Message.Content
|
||||
choicePk.FinishReason = string(choice.FinishReason)
|
||||
rtn = append(rtn, choicePk)
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func CreateErrorPacket(errStr string) *OpenAIPacketType {
|
||||
errPk := MakeOpenAIPacket()
|
||||
errPk.FinishReason = "error"
|
||||
errPk.Error = errStr
|
||||
return errPk
|
||||
}
|
||||
|
||||
func CreateTextPacket(text string) *OpenAIPacketType {
|
||||
pk := MakeOpenAIPacket()
|
||||
pk.Text = text
|
||||
return pk
|
||||
}
|
@ -4,6 +4,7 @@
|
||||
package wconfig
|
||||
|
||||
import (
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/wavetermdev/thenextwave/pkg/wavebase"
|
||||
@ -141,6 +142,13 @@ func applyDefaultSettings(settings *SettingsConfigType) {
|
||||
IntervalMs: 3600000,
|
||||
}
|
||||
}
|
||||
var userName string
|
||||
currentUser, err := user.Current()
|
||||
if err != nil {
|
||||
userName = "user"
|
||||
} else {
|
||||
userName = currentUser.Username
|
||||
}
|
||||
defaultWidgets := []WidgetsConfigType{
|
||||
{
|
||||
Icon: "files",
|
||||
@ -170,6 +178,7 @@ func applyDefaultSettings(settings *SettingsConfigType) {
|
||||
Label: "waveai",
|
||||
BlockDef: wstore.BlockDef{
|
||||
View: "waveai",
|
||||
Meta: map[string]any{"name": userName, "baseurl": "", "apitoken": ""},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
"github.com/wavetermdev/thenextwave/pkg/wshutil"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wshrpc"
|
||||
"github.com/wavetermdev/thenextwave/pkg/waveobj"
|
||||
"github.com/wavetermdev/thenextwave/pkg/waveai"
|
||||
)
|
||||
|
||||
// command "controller:input", wshserver.BlockInputCommand
|
||||
@ -89,6 +90,11 @@ func BlockSetViewCommand(w *wshutil.WshRpc, data wshrpc.CommandBlockSetViewData,
|
||||
return err
|
||||
}
|
||||
|
||||
// command "stream:waveai", wshserver.RespStreamWaveAi
|
||||
func RespStreamWaveAi(w *wshutil.WshRpc, data waveai.OpenAiStreamRequest, opts *wshrpc.WshRpcCommandOpts) chan wshrpc.RespOrErrorUnion[waveai.OpenAIPacketType] {
|
||||
return sendRpcRequestResponseStreamHelper[waveai.OpenAIPacketType](w, "stream:waveai", data, opts)
|
||||
}
|
||||
|
||||
// command "streamtest", wshserver.RespStreamTest
|
||||
func RespStreamTest(w *wshutil.WshRpc, opts *wshrpc.WshRpcCommandOpts) chan wshrpc.RespOrErrorUnion[int] {
|
||||
return sendRpcRequestResponseStreamHelper[int](w, "streamtest", nil, opts)
|
||||
|
@ -28,6 +28,7 @@ const (
|
||||
Command_DeleteBlock = "deleteblock"
|
||||
Command_WriteFile = "file:write"
|
||||
Command_ReadFile = "file:read"
|
||||
Command_StreamWaveAi = "stream:waveai"
|
||||
)
|
||||
|
||||
type MetaDataType = map[string]any
|
||||
|
@ -18,6 +18,7 @@ import (
|
||||
"github.com/wavetermdev/thenextwave/pkg/blockcontroller"
|
||||
"github.com/wavetermdev/thenextwave/pkg/eventbus"
|
||||
"github.com/wavetermdev/thenextwave/pkg/filestore"
|
||||
"github.com/wavetermdev/thenextwave/pkg/waveai"
|
||||
"github.com/wavetermdev/thenextwave/pkg/waveobj"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wshrpc"
|
||||
"github.com/wavetermdev/thenextwave/pkg/wshutil"
|
||||
@ -33,6 +34,15 @@ var RespStreamTest_MethodDecl = &WshServerMethodDecl{
|
||||
DefaultResponseDataType: reflect.TypeOf((int)(0)),
|
||||
}
|
||||
|
||||
var RespStreamWaveAi_MethodDecl = &WshServerMethodDecl{
|
||||
Command: wshrpc.Command_StreamWaveAi,
|
||||
CommandType: wshutil.RpcType_ResponseStream,
|
||||
MethodName: "RespStreamWaveAi",
|
||||
Method: reflect.ValueOf(WshServerImpl.RespStreamWaveAi),
|
||||
CommandDataType: reflect.TypeOf(waveai.OpenAiStreamRequest{}),
|
||||
DefaultResponseDataType: reflect.TypeOf(waveai.OpenAIPacketType{}),
|
||||
}
|
||||
|
||||
var WshServerCommandToDeclMap = map[string]*WshServerMethodDecl{
|
||||
wshrpc.Command_Message: GetWshServerMethod(wshrpc.Command_Message, wshutil.RpcType_Call, "MessageCommand", WshServerImpl.MessageCommand),
|
||||
wshrpc.Command_SetView: GetWshServerMethod(wshrpc.Command_SetView, wshutil.RpcType_Call, "BlockSetViewCommand", WshServerImpl.BlockSetViewCommand),
|
||||
@ -47,6 +57,7 @@ var WshServerCommandToDeclMap = map[string]*WshServerMethodDecl{
|
||||
wshrpc.Command_DeleteBlock: GetWshServerMethod(wshrpc.Command_DeleteBlock, wshutil.RpcType_Call, "DeleteBlockCommand", WshServerImpl.DeleteBlockCommand),
|
||||
wshrpc.Command_WriteFile: GetWshServerMethod(wshrpc.Command_WriteFile, wshutil.RpcType_Call, "WriteFile", WshServerImpl.WriteFile),
|
||||
wshrpc.Command_ReadFile: GetWshServerMethod(wshrpc.Command_ReadFile, wshutil.RpcType_Call, "ReadFile", WshServerImpl.ReadFile),
|
||||
wshrpc.Command_StreamWaveAi: RespStreamWaveAi_MethodDecl,
|
||||
"streamtest": RespStreamTest_MethodDecl,
|
||||
}
|
||||
|
||||
@ -69,6 +80,13 @@ func (ws *WshServer) RespStreamTest(ctx context.Context) chan wshrpc.RespOrError
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (ws *WshServer) RespStreamWaveAi(ctx context.Context, request waveai.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[waveai.OpenAIPacketType] {
|
||||
if request.Opts.BaseURL == "" && request.Opts.APIToken == "" {
|
||||
return waveai.RunCloudCompletionStream(ctx, request)
|
||||
}
|
||||
return waveai.RunLocalCompletionStream(ctx, request)
|
||||
}
|
||||
|
||||
func (ws *WshServer) GetMetaCommand(ctx context.Context, data wshrpc.CommandGetMetaData) (wshrpc.MetaDataType, error) {
|
||||
log.Printf("calling meta: %s\n", data.ORef)
|
||||
obj, err := wstore.DBGetORef(ctx, data.ORef)
|
||||
|
Loading…
Reference in New Issue
Block a user