mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-01-21 21:32:13 +01:00
fix: add error messages to ai chat (#999)
This will print error messages to the chat when there is an error getting an ai response. The actual content of the responses are not forwarded to the models in future requests. <img width="389" alt="Screenshot 2024-10-09 at 2 36 13 PM" src="https://github.com/user-attachments/assets/e6c6b1c1-fa19-4456-be3b-596feaeaafed">
This commit is contained in:
parent
ad3166a2c9
commit
9dd4188810
@ -191,7 +191,7 @@
|
||||
background-color: transparent;
|
||||
outline: none;
|
||||
border: none;
|
||||
color: var(--app-text-color);
|
||||
color: var(--main-text-color);
|
||||
width: 100%;
|
||||
white-space: nowrap;
|
||||
overflow: hidden;
|
||||
|
@ -51,7 +51,7 @@
|
||||
}
|
||||
|
||||
&.chat-msg-assistant {
|
||||
color: var(--app-text-color);
|
||||
color: var(--main-text-color);
|
||||
background-color: rgb(from var(--highlight-bg-color) r g b / 0.1);
|
||||
margin-right: auto;
|
||||
padding: 10px;
|
||||
@ -77,9 +77,23 @@
|
||||
}
|
||||
|
||||
&.chat-msg-error {
|
||||
color: var(--cmdinput-text-error);
|
||||
font-family: var(--markdown-font);
|
||||
font-size: 14px;
|
||||
color: var(--main-text-color);
|
||||
background-color: rgb(from var(--error-color) r g b / 0.25);
|
||||
margin-right: auto;
|
||||
padding: 10px;
|
||||
max-width: 85%;
|
||||
|
||||
.markdown {
|
||||
width: 100%;
|
||||
|
||||
pre {
|
||||
white-space: pre-wrap;
|
||||
word-break: break-word;
|
||||
max-width: 100%;
|
||||
overflow-x: auto;
|
||||
margin-left: 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
&.typing-indicator {
|
||||
|
@ -20,9 +20,7 @@ interface ChatMessageType {
|
||||
id: string;
|
||||
user: string;
|
||||
text: string;
|
||||
isAssistant: boolean;
|
||||
isUpdating?: boolean;
|
||||
isError?: string;
|
||||
}
|
||||
|
||||
const outline = "2px solid var(--accent-color)";
|
||||
@ -36,7 +34,6 @@ function promptToMsg(prompt: OpenAIPromptMessageType): ChatMessageType {
|
||||
id: crypto.randomUUID(),
|
||||
user: prompt.role,
|
||||
text: prompt.content,
|
||||
isAssistant: prompt.role == "assistant",
|
||||
};
|
||||
}
|
||||
|
||||
@ -78,7 +75,7 @@ export class WaveAiModel implements ViewModel {
|
||||
this.updateLastMessageAtom = atom(null, (get, set, text: string, isUpdating: boolean) => {
|
||||
const messages = get(this.messagesAtom);
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
if (lastMessage.isAssistant && !lastMessage.isError) {
|
||||
if (lastMessage.user == "assistant") {
|
||||
const updatedMessage = { ...lastMessage, text: lastMessage.text + text, isUpdating };
|
||||
set(this.messagesAtom, [...messages.slice(0, -1), updatedMessage]);
|
||||
}
|
||||
@ -94,7 +91,6 @@ export class WaveAiModel implements ViewModel {
|
||||
id: crypto.randomUUID(),
|
||||
user: "assistant",
|
||||
text: "",
|
||||
isAssistant: true,
|
||||
};
|
||||
|
||||
// Add a typing indicator
|
||||
@ -190,7 +186,6 @@ export class WaveAiModel implements ViewModel {
|
||||
id: crypto.randomUUID(),
|
||||
user,
|
||||
text,
|
||||
isAssistant: false,
|
||||
};
|
||||
addMessage(newMessage);
|
||||
// send message to backend and get response
|
||||
@ -214,7 +209,6 @@ export class WaveAiModel implements ViewModel {
|
||||
id: crypto.randomUUID(),
|
||||
user: "assistant",
|
||||
text: "",
|
||||
isAssistant: true,
|
||||
};
|
||||
|
||||
// Add a typing indicator
|
||||
@ -225,25 +219,54 @@ export class WaveAiModel implements ViewModel {
|
||||
opts: opts,
|
||||
prompt: [...history, newPrompt],
|
||||
};
|
||||
const aiGen = RpcApi.StreamWaveAiCommand(WindowRpcClient, beMsg, { timeout: opts.timeoutms });
|
||||
let fullMsg = "";
|
||||
for await (const msg of aiGen) {
|
||||
fullMsg += msg.text ?? "";
|
||||
globalStore.set(this.updateLastMessageAtom, msg.text ?? "", true);
|
||||
if (this.cancel) {
|
||||
if (fullMsg == "") {
|
||||
globalStore.set(this.removeLastMessageAtom);
|
||||
try {
|
||||
const aiGen = RpcApi.StreamWaveAiCommand(WindowRpcClient, beMsg, { timeout: opts.timeoutms });
|
||||
for await (const msg of aiGen) {
|
||||
fullMsg += msg.text ?? "";
|
||||
globalStore.set(this.updateLastMessageAtom, msg.text ?? "", true);
|
||||
if (this.cancel) {
|
||||
if (fullMsg == "") {
|
||||
globalStore.set(this.removeLastMessageAtom);
|
||||
}
|
||||
break;
|
||||
}
|
||||
globalStore.set(this.updateLastMessageAtom, "", false);
|
||||
if (fullMsg != "") {
|
||||
const responsePrompt: OpenAIPromptMessageType = {
|
||||
role: "assistant",
|
||||
content: fullMsg,
|
||||
};
|
||||
await BlockService.SaveWaveAiData(blockId, [...history, newPrompt, responsePrompt]);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
globalStore.set(this.updateLastMessageAtom, "", false);
|
||||
if (fullMsg != "") {
|
||||
const responsePrompt: OpenAIPromptMessageType = {
|
||||
role: "assistant",
|
||||
content: fullMsg,
|
||||
} catch (error) {
|
||||
const updatedHist = [...history, newPrompt];
|
||||
if (fullMsg == "") {
|
||||
globalStore.set(this.removeLastMessageAtom);
|
||||
} else {
|
||||
globalStore.set(this.updateLastMessageAtom, "", false);
|
||||
const responsePrompt: OpenAIPromptMessageType = {
|
||||
role: "assistant",
|
||||
content: fullMsg,
|
||||
};
|
||||
updatedHist.push(responsePrompt);
|
||||
}
|
||||
const errMsg: string = (error as Error).message;
|
||||
const errorMessage: ChatMessageType = {
|
||||
id: crypto.randomUUID(),
|
||||
user: "error",
|
||||
text: errMsg,
|
||||
};
|
||||
await BlockService.SaveWaveAiData(blockId, [...history, newPrompt, responsePrompt]);
|
||||
globalStore.set(this.addMessageAtom, errorMessage);
|
||||
globalStore.set(this.updateLastMessageAtom, "", false);
|
||||
const errorPrompt: OpenAIPromptMessageType = {
|
||||
role: "error",
|
||||
content: errMsg,
|
||||
};
|
||||
updatedHist.push(errorPrompt);
|
||||
console.log(updatedHist);
|
||||
await BlockService.SaveWaveAiData(blockId, updatedHist);
|
||||
}
|
||||
setLocked(false);
|
||||
this.cancel = false;
|
||||
@ -264,17 +287,26 @@ function makeWaveAiViewModel(blockId): WaveAiModel {
|
||||
}
|
||||
|
||||
const ChatItem = ({ chatItem }: ChatItemProps) => {
|
||||
const { isAssistant, text, isError } = chatItem;
|
||||
const { user, text } = chatItem;
|
||||
const cssVar = "--panel-bg-color";
|
||||
const panelBgColor = getComputedStyle(document.documentElement).getPropertyValue(cssVar).trim();
|
||||
|
||||
const renderError = (err: string): React.JSX.Element => <div className="chat-msg-error">{err}</div>;
|
||||
|
||||
const renderContent = useMemo(() => {
|
||||
if (isAssistant) {
|
||||
if (isError) {
|
||||
return renderError(isError);
|
||||
}
|
||||
if (user == "error") {
|
||||
return (
|
||||
<>
|
||||
<div className="chat-msg chat-msg-header">
|
||||
<div className="icon-box">
|
||||
<i className="fa-sharp fa-solid fa-circle-exclamation"></i>
|
||||
</div>
|
||||
</div>
|
||||
<div className="chat-msg chat-msg-error">
|
||||
<Markdown text={text} scrollable={false} />
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
}
|
||||
if (user == "assistant") {
|
||||
return text ? (
|
||||
<>
|
||||
<div className="chat-msg chat-msg-header">
|
||||
@ -302,7 +334,7 @@ const ChatItem = ({ chatItem }: ChatItemProps) => {
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
}, [text, isAssistant, isError]);
|
||||
}, [text, user]);
|
||||
|
||||
return <div className={"chat-msg-container"}>{renderContent}</div>;
|
||||
};
|
||||
|
@ -14,7 +14,6 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
openaiapi "github.com/sashabaranov/go-openai"
|
||||
"github.com/wavetermdev/waveterm/pkg/wavebase"
|
||||
"github.com/wavetermdev/waveterm/pkg/wcloud"
|
||||
@ -150,9 +149,16 @@ func RunCloudCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRe
|
||||
rtn <- makeAIError(fmt.Errorf("unable to close openai channel: %v", err))
|
||||
}
|
||||
}()
|
||||
var sendablePromptMsgs []wshrpc.OpenAIPromptMessageType
|
||||
for _, promptMsg := range request.Prompt {
|
||||
if promptMsg.Role == "error" {
|
||||
continue
|
||||
}
|
||||
sendablePromptMsgs = append(sendablePromptMsgs, promptMsg)
|
||||
}
|
||||
reqPk := MakeOpenAICloudReqPacket()
|
||||
reqPk.ClientId = request.ClientId
|
||||
reqPk.Prompt = request.Prompt
|
||||
reqPk.Prompt = sendablePromptMsgs
|
||||
reqPk.MaxTokens = request.Opts.MaxTokens
|
||||
reqPk.MaxChoices = request.Opts.MaxChoices
|
||||
configMessageBuf, err := json.Marshal(reqPk)
|
||||
@ -200,23 +206,23 @@ func defaultAzureMapperFn(model string) string {
|
||||
return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "")
|
||||
}
|
||||
|
||||
func setApiType(opts *wshrpc.OpenAIOptsType, clientConfig *openai.ClientConfig) error {
|
||||
func setApiType(opts *wshrpc.OpenAIOptsType, clientConfig *openaiapi.ClientConfig) error {
|
||||
ourApiType := strings.ToLower(opts.APIType)
|
||||
if ourApiType == "" || ourApiType == strings.ToLower(string(openai.APITypeOpenAI)) {
|
||||
clientConfig.APIType = openai.APITypeOpenAI
|
||||
if ourApiType == "" || ourApiType == strings.ToLower(string(openaiapi.APITypeOpenAI)) {
|
||||
clientConfig.APIType = openaiapi.APITypeOpenAI
|
||||
return nil
|
||||
} else if ourApiType == strings.ToLower(string(openai.APITypeAzure)) {
|
||||
clientConfig.APIType = openai.APITypeAzure
|
||||
} else if ourApiType == strings.ToLower(string(openaiapi.APITypeAzure)) {
|
||||
clientConfig.APIType = openaiapi.APITypeAzure
|
||||
clientConfig.APIVersion = DefaultAzureAPIVersion
|
||||
clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
|
||||
return nil
|
||||
} else if ourApiType == strings.ToLower(string(openai.APITypeAzureAD)) {
|
||||
clientConfig.APIType = openai.APITypeAzureAD
|
||||
} else if ourApiType == strings.ToLower(string(openaiapi.APITypeAzureAD)) {
|
||||
clientConfig.APIType = openaiapi.APITypeAzureAD
|
||||
clientConfig.APIVersion = DefaultAzureAPIVersion
|
||||
clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
|
||||
return nil
|
||||
} else if ourApiType == strings.ToLower(string(openai.APITypeCloudflareAzure)) {
|
||||
clientConfig.APIType = openai.APITypeCloudflareAzure
|
||||
} else if ourApiType == strings.ToLower(string(openaiapi.APITypeCloudflareAzure)) {
|
||||
clientConfig.APIType = openaiapi.APITypeCloudflareAzure
|
||||
clientConfig.APIVersion = DefaultAzureAPIVersion
|
||||
clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
|
||||
return nil
|
||||
|
Loading…
Reference in New Issue
Block a user