// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0

import { Markdown } from "@/app/element/markdown";
import { TypingIndicator } from "@/app/element/typingindicator";
import { atoms, fetchWaveFile, getUserName, globalStore, WOS } from "@/store/global";
import { BlockService } from "@/store/services";
import { WshServer } from "@/store/wshserver";
import { adaptFromReactOrNativeKeyEvent, checkKeyPressed } from "@/util/keyutil";
import { isBlank } from "@/util/util";
import { atom, Atom, PrimitiveAtom, useAtomValue, useSetAtom, WritableAtom } from "jotai";
import type { OverlayScrollbars } from "overlayscrollbars";
import { OverlayScrollbarsComponent, OverlayScrollbarsComponentRef } from "overlayscrollbars-react";
import { forwardRef, memo, useCallback, useEffect, useImperativeHandle, useRef, useState } from "react";
import tinycolor from "tinycolor2";
import "./waveai.less";

/** One entry in the chat transcript. */
interface ChatMessageType {
    id: string;
    user: string;
    text: string;
    isAssistant: boolean;
    /** True while a simulated response is still being typed out word by word. */
    isUpdating?: boolean;
    /** Error text; on an assistant message it is rendered instead of `text`. */
    isError?: string;
}

/** Inline CSS outline applied to the currently selected code block (<pre>). */
const outline = "2px solid var(--accent-color)";

interface ChatItemProps {
    chatItem: ChatMessageType;
    itemCount: number;
}

/**
 * Converts a stored prompt into a transcript entry.
 * A fresh id is generated on every call, so ids are not stable across history reloads.
 */
function promptToMsg(prompt: OpenAIPromptMessageType): ChatMessageType {
    return {
        id: crypto.randomUUID(),
        user: prompt.role,
        text: prompt.content,
        isAssistant: prompt.role === "assistant",
    };
}

/**
 * View model for the "waveai" block view.
 *
 * Holds the jotai atoms backing the chat transcript and the block header, loads the
 * conversation saved in the block's "aidata" file, and sends new prompts to the backend.
 */
export class WaveAiModel implements ViewModel {
    viewType: string;
    blockId: string;
    blockAtom: Atom;
    viewIcon?: Atom;
    viewName?: Atom;
    viewText?: Atom;
    preIconButton?: Atom;
    endIconButtons?: Atom;
    messagesAtom: PrimitiveAtom>;
    addMessageAtom: WritableAtom;
    updateLastMessageAtom: WritableAtom;
    simulateAssistantResponseAtom: WritableAtom>;
    /** Assigned by ChatInput so giveFocus() can focus the prompt textarea. */
    textAreaRef: React.RefObject;

    constructor(blockId: string) {
        this.viewType = "waveai";
        this.blockId = blockId;
        this.blockAtom = WOS.getWaveObjectAtom(`block:${blockId}`);
        this.viewIcon = atom((get) => {
            return "sparkles"; // should not be hardcoded
        });
        this.viewName = atom("Wave Ai");
        this.messagesAtom = atom([]);

        // Write-only atom: appends one message to the transcript.
        this.addMessageAtom = atom(null, (get, set, message: ChatMessageType) => {
            const messages = get(this.messagesAtom);
            set(this.messagesAtom, [...messages, message]);
        });

        // Write-only atom: appends `text` to the last message, but only when that message is a
        // non-error assistant message, and records whether it is still being typed.
        this.updateLastMessageAtom = atom(null, (get, set, text: string, isUpdating: boolean) => {
            const messages = get(this.messagesAtom);
            const lastMessage = messages[messages.length - 1];
            // The transcript can be empty here, e.g. if populateMessages() replaced it with an
            // empty history while a simulated response was still being typed.
            if (lastMessage == null) {
                return;
            }
            if (lastMessage.isAssistant && !lastMessage.isError) {
                const updatedMessage = { ...lastMessage, text: lastMessage.text + text, isUpdating };
                set(this.messagesAtom, [...messages.slice(0, -1), updatedMessage]);
            }
        });

        // Write-only atom: adds an empty assistant message. After 1.5s it replays the text of
        // `fullResponse` into that message, one space-separated word every 100ms, and finally
        // clears isUpdating.
        this.simulateAssistantResponseAtom = atom(null, async (get, set, fullResponse: ChatMessageType) => {
            const typingMessage: ChatMessageType = {
                id: crypto.randomUUID(),
                user: "assistant",
                text: "",
                isAssistant: true,
            };

            // Add a typing indicator
            set(this.addMessageAtom, typingMessage);

            setTimeout(() => {
                const parts = fullResponse.text.split(" ");
                let currentPart = 0;
                const intervalId = setInterval(() => {
                    if (currentPart < parts.length) {
                        const part = parts[currentPart] + " ";
                        set(this.updateLastMessageAtom, part, true);
                        currentPart++;
                    } else {
                        clearInterval(intervalId);
                        set(this.updateLastMessageAtom, "", false);
                    }
                }, 100);
            }, 1500);
        });

        // Header text: the configured "ai:model" when a custom api token or base url is set,
        // otherwise the default "gpt-4o-mini".
        this.viewText = atom((get) => {
            const settings = get(atoms.settingsAtom);
            const isCloud = isBlank(settings?.["ai:apitoken"]) && isBlank(settings?.["ai:baseurl"]);
            let modelText = "gpt-4o-mini";
            if (!isCloud && !isBlank(settings?.["ai:model"])) {
                modelText = settings["ai:model"];
            }
            const viewTextChildren: HeaderElem[] = [
                {
                    elemtype: "text",
                    text: modelText,
                },
            ];
            return viewTextChildren;
        });
    }

    /** Loads the saved conversation for this block and replaces the transcript with it. */
    async populateMessages(): Promise {
        const history = await this.fetchAiData();
        globalStore.set(this.messagesAtom, history.map(promptToMsg));
    }

    /**
     * Reads the block's "aidata" file and parses it as JSON prompt history.
     * @returns the stored history, or [] when the file has no data.
     * @throws if the file contents are not valid JSON.
     */
    async fetchAiData(): Promise> {
        const { data } = await fetchWaveFile(this.blockId, "aidata");
        if (!data) {
            return [];
        }
        const history: Array = JSON.parse(new TextDecoder().decode(data));
        return history;
    }

    /** Focuses the prompt textarea if it is mounted. Returns whether focus was given. */
    giveFocus(): boolean {
        const textArea = this.textAreaRef?.current;
        if (!textArea) {
            return false;
        }
        textArea.focus();
        return true;
    }

    /**
     * React hook exposing the transcript and a `sendMessage` function for this block.
     *
     * `sendMessage` appends the user's message immediately. It then asynchronously:
     * 1. loads the saved history,
     * 2. streams a completion for history + new prompt from the backend,
     * 3. saves the extended history, and
     * 4. replays the full response into the transcript.
     *
     * Failures are logged instead of surfacing as unhandled promise rejections.
     */
    useWaveAi() {
        const messages = useAtomValue(this.messagesAtom);
        const addMessage = useSetAtom(this.addMessageAtom);
        const simulateResponse = useSetAtom(this.simulateAssistantResponseAtom);
        const clientId = useAtomValue(atoms.clientId);
        const blockId = this.blockId;

        const sendMessage = (text: string, user: string = "user") => {
            const newMessage: ChatMessageType = {
                id: crypto.randomUUID(),
                user,
                text,
                isAssistant: false,
            };
            addMessage(newMessage);

            // send message to backend and get response
            const settings = globalStore.get(atoms.settingsAtom);
            const timeoutMs = settings?.["ai:timeoutms"];
            const opts: OpenAIOptsType = {
                model: settings?.["ai:model"],
                apitoken: settings?.["ai:apitoken"],
                maxtokens: settings?.["ai:maxtokens"],
                // Convert ms to seconds; don't send NaN when the setting is unset.
                timeout: timeoutMs != null ? timeoutMs / 1000 : undefined,
                baseurl: settings?.["ai:baseurl"],
            };
            const newPrompt: OpenAIPromptMessageType = {
                role: "user",
                content: text,
            };
            // NOTE(review): newPrompt.name is never set above, so this branch cannot run as written.
            if (newPrompt.name == "*username") {
                newPrompt.name = getUserName();
            }

            const streamResponse = async () => {
                const history = await this.fetchAiData();
                const beMsg: OpenAiStreamRequest = {
                    clientid: clientId,
                    opts: opts,
                    prompt: [...history, newPrompt],
                };
                const aiGen = WshServer.StreamWaveAiCommand(beMsg, { timeout: 60000 });
                let fullMsg = "";
                for await (const msg of aiGen) {
                    fullMsg += msg.text ?? "";
                }
                const response: ChatMessageType = {
                    id: newMessage.id,
                    user: newMessage.user,
                    text: fullMsg,
                    isAssistant: true,
                };

                const responsePrompt: OpenAIPromptMessageType = {
                    role: "assistant",
                    content: fullMsg,
                };
                const writeToHistory = BlockService.SaveWaveAiData(blockId, [...history, newPrompt, responsePrompt]);
                const typeResponse = simulateResponse(response);
                // Await both so that a failed save reaches the catch handler below.
                await Promise.all([writeToHistory, typeResponse]);
            };

            streamResponse().catch((e) => {
                console.error("error getting Wave AI response", e);
            });
        };

        return {
            messages,
            sendMessage,
        };
    }
}

/** Creates the WaveAiModel for the given block id. */
function makeWaveAiViewModel(blockId: string): WaveAiModel {
    return new WaveAiModel(blockId);
}

/**
 * Renders a single transcript entry.
 * Entries with an odd itemCount get a background darkened (tinycolor darken(6)) from
 * --panel-bg-color, provided that variable holds a valid color.
 */
const ChatItem = ({ chatItem, itemCount }: ChatItemProps) => {
    const { isAssistant, text, isError } = chatItem;
    const senderClassName = isAssistant ? "chat-msg-assistant" : "chat-msg-user";
    const msgClassName = `chat-msg ${senderClassName}`;
    const cssVar = "--panel-bg-color";
    const panelBgColor = getComputedStyle(document.documentElement).getPropertyValue(cssVar).trim();
    const color = tinycolor(panelBgColor);
    const newColor = color.isValid() ? color.darken(6).toString() : "none";
    const backgroundColor = itemCount % 2 === 0 ? "none" : newColor;

    const renderError = (err: string): React.JSX.Element =>
        {err}
    ;

    const renderContent = (): React.JSX.Element => {
        if (isAssistant) {
            if (isError) {
                return renderError(isError);
            }
            return text ? (
                <>
            ) : (
                <>
            );
        }
        return (
            <>
        );
    };

    return (
        {renderContent()}
    );
};

interface ChatWindowProps {
    chatWindowRef: React.RefObject;
    messages: ChatMessageType[];
}

/**
 * Scrollable transcript view.
 * It scrolls to the bottom whenever the number of messages changes. On any other message
 * update it also scrolls, unless the user has scrolled manually (wheel or touchmove).
 */
const ChatWindow = memo(
    forwardRef(({ chatWindowRef, messages }, ref) => {
        const [isUserScrolling, setIsUserScrolling] = useState(false);

        const osRef = useRef(null);
        const prevMessagesLenRef = useRef(messages.length);

        useImperativeHandle(ref, () => osRef.current as OverlayScrollbarsComponentRef);

        useEffect(() => {
            if (osRef.current && osRef.current.osInstance()) {
                const { viewport } = osRef.current.osInstance().elements();
                const curMessagesLen = messages.length;
                // A new message always re-pins to the bottom and resets the manual-scroll flag.
                if (prevMessagesLenRef.current !== curMessagesLen || !isUserScrolling) {
                    setIsUserScrolling(false);
                    viewport.scrollTo({
                        behavior: "auto",
                        top: chatWindowRef.current?.scrollHeight || 0,
                    });
                }

                prevMessagesLenRef.current = curMessagesLen;
            }
        }, [messages, isUserScrolling]);

        useEffect(() => {
            if (osRef.current && osRef.current.osInstance()) {
                const { viewport } = osRef.current.osInstance().elements();

                const handleUserScroll = () => {
                    setIsUserScrolling(true);
                };

                viewport.addEventListener("wheel", handleUserScroll, { passive: true });
                viewport.addEventListener("touchmove", handleUserScroll, { passive: true });

                return () => {
                    viewport.removeEventListener("wheel", handleUserScroll);
                    viewport.removeEventListener("touchmove", handleUserScroll);
                    if (osRef.current && osRef.current.osInstance()) {
                        osRef.current.osInstance().destroy();
                    }
                };
            }
        }, []);

        const handleScrollbarInitialized = (instance: OverlayScrollbars) => {
            const { viewport } = instance.elements();
            viewport.scrollTo({
                behavior: "auto",
                top: chatWindowRef.current?.scrollHeight || 0,
            });
        };

        return (
            {messages.map((chitem, idx) => ( ))}
        );
    })
);

interface ChatInputProps {
    value: string;
    termFontSize: number;
    onChange: (e: React.ChangeEvent) => void;
    onKeyDown: (e: React.KeyboardEvent) => void;
    onMouseDown: (e: React.MouseEvent) => void;
    model: WaveAiModel;
}

/**
 * Prompt textarea that grows with its content.
 * Its height is clamped between 1 and 100 lines of (termFontSize * 1.5) px each.
 * It publishes its ref on model.textAreaRef so that the model can focus it.
 */
const ChatInput = forwardRef(
    ({ value, onChange, onKeyDown, onMouseDown, termFontSize, model }, ref) => {
        const textAreaRef = useRef(null);

        useImperativeHandle(ref, () => textAreaRef.current as HTMLTextAreaElement);

        // Re-register if the component is ever given a different model instance.
        useEffect(() => {
            model.textAreaRef = textAreaRef;
        }, [model]);

        const adjustTextAreaHeight = () => {
            if (textAreaRef.current == null) {
                return;
            }
            // Adjust the height of the textarea to fit the text
            const textAreaMaxLines = 100;
            const textAreaLineHeight = termFontSize * 1.5;
            const textAreaMinHeight = textAreaLineHeight;
            const textAreaMaxHeight = textAreaLineHeight * textAreaMaxLines;

            // Collapse first so scrollHeight reflects the content, not the previous height.
            textAreaRef.current.style.height = "1px";
            const scrollHeight = textAreaRef.current.scrollHeight;
            const newHeight = Math.min(Math.max(scrollHeight, textAreaMinHeight), textAreaMaxHeight);
            textAreaRef.current.style.height = newHeight + "px";
        };

        useEffect(() => {
            adjustTextAreaHeight();
        }, [value]);

        return ( );
    }
);

/**
 * Root component of the Wave AI block: the chat transcript plus the prompt input.
 * Enter submits the prompt. ArrowUp/ArrowDown move a selection outline across the code
 * blocks (<pre>) in the transcript.
 */
const WaveAi = ({ model }: { model: WaveAiModel; blockId: string }) => {
    const { messages, sendMessage } = model.useWaveAi();
    const waveaiRef = useRef(null);
    const chatWindowRef = useRef(null);
    const osRef = useRef(null);
    const inputRef = useRef(null);
    const submitTimeoutRef = useRef(null);

    const [value, setValue] = useState("");
    const [selectedBlockIdx, setSelectedBlockIdx] = useState(null);
    const [isSubmitting, setIsSubmitting] = useState(false);

    const termFontSize: number = 14;

    // a weird workaround to initialize asynchronously
    useEffect(() => {
        model.populateMessages().catch((e) => {
            console.error("error loading Wave AI history", e);
        });
    }, []);

    useEffect(() => {
        return () => {
            if (submitTimeoutRef.current) {
                clearTimeout(submitTimeoutRef.current);
            }
        };
    }, []);

    // Sends the prompt, ignoring repeat submits for 500ms after each accepted submit.
    const submit = useCallback(
        (messageStr: string) => {
            if (!isSubmitting) {
                setIsSubmitting(true);
                sendMessage(messageStr);

                clearTimeout(submitTimeoutRef.current);
                submitTimeoutRef.current = setTimeout(() => {
                    setIsSubmitting(false);
                }, 500);
            }
        },
        [isSubmitting, sendMessage, setValue]
    );

    const handleTextAreaChange = (e: React.ChangeEvent) => {
        setValue(e.target.value);
    };

    // Outlines the clicked <pre> (recording its index as the selection) and clears all others.
    const updatePreTagOutline = (clickedPre?: HTMLElement | null) => {
        const pres = chatWindowRef.current?.querySelectorAll("pre");
        if (!pres) return;

        pres.forEach((preElement, idx) => {
            if (preElement === clickedPre) {
                setSelectedBlockIdx(idx);
            } else {
                preElement.style.outline = "none";
            }
        });

        if (clickedPre) {
            clickedPre.style.outline = outline;
        }
    };

    // Keep exactly one outlined <pre>: the selected one, or none when the selection is cleared.
    useEffect(() => {
        const pres = chatWindowRef.current?.querySelectorAll("pre");
        if (!pres) return;
        pres.forEach((preElement, idx) => {
            preElement.style.outline = idx === selectedBlockIdx ? outline : "none";
        });
    }, [selectedBlockIdx]);

    const handleTextAreaMouseDown = () => {
        updatePreTagOutline();
        setSelectedBlockIdx(null);
    };

    const handleEnterKeyPressed = useCallback(() => {
        const isCurrentlyUpdating = messages.some((message) => message.isUpdating);
        if (isCurrentlyUpdating || value === "") return;

        submit(value);
        setValue("");
        setSelectedBlockIdx(null);
    }, [messages, value, submit]);

    const handleContainerClick = (event: React.MouseEvent) => {
        inputRef.current?.focus();
        const target = event.target as HTMLElement;
        if (
            target.closest(".copy-button") ||
            target.closest(".fa-square-terminal") ||
            target.closest(".waveai-input")
        ) {
            return;
        }

        const pre = target.closest("pre");
        updatePreTagOutline(pre);
    };

    /**
     * Scrolls the chat window so that the code block at `blockIdx` is fully visible, with 15px
     * of padding. It defaults to the current selection. Callers that have just called
     * setSelectedBlockIdx must pass the new index, because state updates are not visible until
     * the next render.
     */
    const updateScrollTop = (blockIdx: number | null = selectedBlockIdx) => {
        const pres = chatWindowRef.current?.querySelectorAll("pre");
        if (!pres || blockIdx === null) return;

        const block = pres[blockIdx];
        const osInstance = osRef.current?.osInstance();
        if (!block || !osInstance) return;

        const { viewport, scrollOffsetElement } = osInstance.elements();

        const chatWindowTop = scrollOffsetElement.scrollTop;
        const chatWindowHeight = chatWindowRef.current.clientHeight;
        const chatWindowBottom = chatWindowTop + chatWindowHeight;
        const elemTop = block.offsetTop;
        const elemBottom = elemTop + block.offsetHeight;
        const elementIsInView = elemBottom <= chatWindowBottom && elemTop >= chatWindowTop;

        if (!elementIsInView) {
            let scrollPosition;
            if (elemBottom > chatWindowBottom) {
                scrollPosition = elemTop - chatWindowHeight + block.offsetHeight + 15;
            } else if (elemTop < chatWindowTop) {
                scrollPosition = elemTop - 15;
            }
            viewport.scrollTo({
                behavior: "auto",
                top: scrollPosition,
            });
        }
    };

    // Arrow keys drive code-block selection when a block is already selected, or (for ArrowUp)
    // when the caret is at the very start of the prompt.
    const shouldSelectCodeBlock = (key: "ArrowUp" | "ArrowDown") => {
        const textarea = inputRef.current;
        const cursorPosition = textarea?.selectionStart || 0;
        const textBeforeCursor = textarea?.value.slice(0, cursorPosition) || "";

        return (
            (textBeforeCursor.indexOf("\n") === -1 && cursorPosition === 0 && key === "ArrowUp") ||
            selectedBlockIdx !== null
        );
    };

    const handleArrowUpPressed = (e: React.KeyboardEvent) => {
        if (shouldSelectCodeBlock("ArrowUp")) {
            e.preventDefault();
            const pres = chatWindowRef.current?.querySelectorAll("pre");
            // With no code blocks there is nothing to select; never enter an index -1 selection.
            if (!pres || pres.length === 0) return;
            let nextIdx = selectedBlockIdx;
            if (nextIdx === null) {
                nextIdx = pres.length - 1;
            } else if (nextIdx > 0) {
                nextIdx--;
            }
            setSelectedBlockIdx(nextIdx);
            updateScrollTop(nextIdx);
        }
    };

    const handleArrowDownPressed = (e: React.KeyboardEvent) => {
        if (shouldSelectCodeBlock("ArrowDown")) {
            e.preventDefault();
            const pres = chatWindowRef.current?.querySelectorAll("pre");
            if (!pres || selectedBlockIdx === null) return;
            if (selectedBlockIdx >= 0 && selectedBlockIdx < pres.length - 1) {
                const nextIdx = selectedBlockIdx + 1;
                setSelectedBlockIdx(nextIdx);
                updateScrollTop(nextIdx);
            } else {
                // Moving past the last code block hands focus back to the prompt.
                inputRef.current?.focus();
                setSelectedBlockIdx(null);
            }
        }
    };

    const handleTextAreaKeyDown = (e: React.KeyboardEvent) => {
        const waveEvent = adaptFromReactOrNativeKeyEvent(e);
        if (checkKeyPressed(waveEvent, "Enter")) {
            e.preventDefault();
            handleEnterKeyPressed();
        } else if (checkKeyPressed(waveEvent, "ArrowUp")) {
            handleArrowUpPressed(e);
        } else if (checkKeyPressed(waveEvent, "ArrowDown")) {
            handleArrowDownPressed(e);
        }
    };

    return (
    );
};

export { makeWaveAiViewModel, WaveAi };