// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0

// NOTE(review): in the copy of this file that was reviewed, text inside angle
// brackets appears to have been lost: generic type arguments (for example
// `messagesAtom: PrimitiveAtom>;`, `Promise>`, `useRef(null)`) and the JSX
// markup returned by the components at the bottom of the file (for example
// `return ( <>` directly followed by `);`). Those spans are reproduced exactly
// as they were seen - nothing was guessed - and must be restored from version
// control before this change is merged.

import { Button } from "@/app/element/button";
import { Markdown } from "@/app/element/markdown";
import { TypingIndicator } from "@/app/element/typingindicator";
import { RpcResponseHelper, WshClient } from "@/app/store/wshclient";
import { RpcApi } from "@/app/store/wshclientapi";
import { makeFeBlockRouteId } from "@/app/store/wshrouter";
import { DefaultRouter, TabRpcClient } from "@/app/store/wshrpcutil";
import { atoms, createBlock, fetchWaveFile, getApi, globalStore, WOS } from "@/store/global";
import { BlockService, ObjectService } from "@/store/services";
import { adaptFromReactOrNativeKeyEvent, checkKeyPressed } from "@/util/keyutil";
import { fireAndForget, isBlank, makeIconClass } from "@/util/util";
import { atom, Atom, PrimitiveAtom, useAtomValue, WritableAtom } from "jotai";
import type { OverlayScrollbars } from "overlayscrollbars";
import { OverlayScrollbarsComponent, OverlayScrollbarsComponentRef } from "overlayscrollbars-react";
import { forwardRef, memo, useCallback, useEffect, useImperativeHandle, useMemo, useRef, useState } from "react";
import "./waveai.scss";

/** One entry in the on-screen transcript. */
interface ChatMessageType {
    id: string;
    /** Author: "user", "assistant", "error", or a prompt role from saved history. */
    user: string;
    text: string;
    /** True while an assistant reply is still being streamed into `text`. */
    isUpdating?: boolean;
}

/** Outline applied to the `<pre>` code block selected with the arrow keys. */
const outline = "2px solid var(--accent-color)";
/** Maximum number of saved prompts loaded from the block's `aidata` file. */
const slidingWindowSize = 30;

interface ChatItemProps {
    chatItem: ChatMessageType;
}

/**
 * Converts a saved prompt (as read from the block's `aidata` file) into a
 * transcript message with a freshly generated id.
 */
function promptToMsg(prompt: OpenAIPromptMessageType): ChatMessageType {
    return {
        id: crypto.randomUUID(),
        user: prompt.role,
        text: prompt.content,
    };
}

/**
 * wsh RPC client bound to this block's frontend route. Incoming
 * `aisendmessage` requests (presumably dispatched by `WshClient` on the
 * `handle_` method-name prefix - confirm) are forwarded to the model's
 * `sendMessage`.
 */
class AiWshClient extends WshClient {
    blockId: string;
    model: WaveAiModel;

    constructor(blockId: string, model: WaveAiModel) {
        super(makeFeBlockRouteId(blockId));
        this.blockId = blockId;
        this.model = model;
    }

    /** Sends `data.message` as a chat prompt; blank messages are ignored. */
    handle_aisendmessage(rh: RpcResponseHelper, data: AiMessageData) {
        if (isBlank(data.message)) {
            return;
        }
        this.model.sendMessage(data.message);
    }
}

/**
 * View model for the "waveai" block view.
 *
 * Holds the chat transcript (`messagesAtom`), the atoms that derive the AI
 * options and preset menu shown in the block header, and the streaming
 * round-trip to the backend via `RpcApi.StreamWaveAiCommand`. The constructor
 * registers an `AiWshClient` on this block's frontend route; `dispose`
 * unregisters it.
 */
export class WaveAiModel implements ViewModel {
    viewType: string;
    blockId: string;
    blockAtom: Atom;
    presetKey: Atom;
    presetMap: Atom<{ [k: string]: MetaType }>;
    aiOpts: Atom;
    viewIcon?: Atom;
    viewName?: Atom;
    viewText?: Atom;
    preIconButton?: Atom;
    endIconButtons?: Atom;
    messagesAtom: PrimitiveAtom>;
    addMessageAtom: WritableAtom;
    updateLastMessageAtom: WritableAtom;
    removeLastMessageAtom: WritableAtom;
    simulateAssistantResponseAtom: WritableAtom>;
    textAreaRef: React.RefObject;
    // True while a request is in flight; the input ignores Enter and the
    // submit button turns into a stop button while it is set.
    locked: PrimitiveAtom;
    // Set by the stop button; checked after each streamed chunk.
    cancel: boolean;
    aiWshClient: AiWshClient;

    constructor(blockId: string) {
        this.aiWshClient = new AiWshClient(blockId, this);
        DefaultRouter.registerRoute(makeFeBlockRouteId(blockId), this.aiWshClient);
        this.locked = atom(false);
        this.cancel = false;
        this.viewType = "waveai";
        this.blockId = blockId;
        this.blockAtom = WOS.getWaveObjectAtom(`block:${blockId}`);
        this.viewIcon = atom("sparkles");
        this.viewName = atom("Wave AI");
        this.messagesAtom = atom([]);

        // Block-level "ai:preset" wins over the global setting.
        this.presetKey = atom((get) => {
            const metaPresetKey = get(this.blockAtom).meta["ai:preset"];
            const globalPresetKey = get(atoms.settingsAtom)["ai:preset"];
            return metaPresetKey ?? globalPresetKey;
        });

        // Presets whose key starts with "ai@". A preset whose only "ai:" key is
        // the literal "ai:*" gets the global "ai:model" appended to its display
        // name (defaulting the name to "Default").
        this.presetMap = atom((get) => {
            const fullConfig = get(atoms.fullConfigAtom);
            // Tolerate a config without a presets section instead of throwing.
            const presets = fullConfig.presets ?? {};
            const settings = fullConfig.settings;
            return Object.fromEntries(
                Object.entries(presets)
                    .filter(([k]) => k.startsWith("ai@"))
                    .map(([k, v]) => {
                        const aiPresetKeys = Object.keys(v).filter((k) => k.startsWith("ai:"));
                        const newV = { ...v };
                        newV["display:name"] =
                            aiPresetKeys.length == 1 && aiPresetKeys.includes("ai:*")
                                ? `${newV["display:name"] ?? "Default"} (${settings["ai:model"]})`
                                : newV["display:name"];
                        return [k, newV];
                    })
            );
        });

        this.addMessageAtom = atom(null, (get, set, message: ChatMessageType) => {
            const messages = get(this.messagesAtom);
            set(this.messagesAtom, [...messages, message]);
        });

        // Appends `text` to the last message if it is an assistant message.
        this.updateLastMessageAtom = atom(null, (get, set, text: string, isUpdating: boolean) => {
            const messages = get(this.messagesAtom);
            const lastMessage = messages[messages.length - 1];
            // The transcript can be empty here (e.g. cleared with Cmd:l while a
            // reply is streaming), so do not dereference a missing message.
            if (lastMessage?.user == "assistant") {
                const updatedMessage = { ...lastMessage, text: lastMessage.text + text, isUpdating };
                set(this.messagesAtom, [...messages.slice(0, -1), updatedMessage]);
            }
        });

        this.removeLastMessageAtom = atom(null, (get, set) => {
            const messages = get(this.messagesAtom);
            // Build a new array rather than pop(): the array held by the atom
            // must not be mutated in place.
            set(this.messagesAtom, messages.slice(0, -1));
        });

        this.simulateAssistantResponseAtom = atom(null, async (_, set, userMessage: ChatMessageType) => {
            // unused at the moment. can replace the temp() function in the future
            const typingMessage: ChatMessageType = {
                id: crypto.randomUUID(),
                user: "assistant",
                text: "",
            };

            // Add a typing indicator
            set(this.addMessageAtom, typingMessage);
            const parts = userMessage.text.split(" ");
            let currentPart = 0;
            while (currentPart < parts.length) {
                const part = parts[currentPart] + " ";
                set(this.updateLastMessageAtom, part, true);
                currentPart++;
            }
            set(this.updateLastMessageAtom, "", false);
        });

        // Effective request options: block meta overrides global settings;
        // unset values become null, except timeoutms which defaults to 60000.
        this.aiOpts = atom((get) => {
            const meta = get(this.blockAtom).meta;
            let settings = get(atoms.settingsAtom);
            settings = {
                ...settings,
                ...meta,
            };
            const opts: OpenAIOptsType = {
                model: settings["ai:model"] ?? null,
                apitype: settings["ai:apitype"] ?? null,
                orgid: settings["ai:orgid"] ?? null,
                apitoken: settings["ai:apitoken"] ?? null,
                apiversion: settings["ai:apiversion"] ?? null,
                maxtokens: settings["ai:maxtokens"] ?? null,
                timeoutms: settings["ai:timeoutms"] ?? 60000,
                baseurl: settings["ai:baseurl"] ?? null,
            };
            return opts;
        });

        // Header elements: an icon describing where requests go, followed by a
        // menu for switching presets or opening the presets file.
        this.viewText = atom((get) => {
            const viewTextChildren: HeaderElem[] = [];
            const aiOpts = get(this.aiOpts);
            const presets = get(this.presetMap);
            const presetKey = get(this.presetKey);
            const presetName = presets[presetKey]?.["display:name"] ?? "";
            const isCloud = isBlank(aiOpts.apitoken) && isBlank(aiOpts.baseurl);

            // Handle known API providers
            switch (aiOpts?.apitype) {
                case "anthropic":
                    viewTextChildren.push({
                        elemtype: "iconbutton",
                        icon: "globe",
                        title: `Using Remote Anthropic API (${aiOpts.model})`,
                        noAction: true,
                    });
                    break;
                case "perplexity":
                    viewTextChildren.push({
                        elemtype: "iconbutton",
                        icon: "globe",
                        title: `Using Remote Perplexity API (${aiOpts.model})`,
                        noAction: true,
                    });
                    break;
                default:
                    if (isCloud) {
                        viewTextChildren.push({
                            elemtype: "iconbutton",
                            icon: "cloud",
                            title: "Using Wave's AI Proxy (gpt-4o-mini)",
                            noAction: true,
                        });
                    } else {
                        const baseUrl = aiOpts.baseurl ?? "OpenAI Default Endpoint";
                        const modelName = aiOpts.model;
                        if (baseUrl.startsWith("http://localhost") || baseUrl.startsWith("http://127.0.0.1")) {
                            viewTextChildren.push({
                                elemtype: "iconbutton",
                                icon: "location-dot",
                                title: `Using Local Model @ ${baseUrl} (${modelName})`,
                                noAction: true,
                            });
                        } else {
                            viewTextChildren.push({
                                elemtype: "iconbutton",
                                icon: "globe",
                                title: `Using Remote Model @ ${baseUrl} (${modelName})`,
                                noAction: true,
                            });
                        }
                    }
            }

            const dropdownItems = Object.entries(presets)
                // Ascending "display:order" (missing = 0). A numeric difference
                // gives a consistent comparator, so ties keep their original order.
                .sort((a, b) => (a[1]["display:order"] ?? 0) - (b[1]["display:order"] ?? 0))
                .map(
                    (preset) =>
                        ({
                            label: preset[1]["display:name"],
                            onClick: () =>
                                fireAndForget(() =>
                                    ObjectService.UpdateObjectMeta(WOS.makeORef("block", this.blockId), {
                                        ...preset[1],
                                        "ai:preset": preset[0],
                                    })
                                ),
                        }) as MenuItem
                );
            dropdownItems.push({
                label: "Add AI preset...",
                onClick: () => {
                    fireAndForget(async () => {
                        const path = `${getApi().getConfigDir()}/presets/ai.json`;
                        const blockDef: BlockDef = {
                            meta: {
                                view: "preview",
                                file: path,
                            },
                        };
                        await createBlock(blockDef, false, true);
                    });
                },
            });
            viewTextChildren.push({
                elemtype: "menubutton",
                text: presetName,
                title: "Select AI Configuration",
                items: dropdownItems,
            });
            return viewTextChildren;
        });

        this.endIconButtons = atom((_) => {
            const clearButton: IconButtonDecl = {
                elemtype: "iconbutton",
                icon: "delete-left",
                title: "Clear Chat History",
                click: this.clearMessages.bind(this),
            };
            return [clearButton];
        });
    }

    /** Unregisters the wsh route registered in the constructor. */
    dispose() {
        DefaultRouter.unregisterRoute(makeFeBlockRouteId(this.blockId));
    }

    /** Loads the saved history and replaces the transcript with it. */
    async populateMessages(): Promise {
        const history = await this.fetchAiData();
        globalStore.set(this.messagesAtom, history.map(promptToMsg));
    }

    /**
     * Reads the block's `aidata` file and returns at most the last
     * `slidingWindowSize` saved prompts, or [] when the file has no data.
     */
    async fetchAiData(): Promise> {
        const { data } = await fetchWaveFile(this.blockId, "aidata");
        if (!data) {
            return [];
        }
        const history: Array = JSON.parse(new TextDecoder().decode(data));
        return history.slice(Math.max(history.length - slidingWindowSize, 0));
    }

    /** Focuses the prompt textarea if it is mounted; returns whether it did. */
    giveFocus(): boolean {
        if (this?.textAreaRef?.current) {
            this.textAreaRef.current?.focus();
            return true;
        }
        return false;
    }

    /** "ai:name" from block meta, else from global settings, else null. */
    getAiName(): string {
        const blockMeta = globalStore.get(this.blockAtom)?.meta ?? {};
        const settings = globalStore.get(atoms.settingsAtom) ?? {};
        const name = blockMeta["ai:name"] ?? settings["ai:name"] ?? null;
        return name;
    }

    setLocked(locked: boolean) {
        globalStore.set(this.locked, locked);
    }

    /**
     * Appends `text` to the transcript, then streams the assistant's reply
     * from the backend into a placeholder message.
     *
     * The model is locked until the request finishes, however it finishes.
     * Afterwards the prompt is saved with `BlockService.SaveWaveAiData`, along
     * with the (possibly partial) reply and any stream error.
     *
     * @param text prompt text.
     * @param user author recorded on the displayed message; the prompt sent to
     *     the backend always uses the "user" role.
     */
    sendMessage(text: string, user: string = "user") {
        const clientId = globalStore.get(atoms.clientId);
        this.setLocked(true);

        const newMessage: ChatMessageType = {
            id: crypto.randomUUID(),
            user,
            text,
        };
        globalStore.set(this.addMessageAtom, newMessage);

        // send message to backend and get response
        const opts = globalStore.get(this.aiOpts);
        const newPrompt: OpenAIPromptMessageType = {
            role: "user",
            content: text,
        };

        const handleAiStreamingResponse = async () => {
            const typingMessage: ChatMessageType = {
                id: crypto.randomUUID(),
                user: "assistant",
                text: "",
            };

            // Add a typing indicator
            globalStore.set(this.addMessageAtom, typingMessage);
            try {
                let history: OpenAIPromptMessageType[];
                try {
                    history = await this.fetchAiData();
                } catch (error) {
                    // Without the saved history we can neither build the request
                    // nor safely re-save the conversation, so drop the placeholder
                    // and let the error propagate to fireAndForget.
                    globalStore.set(this.removeLastMessageAtom);
                    throw error;
                }
                const beMsg: OpenAiStreamRequest = {
                    clientid: clientId,
                    opts: opts,
                    prompt: [...history, newPrompt],
                };

                let fullMsg = "";
                // Only errors from the stream itself are reported in the chat;
                // a failure while saving below propagates instead of being
                // mistaken for a stream error.
                let streamFailure: { error: unknown } | null = null;
                try {
                    const aiGen = RpcApi.StreamWaveAiCommand(TabRpcClient, beMsg, { timeout: opts.timeoutms });
                    for await (const msg of aiGen) {
                        fullMsg += msg.text ?? "";
                        globalStore.set(this.updateLastMessageAtom, msg.text ?? "", true);
                        if (this.cancel) {
                            break;
                        }
                    }
                } catch (error) {
                    streamFailure = { error };
                }

                const updatedHist: OpenAIPromptMessageType[] = [...history, newPrompt];
                if (fullMsg == "") {
                    // remove a message if empty; only the author's prompt is saved
                    globalStore.set(this.removeLastMessageAtom);
                } else {
                    //mark message as complete
                    globalStore.set(this.updateLastMessageAtom, "", false);
                    const responsePrompt: OpenAIPromptMessageType = {
                        role: "assistant",
                        content: fullMsg,
                    };
                    updatedHist.push(responsePrompt);
                }

                if (streamFailure != null) {
                    const { error } = streamFailure;
                    // Anything can be thrown; only Error instances carry a message.
                    const errMsg: string = error instanceof Error ? error.message : String(error);
                    const errorMessage: ChatMessageType = {
                        id: crypto.randomUUID(),
                        user: "error",
                        text: errMsg,
                    };
                    globalStore.set(this.addMessageAtom, errorMessage);
                    const errorPrompt: OpenAIPromptMessageType = {
                        role: "error",
                        content: errMsg,
                    };
                    updatedHist.push(errorPrompt);
                }

                // save the prompt together with whatever reply/error it produced
                await BlockService.SaveWaveAiData(this.blockId, updatedHist);
            } finally {
                // Always unlock, even if loading or saving the history threw;
                // otherwise the input would stay locked permanently.
                this.setLocked(false);
                this.cancel = false;
            }
        };
        fireAndForget(handleAiStreamingResponse);
    }

    /** React hook: current transcript plus a bound `sendMessage`. */
    useWaveAi() {
        const messages = useAtomValue(this.messagesAtom);
        return {
            messages,
            sendMessage: this.sendMessage.bind(this),
        };
    }

    /** Saves an empty history, then clears the transcript. */
    async clearMessages() {
        await BlockService.SaveWaveAiData(this.blockId, []);
        globalStore.set(this.messagesAtom, []);
    }

    /** Cmd:l clears the chat history; returns whether the key was handled. */
    keyDownHandler(waveEvent: WaveKeyboardEvent): boolean {
        if (checkKeyPressed(waveEvent, "Cmd:l")) {
            fireAndForget(this.clearMessages.bind(this));
            return true;
        }
        return false;
    }
}

function makeWaveAiViewModel(blockId: string): WaveAiModel {
    const waveAiModel = new WaveAiModel(blockId);
    return waveAiModel;
}

/**
 * Renders one transcript entry, branching on `user`: "error", "assistant"
 * (with separate markup while `text` is still empty), and everything else.
 *
 * NOTE(review): the JSX for every branch, and the markup around
 * `{renderContent}`, is missing from the reviewed copy; the surviving tokens
 * are reproduced unchanged.
 */
const ChatItem = ({ chatItem }: ChatItemProps) => {
    const { user, text } = chatItem;
    const cssVar = "--panel-bg-color";
    const panelBgColor = getComputedStyle(document.documentElement).getPropertyValue(cssVar).trim();

    const renderContent = useMemo(() => {
        if (user == "error") {
            return (
                <>
            );
        }
        if (user == "assistant") {
            return text ? (
                <>
            ) : (
                <>
            );
        }
        return (
            <>
        );
    }, [text, user]);

    return
    {renderContent}
    ;
};

interface ChatWindowProps {
    chatWindowRef: React.RefObject;
    messages: ChatMessageType[];
    msgWidths: Object;
}

/**
 * Scrollable transcript wrapped in OverlayScrollbars. It scrolls to the bottom
 * when the message count changes, or on any update while the user has not
 * scrolled manually (wheel/touchmove). It also removes `tabindex` from the
 * scrollbar viewport.
 *
 * NOTE(review): the JSX markup is missing from the reviewed copy.
 */
const ChatWindow = memo(
    forwardRef(({ chatWindowRef, messages, msgWidths }, ref) => {
        const [isUserScrolling, setIsUserScrolling] = useState(false);
        const osRef = useRef(null);
        const prevMessagesLenRef = useRef(messages.length);

        useImperativeHandle(ref, () => osRef.current as OverlayScrollbarsComponentRef);

        useEffect(() => {
            if (osRef.current && osRef.current.osInstance()) {
                const { viewport } = osRef.current.osInstance().elements();
                const curMessagesLen = messages.length;
                if (prevMessagesLenRef.current !== curMessagesLen || !isUserScrolling) {
                    setIsUserScrolling(false);
                    viewport.scrollTo({
                        behavior: "auto",
                        top: chatWindowRef.current?.scrollHeight || 0,
                    });
                }
                prevMessagesLenRef.current = curMessagesLen;
            }
        }, [messages, isUserScrolling]);

        useEffect(() => {
            if (osRef.current && osRef.current.osInstance()) {
                const { viewport } = osRef.current.osInstance().elements();
                const handleUserScroll = () => {
                    setIsUserScrolling(true);
                };
                viewport.addEventListener("wheel", handleUserScroll, { passive: true });
                viewport.addEventListener("touchmove", handleUserScroll, { passive: true });
                return () => {
                    viewport.removeEventListener("wheel", handleUserScroll);
                    viewport.removeEventListener("touchmove", handleUserScroll);
                    if (osRef.current && osRef.current.osInstance()) {
                        osRef.current.osInstance().destroy();
                    }
                };
            }
        }, []);

        const handleScrollbarInitialized = (instance: OverlayScrollbars) => {
            const { viewport } = instance.elements();
            viewport.removeAttribute("tabindex");
            viewport.scrollTo({
                behavior: "auto",
                top: chatWindowRef.current?.scrollHeight || 0,
            });
        };

        const handleScrollbarUpdated = (instance: OverlayScrollbars) => {
            const { viewport } = instance.elements();
            viewport.removeAttribute("tabindex");
        };

        return (
            {messages.map((chitem, idx) => ( ))}
        );
    })
);

interface ChatInputProps {
    value: string;
    termFontSize: number;
    onChange: (e: React.ChangeEvent) => void;
    onKeyDown: (e: React.KeyboardEvent) => void;
    onMouseDown: (e: React.MouseEvent) => void;
    model: WaveAiModel;
}

/**
 * Prompt textarea. It grows with its content up to 5 lines (line height =
 * 1.5 x termFontSize). On mount it stores its ref on `model.textAreaRef`, which
 * is what `WaveAiModel.giveFocus` focuses.
 *
 * NOTE(review): the returned JSX is missing from the reviewed copy.
 */
const ChatInput = forwardRef(
    ({ value, onChange, onKeyDown, onMouseDown, termFontSize, model }, ref) => {
        const textAreaRef = useRef(null);

        useImperativeHandle(ref, () => textAreaRef.current as HTMLTextAreaElement);

        useEffect(() => {
            model.textAreaRef = textAreaRef;
        }, []);

        const adjustTextAreaHeight = useCallback(
            (value: string) => {
                if (textAreaRef.current == null) {
                    return;
                }

                // Adjust the height of the textarea to fit the text
                const textAreaMaxLines = 5;
                const textAreaLineHeight = termFontSize * 1.5;
                const textAreaMinHeight = textAreaLineHeight;
                const textAreaMaxHeight = textAreaLineHeight * textAreaMaxLines;

                if (value === "") {
                    textAreaRef.current.style.height = `${textAreaLineHeight}px`;
                    return;
                }

                textAreaRef.current.style.height = `${textAreaLineHeight}px`;
                const scrollHeight = textAreaRef.current.scrollHeight;
                const newHeight = Math.min(Math.max(scrollHeight, textAreaMinHeight), textAreaMaxHeight);
                textAreaRef.current.style.height = newHeight + "px";
            },
            [termFontSize]
        );

        useEffect(() => {
            adjustTextAreaHeight(value);
        }, [value]);

        return ( );
    }
);

/**
 * Top-level Wave AI view. It loads the saved history on mount and submits on
 * Enter. ArrowUp/ArrowDown move an outline across the `<pre>` code blocks in
 * the transcript.
 *
 * NOTE(review): the returned JSX is missing from the reviewed copy.
 */
const WaveAi = ({ model }: { model: WaveAiModel; blockId: string }) => {
    const { messages, sendMessage } = model.useWaveAi();
    const waveaiRef = useRef(null);
    const chatWindowRef = useRef(null);
    const osRef = useRef(null);
    const inputRef = useRef(null);

    const [value, setValue] = useState("");
    const [selectedBlockIdx, setSelectedBlockIdx] = useState(null);

    const termFontSize: number = 14;
    const msgWidths = {};
    const locked = useAtomValue(model.locked);

    // a weird workaround to initialize asynchronously
    useEffect(() => {
        fireAndForget(model.populateMessages.bind(model));
    }, []);

    const handleTextAreaChange = (e: React.ChangeEvent) => {
        setValue(e.target.value);
    };

    const updatePreTagOutline = (clickedPre?: HTMLElement | null) => {
        const pres = chatWindowRef.current?.querySelectorAll("pre");
        if (!pres) return;

        pres.forEach((preElement, idx) => {
            if (preElement === clickedPre) {
                setSelectedBlockIdx(idx);
            } else {
                preElement.style.outline = "none";
            }
        });

        if (clickedPre) {
            clickedPre.style.outline = outline;
        }
    };

    useEffect(() => {
        if (selectedBlockIdx !== null) {
            const pres = chatWindowRef.current?.querySelectorAll("pre");
            if (pres && pres[selectedBlockIdx]) {
                pres[selectedBlockIdx].style.outline = outline;
            }
        }
    }, [selectedBlockIdx]);

    const handleTextAreaMouseDown = () => {
        updatePreTagOutline();
        setSelectedBlockIdx(null);
    };

    const handleEnterKeyPressed = useCallback(() => {
        // using globalStore to avoid potential timing problems
        // useAtom means the component must rerender once before
        // the unlock is detected. this automatically checks on the
        // callback firing instead
        const locked = globalStore.get(model.locked);
        if (locked || value === "") return;

        sendMessage(value);
        setValue("");
        setSelectedBlockIdx(null);
    }, [messages, value]);

    const updateScrollTop = () => {
        const pres = chatWindowRef.current?.querySelectorAll("pre");
        if (!pres || selectedBlockIdx === null) return;

        const block = pres[selectedBlockIdx];
        if (!block || !osRef.current?.osInstance()) return;

        const { viewport, scrollOffsetElement } = osRef.current?.osInstance().elements();
        const chatWindowTop = scrollOffsetElement.scrollTop;
        const chatWindowHeight = chatWindowRef.current.clientHeight;
        const chatWindowBottom = chatWindowTop + chatWindowHeight;
        const elemTop = block.offsetTop;
        const elemBottom = elemTop + block.offsetHeight;
        const elementIsInView = elemBottom <= chatWindowBottom && elemTop >= chatWindowTop;

        if (!elementIsInView) {
            let scrollPosition;
            if (elemBottom > chatWindowBottom) {
                scrollPosition = elemTop - chatWindowHeight + block.offsetHeight + 15;
            } else if (elemTop < chatWindowTop) {
                scrollPosition = elemTop - 15;
            }
            viewport.scrollTo({
                behavior: "auto",
                top: scrollPosition,
            });
        }
    };

    const shouldSelectCodeBlock = (key: "ArrowUp" | "ArrowDown") => {
        const textarea = inputRef.current;
        const cursorPosition = textarea?.selectionStart || 0;
        const textBeforeCursor = textarea?.value.slice(0, cursorPosition) || "";

        return (
            (textBeforeCursor.indexOf("\n") === -1 && cursorPosition === 0 && key === "ArrowUp") ||
            selectedBlockIdx !== null
        );
    };

    const handleArrowUpPressed = (e: React.KeyboardEvent) => {
        if (shouldSelectCodeBlock("ArrowUp")) {
            e.preventDefault();
            const pres = chatWindowRef.current?.querySelectorAll("pre");
            let blockIndex = selectedBlockIdx;
            if (!pres) return;
            if (blockIndex === null) {
                setSelectedBlockIdx(pres.length - 1);
            } else if (blockIndex > 0) {
                blockIndex--;
                setSelectedBlockIdx(blockIndex);
            }
            updateScrollTop();
        }
    };

    const handleArrowDownPressed = (e: React.KeyboardEvent) => {
        if (shouldSelectCodeBlock("ArrowDown")) {
            e.preventDefault();
            const pres = chatWindowRef.current?.querySelectorAll("pre");
            let blockIndex = selectedBlockIdx;
            if (!pres) return;
            if (blockIndex === null) return;
            if (blockIndex < pres.length - 1 && blockIndex >= 0) {
                setSelectedBlockIdx(++blockIndex);
                updateScrollTop();
            } else {
                inputRef.current.focus();
                setSelectedBlockIdx(null);
            }
            updateScrollTop();
        }
    };

    const handleTextAreaKeyDown = (e: React.KeyboardEvent) => {
        const waveEvent = adaptFromReactOrNativeKeyEvent(e);
        if (checkKeyPressed(waveEvent, "Enter")) {
            e.preventDefault();
            handleEnterKeyPressed();
        } else if (checkKeyPressed(waveEvent, "ArrowUp")) {
            handleArrowUpPressed(e);
        } else if (checkKeyPressed(waveEvent, "ArrowDown")) {
            handleArrowDownPressed(e);
        }
    };

    let buttonClass = "waveai-submit-button";
    let buttonIcon = makeIconClass("arrow-up", false);
    let buttonTitle = "run";
    if (locked) {
        buttonClass = "waveai-submit-button stop";
        buttonIcon = makeIconClass("stop", false);
        buttonTitle = "stop";
    }

    const handleButtonPress = useCallback(() => {
        if (locked) {
            model.cancel = true;
        } else {
            handleEnterKeyPressed();
        }
    }, [locked, handleEnterKeyPressed]);

    return (
    );
};

export { makeWaveAiViewModel, WaveAi };