diff --git a/src/app/common/common.less b/src/app/common/common.less index 8552e15f3..ddfd827c2 100644 --- a/src/app/common/common.less +++ b/src/app/common/common.less @@ -352,7 +352,7 @@ background-color: @markdown-highlight; color: @term-white; font-family: @terminal-font; - display: inline-block; + border-radius: 4px; } code.inline { @@ -403,6 +403,13 @@ background-color: @markdown-highlight; margin: 4px 10px 4px 10px; padding: 6px 6px 6px 10px; + border-radius: 4px; + } + + pre.selected { + border-style: solid; + outline-width: 2px; + border-color: @term-green; } .title.is-1 { diff --git a/src/app/common/common.tsx b/src/app/common/common.tsx index 2a2b2c3fe..90183571f 100644 --- a/src/app/common/common.tsx +++ b/src/app/common/common.tsx @@ -11,6 +11,7 @@ import cn from "classnames"; import { If } from "tsx-control-statements/components"; import type { RemoteType } from "../../types/types"; import ReactDOM from "react-dom"; +import { GlobalModel } from "../../model/model"; import { ReactComponent as CheckIcon } from "../assets/icons/line/check.svg"; import { ReactComponent as CopyIcon } from "../assets/icons/history/copy.svg"; @@ -828,9 +829,67 @@ function CodeRenderer(props: any): any { } @mobxReact.observer -class Markdown extends React.Component<{ text: string; style?: any; extraClassName?: string }, {}> { +class CodeBlockMarkdown extends React.Component< + { children: React.ReactNode; blockText: string; codeSelectSelectedIndex?: number }, + {} +> { + blockIndex: number; + blockRef: React.RefObject; + + constructor(props) { + super(props); + this.blockRef = React.createRef(); + this.blockIndex = GlobalModel.inputModel.addCodeBlockToCodeSelect(this.blockRef); + } + + render() { + let codeText = this.props.blockText; + let clickHandler: (e: React.MouseEvent, blockIndex: number) => void; + let inputModel = GlobalModel.inputModel; + clickHandler = (e: React.MouseEvent, blockIndex: number) => { + inputModel.setCodeSelectSelectedCodeBlock(blockIndex); + }; + let selected = this.blockIndex == this.props.codeSelectSelectedIndex; + return ( +
 clickHandler(event, this.blockIndex)}
+            >
+                {this.props.children}
+            
+ ); + } +} + +@mobxReact.observer +class Markdown extends React.Component< + { text: string; style?: any; extraClassName?: string; codeSelect?: boolean }, + {} +> { + CodeBlockRenderer(props: any, codeSelect: boolean, codeSelectIndex: number): any { + let codeText = codeSelect ? props.node.children[0].children[0].value : props.children; + if (codeText) { + codeText = codeText.replace(/\n$/, ""); // remove trailing newline + } + if (codeSelect) { + return ( + + {props.children} + + ); + } else { + let clickHandler = (e: React.MouseEvent) => { + navigator.clipboard.writeText(codeText); + }; + return
 clickHandler(event)}>{props.children}
; + } + } + render() { let text = this.props.text; + let codeSelect = this.props.codeSelect; + let curCodeSelectIndex = GlobalModel.inputModel.getCodeSelectSelectedIndex(); let markdownComponents = { a: LinkRenderer, h1: (props) => HeaderRenderer(props, 1), @@ -839,7 +898,8 @@ class Markdown extends React.Component<{ text: string; style?: any; extraClassNa h4: (props) => HeaderRenderer(props, 4), h5: (props) => HeaderRenderer(props, 5), h6: (props) => HeaderRenderer(props, 6), - code: CodeRenderer, + code: (props) => CodeRenderer(props), + pre: (props) => this.CodeBlockRenderer(props, codeSelect, curCodeSelectIndex), }; return (
diff --git a/src/app/workspace/cmdinput/aichat.tsx b/src/app/workspace/cmdinput/aichat.tsx new file mode 100644 index 000000000..008fe8eba --- /dev/null +++ b/src/app/workspace/cmdinput/aichat.tsx @@ -0,0 +1,219 @@ +// Copyright 2023, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +import * as React from "react"; +import * as mobxReact from "mobx-react"; +import * as mobx from "mobx"; +import { GlobalModel } from "../../../model/model"; +import { isBlank } from "../../../util/util"; +import { boundMethod } from "autobind-decorator"; +import cn from "classnames"; +import { Prompt } from "../../common/prompt/prompt"; +import { TextAreaInput } from "./textareainput"; +import { If, For } from "tsx-control-statements/components"; +import type { OpenAICmdInfoChatMessageType } from "../../../types/types"; +import { Markdown } from "../../common/common"; + +@mobxReact.observer +class AIChat extends React.Component<{}, {}> { + chatListKeyCount: number = 0; + textAreaNumLines: mobx.IObservableValue = mobx.observable.box(1, { name: "textAreaNumLines" }); + chatWindowScrollRef: React.RefObject; + textAreaRef: React.RefObject; + + constructor(props: any) { + super(props); + this.chatWindowScrollRef = React.createRef(); + this.textAreaRef = React.createRef(); + } + + componentDidMount() { + let model = GlobalModel; + let inputModel = model.inputModel; + if (this.chatWindowScrollRef != null && this.chatWindowScrollRef.current != null) { + this.chatWindowScrollRef.current.scrollTop = this.chatWindowScrollRef.current.scrollHeight; + } + if (this.textAreaRef.current != null) { + this.textAreaRef.current.focus(); + inputModel.setCmdInfoChatRefs(this.textAreaRef, this.chatWindowScrollRef); + } + this.requestChatUpdate(); + } + + componentDidUpdate() { + if (this.chatWindowScrollRef != null && this.chatWindowScrollRef.current != null) { + this.chatWindowScrollRef.current.scrollTop = this.chatWindowScrollRef.current.scrollHeight; + } + } + + requestChatUpdate() { + this.submitChatMessage(""); + } + + submitChatMessage(messageStr: string) { + let model = GlobalModel; + let inputModel = model.inputModel; + let curLine = inputModel.getCurLine(); + let prtn = GlobalModel.submitChatInfoCommand(messageStr, curLine, false); + prtn.then((rtn) => { + if (!rtn.success) { + console.log("submit chat command error: " + rtn.error); + } + }).catch((error) => {}); + } + + getLinePos(elem: any): { numLines: number; linePos: number } { + let numLines = elem.value.split("\n").length; + let linePos = elem.value.substr(0, elem.selectionStart).split("\n").length; + return { numLines, linePos }; + } + + @mobx.action + @boundMethod + onKeyDown(e: any) { + mobx.action(() => { + let model = GlobalModel; + let inputModel = model.inputModel; + let ctrlMod = e.getModifierState("Control") || e.getModifierState("Meta") || e.getModifierState("Shift"); + let resetCodeSelect = !ctrlMod; + + if (e.code == "Enter") { + e.preventDefault(); + if (!ctrlMod) { + if (inputModel.getCodeSelectSelectedIndex() == -1) { + let messageStr = e.target.value; + this.submitChatMessage(messageStr); + e.target.value = ""; + } else { + inputModel.grabCodeSelectSelection(); + } + } else { + e.target.setRangeText("\n", e.target.selectionStart, e.target.selectionEnd, "end"); + } + } + if (e.code == "Escape") { + e.preventDefault(); + e.stopPropagation(); + inputModel.closeAIAssistantChat(); + } + if (e.code == "KeyL" && e.getModifierState("Control")) { + e.preventDefault(); + e.stopPropagation(); + inputModel.clearAIAssistantChat(); + } + if (e.code == "ArrowUp") { + if (this.getLinePos(e.target).linePos > 1) { + // normal up arrow + return; + } + e.preventDefault(); + inputModel.codeSelectSelectNextOldestCodeBlock(); + resetCodeSelect = false; + } + if (e.code == "ArrowDown") { + if (inputModel.getCodeSelectSelectedIndex() == inputModel.codeSelectBottom) { + return; + } + e.preventDefault(); + inputModel.codeSelectSelectNextNewestCodeBlock(); + resetCodeSelect = false; + } + + if (resetCodeSelect) { + inputModel.codeSelectDeselectAll(); + } + + // set height of textarea based on number of newlines + this.textAreaNumLines.set(e.target.value.split(/\n/).length); + })(); + } + + renderError(err: string): any { + return
{err}
; + } + + renderChatMessage(chatItem: OpenAICmdInfoChatMessageType): any { + let curKey = "chatmsg-" + this.chatListKeyCount; + this.chatListKeyCount++; + let senderClassName = chatItem.isassistantresponse ? "chat-msg-assistant" : "chat-msg-user"; + let msgClassName = "chat-msg " + senderClassName; + let innerHTML: React.JSX.Element = ( + + + +

You

+
+

{chatItem.userquery}

+
+ ); + if (chatItem.isassistantresponse) { + if (chatItem.assistantresponse.error != null && chatItem.assistantresponse.error != "") { + innerHTML = this.renderError(chatItem.assistantresponse.error); + } else { + innerHTML = ( + + + +

ChatGPT

+
+ +
+ ); + } + } + + return ( +
+ {innerHTML} +
+ ); + } + + renderChatWindow(): any { + let model = GlobalModel; + let inputModel = model.inputModel; + let chatMessageItems = inputModel.AICmdInfoChatItems.slice(); + let chitem: OpenAICmdInfoChatMessageType = null; + return ( +
+ + {this.renderChatMessage(chitem)} + +
+ ); + } + + render() { + let model = GlobalModel; + let inputModel = model.inputModel; + + const termFontSize = 14; + const textAreaMaxLines = 4; + const textAreaLineHeight = termFontSize * 1.5; + const textAreaPadding = 2 * 0.5 * termFontSize; + let textAreaMaxHeight = textAreaLineHeight * textAreaMaxLines + textAreaPadding; + let textAreaInnerHeight = this.textAreaNumLines.get() * textAreaLineHeight + textAreaPadding; + + return ( +
+ {this.renderChatWindow()} + +
+ ); + } +} + +export { AIChat }; diff --git a/src/app/workspace/cmdinput/cmdinput.less b/src/app/workspace/cmdinput/cmdinput.less index 7324b86ea..69efc8521 100644 --- a/src/app/workspace/cmdinput/cmdinput.less +++ b/src/app/workspace/cmdinput/cmdinput.less @@ -42,6 +42,10 @@ max-height: max(300px, 70%); } + &.has-aichat { + max-height: max(300px, 70%); + } + .remote-status-warning { display: flex; flex-direction: row; @@ -196,6 +200,72 @@ } } } + + .cmd-aichat { + display: flex; + justify-content: flex-end; + flex-flow: column nowrap; + margin-bottom: 10px; + flex-shrink: 1; + overflow-y: auto; + + .chat-window { + overflow-y: auto; + margin-bottom: 5px; + flex-shrink: 1; + flex-direction: column-reverse; + } + + .chat-textarea { + color: @term-bright-white; + background-color: @textarea-background; + padding: 0.5em; + resize: none; + overflow: auto; + overflow-wrap: anywhere; + border-color: transparent; + border: none; + font-family: @terminal-font; + flex-shrink: 0; + flex-grow: 1; + border-radius: 4px; + + &:focus { + box-shadow: none; + border: none; + outline: none; + } + } + + .chat-msg { + margin-top:5px; + margin-bottom:5px; + } + + .chat-msg-assistant { + color: @term-white; + } + + .chat-msg-user { + + .msg-text { + font-family: @markdown-font; + font-size: 14px; + white-space: pre-wrap; + } + } + + .chat-msg-error { + color: @term-bright-red; + font-family: @markdown-font; + font-size: 14px; + } + + + .grow-spacer { + flex: 1 0 10px; + } + } .cmd-history { color: @term-white; diff --git a/src/app/workspace/cmdinput/cmdinput.tsx b/src/app/workspace/cmdinput/cmdinput.tsx index fd157b776..a13e961d0 100644 --- a/src/app/workspace/cmdinput/cmdinput.tsx +++ b/src/app/workspace/cmdinput/cmdinput.tsx @@ -19,6 +19,7 @@ import { Prompt } from "../../common/prompt/prompt"; import { ReactComponent as ExecIcon } from "../../assets/icons/exec.svg"; import { ReactComponent as RotateIcon } from "../../assets/icons/line/rotate.svg"; import "./cmdinput.less"; +import { AIChat } from "./aichat"; dayjs.extend(localizedFormat); @@ -116,6 +117,7 @@ class CmdInput extends React.Component<{}, {}> { } let infoShow = inputModel.infoShow.get(); let historyShow = !infoShow && inputModel.historyShow.get(); + let aiChatShow = inputModel.aIChatShow.get(); let infoMsg = inputModel.infoMsg.get(); let hasInfo = infoMsg != null; let focusVal = inputModel.physicalInputFocused.get(); @@ -127,11 +129,23 @@ class CmdInput extends React.Component<{}, {}> { numRunningLines = mobx.computed(() => win.getRunningCmdLines().length).get(); } return ( -
+
+ +
+ +
diff --git a/src/app/workspace/cmdinput/textareainput.tsx b/src/app/workspace/cmdinput/textareainput.tsx index 04e0c6b93..8cca3b89d 100644 --- a/src/app/workspace/cmdinput/textareainput.tsx +++ b/src/app/workspace/cmdinput/textareainput.tsx @@ -230,6 +230,7 @@ class TextAreaInput extends React.Component<{ screen: Screen; onHeightChange: () if (inputModel.inputMode.get() != null) { inputModel.resetInputMode(); } + inputModel.closeAIAssistantChat(); return; } if (e.code == "KeyE" && e.getModifierState("Meta")) { @@ -313,6 +314,10 @@ class TextAreaInput extends React.Component<{ screen: Screen; onHeightChange: () scrollDiv(div, e.code == "PageUp" ? -amt : amt); } } + if (e.code == "Space" && e.getModifierState("Control")) { + e.preventDefault(); + inputModel.openAIAssistantChat(); + } // console.log(e.code, e.keyCode, e.key, event.which, ctrlMod, e); })(); } diff --git a/src/model/model.ts b/src/model/model.ts index 6b867636f..f7cc0d56b 100644 --- a/src/model/model.ts +++ b/src/model/model.ts @@ -65,6 +65,7 @@ import type { CommandRtnType, WebCmd, WebRemote, + OpenAICmdInfoChatMessageType, } from "../types/types"; import * as T from "../types/types"; import { WSControl } from "./ws"; @@ -1233,7 +1234,18 @@ function getDefaultHistoryQueryOpts(): HistoryQueryOpts { class InputModel { historyShow: OV = mobx.observable.box(false); infoShow: OV = mobx.observable.box(false); + aIChatShow: OV = mobx.observable.box(false); cmdInputHeight: OV = mobx.observable.box(0); + aiChatTextAreaRef: React.RefObject; + aiChatWindowRef: React.RefObject; + codeSelectBlockRefArray: Array>; + codeSelectSelectedIndex: OV = mobx.observable.box(-1); + + AICmdInfoChatItems: mobx.IObservableArray = mobx.observable.array([], { + name: "aicmdinfo-chat", + }); + readonly codeSelectTop: number = -2; + readonly codeSelectBottom: number = -1; historyType: mobx.IObservableValue = mobx.observable.box("screen"); historyLoading: mobx.IObservableValue = mobx.observable.box(false); @@ -1271,6 +1283,10 @@ class InputModel { this.filteredHistoryItems = mobx.computed(() => { return this._getFilteredHistoryItems(); }); + mobx.action(() => { + this.codeSelectSelectedIndex.set(-1); + this.codeSelectBlockRefArray = []; + })(); } setInputMode(inputMode: null | "comment" | "global"): void { @@ -1395,6 +1411,11 @@ class InputModel { })(); } + setOpenAICmdInfoChat(chat: OpenAICmdInfoChatMessageType[]): void { + this.AICmdInfoChatItems.replace(chat); + this.codeSelectBlockRefArray = []; + } + setHistoryShow(show: boolean): void { if (this.historyShow.get() == show) { return; @@ -1683,6 +1704,152 @@ class InputModel { } } + setCmdInfoChatRefs( + textAreaRef: React.RefObject, + chatWindowRef: React.RefObject + ) { + this.aiChatTextAreaRef = textAreaRef; + this.aiChatWindowRef = chatWindowRef; + } + + setAIChatFocus() { + if (this.aiChatTextAreaRef != null && this.aiChatTextAreaRef.current != null) { + this.aiChatTextAreaRef.current.focus(); + } + } + + grabCodeSelectSelection() { + if ( + this.codeSelectSelectedIndex.get() >= 0 && + this.codeSelectSelectedIndex.get() < this.codeSelectBlockRefArray.length + ) { + let curBlockRef = this.codeSelectBlockRefArray[this.codeSelectSelectedIndex.get()]; + let codeText = curBlockRef.current.innerText; + codeText = codeText.replace(/\n$/, ""); // remove trailing newline + let newLineValue = this.getCurLine() + " " + codeText; + this.setCurLine(newLineValue); + this.giveFocus(); + } + } + + addCodeBlockToCodeSelect(blockRef: React.RefObject): number { + let rtn = -1; + rtn = this.codeSelectBlockRefArray.length; + this.codeSelectBlockRefArray.push(blockRef); + return rtn; + } + + setCodeSelectSelectedCodeBlock(blockIndex: number) { + mobx.action(() => { + if (blockIndex >= 0 && blockIndex < this.codeSelectBlockRefArray.length) { + this.codeSelectSelectedIndex.set(blockIndex); + let currentRef = this.codeSelectBlockRefArray[blockIndex].current; + if (currentRef != null) { + if (this.aiChatWindowRef != null && this.aiChatWindowRef.current != null) { + let chatWindowTop = this.aiChatWindowRef.current.scrollTop; + let chatWindowBottom = chatWindowTop + this.aiChatWindowRef.current.clientHeight - 100; + let elemTop = currentRef.offsetTop; + let elemBottom = elemTop - currentRef.offsetHeight; + let elementIsInView = elemBottom < chatWindowBottom && elemTop > chatWindowTop; + if (!elementIsInView) { + this.aiChatWindowRef.current.scrollTop = + elemBottom - this.aiChatWindowRef.current.clientHeight / 3; + } + } + } + this.codeSelectBlockRefArray = []; + this.setAIChatFocus(); + } + })(); + } + + codeSelectSelectNextNewestCodeBlock() { + // oldest code block = index 0 in array + // this decrements codeSelectSelected index + mobx.action(() => { + if (this.codeSelectSelectedIndex.get() == this.codeSelectTop) { + this.codeSelectSelectedIndex.set(this.codeSelectBottom); + } else if (this.codeSelectSelectedIndex.get() == this.codeSelectBottom) { + return; + } + let incBlockIndex = this.codeSelectSelectedIndex.get() + 1; + if (this.codeSelectSelectedIndex.get() == this.codeSelectBlockRefArray.length - 1) { + this.codeSelectDeselectAll(); + if (this.aiChatWindowRef != null && this.aiChatWindowRef.current != null) { + this.aiChatWindowRef.current.scrollTop = this.aiChatWindowRef.current.scrollHeight; + } + } + if (incBlockIndex >= 0 && incBlockIndex < this.codeSelectBlockRefArray.length) { + this.setCodeSelectSelectedCodeBlock(incBlockIndex); + } + })(); + } + + codeSelectSelectNextOldestCodeBlock() { + mobx.action(() => { + if (this.codeSelectSelectedIndex.get() == this.codeSelectBottom) { + if (this.codeSelectBlockRefArray.length > 0) { + this.codeSelectSelectedIndex.set(this.codeSelectBlockRefArray.length); + } else { + return; + } + } else if (this.codeSelectSelectedIndex.get() == this.codeSelectTop) { + return; + } + let decBlockIndex = this.codeSelectSelectedIndex.get() - 1; + if (decBlockIndex < 0) { + this.codeSelectDeselectAll(this.codeSelectTop); + if (this.aiChatWindowRef != null && this.aiChatWindowRef.current != null) { + this.aiChatWindowRef.current.scrollTop = 0; + } + } + if (decBlockIndex >= 0 && decBlockIndex < this.codeSelectBlockRefArray.length) { + this.setCodeSelectSelectedCodeBlock(decBlockIndex); + } + })(); + } + + getCodeSelectSelectedIndex() { + return this.codeSelectSelectedIndex.get(); + } + + getCodeSelectRefArrayLength() { + return this.codeSelectBlockRefArray.length; + } + + codeBlockIsSelected(blockIndex: number): boolean { + return blockIndex == this.codeSelectSelectedIndex.get(); + } + + codeSelectDeselectAll(direction: number = this.codeSelectBottom) { + mobx.action(() => { + this.codeSelectSelectedIndex.set(direction); + this.codeSelectBlockRefArray = []; + })(); + } + + openAIAssistantChat(): void { + this.aIChatShow.set(true); + this.setAIChatFocus(); + } + + closeAIAssistantChat(): void { + this.aIChatShow.set(false); + this.giveFocus(); + } + + clearAIAssistantChat(): void { + let prtn = GlobalModel.submitChatInfoCommand("", "", true); + prtn.then((rtn) => { + if (rtn.success) { + } else { + console.log("submit chat command error: " + rtn.error); + } + }).catch((error) => { + console.log("submit chat command error: ", error); + }); + } + hasScrollingInfoMsg(): boolean { if (!this.infoShow.get()) { return false; @@ -1778,6 +1945,7 @@ class InputModel { resetInput(): void { mobx.action(() => { this.setHistoryShow(false); + this.closeAIAssistantChat(); this.infoShow.set(false); this.inputMode.set(null); this.resetHistory(); @@ -3834,6 +4002,9 @@ class Model { this.sessionListLoaded.set(true); this.remotesLoaded.set(true); } + if ("openaicmdinfochat" in update) { + this.inputModel.setOpenAICmdInfoChat(update.openaicmdinfochat); + } // console.log("run-update>", Date.now(), interactive, update); } @@ -4068,6 +4239,28 @@ class Model { return this.submitCommandPacket(pk, interactive); } + submitChatInfoCommand(chatMsg: string, curLineStr: string, clear: boolean): Promise { + let commandStr = "/chat " + chatMsg; + let interactive = false; + let pk: FeCmdPacketType = { + type: "fecmd", + metacmd: "eval", + args: [commandStr], + kwargs: {}, + uicontext: this.getUIContext(), + interactive: interactive, + rawstr: chatMsg, + }; + pk.kwargs["nohist"] = "1"; + if (clear) { + pk.kwargs["cmdinfoclear"] = "1"; + } else { + pk.kwargs["cmdinfo"] = "1"; + } + pk.kwargs["curline"] = curLineStr; + return this.submitCommandPacket(pk, interactive); + } + submitRawCommand(cmdStr: string, addToHistory: boolean, interactive: boolean): Promise { let pk: FeCmdPacketType = { type: "fecmd", diff --git a/src/types/types.ts b/src/types/types.ts index 29b6fb070..c69bbd709 100644 --- a/src/types/types.ts +++ b/src/types/types.ts @@ -265,6 +265,20 @@ type ScreenLinesType = { cmds: CmdDataType[]; }; +type OpenAIPacketOutputType = { + model: string; + created: number; + finish_reason: string; + message: string; + error?: string; +}; + +type OpenAICmdInfoChatMessageType = { + isassistantresponse?: boolean; + assistantresponse?: OpenAIPacketOutputType; + userquery?: string; +}; + type ModelUpdateType = { interactive: boolean; sessions?: SessionDataType[]; @@ -285,6 +299,7 @@ type ModelUpdateType = { clientdata?: ClientDataType; historyviewdata?: HistoryViewDataType; remoteview?: RemoteViewType; + openaicmdinfochat?: OpenAICmdInfoChatMessageType[]; alertmessage?: AlertMessageType; }; @@ -763,4 +778,5 @@ export type { ModalStoreEntry, StrWithPos, CmdInputTextPacketType, + OpenAICmdInfoChatMessageType, }; diff --git a/waveshell/pkg/packet/packet.go b/waveshell/pkg/packet/packet.go index f75030558..b485b6919 100644 --- a/waveshell/pkg/packet/packet.go +++ b/waveshell/pkg/packet/packet.go @@ -70,6 +70,8 @@ const PacketEOFStr = "EOF" var TypeStrToFactory map[string]reflect.Type +const OpenAICmdInfoChatGreetingMessage = "Hello, may I help you with this command? \n(Press ESC to close and Ctrl+L to clear chat buffer)" + func init() { TypeStrToFactory = make(map[string]reflect.Type) TypeStrToFactory[RunPacketStr] = reflect.TypeOf(RunPacketType{}) @@ -729,6 +731,14 @@ type OpenAIUsageType struct { TotalTokens int `json:"total_tokens,omitempty"` } +type OpenAICmdInfoPacketOutputType struct { + Model string `json:"model,omitempty"` + Created int64 `json:"created,omitempty"` + FinishReason string `json:"finish_reason,omitempty"` + Message string `json:"message,omitempty"` + Error string `json:"error,omitempty"` +} + type OpenAIPacketType struct { Type string `json:"type"` Model string `json:"model,omitempty"` @@ -843,6 +853,14 @@ func MakeWriteFileDonePacket(reqId string) *WriteFileDonePacketType { } } +type OpenAICmdInfoChatMessage struct { + MessageID int `json:"messageid"` + IsAssistantResponse bool `json:"isassistantresponse,omitempty"` + AssistantResponse *OpenAICmdInfoPacketOutputType `json:"assistantresponse,omitempty"` + UserQuery string `json:"userquery,omitempty"` + UserEngineeredQuery string `json:"userengineeredquery,omitempty"` +} + type OpenAIPromptMessageType struct { Role string `json:"role"` Content string `json:"content"` diff --git a/wavesrv/pkg/cmdrunner/cmdrunner.go b/wavesrv/pkg/cmdrunner/cmdrunner.go index 810771a58..36ff58d63 100644 --- a/wavesrv/pkg/cmdrunner/cmdrunner.go +++ b/wavesrv/pkg/cmdrunner/cmdrunner.go @@ -2001,6 +2001,114 @@ func doOpenAICompletion(cmd *sstore.CmdType, opts *sstore.OpenAIOptsType, prompt return } +func writePacketToUpdateBus(ctx context.Context, cmd *sstore.CmdType, pk *packet.OpenAICmdInfoChatMessage) { + update, err := sstore.UpdateWithAddNewOpenAICmdInfoPacket(ctx, cmd.ScreenId, pk) + if err != nil { + log.Printf("Open AI Update packet err: %v\n", err) + } + sstore.MainBus.SendScreenUpdate(cmd.ScreenId, update) +} + +func updateAsstResponseAndWriteToUpdateBus(ctx context.Context, cmd *sstore.CmdType, pk *packet.OpenAICmdInfoChatMessage, messageID int) { + update, err := sstore.UpdateWithUpdateOpenAICmdInfoPacket(ctx, cmd.ScreenId, messageID, pk) + if err != nil { + log.Printf("Open AI Update packet err: %v\n", err) + } + sstore.MainBus.SendScreenUpdate(cmd.ScreenId, update) +} + +func getCmdInfoEngineeredPrompt(userQuery string, curLineStr string) string { + rtn := "You are an expert on the command line terminal. Your task is to help me write a command." + if curLineStr != "" { + rtn += "My current command is: " + curLineStr + } + rtn += ". My question is: " + userQuery + "." + return rtn +} + +func doOpenAICmdInfoCompletion(cmd *sstore.CmdType, clientId string, opts *sstore.OpenAIOptsType, prompt []packet.OpenAIPromptMessageType, curLineStr string) { + var hadError bool + log.Println("had error: ", hadError) + ctx, cancelFn := context.WithTimeout(context.Background(), OpenAIStreamTimeout) + defer cancelFn() + defer func() { + r := recover() + if r != nil { + panicMsg := fmt.Sprintf("panic: %v", r) + log.Printf("panic in doOpenAICompletion: %s\n", panicMsg) + hadError = true + } + }() + var ch chan *packet.OpenAIPacketType + var err error + if opts.APIToken == "" { + var conn *websocket.Conn + ch, conn, err = openai.RunCloudCompletionStream(ctx, clientId, opts, prompt) + if conn != nil { + defer conn.Close() + } + } else { + ch, err = openai.RunCompletionStream(ctx, opts, prompt) + } + asstOutputPk := &packet.OpenAICmdInfoPacketOutputType{ + Model: "", + Created: 0, + FinishReason: "", + Message: "", + } + asstOutputMessageID := sstore.ScreenMemGetCmdInfoMessageCount(cmd.ScreenId) + asstMessagePk := &packet.OpenAICmdInfoChatMessage{IsAssistantResponse: true, AssistantResponse: asstOutputPk, MessageID: asstOutputMessageID} + if err != nil { + asstOutputPk.Error = fmt.Sprintf("Error calling OpenAI API: %v", err) + writePacketToUpdateBus(ctx, cmd, asstMessagePk) + return + } + writePacketToUpdateBus(ctx, cmd, asstMessagePk) + doneWaitingForPackets := false + for !doneWaitingForPackets { + select { + case <-time.After(OpenAIPacketTimeout): + // timeout reading from channel + hadError = true + doneWaitingForPackets = true + asstOutputPk.Error = "timeout waiting for server response" + updateAsstResponseAndWriteToUpdateBus(ctx, cmd, asstMessagePk, asstOutputMessageID) + break + case pk, ok := <-ch: + if ok { + // got a packet + if pk.Error != "" { + hadError = true + asstOutputPk.Error = pk.Error + } + if pk.Model != "" && pk.Index == 0 { + asstOutputPk.Model = pk.Model + asstOutputPk.Created = pk.Created + asstOutputPk.FinishReason = pk.FinishReason + if pk.Text != "" { + asstOutputPk.Message += pk.Text + } + } + if pk.Index == 0 { + if pk.FinishReason != "" { + asstOutputPk.FinishReason = pk.FinishReason + } + if pk.Text != "" { + asstOutputPk.Message += pk.Text + } + } + asstMessagePk.AssistantResponse = asstOutputPk + updateAsstResponseAndWriteToUpdateBus(ctx, cmd, asstMessagePk, asstOutputMessageID) + + } else { + // channel closed + doneWaitingForPackets = true + break + } + } + } +} + func doOpenAIStreamCompletion(cmd *sstore.CmdType, clientId string, opts *sstore.OpenAIOptsType, prompt []packet.OpenAIPromptMessageType) { var outputPos int64 var hadError bool @@ -2086,6 +2194,23 @@ func doOpenAIStreamCompletion(cmd *sstore.CmdType, clientId string, opts *sstore return } +func BuildOpenAIPromptArrayWithContext(messages []*packet.OpenAICmdInfoChatMessage) []packet.OpenAIPromptMessageType { + rtn := make([]packet.OpenAIPromptMessageType, 0) + for _, msg := range messages { + content := msg.UserEngineeredQuery + if msg.UserEngineeredQuery == "" { + content = msg.UserQuery + } + msgRole := sstore.OpenAIRoleUser + if msg.IsAssistantResponse { + msgRole = sstore.OpenAIRoleAssistant + content = msg.AssistantResponse.Message + } + rtn = append(rtn, packet.OpenAIPromptMessageType{Role: msgRole, Content: content}) + } + return rtn +} + func OpenAICommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sstore.UpdatePacket, error) { ids, err := resolveUiIds(ctx, pk, R_Session|R_Screen) if err != nil { @@ -2111,9 +2236,6 @@ func OpenAICommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sstor opts.MaxTokens = openai.DefaultMaxTokens } promptStr := firstArg(pk) - if promptStr == "" { - return nil, fmt.Errorf("openai error, prompt string is blank") - } ptermVal := defaultStr(pk.Kwargs["wterm"], DefaultPTERM) pkTermOpts, err := GetUITermOpts(pk.UIContext.WinSize, ptermVal) if err != nil { @@ -2124,11 +2246,40 @@ func OpenAICommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sstor if err != nil { return nil, fmt.Errorf("openai error, cannot make dyn cmd") } + if resolveBool(pk.Kwargs["cmdinfo"], false) { + if promptStr == "" { + // this is requesting an update without wanting an openai query + update, err := sstore.UpdateWithCurrentOpenAICmdInfoChat(cmd.ScreenId) + if err != nil { + return nil, fmt.Errorf("error getting update for CmdInfoChat %v", err) + } + return update, nil + } + curLineStr := defaultStr(pk.Kwargs["curline"], "") + userQueryPk := &packet.OpenAICmdInfoChatMessage{UserQuery: promptStr, MessageID: sstore.ScreenMemGetCmdInfoMessageCount(cmd.ScreenId)} + engineeredQuery := getCmdInfoEngineeredPrompt(promptStr, curLineStr) + userQueryPk.UserEngineeredQuery = engineeredQuery + writePacketToUpdateBus(ctx, cmd, userQueryPk) + prompt := BuildOpenAIPromptArrayWithContext(sstore.ScreenMemGetCmdInfoChat(cmd.ScreenId).Messages) + go doOpenAICmdInfoCompletion(cmd, clientData.ClientId, opts, prompt, curLineStr) + update := &sstore.ModelUpdate{} + return update, nil + } + prompt := []packet.OpenAIPromptMessageType{{Role: sstore.OpenAIRoleUser, Content: promptStr}} + if resolveBool(pk.Kwargs["cmdinfoclear"], false) { + update, err := sstore.UpdateWithClearOpenAICmdInfo(cmd.ScreenId) + if err != nil { + return nil, fmt.Errorf("error clearing CmdInfoChat: %v", err) + } + return update, nil + } + if promptStr == "" { + return nil, fmt.Errorf("openai error, prompt string is blank") + } line, err := sstore.AddOpenAILine(ctx, ids.ScreenId, DefaultUserId, cmd) if err != nil { return nil, fmt.Errorf("cannot add new line: %v", err) } - prompt := []packet.OpenAIPromptMessageType{{Role: sstore.OpenAIRoleUser, Content: promptStr}} if resolveBool(pk.Kwargs["stream"], true) { go doOpenAIStreamCompletion(cmd, clientData.ClientId, opts, prompt) } else { diff --git a/wavesrv/pkg/sstore/dbops.go b/wavesrv/pkg/sstore/dbops.go index d2653cca5..e6a8b9278 100644 --- a/wavesrv/pkg/sstore/dbops.go +++ b/wavesrv/pkg/sstore/dbops.go @@ -748,6 +748,7 @@ func InsertScreen(ctx context.Context, sessionId string, origScreenName string, return nil, txErr } update.Sessions = []*SessionType{bareSession} + update.OpenAICmdInfoChat = ScreenMemGetCmdInfoChat(newScreenId).Messages } return update, nil } @@ -854,6 +855,29 @@ func GetCmdByScreenId(ctx context.Context, screenId string, lineId string) (*Cmd }) } +func UpdateWithClearOpenAICmdInfo(screenId string) (*ModelUpdate, error) { + ScreenMemClearCmdInfoChat(screenId) + return UpdateWithCurrentOpenAICmdInfoChat(screenId) +} + +func UpdateWithAddNewOpenAICmdInfoPacket(ctx context.Context, screenId string, pk *packet.OpenAICmdInfoChatMessage) (*ModelUpdate, error) { + ScreenMemAddCmdInfoChatMessage(screenId, pk) + return UpdateWithCurrentOpenAICmdInfoChat(screenId) +} + +func UpdateWithCurrentOpenAICmdInfoChat(screenId string) (*ModelUpdate, error) { + cmdInfoUpdate := ScreenMemGetCmdInfoChat(screenId).Messages + return &ModelUpdate{OpenAICmdInfoChat: cmdInfoUpdate}, nil +} + +func UpdateWithUpdateOpenAICmdInfoPacket(ctx context.Context, screenId string, messageID int, pk *packet.OpenAICmdInfoChatMessage) (*ModelUpdate, error) { + err := ScreenMemUpdateCmdInfoChatMessage(screenId, messageID, pk) + if err != nil { + return nil, err + } + return UpdateWithCurrentOpenAICmdInfoChat(screenId) +} + func UpdateCmdDoneInfo(ctx context.Context, ck base.CommandKey, donePk *packet.CmdDonePacketType, status string) (*ModelUpdate, error) { if donePk == nil { return nil, fmt.Errorf("invalid cmddone packet") @@ -1039,6 +1063,7 @@ func SwitchScreenById(ctx context.Context, sessionId string, screenId string) (* memState := GetScreenMemState(screenId) if memState != nil { update.CmdLine = &memState.CmdInputText + update.OpenAICmdInfoChat = ScreenMemGetCmdInfoChat(screenId).Messages } return update, nil } diff --git a/wavesrv/pkg/sstore/memops.go b/wavesrv/pkg/sstore/memops.go index 4571d8089..9b72a3bbf 100644 --- a/wavesrv/pkg/sstore/memops.go +++ b/wavesrv/pkg/sstore/memops.go @@ -5,9 +5,11 @@ package sstore import ( + "fmt" "log" "sync" + "github.com/wavetermdev/waveterm/waveshell/pkg/packet" "github.com/wavetermdev/waveterm/wavesrv/pkg/utilfn" ) @@ -43,11 +45,109 @@ func isIndicatorGreater(i1 string, i2 string) bool { return screenIndicatorLevels[i1] > screenIndicatorLevels[i2] } +type OpenAICmdInfoChatStore struct { + MessageCount int `json:"messagecount"` + Messages []*packet.OpenAICmdInfoChatMessage `json:"messages"` +} + type ScreenMemState struct { - NumRunningCommands int `json:"numrunningcommands,omitempty"` - IndicatorType string `json:"indicatortype,omitempty"` - CmdInputText utilfn.StrWithPos `json:"cmdinputtext,omitempty"` - CmdInputSeqNum int `json:"cmdinputseqnum,omitempty"` + NumRunningCommands int `json:"numrunningcommands,omitempty"` + IndicatorType string `json:"indicatortype,omitempty"` + CmdInputText utilfn.StrWithPos `json:"cmdinputtext,omitempty"` + CmdInputSeqNum int `json:"cmdinputseqnum,omitempty"` + AICmdInfoChat *OpenAICmdInfoChatStore `json:"aicmdinfochat,omitempty"` +} + +func ScreenMemDeepCopyCmdInfoChatStore(store *OpenAICmdInfoChatStore) *OpenAICmdInfoChatStore { + rtnMessages := []*packet.OpenAICmdInfoChatMessage{} + for index := 0; index < len(store.Messages); index++ { + messageToCopy := *store.Messages[index] + if messageToCopy.AssistantResponse != nil { + assistantResponseCopy := *messageToCopy.AssistantResponse + messageToCopy.AssistantResponse = &assistantResponseCopy + } + rtnMessages = append(rtnMessages, &messageToCopy) + } + rtn := &OpenAICmdInfoChatStore{MessageCount: store.MessageCount, Messages: rtnMessages} + return rtn +} + +func ScreenMemInitCmdInfoChat(screenId string) { + greetingMessagePk := &packet.OpenAICmdInfoChatMessage{ + MessageID: 0, + IsAssistantResponse: true, + AssistantResponse: &packet.OpenAICmdInfoPacketOutputType{ + Message: packet.OpenAICmdInfoChatGreetingMessage, + }, + } + ScreenMemStore[screenId].AICmdInfoChat = &OpenAICmdInfoChatStore{MessageCount: 1, Messages: []*packet.OpenAICmdInfoChatMessage{greetingMessagePk}} +} + +func ScreenMemClearCmdInfoChat(screenId string) { + MemLock.Lock() + defer MemLock.Unlock() + if ScreenMemStore[screenId] == nil { + ScreenMemStore[screenId] = &ScreenMemState{} + } + ScreenMemInitCmdInfoChat(screenId) +} + +func ScreenMemAddCmdInfoChatMessage(screenId string, msg *packet.OpenAICmdInfoChatMessage) { + MemLock.Lock() + defer MemLock.Unlock() + if ScreenMemStore[screenId] == nil { + ScreenMemStore[screenId] = &ScreenMemState{} + } + if ScreenMemStore[screenId].AICmdInfoChat == nil { + log.Printf("AICmdInfoChat is null, creating") + ScreenMemInitCmdInfoChat(screenId) + } + + CmdInfoChat := ScreenMemStore[screenId].AICmdInfoChat + CmdInfoChat.Messages = append(CmdInfoChat.Messages, msg) + CmdInfoChat.MessageCount++ +} + +func ScreenMemGetCmdInfoMessageCount(screenId string) int { + MemLock.Lock() + defer MemLock.Unlock() + if ScreenMemStore[screenId] == nil { + ScreenMemStore[screenId] = &ScreenMemState{} + } + if ScreenMemStore[screenId].AICmdInfoChat == nil { + ScreenMemInitCmdInfoChat(screenId) + } + return ScreenMemStore[screenId].AICmdInfoChat.MessageCount +} + +func ScreenMemGetCmdInfoChat(screenId string) *OpenAICmdInfoChatStore { + MemLock.Lock() + defer MemLock.Unlock() + if ScreenMemStore[screenId] == nil { + ScreenMemStore[screenId] = &ScreenMemState{} + } + if ScreenMemStore[screenId].AICmdInfoChat == nil { + ScreenMemInitCmdInfoChat(screenId) + } + return ScreenMemDeepCopyCmdInfoChatStore(ScreenMemStore[screenId].AICmdInfoChat) +} + +func ScreenMemUpdateCmdInfoChatMessage(screenId string, messageID int, msg *packet.OpenAICmdInfoChatMessage) error { + MemLock.Lock() + defer MemLock.Unlock() + if ScreenMemStore[screenId] == nil { + ScreenMemStore[screenId] = &ScreenMemState{} + } + if ScreenMemStore[screenId].AICmdInfoChat == nil { + ScreenMemInitCmdInfoChat(screenId) + } + CmdInfoChat := ScreenMemStore[screenId].AICmdInfoChat + if messageID >= 0 && messageID < len(CmdInfoChat.Messages) { + CmdInfoChat.Messages[messageID] = msg + } else { + return fmt.Errorf("ScreenMemUpdateCmdInfoChatMessage: error: Message Id out of range: %d", messageID) + } + return nil } func ScreenMemSetCmdInputText(screenId string, sp utilfn.StrWithPos, seqNum int) { diff --git a/wavesrv/pkg/sstore/updatebus.go b/wavesrv/pkg/sstore/updatebus.go index a3715d934..8247c4d2f 100644 --- a/wavesrv/pkg/sstore/updatebus.go +++ b/wavesrv/pkg/sstore/updatebus.go @@ -8,6 +8,7 @@ import ( "log" "sync" + "github.com/wavetermdev/waveterm/waveshell/pkg/packet" "github.com/wavetermdev/waveterm/wavesrv/pkg/utilfn" ) @@ -38,29 +39,30 @@ func (*PtyDataUpdate) UpdateType() string { func (pdu *PtyDataUpdate) Clean() {} type ModelUpdate struct { - Sessions []*SessionType `json:"sessions,omitempty"` - ActiveSessionId string `json:"activesessionid,omitempty"` - Screens []*ScreenType `json:"screens,omitempty"` - ScreenLines *ScreenLinesType `json:"screenlines,omitempty"` - Line *LineType `json:"line,omitempty"` - Lines []*LineType `json:"lines,omitempty"` - Cmd *CmdType `json:"cmd,omitempty"` - CmdLine *utilfn.StrWithPos `json:"cmdline,omitempty"` - Info *InfoMsgType `json:"info,omitempty"` - ClearInfo bool `json:"clearinfo,omitempty"` - Remotes []RemoteRuntimeState `json:"remotes,omitempty"` - History *HistoryInfoType `json:"history,omitempty"` - Interactive bool `json:"interactive"` - Connect bool `json:"connect,omitempty"` - MainView string `json:"mainview,omitempty"` - Bookmarks []*BookmarkType `json:"bookmarks,omitempty"` - SelectedBookmark string `json:"selectedbookmark,omitempty"` - HistoryViewData *HistoryViewData `json:"historyviewdata,omitempty"` - ClientData *ClientData `json:"clientdata,omitempty"` - RemoteView *RemoteViewType `json:"remoteview,omitempty"` - ScreenTombstones []*ScreenTombstoneType `json:"screentombstones,omitempty"` - SessionTombstones []*SessionTombstoneType `json:"sessiontombstones,omitempty"` - AlertMessage *AlertMessageType `json:"alertmessage,omitempty"` + Sessions []*SessionType `json:"sessions,omitempty"` + ActiveSessionId string `json:"activesessionid,omitempty"` + Screens []*ScreenType `json:"screens,omitempty"` + ScreenLines *ScreenLinesType `json:"screenlines,omitempty"` + Line *LineType `json:"line,omitempty"` + Lines []*LineType `json:"lines,omitempty"` + Cmd *CmdType `json:"cmd,omitempty"` + CmdLine *utilfn.StrWithPos `json:"cmdline,omitempty"` + Info *InfoMsgType `json:"info,omitempty"` + ClearInfo bool `json:"clearinfo,omitempty"` + Remotes []RemoteRuntimeState `json:"remotes,omitempty"` + History *HistoryInfoType `json:"history,omitempty"` + Interactive bool `json:"interactive"` + Connect bool `json:"connect,omitempty"` + MainView string `json:"mainview,omitempty"` + Bookmarks []*BookmarkType `json:"bookmarks,omitempty"` + SelectedBookmark string `json:"selectedbookmark,omitempty"` + HistoryViewData *HistoryViewData `json:"historyviewdata,omitempty"` + ClientData *ClientData `json:"clientdata,omitempty"` + RemoteView *RemoteViewType `json:"remoteview,omitempty"` + ScreenTombstones []*ScreenTombstoneType `json:"screentombstones,omitempty"` + SessionTombstones []*SessionTombstoneType `json:"sessiontombstones,omitempty"` + OpenAICmdInfoChat []*packet.OpenAICmdInfoChatMessage `json:"openaicmdinfochat,omitempty"` + AlertMessage *AlertMessageType `json:"alertmessage,omitempty"` } func (*ModelUpdate) UpdateType() string {