fix: add error messages to ai chat (#999)

This prints error messages to the chat when there is an error getting an AI response. The content of those error messages is not forwarded to the models in future requests.

<img width="389" alt="Screenshot 2024-10-09 at 2 36 13 PM"
src="https://github.com/user-attachments/assets/e6c6b1c1-fa19-4456-be3b-596feaeaafed">
Sylvie Crowe 2024-10-09 14:50:56 -07:00 committed by GitHub
parent ad3166a2c9
commit 9dd4188810
4 changed files with 98 additions and 46 deletions

View File

@@ -191,7 +191,7 @@
     background-color: transparent;
     outline: none;
     border: none;
-    color: var(--app-text-color);
+    color: var(--main-text-color);
     width: 100%;
     white-space: nowrap;
     overflow: hidden;

View File

@@ -51,7 +51,7 @@
     }
     &.chat-msg-assistant {
-        color: var(--app-text-color);
+        color: var(--main-text-color);
         background-color: rgb(from var(--highlight-bg-color) r g b / 0.1);
         margin-right: auto;
         padding: 10px;
@@ -77,9 +77,23 @@
     }
     &.chat-msg-error {
-        color: var(--cmdinput-text-error);
-        font-family: var(--markdown-font);
-        font-size: 14px;
+        color: var(--main-text-color);
+        background-color: rgb(from var(--error-color) r g b / 0.25);
+        margin-right: auto;
+        padding: 10px;
+        max-width: 85%;
+        .markdown {
+            width: 100%;
+            pre {
+                white-space: pre-wrap;
+                word-break: break-word;
+                max-width: 100%;
+                overflow-x: auto;
+                margin-left: 0;
+            }
+        }
     }
     &.typing-indicator {

View File

@@ -20,9 +20,7 @@ interface ChatMessageType {
     id: string;
     user: string;
     text: string;
-    isAssistant: boolean;
     isUpdating?: boolean;
-    isError?: string;
 }
 
 const outline = "2px solid var(--accent-color)";
@@ -36,7 +34,6 @@ function promptToMsg(prompt: OpenAIPromptMessageType): ChatMessageType {
         id: crypto.randomUUID(),
         user: prompt.role,
         text: prompt.content,
-        isAssistant: prompt.role == "assistant",
     };
 }
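
With `isAssistant` and `isError` gone, the `user` role string is the single source of truth for how a message is handled. A small sketch of the round trip through `promptToMsg` (simplified local types, illustrative prompt):

```ts
type OpenAIPromptMessageType = { role: string; content: string };
type ChatMessageType = { id: string; user: string; text: string; isUpdating?: boolean };

function promptToMsg(prompt: OpenAIPromptMessageType): ChatMessageType {
    return { id: crypto.randomUUID(), user: prompt.role, text: prompt.content };
}

const msg = promptToMsg({ role: "assistant", content: "hello" }); // illustrative
// msg.user == "assistant": atoms and renderers branch on this value directly.
```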
@@ -78,7 +75,7 @@ export class WaveAiModel implements ViewModel {
         this.updateLastMessageAtom = atom(null, (get, set, text: string, isUpdating: boolean) => {
             const messages = get(this.messagesAtom);
             const lastMessage = messages[messages.length - 1];
-            if (lastMessage.isAssistant && !lastMessage.isError) {
+            if (lastMessage.user == "assistant") {
                 const updatedMessage = { ...lastMessage, text: lastMessage.text + text, isUpdating };
                 set(this.messagesAtom, [...messages.slice(0, -1), updatedMessage]);
             }
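
For readers unfamiliar with the pattern: `atom(null, write)` creates a write-only jotai atom, and the `user == "assistant"` guard means streamed chunks can only ever be appended to an assistant message, never to an error bubble added afterwards. A standalone sketch with simplified types:

```ts
import { atom, createStore } from "jotai";

type Msg = { user: string; text: string; isUpdating?: boolean };

const messagesAtom = atom<Msg[]>([]);

// Write-only atom: append streamed text to the last message,
// but only when that message belongs to the assistant.
const updateLastMessageAtom = atom(null, (get, set, text: string, isUpdating: boolean) => {
    const messages = get(messagesAtom);
    const last = messages[messages.length - 1];
    if (last?.user == "assistant") {
        set(messagesAtom, [...messages.slice(0, -1), { ...last, text: last.text + text, isUpdating }]);
    }
});

const store = createStore();
store.set(messagesAtom, [{ user: "error", text: "boom" }]);
store.set(updateLastMessageAtom, " extra", true); // no-op: last message is an error
```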
@@ -94,7 +91,6 @@ export class WaveAiModel implements ViewModel {
             id: crypto.randomUUID(),
             user: "assistant",
             text: "",
-            isAssistant: true,
         };
         // Add a typing indicator
@@ -190,7 +186,6 @@ export class WaveAiModel implements ViewModel {
             id: crypto.randomUUID(),
             user,
             text,
-            isAssistant: false,
         };
         addMessage(newMessage);
         // send message to backend and get response
@@ -214,7 +209,6 @@ export class WaveAiModel implements ViewModel {
             id: crypto.randomUUID(),
             user: "assistant",
             text: "",
-            isAssistant: true,
         };
         // Add a typing indicator
@@ -225,8 +219,9 @@ export class WaveAiModel implements ViewModel {
             opts: opts,
             prompt: [...history, newPrompt],
         };
-        const aiGen = RpcApi.StreamWaveAiCommand(WindowRpcClient, beMsg, { timeout: opts.timeoutms });
         let fullMsg = "";
+        try {
+            const aiGen = RpcApi.StreamWaveAiCommand(WindowRpcClient, beMsg, { timeout: opts.timeoutms });
             for await (const msg of aiGen) {
                 fullMsg += msg.text ?? "";
                 globalStore.set(this.updateLastMessageAtom, msg.text ?? "", true);
@@ -236,7 +231,6 @@ export class WaveAiModel implements ViewModel {
                 }
                 break;
             }
-        }
         }
         globalStore.set(this.updateLastMessageAtom, "", false);
         if (fullMsg != "") {
             const responsePrompt: OpenAIPromptMessageType = {
@@ -245,6 +239,35 @@ export class WaveAiModel implements ViewModel {
             };
             await BlockService.SaveWaveAiData(blockId, [...history, newPrompt, responsePrompt]);
         }
+            }
+        } catch (error) {
+            const updatedHist = [...history, newPrompt];
+            if (fullMsg == "") {
+                globalStore.set(this.removeLastMessageAtom);
+            } else {
+                globalStore.set(this.updateLastMessageAtom, "", false);
+                const responsePrompt: OpenAIPromptMessageType = {
+                    role: "assistant",
+                    content: fullMsg,
+                };
+                updatedHist.push(responsePrompt);
+            }
+            const errMsg: string = (error as Error).message;
+            const errorMessage: ChatMessageType = {
+                id: crypto.randomUUID(),
+                user: "error",
+                text: errMsg,
+            };
+            globalStore.set(this.addMessageAtom, errorMessage);
+            globalStore.set(this.updateLastMessageAtom, "", false);
+            const errorPrompt: OpenAIPromptMessageType = {
+                role: "error",
+                content: errMsg,
+            };
+            updatedHist.push(errorPrompt);
+            console.log(updatedHist);
+            await BlockService.SaveWaveAiData(blockId, updatedHist);
+        }
         setLocked(false);
         this.cancel = false;
     };
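
The essential shape of the new control flow: stream inside a `try`, keep any partial text that arrived, and turn a failure into both a visible chat message and a persisted history entry under the "error" role. A condensed, runnable sketch with simplified types (`stream` and `save` are stand-ins for the RPC call and `BlockService.SaveWaveAiData`):

```ts
type Prompt = { role: string; content: string };

// Stand-ins for RpcApi.StreamWaveAiCommand and BlockService.SaveWaveAiData.
async function* stream(): AsyncGenerator<{ text?: string }> {
    yield { text: "partial " };
    throw new Error("connection dropped"); // simulate a mid-stream failure
}
async function save(history: Prompt[]): Promise<void> {
    console.log("saved", history);
}

async function runChat(history: Prompt[], newPrompt: Prompt): Promise<void> {
    const updatedHist = [...history, newPrompt];
    let fullMsg = "";
    try {
        for await (const msg of stream()) {
            fullMsg += msg.text ?? "";
        }
        if (fullMsg != "") {
            await save([...updatedHist, { role: "assistant", content: fullMsg }]);
        }
    } catch (error) {
        // Keep whatever partial assistant text arrived before the failure...
        if (fullMsg != "") {
            updatedHist.push({ role: "assistant", content: fullMsg });
        }
        // ...then record the error itself under the "error" role, so it is
        // rendered in the chat but filtered out of future model requests.
        updatedHist.push({ role: "error", content: (error as Error).message });
        await save(updatedHist);
    }
}
```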
@@ -264,17 +287,26 @@ function makeWaveAiViewModel(blockId): WaveAiModel {
 }
 
 const ChatItem = ({ chatItem }: ChatItemProps) => {
-    const { isAssistant, text, isError } = chatItem;
+    const { user, text } = chatItem;
     const cssVar = "--panel-bg-color";
     const panelBgColor = getComputedStyle(document.documentElement).getPropertyValue(cssVar).trim();
-    const renderError = (err: string): React.JSX.Element => <div className="chat-msg-error">{err}</div>;
     const renderContent = useMemo(() => {
-        if (isAssistant) {
-            if (isError) {
-                return renderError(isError);
-            }
+        if (user == "error") {
+            return (
+                <>
+                    <div className="chat-msg chat-msg-header">
+                        <div className="icon-box">
+                            <i className="fa-sharp fa-solid fa-circle-exclamation"></i>
+                        </div>
+                    </div>
+                    <div className="chat-msg chat-msg-error">
+                        <Markdown text={text} scrollable={false} />
+                    </div>
+                </>
+            );
+        }
+        if (user == "assistant") {
             return text ? (
                 <>
                     <div className="chat-msg chat-msg-header">
@@ -302,7 +334,7 @@ const ChatItem = ({ chatItem }: ChatItemProps) => {
                 </div>
             </>
         );
-    }, [text, isAssistant, isError]);
+    }, [text, user]);
 
     return <div className={"chat-msg-container"}>{renderContent}</div>;
 };
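
Note the corrected `useMemo` dependency list: it shrinks from `[text, isAssistant, isError]` to `[text, user]`, the only two fields the renderer now reads, so the memoized JSX cannot go stale. A reduced sketch of the role-keyed branch (markup trimmed; the real component renders Markdown and an icon header as shown above):

```tsx
import { useMemo } from "react";

function ChatBody({ user, text }: { user: string; text: string }) {
    const content = useMemo(() => {
        if (user == "error") {
            return <div className="chat-msg chat-msg-error">{text}</div>;
        }
        if (user == "assistant") {
            return <div className="chat-msg chat-msg-assistant">{text}</div>;
        }
        return <div className="chat-msg">{text}</div>;
    }, [text, user]); // exactly the fields read inside the memo
    return <div className="chat-msg-container">{content}</div>;
}
```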

View File

@@ -14,7 +14,6 @@ import (
	"strings"
	"time"

-	"github.com/sashabaranov/go-openai"
	openaiapi "github.com/sashabaranov/go-openai"
	"github.com/wavetermdev/waveterm/pkg/wavebase"
	"github.com/wavetermdev/waveterm/pkg/wcloud"
@@ -150,9 +149,16 @@ func RunCloudCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRe
			rtn <- makeAIError(fmt.Errorf("unable to close openai channel: %v", err))
		}
	}()
+	var sendablePromptMsgs []wshrpc.OpenAIPromptMessageType
+	for _, promptMsg := range request.Prompt {
+		if promptMsg.Role == "error" {
+			continue
+		}
+		sendablePromptMsgs = append(sendablePromptMsgs, promptMsg)
+	}
	reqPk := MakeOpenAICloudReqPacket()
	reqPk.ClientId = request.ClientId
-	reqPk.Prompt = request.Prompt
+	reqPk.Prompt = sendablePromptMsgs
	reqPk.MaxTokens = request.Opts.MaxTokens
	reqPk.MaxChoices = request.Opts.MaxChoices
	configMessageBuf, err := json.Marshal(reqPk)
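
This loop is the enforcement point for the "never forward errors" rule: prompts whose role is "error" stay in the saved history but are dropped before the cloud request packet is built. The same predicate, expressed over the frontend prompt shape for illustration:

```ts
type Prompt = { role: string; content: string };

// Mirrors the Go filter above: anything saved under role "error"
// never reaches the model on subsequent requests.
const sendablePromptMsgs = (prompts: Prompt[]): Prompt[] =>
    prompts.filter((p) => p.role != "error");
```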
@@ -200,23 +206,23 @@ func defaultAzureMapperFn(model string) string {
	return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "")
}

-func setApiType(opts *wshrpc.OpenAIOptsType, clientConfig *openai.ClientConfig) error {
+func setApiType(opts *wshrpc.OpenAIOptsType, clientConfig *openaiapi.ClientConfig) error {
	ourApiType := strings.ToLower(opts.APIType)
-	if ourApiType == "" || ourApiType == strings.ToLower(string(openai.APITypeOpenAI)) {
-		clientConfig.APIType = openai.APITypeOpenAI
+	if ourApiType == "" || ourApiType == strings.ToLower(string(openaiapi.APITypeOpenAI)) {
+		clientConfig.APIType = openaiapi.APITypeOpenAI
		return nil
-	} else if ourApiType == strings.ToLower(string(openai.APITypeAzure)) {
-		clientConfig.APIType = openai.APITypeAzure
+	} else if ourApiType == strings.ToLower(string(openaiapi.APITypeAzure)) {
+		clientConfig.APIType = openaiapi.APITypeAzure
		clientConfig.APIVersion = DefaultAzureAPIVersion
		clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
		return nil
-	} else if ourApiType == strings.ToLower(string(openai.APITypeAzureAD)) {
-		clientConfig.APIType = openai.APITypeAzureAD
+	} else if ourApiType == strings.ToLower(string(openaiapi.APITypeAzureAD)) {
+		clientConfig.APIType = openaiapi.APITypeAzureAD
		clientConfig.APIVersion = DefaultAzureAPIVersion
		clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
		return nil
-	} else if ourApiType == strings.ToLower(string(openai.APITypeCloudflareAzure)) {
-		clientConfig.APIType = openai.APITypeCloudflareAzure
+	} else if ourApiType == strings.ToLower(string(openaiapi.APITypeCloudflareAzure)) {
+		clientConfig.APIType = openaiapi.APITypeCloudflareAzure
		clientConfig.APIVersion = DefaultAzureAPIVersion
		clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
		return nil