diff --git a/frontend/app/block/block.less b/frontend/app/block/block.less index 7151084aa..f3e172912 100644 --- a/frontend/app/block/block.less +++ b/frontend/app/block/block.less @@ -191,7 +191,7 @@ background-color: transparent; outline: none; border: none; - color: var(--app-text-color); + color: var(--main-text-color); width: 100%; white-space: nowrap; overflow: hidden; diff --git a/frontend/app/view/waveai/waveai.less b/frontend/app/view/waveai/waveai.less index 93d5911e7..ed2ed177c 100644 --- a/frontend/app/view/waveai/waveai.less +++ b/frontend/app/view/waveai/waveai.less @@ -51,7 +51,7 @@ } &.chat-msg-assistant { - color: var(--app-text-color); + color: var(--main-text-color); background-color: rgb(from var(--highlight-bg-color) r g b / 0.1); margin-right: auto; padding: 10px; @@ -77,9 +77,23 @@ } &.chat-msg-error { - color: var(--cmdinput-text-error); - font-family: var(--markdown-font); - font-size: 14px; + color: var(--main-text-color); + background-color: rgb(from var(--error-color) r g b / 0.25); + margin-right: auto; + padding: 10px; + max-width: 85%; + + .markdown { + width: 100%; + + pre { + white-space: pre-wrap; + word-break: break-word; + max-width: 100%; + overflow-x: auto; + margin-left: 0; + } + } } &.typing-indicator { diff --git a/frontend/app/view/waveai/waveai.tsx b/frontend/app/view/waveai/waveai.tsx index c29458bc2..5519b745d 100644 --- a/frontend/app/view/waveai/waveai.tsx +++ b/frontend/app/view/waveai/waveai.tsx @@ -20,9 +20,7 @@ interface ChatMessageType { id: string; user: string; text: string; - isAssistant: boolean; isUpdating?: boolean; - isError?: string; } const outline = "2px solid var(--accent-color)"; @@ -36,7 +34,6 @@ function promptToMsg(prompt: OpenAIPromptMessageType): ChatMessageType { id: crypto.randomUUID(), user: prompt.role, text: prompt.content, - isAssistant: prompt.role == "assistant", }; } @@ -78,7 +75,7 @@ export class WaveAiModel implements ViewModel { this.updateLastMessageAtom = atom(null, (get, set, text: string, isUpdating: boolean) => { const messages = get(this.messagesAtom); const lastMessage = messages[messages.length - 1]; - if (lastMessage.isAssistant && !lastMessage.isError) { + if (lastMessage.user == "assistant") { const updatedMessage = { ...lastMessage, text: lastMessage.text + text, isUpdating }; set(this.messagesAtom, [...messages.slice(0, -1), updatedMessage]); } @@ -94,7 +91,6 @@ export class WaveAiModel implements ViewModel { id: crypto.randomUUID(), user: "assistant", text: "", - isAssistant: true, }; // Add a typing indicator @@ -190,7 +186,6 @@ export class WaveAiModel implements ViewModel { id: crypto.randomUUID(), user, text, - isAssistant: false, }; addMessage(newMessage); // send message to backend and get response @@ -214,7 +209,6 @@ export class WaveAiModel implements ViewModel { id: crypto.randomUUID(), user: "assistant", text: "", - isAssistant: true, }; // Add a typing indicator @@ -225,25 +219,54 @@ export class WaveAiModel implements ViewModel { opts: opts, prompt: [...history, newPrompt], }; - const aiGen = RpcApi.StreamWaveAiCommand(WindowRpcClient, beMsg, { timeout: opts.timeoutms }); let fullMsg = ""; - for await (const msg of aiGen) { - fullMsg += msg.text ?? ""; - globalStore.set(this.updateLastMessageAtom, msg.text ?? "", true); - if (this.cancel) { - if (fullMsg == "") { - globalStore.set(this.removeLastMessageAtom); + try { + const aiGen = RpcApi.StreamWaveAiCommand(WindowRpcClient, beMsg, { timeout: opts.timeoutms }); + for await (const msg of aiGen) { + fullMsg += msg.text ?? ""; + globalStore.set(this.updateLastMessageAtom, msg.text ?? "", true); + if (this.cancel) { + if (fullMsg == "") { + globalStore.set(this.removeLastMessageAtom); + } + break; + } + globalStore.set(this.updateLastMessageAtom, "", false); + if (fullMsg != "") { + const responsePrompt: OpenAIPromptMessageType = { + role: "assistant", + content: fullMsg, + }; + await BlockService.SaveWaveAiData(blockId, [...history, newPrompt, responsePrompt]); } - break; } - } - globalStore.set(this.updateLastMessageAtom, "", false); - if (fullMsg != "") { - const responsePrompt: OpenAIPromptMessageType = { - role: "assistant", - content: fullMsg, + } catch (error) { + const updatedHist = [...history, newPrompt]; + if (fullMsg == "") { + globalStore.set(this.removeLastMessageAtom); + } else { + globalStore.set(this.updateLastMessageAtom, "", false); + const responsePrompt: OpenAIPromptMessageType = { + role: "assistant", + content: fullMsg, + }; + updatedHist.push(responsePrompt); + } + const errMsg: string = (error as Error).message; + const errorMessage: ChatMessageType = { + id: crypto.randomUUID(), + user: "error", + text: errMsg, }; - await BlockService.SaveWaveAiData(blockId, [...history, newPrompt, responsePrompt]); + globalStore.set(this.addMessageAtom, errorMessage); + globalStore.set(this.updateLastMessageAtom, "", false); + const errorPrompt: OpenAIPromptMessageType = { + role: "error", + content: errMsg, + }; + updatedHist.push(errorPrompt); + console.log(updatedHist); + await BlockService.SaveWaveAiData(blockId, updatedHist); } setLocked(false); this.cancel = false; @@ -264,17 +287,26 @@ function makeWaveAiViewModel(blockId): WaveAiModel { } const ChatItem = ({ chatItem }: ChatItemProps) => { - const { isAssistant, text, isError } = chatItem; + const { user, text } = chatItem; const cssVar = "--panel-bg-color"; const panelBgColor = getComputedStyle(document.documentElement).getPropertyValue(cssVar).trim(); - const renderError = (err: string): React.JSX.Element =>
{err}
; - const renderContent = useMemo(() => { - if (isAssistant) { - if (isError) { - return renderError(isError); - } + if (user == "error") { + return ( + <> +
+
+ +
+
+
+ +
+ + ); + } + if (user == "assistant") { return text ? ( <>
@@ -302,7 +334,7 @@ const ChatItem = ({ chatItem }: ChatItemProps) => {
); - }, [text, isAssistant, isError]); + }, [text, user]); return
{renderContent}
; }; diff --git a/pkg/waveai/waveai.go b/pkg/waveai/waveai.go index e9be4de4f..f22a48101 100644 --- a/pkg/waveai/waveai.go +++ b/pkg/waveai/waveai.go @@ -14,7 +14,6 @@ import ( "strings" "time" - "github.com/sashabaranov/go-openai" openaiapi "github.com/sashabaranov/go-openai" "github.com/wavetermdev/waveterm/pkg/wavebase" "github.com/wavetermdev/waveterm/pkg/wcloud" @@ -150,9 +149,16 @@ func RunCloudCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRe rtn <- makeAIError(fmt.Errorf("unable to close openai channel: %v", err)) } }() + var sendablePromptMsgs []wshrpc.OpenAIPromptMessageType + for _, promptMsg := range request.Prompt { + if promptMsg.Role == "error" { + continue + } + sendablePromptMsgs = append(sendablePromptMsgs, promptMsg) + } reqPk := MakeOpenAICloudReqPacket() reqPk.ClientId = request.ClientId - reqPk.Prompt = request.Prompt + reqPk.Prompt = sendablePromptMsgs reqPk.MaxTokens = request.Opts.MaxTokens reqPk.MaxChoices = request.Opts.MaxChoices configMessageBuf, err := json.Marshal(reqPk) @@ -200,23 +206,23 @@ func defaultAzureMapperFn(model string) string { return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "") } -func setApiType(opts *wshrpc.OpenAIOptsType, clientConfig *openai.ClientConfig) error { +func setApiType(opts *wshrpc.OpenAIOptsType, clientConfig *openaiapi.ClientConfig) error { ourApiType := strings.ToLower(opts.APIType) - if ourApiType == "" || ourApiType == strings.ToLower(string(openai.APITypeOpenAI)) { - clientConfig.APIType = openai.APITypeOpenAI + if ourApiType == "" || ourApiType == strings.ToLower(string(openaiapi.APITypeOpenAI)) { + clientConfig.APIType = openaiapi.APITypeOpenAI return nil - } else if ourApiType == strings.ToLower(string(openai.APITypeAzure)) { - clientConfig.APIType = openai.APITypeAzure + } else if ourApiType == strings.ToLower(string(openaiapi.APITypeAzure)) { + clientConfig.APIType = openaiapi.APITypeAzure clientConfig.APIVersion = DefaultAzureAPIVersion clientConfig.AzureModelMapperFunc = defaultAzureMapperFn return nil - } else if ourApiType == strings.ToLower(string(openai.APITypeAzureAD)) { - clientConfig.APIType = openai.APITypeAzureAD + } else if ourApiType == strings.ToLower(string(openaiapi.APITypeAzureAD)) { + clientConfig.APIType = openaiapi.APITypeAzureAD clientConfig.APIVersion = DefaultAzureAPIVersion clientConfig.AzureModelMapperFunc = defaultAzureMapperFn return nil - } else if ourApiType == strings.ToLower(string(openai.APITypeCloudflareAzure)) { - clientConfig.APIType = openai.APITypeCloudflareAzure + } else if ourApiType == strings.ToLower(string(openaiapi.APITypeCloudflareAzure)) { + clientConfig.APIType = openaiapi.APITypeCloudflareAzure clientConfig.APIVersion = DefaultAzureAPIVersion clientConfig.AzureModelMapperFunc = defaultAzureMapperFn return nil