// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0

// NOTE(review): the copy of this file that was reviewed had lost everything
// written between angle brackets (generic type arguments and all JSX markup).
// Type arguments were restored below only where the surrounding code makes
// them evident; every place where markup is still missing carries its own
// NOTE(review) and must be restored from version control before this builds.

import { Button } from "@/app/element/button";
import { Markdown } from "@/app/element/markdown";
import { TypingIndicator } from "@/app/element/typingindicator";
import { RpcApi } from "@/app/store/wshclientapi";
import { WindowRpcClient } from "@/app/store/wshrpcutil";
import { atoms, fetchWaveFile, globalStore, WOS } from "@/store/global";
import { BlockService } from "@/store/services";
import { adaptFromReactOrNativeKeyEvent, checkKeyPressed } from "@/util/keyutil";
import { isBlank, makeIconClass } from "@/util/util";
import { atom, Atom, PrimitiveAtom, useAtomValue, useSetAtom, WritableAtom } from "jotai";
import type { OverlayScrollbars } from "overlayscrollbars";
import { OverlayScrollbarsComponent, OverlayScrollbarsComponentRef } from "overlayscrollbars-react";
import { forwardRef, memo, useCallback, useEffect, useImperativeHandle, useMemo, useRef, useState } from "react";
import "./waveai.less";

/** One entry of the chat transcript kept in `WaveAiModel.messagesAtom`. */
interface ChatMessageType {
    id: string;
    /** Author tag; this file produces "user", "assistant" and "error". */
    user: string;
    text: string;
    /** Flag passed along with each streamed chunk; cleared once streaming stops. */
    isUpdating?: boolean;
}

/** CSS outline applied to the currently selected `pre` element. */
const outline = "2px solid var(--accent-color)";

interface ChatItemProps {
    chatItem: ChatMessageType;
}

/** Converts a stored prompt into a transcript entry with a freshly generated id. */
function promptToMsg(prompt: OpenAIPromptMessageType): ChatMessageType {
    return {
        id: crypto.randomUUID(),
        user: prompt.role,
        text: prompt.content,
    };
}

/**
 * View model for the "waveai" block: owns the chat transcript atoms, builds the
 * header text, loads the saved history and streams replies from the backend.
 */
export class WaveAiModel implements ViewModel {
    viewType: string;
    blockId: string;
    // Typed as whatever WOS.getWaveObjectAtom returns; the explicit type
    // argument that was originally written here is not recoverable.
    blockAtom: ReturnType<typeof WOS.getWaveObjectAtom>;
    viewIcon?: Atom<string>;
    viewName?: Atom<string>;
    viewText?: Atom<HeaderElem[]>;
    // NOTE(review): never assigned in this file; typed from the implemented
    // interface on the assumption that ViewModel declares both members - confirm.
    preIconButton?: ViewModel["preIconButton"];
    endIconButtons?: ViewModel["endIconButtons"];
    messagesAtom: PrimitiveAtom<Array<ChatMessageType>>;
    addMessageAtom: WritableAtom<unknown, [message: ChatMessageType], void>;
    updateLastMessageAtom: WritableAtom<unknown, [text: string, isUpdating: boolean], void>;
    removeLastMessageAtom: WritableAtom<unknown, [], void>;
    simulateAssistantResponseAtom: WritableAtom<unknown, [userMessage: ChatMessageType], Promise<void>>;
    textAreaRef: React.RefObject<HTMLTextAreaElement>;
    /** True from the moment a message is sent until its reply has been handled. */
    locked: PrimitiveAtom<boolean>;
    /** Set by the UI to ask the streaming loop to stop at the next chunk. */
    cancel: boolean;

    constructor(blockId: string) {
        this.locked = atom(false);
        this.cancel = false;
        this.viewType = "waveai";
        this.blockId = blockId;
        this.blockAtom = WOS.getWaveObjectAtom(`block:${blockId}`);
        this.viewIcon = atom((get) => {
            return "sparkles"; // should not be hardcoded
        });
        this.viewName = atom("Wave Ai");
        this.messagesAtom = atom([]);

        // Appends one message to the transcript.
        this.addMessageAtom = atom(null, (get, set, message: ChatMessageType) => {
            const messages = get(this.messagesAtom);
            set(this.messagesAtom, [...messages, message]);
        });

        // Appends `text` to the last message and sets its isUpdating flag, but
        // only when that message belongs to the assistant.
        this.updateLastMessageAtom = atom(null, (get, set, text: string, isUpdating: boolean) => {
            const messages = get(this.messagesAtom);
            const lastMessage = messages[messages.length - 1];
            // The transcript may be empty (nothing to update); without this
            // guard the property access below would throw.
            if (lastMessage == null) {
                return;
            }
            if (lastMessage.user == "assistant") {
                const updatedMessage = { ...lastMessage, text: lastMessage.text + text, isUpdating };
                set(this.messagesAtom, [...messages.slice(0, -1), updatedMessage]);
            }
        });

        // Drops the last message. A new array is built instead of calling pop()
        // so the array currently held by the atom is never mutated in place.
        this.removeLastMessageAtom = atom(null, (get, set) => {
            const messages = get(this.messagesAtom);
            set(this.messagesAtom, messages.slice(0, -1));
        });

        this.simulateAssistantResponseAtom = atom(null, async (get, set, userMessage: ChatMessageType) => {
            // unused at the moment. can replace the streaming function in useWaveAi in the future
            const typingMessage: ChatMessageType = {
                id: crypto.randomUUID(),
                user: "assistant",
                text: "",
            };

            // Add a typing indicator
            set(this.addMessageAtom, typingMessage);

            // Echo the user's text back one space-separated part at a time.
            const parts = userMessage.text.split(" ");
            let currentPart = 0;
            while (currentPart < parts.length) {
                const part = parts[currentPart] + " ";
                set(this.updateLastMessageAtom, part, true);
                currentPart++;
            }
            set(this.updateLastMessageAtom, "", false);
        });

        // Header contents: an icon describing where the model runs, then its name.
        this.viewText = atom((get) => {
            const viewTextChildren: HeaderElem[] = [];
            const settings = get(atoms.settingsAtom);
            // "Cloud" means neither an API token nor a base URL is configured.
            const isCloud = isBlank(settings?.["ai:apitoken"]) && isBlank(settings?.["ai:baseurl"]);
            let modelText = "gpt-4o-mini";
            if (!isCloud && !isBlank(settings?.["ai:model"])) {
                if (!isBlank(settings?.["ai:name"])) {
                    modelText = settings["ai:name"];
                } else {
                    modelText = settings["ai:model"];
                }
            }
            if (isCloud) {
                viewTextChildren.push({
                    elemtype: "iconbutton",
                    icon: "cloud",
                    title: "Using Wave's AI Proxy (gpt-4o-mini)",
                    disabled: true,
                });
            } else {
                const baseUrl = settings["ai:baseurl"] ?? "OpenAI Default Endpoint";
                const modelName = settings["ai:model"];
                if (baseUrl.startsWith("http://localhost") || baseUrl.startsWith("http://127.0.0.1")) {
                    viewTextChildren.push({
                        elemtype: "iconbutton",
                        icon: "location-dot",
                        title: "Using Local Model @ " + baseUrl + " (" + modelName + ")",
                        disabled: true,
                    });
                } else {
                    viewTextChildren.push({
                        elemtype: "iconbutton",
                        icon: "globe",
                        title: "Using Remote Model @ " + baseUrl + " (" + modelName + ")",
                        disabled: true,
                    });
                }
            }
            viewTextChildren.push({
                elemtype: "text",
                text: modelText,
            });
            return viewTextChildren;
        });
    }

    /** Replaces the transcript with the history returned by `fetchAiData`. */
    async populateMessages(): Promise<void> {
        const history = await this.fetchAiData();
        globalStore.set(this.messagesAtom, history.map(promptToMsg));
    }

    /**
     * Reads the block's "aidata" file and parses it as JSON.
     *
     * @returns the parsed prompt history, or an empty array when the file has no data.
     */
    async fetchAiData(): Promise<Array<OpenAIPromptMessageType>> {
        const { data } = await fetchWaveFile(this.blockId, "aidata");
        if (!data) {
            return [];
        }
        // The parsed value is trusted to have the prompt-history shape; it is not validated.
        const history: Array<OpenAIPromptMessageType> = JSON.parse(new TextDecoder().decode(data));
        return history;
    }

    /**
     * Focuses the chat textarea if one has been registered on the model.
     *
     * @returns true when focus was requested, false when no textarea is available.
     */
    giveFocus(): boolean {
        if (this?.textAreaRef?.current) {
            this.textAreaRef.current?.focus();
            return true;
        }
        return false;
    }

    /**
     * React hook exposing the transcript and a `sendMessage` function.
     *
     * `sendMessage` locks the model, appends the user's message, streams the
     * reply into a placeholder assistant message and saves the updated history
     * through `BlockService.SaveWaveAiData`.
     */
    useWaveAi() {
        const messages = useAtomValue(this.messagesAtom);
        const addMessage = useSetAtom(this.addMessageAtom);
        const simulateResponse = useSetAtom(this.simulateAssistantResponseAtom);
        const clientId = useAtomValue(atoms.clientId);
        const blockId = this.blockId;
        const setLocked = useSetAtom(this.locked);

        const sendMessage = (text: string, user: string = "user") => {
            setLocked(true);

            const newMessage: ChatMessageType = {
                id: crypto.randomUUID(),
                user,
                text,
            };
            addMessage(newMessage);

            // send message to backend and get response
            const settings = globalStore.get(atoms.settingsAtom) ?? {};
            const opts: OpenAIOptsType = {
                model: settings["ai:model"],
                apitype: settings["ai:apitype"],
                orgid: settings["ai:orgid"],
                apitoken: settings["ai:apitoken"],
                apiversion: settings["ai:apiversion"],
                maxtokens: settings["ai:maxtokens"],
                timeoutms: settings["ai:timeoutms"] ?? 60000,
                baseurl: settings["ai:baseurl"],
            };
            const newPrompt: OpenAIPromptMessageType = {
                role: "user",
                content: text,
            };

            const streamResponse = async () => {
                try {
                    const typingMessage: ChatMessageType = {
                        id: crypto.randomUUID(),
                        user: "assistant",
                        text: "",
                    };

                    // Add a typing indicator
                    globalStore.set(this.addMessageAtom, typingMessage);
                    const history = await this.fetchAiData();
                    const beMsg: OpenAiStreamRequest = {
                        clientid: clientId,
                        opts: opts,
                        prompt: [...history, newPrompt],
                    };
                    let fullMsg = "";
                    try {
                        const aiGen = RpcApi.StreamWaveAiCommand(WindowRpcClient, beMsg, { timeout: opts.timeoutms });
                        for await (const msg of aiGen) {
                            fullMsg += msg.text ?? "";
                            globalStore.set(this.updateLastMessageAtom, msg.text ?? "", true);
                            if (this.cancel) {
                                // Nothing was received: drop the empty placeholder.
                                if (fullMsg == "") {
                                    globalStore.set(this.removeLastMessageAtom);
                                }
                                break;
                            }
                        }
                        // Finalise once, after the stream ended or was cancelled, so
                        // the reply is never left flagged as updating and the history
                        // is saved a single time rather than once per chunk.
                        globalStore.set(this.updateLastMessageAtom, "", false);
                        if (fullMsg != "") {
                            const responsePrompt: OpenAIPromptMessageType = {
                                role: "assistant",
                                content: fullMsg,
                            };
                            await BlockService.SaveWaveAiData(blockId, [...history, newPrompt, responsePrompt]);
                        }
                    } catch (error) {
                        const updatedHist = [...history, newPrompt];
                        if (fullMsg == "") {
                            globalStore.set(this.removeLastMessageAtom);
                        } else {
                            // Keep whatever part of the reply arrived before the failure.
                            globalStore.set(this.updateLastMessageAtom, "", false);
                            const responsePrompt: OpenAIPromptMessageType = {
                                role: "assistant",
                                content: fullMsg,
                            };
                            updatedHist.push(responsePrompt);
                        }
                        // Not every thrown value is an Error instance.
                        const errMsg: string = error instanceof Error ? error.message : String(error);
                        const errorMessage: ChatMessageType = {
                            id: crypto.randomUUID(),
                            user: "error",
                            text: errMsg,
                        };
                        globalStore.set(this.addMessageAtom, errorMessage);
                        globalStore.set(this.updateLastMessageAtom, "", false);
                        const errorPrompt: OpenAIPromptMessageType = {
                            role: "error",
                            content: errMsg,
                        };
                        updatedHist.push(errorPrompt);
                        console.log(updatedHist);
                        await BlockService.SaveWaveAiData(blockId, updatedHist);
                    }
                } finally {
                    // Always release the lock, including when loading or saving the
                    // history rejects; otherwise no further message could be sent.
                    setLocked(false);
                    this.cancel = false;
                }
            };
            // Fire and forget, but never leave a rejection unhandled.
            streamResponse().catch((e: unknown) => {
                console.error("waveai: failed to process message", e);
            });
        };

        return {
            messages,
            sendMessage,
        };
    }
}

/** Creates the view model for the block with the given id. */
function makeWaveAiViewModel(blockId: string): WaveAiModel {
    const waveAiModel = new WaveAiModel(blockId);
    return waveAiModel;
}

/**
 * Renders one transcript entry, choosing among an error branch, an assistant
 * branch (which differs depending on whether `text` is empty) and a default branch.
 */
const ChatItem = ({ chatItem }: ChatItemProps) => {
    const { user, text } = chatItem;
    const cssVar = "--panel-bg-color";
    const panelBgColor = getComputedStyle(document.documentElement).getPropertyValue(cssVar).trim();

    // NOTE(review): the JSX of every branch below is missing from the reviewed
    // copy (only the opening fragment of each survived). Restore it from
    // version control; nothing here was reconstructed.
    const renderContent = useMemo(() => {
        if (user == "error") {
            return (
                <>
            );
        }
        if (user == "assistant") {
            return text ? (
                <>
            ) : (
                <>
            );
        }
        return (
            <>
        );
    }, [text, user]);

    // NOTE(review): the element that wrapped {renderContent} is missing as well.
    return
    {renderContent}
    ;
};

interface ChatWindowProps {
    // NOTE(review): the element type argument of this ref was lost.
    chatWindowRef: React.RefObject;
    messages: ChatMessageType[];
    msgWidths: Object;
}

/**
 * Scrollable transcript. Scrolls to the bottom when the number of messages
 * changes or while the user has not scrolled, and stops following after a
 * wheel or touchmove event on the viewport.
 */
const ChatWindow = memo(
    forwardRef<OverlayScrollbarsComponentRef, ChatWindowProps>(({ chatWindowRef, messages, msgWidths }, ref) => {
        const [isUserScrolling, setIsUserScrolling] = useState(false);
        const osRef = useRef<OverlayScrollbarsComponentRef>(null);
        const prevMessagesLenRef = useRef(messages.length);

        useImperativeHandle(ref, () => osRef.current as OverlayScrollbarsComponentRef);

        useEffect(() => {
            if (osRef.current && osRef.current.osInstance()) {
                const { viewport } = osRef.current.osInstance().elements();
                const curMessagesLen = messages.length;
                // A change in message count re-enables following the bottom.
                if (prevMessagesLenRef.current !== curMessagesLen || !isUserScrolling) {
                    setIsUserScrolling(false);
                    viewport.scrollTo({
                        behavior: "auto",
                        top: chatWindowRef.current?.scrollHeight || 0,
                    });
                }
                prevMessagesLenRef.current = curMessagesLen;
            }
        }, [messages, isUserScrolling]);

        useEffect(() => {
            if (osRef.current && osRef.current.osInstance()) {
                const { viewport } = osRef.current.osInstance().elements();
                const handleUserScroll = () => {
                    setIsUserScrolling(true);
                };
                viewport.addEventListener("wheel", handleUserScroll, { passive: true });
                viewport.addEventListener("touchmove", handleUserScroll, { passive: true });
                return () => {
                    viewport.removeEventListener("wheel", handleUserScroll);
                    viewport.removeEventListener("touchmove", handleUserScroll);
                    if (osRef.current && osRef.current.osInstance()) {
                        osRef.current.osInstance().destroy();
                    }
                };
            }
        }, []);

        const handleScrollbarInitialized = (instance: OverlayScrollbars) => {
            const { viewport } = instance.elements();
            viewport.removeAttribute("tabindex");
            viewport.scrollTo({
                behavior: "auto",
                top: chatWindowRef.current?.scrollHeight || 0,
            });
        };

        const handleScrollbarUpdated = (instance: OverlayScrollbars) => {
            const { viewport } = instance.elements();
            viewport.removeAttribute("tabindex");
        };

        // NOTE(review): the markup around the message list and the element
        // rendered for each message are missing from the reviewed copy.
        return (
            {messages.map((chitem, idx) => ( ))}
        );
    })
);

interface ChatInputProps {
    value: string;
    termFontSize: number;
    onChange: (e: React.ChangeEvent<HTMLTextAreaElement>) => void;
    onKeyDown: (e: React.KeyboardEvent<HTMLTextAreaElement>) => void;
    onMouseDown: (e: React.MouseEvent<HTMLTextAreaElement>) => void;
    model: WaveAiModel;
}

/**
 * Chat textarea whose height is recomputed from its scroll height whenever
 * `value` changes. It also registers its ref on the model.
 */
const ChatInput = forwardRef<HTMLTextAreaElement, ChatInputProps>(
    ({ value, onChange, onKeyDown, onMouseDown, termFontSize, model }, ref) => {
        const textAreaRef = useRef<HTMLTextAreaElement>(null);

        useImperativeHandle(ref, () => textAreaRef.current as HTMLTextAreaElement);

        useEffect(() => {
            model.textAreaRef = textAreaRef;
        }, []);

        const adjustTextAreaHeight = () => {
            if (textAreaRef.current == null) {
                return;
            }
            // Adjust the height of the textarea to fit the text
            const textAreaMaxLines = 100;
            const textAreaLineHeight = termFontSize * 1.5;
            const textAreaMinHeight = textAreaLineHeight;
            const textAreaMaxHeight = textAreaLineHeight * textAreaMaxLines;

            // Collapse first so scrollHeight reflects the content, not the old height.
            textAreaRef.current.style.height = "1px";
            const scrollHeight = textAreaRef.current.scrollHeight;
            const newHeight = Math.min(Math.max(scrollHeight, textAreaMinHeight), textAreaMaxHeight);
            textAreaRef.current.style.height = newHeight + "px";
        };

        useEffect(() => {
            adjustTextAreaHeight();
        }, [value]);

        // NOTE(review): the rendered markup is missing from the reviewed copy.
        return ( );
    }
);

/**
 * Top-level chat view: wires the transcript, the input and keyboard handling
 * (Enter sends, ArrowUp/ArrowDown move a selection across `pre` elements).
 */
const WaveAi = ({ model }: { model: WaveAiModel; blockId: string }) => {
    const { messages, sendMessage } = model.useWaveAi();
    // NOTE(review): the element type arguments of waveaiRef and chatWindowRef were lost.
    const waveaiRef = useRef(null);
    const chatWindowRef = useRef(null);
    const osRef = useRef<OverlayScrollbarsComponentRef>(null);
    const inputRef = useRef<HTMLTextAreaElement>(null);

    const [value, setValue] = useState("");
    const [selectedBlockIdx, setSelectedBlockIdx] = useState<number | null>(null);

    const termFontSize: number = 14;
    const msgWidths = {};
    const locked = useAtomValue(model.locked);

    // a weird workaround to initialize asynchronously
    useEffect(() => {
        model.populateMessages();
    }, []);

    const handleTextAreaChange = (e: React.ChangeEvent<HTMLTextAreaElement>) => {
        setValue(e.target.value);
    };

    // Clears the outline on every `pre` except the clicked one, which becomes the selection.
    const updatePreTagOutline = (clickedPre?: HTMLElement | null) => {
        const pres = chatWindowRef.current?.querySelectorAll("pre");
        if (!pres) return;
        pres.forEach((preElement, idx) => {
            if (preElement === clickedPre) {
                setSelectedBlockIdx(idx);
            } else {
                preElement.style.outline = "none";
            }
        });

        if (clickedPre) {
            clickedPre.style.outline = outline;
        }
    };

    useEffect(() => {
        if (selectedBlockIdx !== null) {
            const pres = chatWindowRef.current?.querySelectorAll("pre");
            if (pres && pres[selectedBlockIdx]) {
                pres[selectedBlockIdx].style.outline = outline;
            }
        }
    }, [selectedBlockIdx]);

    const handleTextAreaMouseDown = () => {
        updatePreTagOutline();
        setSelectedBlockIdx(null);
    };

    const handleEnterKeyPressed = useCallback(() => {
        // using globalStore to avoid potential timing problems
        // useAtom means the component must rerender once before
        // the unlock is detected. this automatically checks on the
        // callback firing instead
        const locked = globalStore.get(model.locked);
        if (locked || value === "") return;

        sendMessage(value);
        setValue("");
        setSelectedBlockIdx(null);
    }, [messages, value]);

    // Scrolls the viewport so the selected `pre` is visible, with a 15px margin.
    const updateScrollTop = () => {
        const pres = chatWindowRef.current?.querySelectorAll("pre");
        if (!pres || selectedBlockIdx === null) return;

        const block = pres[selectedBlockIdx];
        if (!block || !osRef.current?.osInstance()) return;

        const { viewport, scrollOffsetElement } = osRef.current?.osInstance().elements();
        const chatWindowTop = scrollOffsetElement.scrollTop;
        const chatWindowHeight = chatWindowRef.current.clientHeight;
        const chatWindowBottom = chatWindowTop + chatWindowHeight;
        const elemTop = block.offsetTop;
        const elemBottom = elemTop + block.offsetHeight;
        const elementIsInView = elemBottom <= chatWindowBottom && elemTop >= chatWindowTop;

        if (!elementIsInView) {
            let scrollPosition;
            if (elemBottom > chatWindowBottom) {
                scrollPosition = elemTop - chatWindowHeight + block.offsetHeight + 15;
            } else if (elemTop < chatWindowTop) {
                scrollPosition = elemTop - 15;
            }
            viewport.scrollTo({
                behavior: "auto",
                top: scrollPosition,
            });
        }
    };

    // Arrow keys drive the selection when one already exists, or when ArrowUp is
    // pressed with the caret at the very start of the textarea.
    const shouldSelectCodeBlock = (key: "ArrowUp" | "ArrowDown") => {
        const textarea = inputRef.current;
        const cursorPosition = textarea?.selectionStart || 0;
        const textBeforeCursor = textarea?.value.slice(0, cursorPosition) || "";

        return (
            (textBeforeCursor.indexOf("\n") === -1 && cursorPosition === 0 && key === "ArrowUp") ||
            selectedBlockIdx !== null
        );
    };

    const handleArrowUpPressed = (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
        if (shouldSelectCodeBlock("ArrowUp")) {
            e.preventDefault();
            const pres = chatWindowRef.current?.querySelectorAll("pre");
            let blockIndex = selectedBlockIdx;
            if (!pres) return;
            if (blockIndex === null) {
                setSelectedBlockIdx(pres.length - 1);
            } else if (blockIndex > 0) {
                blockIndex--;
                setSelectedBlockIdx(blockIndex);
            }
            updateScrollTop();
        }
    };

    const handleArrowDownPressed = (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
        if (shouldSelectCodeBlock("ArrowDown")) {
            e.preventDefault();
            const pres = chatWindowRef.current?.querySelectorAll("pre");
            let blockIndex = selectedBlockIdx;
            if (!pres) return;
            if (blockIndex === null) return;
            if (blockIndex < pres.length - 1 && blockIndex >= 0) {
                setSelectedBlockIdx(++blockIndex);
                updateScrollTop();
            } else {
                // Moving past the last element returns focus to the textarea.
                inputRef.current.focus();
                setSelectedBlockIdx(null);
            }
            updateScrollTop();
        }
    };

    const handleTextAreaKeyDown = (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
        const waveEvent = adaptFromReactOrNativeKeyEvent(e);
        if (checkKeyPressed(waveEvent, "Enter")) {
            e.preventDefault();
            handleEnterKeyPressed();
        } else if (checkKeyPressed(waveEvent, "ArrowUp")) {
            handleArrowUpPressed(e);
        } else if (checkKeyPressed(waveEvent, "ArrowDown")) {
            handleArrowDownPressed(e);
        }
    };

    // While a reply is in flight the submit button turns into a stop button.
    let buttonClass = "waveai-submit-button";
    let buttonIcon = makeIconClass("arrow-up", false);
    let buttonTitle = "run";
    if (locked) {
        buttonClass = "waveai-submit-button stop";
        buttonIcon = makeIconClass("stop", false);
        buttonTitle = "stop";
    }

    const handleButtonPress = useCallback(() => {
        if (locked) {
            model.cancel = true;
        } else {
            handleEnterKeyPressed();
        }
    }, [locked, handleEnterKeyPressed]);

    // NOTE(review): the rendered markup is missing from the reviewed copy.
    return (
    );
};

export { makeWaveAiViewModel, WaveAi };