mirror of
https://github.com/wavetermdev/waveterm.git
synced 2024-12-22 16:48:23 +01:00
fix: add error messages to ai chat (#999)
This will print error messages to the chat when there is an error getting an ai response. The actual content of the responses are not forwarded to the models in future requests. <img width="389" alt="Screenshot 2024-10-09 at 2 36 13 PM" src="https://github.com/user-attachments/assets/e6c6b1c1-fa19-4456-be3b-596feaeaafed">
This commit is contained in:
parent
ad3166a2c9
commit
9dd4188810
@ -191,7 +191,7 @@
|
|||||||
background-color: transparent;
|
background-color: transparent;
|
||||||
outline: none;
|
outline: none;
|
||||||
border: none;
|
border: none;
|
||||||
color: var(--app-text-color);
|
color: var(--main-text-color);
|
||||||
width: 100%;
|
width: 100%;
|
||||||
white-space: nowrap;
|
white-space: nowrap;
|
||||||
overflow: hidden;
|
overflow: hidden;
|
||||||
|
@ -51,7 +51,7 @@
|
|||||||
}
|
}
|
||||||
|
|
||||||
&.chat-msg-assistant {
|
&.chat-msg-assistant {
|
||||||
color: var(--app-text-color);
|
color: var(--main-text-color);
|
||||||
background-color: rgb(from var(--highlight-bg-color) r g b / 0.1);
|
background-color: rgb(from var(--highlight-bg-color) r g b / 0.1);
|
||||||
margin-right: auto;
|
margin-right: auto;
|
||||||
padding: 10px;
|
padding: 10px;
|
||||||
@ -77,9 +77,23 @@
|
|||||||
}
|
}
|
||||||
|
|
||||||
&.chat-msg-error {
|
&.chat-msg-error {
|
||||||
color: var(--cmdinput-text-error);
|
color: var(--main-text-color);
|
||||||
font-family: var(--markdown-font);
|
background-color: rgb(from var(--error-color) r g b / 0.25);
|
||||||
font-size: 14px;
|
margin-right: auto;
|
||||||
|
padding: 10px;
|
||||||
|
max-width: 85%;
|
||||||
|
|
||||||
|
.markdown {
|
||||||
|
width: 100%;
|
||||||
|
|
||||||
|
pre {
|
||||||
|
white-space: pre-wrap;
|
||||||
|
word-break: break-word;
|
||||||
|
max-width: 100%;
|
||||||
|
overflow-x: auto;
|
||||||
|
margin-left: 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
&.typing-indicator {
|
&.typing-indicator {
|
||||||
|
@ -20,9 +20,7 @@ interface ChatMessageType {
|
|||||||
id: string;
|
id: string;
|
||||||
user: string;
|
user: string;
|
||||||
text: string;
|
text: string;
|
||||||
isAssistant: boolean;
|
|
||||||
isUpdating?: boolean;
|
isUpdating?: boolean;
|
||||||
isError?: string;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const outline = "2px solid var(--accent-color)";
|
const outline = "2px solid var(--accent-color)";
|
||||||
@ -36,7 +34,6 @@ function promptToMsg(prompt: OpenAIPromptMessageType): ChatMessageType {
|
|||||||
id: crypto.randomUUID(),
|
id: crypto.randomUUID(),
|
||||||
user: prompt.role,
|
user: prompt.role,
|
||||||
text: prompt.content,
|
text: prompt.content,
|
||||||
isAssistant: prompt.role == "assistant",
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -78,7 +75,7 @@ export class WaveAiModel implements ViewModel {
|
|||||||
this.updateLastMessageAtom = atom(null, (get, set, text: string, isUpdating: boolean) => {
|
this.updateLastMessageAtom = atom(null, (get, set, text: string, isUpdating: boolean) => {
|
||||||
const messages = get(this.messagesAtom);
|
const messages = get(this.messagesAtom);
|
||||||
const lastMessage = messages[messages.length - 1];
|
const lastMessage = messages[messages.length - 1];
|
||||||
if (lastMessage.isAssistant && !lastMessage.isError) {
|
if (lastMessage.user == "assistant") {
|
||||||
const updatedMessage = { ...lastMessage, text: lastMessage.text + text, isUpdating };
|
const updatedMessage = { ...lastMessage, text: lastMessage.text + text, isUpdating };
|
||||||
set(this.messagesAtom, [...messages.slice(0, -1), updatedMessage]);
|
set(this.messagesAtom, [...messages.slice(0, -1), updatedMessage]);
|
||||||
}
|
}
|
||||||
@ -94,7 +91,6 @@ export class WaveAiModel implements ViewModel {
|
|||||||
id: crypto.randomUUID(),
|
id: crypto.randomUUID(),
|
||||||
user: "assistant",
|
user: "assistant",
|
||||||
text: "",
|
text: "",
|
||||||
isAssistant: true,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Add a typing indicator
|
// Add a typing indicator
|
||||||
@ -190,7 +186,6 @@ export class WaveAiModel implements ViewModel {
|
|||||||
id: crypto.randomUUID(),
|
id: crypto.randomUUID(),
|
||||||
user,
|
user,
|
||||||
text,
|
text,
|
||||||
isAssistant: false,
|
|
||||||
};
|
};
|
||||||
addMessage(newMessage);
|
addMessage(newMessage);
|
||||||
// send message to backend and get response
|
// send message to backend and get response
|
||||||
@ -214,7 +209,6 @@ export class WaveAiModel implements ViewModel {
|
|||||||
id: crypto.randomUUID(),
|
id: crypto.randomUUID(),
|
||||||
user: "assistant",
|
user: "assistant",
|
||||||
text: "",
|
text: "",
|
||||||
isAssistant: true,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Add a typing indicator
|
// Add a typing indicator
|
||||||
@ -225,8 +219,9 @@ export class WaveAiModel implements ViewModel {
|
|||||||
opts: opts,
|
opts: opts,
|
||||||
prompt: [...history, newPrompt],
|
prompt: [...history, newPrompt],
|
||||||
};
|
};
|
||||||
const aiGen = RpcApi.StreamWaveAiCommand(WindowRpcClient, beMsg, { timeout: opts.timeoutms });
|
|
||||||
let fullMsg = "";
|
let fullMsg = "";
|
||||||
|
try {
|
||||||
|
const aiGen = RpcApi.StreamWaveAiCommand(WindowRpcClient, beMsg, { timeout: opts.timeoutms });
|
||||||
for await (const msg of aiGen) {
|
for await (const msg of aiGen) {
|
||||||
fullMsg += msg.text ?? "";
|
fullMsg += msg.text ?? "";
|
||||||
globalStore.set(this.updateLastMessageAtom, msg.text ?? "", true);
|
globalStore.set(this.updateLastMessageAtom, msg.text ?? "", true);
|
||||||
@ -236,7 +231,6 @@ export class WaveAiModel implements ViewModel {
|
|||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
globalStore.set(this.updateLastMessageAtom, "", false);
|
globalStore.set(this.updateLastMessageAtom, "", false);
|
||||||
if (fullMsg != "") {
|
if (fullMsg != "") {
|
||||||
const responsePrompt: OpenAIPromptMessageType = {
|
const responsePrompt: OpenAIPromptMessageType = {
|
||||||
@ -245,6 +239,35 @@ export class WaveAiModel implements ViewModel {
|
|||||||
};
|
};
|
||||||
await BlockService.SaveWaveAiData(blockId, [...history, newPrompt, responsePrompt]);
|
await BlockService.SaveWaveAiData(blockId, [...history, newPrompt, responsePrompt]);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
const updatedHist = [...history, newPrompt];
|
||||||
|
if (fullMsg == "") {
|
||||||
|
globalStore.set(this.removeLastMessageAtom);
|
||||||
|
} else {
|
||||||
|
globalStore.set(this.updateLastMessageAtom, "", false);
|
||||||
|
const responsePrompt: OpenAIPromptMessageType = {
|
||||||
|
role: "assistant",
|
||||||
|
content: fullMsg,
|
||||||
|
};
|
||||||
|
updatedHist.push(responsePrompt);
|
||||||
|
}
|
||||||
|
const errMsg: string = (error as Error).message;
|
||||||
|
const errorMessage: ChatMessageType = {
|
||||||
|
id: crypto.randomUUID(),
|
||||||
|
user: "error",
|
||||||
|
text: errMsg,
|
||||||
|
};
|
||||||
|
globalStore.set(this.addMessageAtom, errorMessage);
|
||||||
|
globalStore.set(this.updateLastMessageAtom, "", false);
|
||||||
|
const errorPrompt: OpenAIPromptMessageType = {
|
||||||
|
role: "error",
|
||||||
|
content: errMsg,
|
||||||
|
};
|
||||||
|
updatedHist.push(errorPrompt);
|
||||||
|
console.log(updatedHist);
|
||||||
|
await BlockService.SaveWaveAiData(blockId, updatedHist);
|
||||||
|
}
|
||||||
setLocked(false);
|
setLocked(false);
|
||||||
this.cancel = false;
|
this.cancel = false;
|
||||||
};
|
};
|
||||||
@ -264,17 +287,26 @@ function makeWaveAiViewModel(blockId): WaveAiModel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const ChatItem = ({ chatItem }: ChatItemProps) => {
|
const ChatItem = ({ chatItem }: ChatItemProps) => {
|
||||||
const { isAssistant, text, isError } = chatItem;
|
const { user, text } = chatItem;
|
||||||
const cssVar = "--panel-bg-color";
|
const cssVar = "--panel-bg-color";
|
||||||
const panelBgColor = getComputedStyle(document.documentElement).getPropertyValue(cssVar).trim();
|
const panelBgColor = getComputedStyle(document.documentElement).getPropertyValue(cssVar).trim();
|
||||||
|
|
||||||
const renderError = (err: string): React.JSX.Element => <div className="chat-msg-error">{err}</div>;
|
|
||||||
|
|
||||||
const renderContent = useMemo(() => {
|
const renderContent = useMemo(() => {
|
||||||
if (isAssistant) {
|
if (user == "error") {
|
||||||
if (isError) {
|
return (
|
||||||
return renderError(isError);
|
<>
|
||||||
|
<div className="chat-msg chat-msg-header">
|
||||||
|
<div className="icon-box">
|
||||||
|
<i className="fa-sharp fa-solid fa-circle-exclamation"></i>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div className="chat-msg chat-msg-error">
|
||||||
|
<Markdown text={text} scrollable={false} />
|
||||||
|
</div>
|
||||||
|
</>
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
if (user == "assistant") {
|
||||||
return text ? (
|
return text ? (
|
||||||
<>
|
<>
|
||||||
<div className="chat-msg chat-msg-header">
|
<div className="chat-msg chat-msg-header">
|
||||||
@ -302,7 +334,7 @@ const ChatItem = ({ chatItem }: ChatItemProps) => {
|
|||||||
</div>
|
</div>
|
||||||
</>
|
</>
|
||||||
);
|
);
|
||||||
}, [text, isAssistant, isError]);
|
}, [text, user]);
|
||||||
|
|
||||||
return <div className={"chat-msg-container"}>{renderContent}</div>;
|
return <div className={"chat-msg-container"}>{renderContent}</div>;
|
||||||
};
|
};
|
||||||
|
@ -14,7 +14,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai"
|
|
||||||
openaiapi "github.com/sashabaranov/go-openai"
|
openaiapi "github.com/sashabaranov/go-openai"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wavebase"
|
"github.com/wavetermdev/waveterm/pkg/wavebase"
|
||||||
"github.com/wavetermdev/waveterm/pkg/wcloud"
|
"github.com/wavetermdev/waveterm/pkg/wcloud"
|
||||||
@ -150,9 +149,16 @@ func RunCloudCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRe
|
|||||||
rtn <- makeAIError(fmt.Errorf("unable to close openai channel: %v", err))
|
rtn <- makeAIError(fmt.Errorf("unable to close openai channel: %v", err))
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
var sendablePromptMsgs []wshrpc.OpenAIPromptMessageType
|
||||||
|
for _, promptMsg := range request.Prompt {
|
||||||
|
if promptMsg.Role == "error" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
sendablePromptMsgs = append(sendablePromptMsgs, promptMsg)
|
||||||
|
}
|
||||||
reqPk := MakeOpenAICloudReqPacket()
|
reqPk := MakeOpenAICloudReqPacket()
|
||||||
reqPk.ClientId = request.ClientId
|
reqPk.ClientId = request.ClientId
|
||||||
reqPk.Prompt = request.Prompt
|
reqPk.Prompt = sendablePromptMsgs
|
||||||
reqPk.MaxTokens = request.Opts.MaxTokens
|
reqPk.MaxTokens = request.Opts.MaxTokens
|
||||||
reqPk.MaxChoices = request.Opts.MaxChoices
|
reqPk.MaxChoices = request.Opts.MaxChoices
|
||||||
configMessageBuf, err := json.Marshal(reqPk)
|
configMessageBuf, err := json.Marshal(reqPk)
|
||||||
@ -200,23 +206,23 @@ func defaultAzureMapperFn(model string) string {
|
|||||||
return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "")
|
return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
func setApiType(opts *wshrpc.OpenAIOptsType, clientConfig *openai.ClientConfig) error {
|
func setApiType(opts *wshrpc.OpenAIOptsType, clientConfig *openaiapi.ClientConfig) error {
|
||||||
ourApiType := strings.ToLower(opts.APIType)
|
ourApiType := strings.ToLower(opts.APIType)
|
||||||
if ourApiType == "" || ourApiType == strings.ToLower(string(openai.APITypeOpenAI)) {
|
if ourApiType == "" || ourApiType == strings.ToLower(string(openaiapi.APITypeOpenAI)) {
|
||||||
clientConfig.APIType = openai.APITypeOpenAI
|
clientConfig.APIType = openaiapi.APITypeOpenAI
|
||||||
return nil
|
return nil
|
||||||
} else if ourApiType == strings.ToLower(string(openai.APITypeAzure)) {
|
} else if ourApiType == strings.ToLower(string(openaiapi.APITypeAzure)) {
|
||||||
clientConfig.APIType = openai.APITypeAzure
|
clientConfig.APIType = openaiapi.APITypeAzure
|
||||||
clientConfig.APIVersion = DefaultAzureAPIVersion
|
clientConfig.APIVersion = DefaultAzureAPIVersion
|
||||||
clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
|
clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
|
||||||
return nil
|
return nil
|
||||||
} else if ourApiType == strings.ToLower(string(openai.APITypeAzureAD)) {
|
} else if ourApiType == strings.ToLower(string(openaiapi.APITypeAzureAD)) {
|
||||||
clientConfig.APIType = openai.APITypeAzureAD
|
clientConfig.APIType = openaiapi.APITypeAzureAD
|
||||||
clientConfig.APIVersion = DefaultAzureAPIVersion
|
clientConfig.APIVersion = DefaultAzureAPIVersion
|
||||||
clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
|
clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
|
||||||
return nil
|
return nil
|
||||||
} else if ourApiType == strings.ToLower(string(openai.APITypeCloudflareAzure)) {
|
} else if ourApiType == strings.ToLower(string(openaiapi.APITypeCloudflareAzure)) {
|
||||||
clientConfig.APIType = openai.APITypeCloudflareAzure
|
clientConfig.APIType = openaiapi.APITypeCloudflareAzure
|
||||||
clientConfig.APIVersion = DefaultAzureAPIVersion
|
clientConfig.APIVersion = DefaultAzureAPIVersion
|
||||||
clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
|
clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
|
||||||
return nil
|
return nil
|
||||||
|
Loading…
Reference in New Issue
Block a user