fix: add error messages to ai chat (#999)

This will print error messages to the chat when there is an error
getting an ai response. The actual content of the responses are not
forwarded to the models in future requests.

<img width="389" alt="Screenshot 2024-10-09 at 2 36 13 PM"
src="https://github.com/user-attachments/assets/e6c6b1c1-fa19-4456-be3b-596feaeaafed">
This commit is contained in:
Sylvie Crowe 2024-10-09 14:50:56 -07:00 committed by GitHub
parent ad3166a2c9
commit 9dd4188810
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 98 additions and 46 deletions

View File

@ -191,7 +191,7 @@
background-color: transparent;
outline: none;
border: none;
color: var(--app-text-color);
color: var(--main-text-color);
width: 100%;
white-space: nowrap;
overflow: hidden;

View File

@ -51,7 +51,7 @@
}
&.chat-msg-assistant {
color: var(--app-text-color);
color: var(--main-text-color);
background-color: rgb(from var(--highlight-bg-color) r g b / 0.1);
margin-right: auto;
padding: 10px;
@ -77,9 +77,23 @@
}
&.chat-msg-error {
color: var(--cmdinput-text-error);
font-family: var(--markdown-font);
font-size: 14px;
color: var(--main-text-color);
background-color: rgb(from var(--error-color) r g b / 0.25);
margin-right: auto;
padding: 10px;
max-width: 85%;
.markdown {
width: 100%;
pre {
white-space: pre-wrap;
word-break: break-word;
max-width: 100%;
overflow-x: auto;
margin-left: 0;
}
}
}
&.typing-indicator {

View File

@ -20,9 +20,7 @@ interface ChatMessageType {
id: string;
user: string;
text: string;
isAssistant: boolean;
isUpdating?: boolean;
isError?: string;
}
const outline = "2px solid var(--accent-color)";
@ -36,7 +34,6 @@ function promptToMsg(prompt: OpenAIPromptMessageType): ChatMessageType {
id: crypto.randomUUID(),
user: prompt.role,
text: prompt.content,
isAssistant: prompt.role == "assistant",
};
}
@ -78,7 +75,7 @@ export class WaveAiModel implements ViewModel {
this.updateLastMessageAtom = atom(null, (get, set, text: string, isUpdating: boolean) => {
const messages = get(this.messagesAtom);
const lastMessage = messages[messages.length - 1];
if (lastMessage.isAssistant && !lastMessage.isError) {
if (lastMessage.user == "assistant") {
const updatedMessage = { ...lastMessage, text: lastMessage.text + text, isUpdating };
set(this.messagesAtom, [...messages.slice(0, -1), updatedMessage]);
}
@ -94,7 +91,6 @@ export class WaveAiModel implements ViewModel {
id: crypto.randomUUID(),
user: "assistant",
text: "",
isAssistant: true,
};
// Add a typing indicator
@ -190,7 +186,6 @@ export class WaveAiModel implements ViewModel {
id: crypto.randomUUID(),
user,
text,
isAssistant: false,
};
addMessage(newMessage);
// send message to backend and get response
@ -214,7 +209,6 @@ export class WaveAiModel implements ViewModel {
id: crypto.randomUUID(),
user: "assistant",
text: "",
isAssistant: true,
};
// Add a typing indicator
@ -225,25 +219,54 @@ export class WaveAiModel implements ViewModel {
opts: opts,
prompt: [...history, newPrompt],
};
const aiGen = RpcApi.StreamWaveAiCommand(WindowRpcClient, beMsg, { timeout: opts.timeoutms });
let fullMsg = "";
for await (const msg of aiGen) {
fullMsg += msg.text ?? "";
globalStore.set(this.updateLastMessageAtom, msg.text ?? "", true);
if (this.cancel) {
if (fullMsg == "") {
globalStore.set(this.removeLastMessageAtom);
try {
const aiGen = RpcApi.StreamWaveAiCommand(WindowRpcClient, beMsg, { timeout: opts.timeoutms });
for await (const msg of aiGen) {
fullMsg += msg.text ?? "";
globalStore.set(this.updateLastMessageAtom, msg.text ?? "", true);
if (this.cancel) {
if (fullMsg == "") {
globalStore.set(this.removeLastMessageAtom);
}
break;
}
globalStore.set(this.updateLastMessageAtom, "", false);
if (fullMsg != "") {
const responsePrompt: OpenAIPromptMessageType = {
role: "assistant",
content: fullMsg,
};
await BlockService.SaveWaveAiData(blockId, [...history, newPrompt, responsePrompt]);
}
break;
}
}
globalStore.set(this.updateLastMessageAtom, "", false);
if (fullMsg != "") {
const responsePrompt: OpenAIPromptMessageType = {
role: "assistant",
content: fullMsg,
} catch (error) {
const updatedHist = [...history, newPrompt];
if (fullMsg == "") {
globalStore.set(this.removeLastMessageAtom);
} else {
globalStore.set(this.updateLastMessageAtom, "", false);
const responsePrompt: OpenAIPromptMessageType = {
role: "assistant",
content: fullMsg,
};
updatedHist.push(responsePrompt);
}
const errMsg: string = (error as Error).message;
const errorMessage: ChatMessageType = {
id: crypto.randomUUID(),
user: "error",
text: errMsg,
};
await BlockService.SaveWaveAiData(blockId, [...history, newPrompt, responsePrompt]);
globalStore.set(this.addMessageAtom, errorMessage);
globalStore.set(this.updateLastMessageAtom, "", false);
const errorPrompt: OpenAIPromptMessageType = {
role: "error",
content: errMsg,
};
updatedHist.push(errorPrompt);
console.log(updatedHist);
await BlockService.SaveWaveAiData(blockId, updatedHist);
}
setLocked(false);
this.cancel = false;
@ -264,17 +287,26 @@ function makeWaveAiViewModel(blockId): WaveAiModel {
}
const ChatItem = ({ chatItem }: ChatItemProps) => {
const { isAssistant, text, isError } = chatItem;
const { user, text } = chatItem;
const cssVar = "--panel-bg-color";
const panelBgColor = getComputedStyle(document.documentElement).getPropertyValue(cssVar).trim();
const renderError = (err: string): React.JSX.Element => <div className="chat-msg-error">{err}</div>;
const renderContent = useMemo(() => {
if (isAssistant) {
if (isError) {
return renderError(isError);
}
if (user == "error") {
return (
<>
<div className="chat-msg chat-msg-header">
<div className="icon-box">
<i className="fa-sharp fa-solid fa-circle-exclamation"></i>
</div>
</div>
<div className="chat-msg chat-msg-error">
<Markdown text={text} scrollable={false} />
</div>
</>
);
}
if (user == "assistant") {
return text ? (
<>
<div className="chat-msg chat-msg-header">
@ -302,7 +334,7 @@ const ChatItem = ({ chatItem }: ChatItemProps) => {
</div>
</>
);
}, [text, isAssistant, isError]);
}, [text, user]);
return <div className={"chat-msg-container"}>{renderContent}</div>;
};

View File

@ -14,7 +14,6 @@ import (
"strings"
"time"
"github.com/sashabaranov/go-openai"
openaiapi "github.com/sashabaranov/go-openai"
"github.com/wavetermdev/waveterm/pkg/wavebase"
"github.com/wavetermdev/waveterm/pkg/wcloud"
@ -150,9 +149,16 @@ func RunCloudCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRe
rtn <- makeAIError(fmt.Errorf("unable to close openai channel: %v", err))
}
}()
var sendablePromptMsgs []wshrpc.OpenAIPromptMessageType
for _, promptMsg := range request.Prompt {
if promptMsg.Role == "error" {
continue
}
sendablePromptMsgs = append(sendablePromptMsgs, promptMsg)
}
reqPk := MakeOpenAICloudReqPacket()
reqPk.ClientId = request.ClientId
reqPk.Prompt = request.Prompt
reqPk.Prompt = sendablePromptMsgs
reqPk.MaxTokens = request.Opts.MaxTokens
reqPk.MaxChoices = request.Opts.MaxChoices
configMessageBuf, err := json.Marshal(reqPk)
@ -200,23 +206,23 @@ func defaultAzureMapperFn(model string) string {
return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "")
}
func setApiType(opts *wshrpc.OpenAIOptsType, clientConfig *openai.ClientConfig) error {
func setApiType(opts *wshrpc.OpenAIOptsType, clientConfig *openaiapi.ClientConfig) error {
ourApiType := strings.ToLower(opts.APIType)
if ourApiType == "" || ourApiType == strings.ToLower(string(openai.APITypeOpenAI)) {
clientConfig.APIType = openai.APITypeOpenAI
if ourApiType == "" || ourApiType == strings.ToLower(string(openaiapi.APITypeOpenAI)) {
clientConfig.APIType = openaiapi.APITypeOpenAI
return nil
} else if ourApiType == strings.ToLower(string(openai.APITypeAzure)) {
clientConfig.APIType = openai.APITypeAzure
} else if ourApiType == strings.ToLower(string(openaiapi.APITypeAzure)) {
clientConfig.APIType = openaiapi.APITypeAzure
clientConfig.APIVersion = DefaultAzureAPIVersion
clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
return nil
} else if ourApiType == strings.ToLower(string(openai.APITypeAzureAD)) {
clientConfig.APIType = openai.APITypeAzureAD
} else if ourApiType == strings.ToLower(string(openaiapi.APITypeAzureAD)) {
clientConfig.APIType = openaiapi.APITypeAzureAD
clientConfig.APIVersion = DefaultAzureAPIVersion
clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
return nil
} else if ourApiType == strings.ToLower(string(openai.APITypeCloudflareAzure)) {
clientConfig.APIType = openai.APITypeCloudflareAzure
} else if ourApiType == strings.ToLower(string(openaiapi.APITypeCloudflareAzure)) {
clientConfig.APIType = openaiapi.APITypeCloudflareAzure
clientConfig.APIVersion = DefaultAzureAPIVersion
clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
return nil