OpenAI Timeout setting and aliases for AI settings (#590)

* added ai timeout setting

* addressed review comments

* fixed baseurl gating for telemetry

* updated copy

* addressed review comments

* removed prefix for client:show and added units to timeout

* changed timeout to use ms precision
Cole Lashley 2024-04-25 16:14:37 -07:00 committed by GitHub
parent fcf8e4ed44
commit 5e3243564b
5 changed files with 90 additions and 30 deletions

View File

@@ -157,6 +157,12 @@ class ClientSettingsView extends React.Component<{ model: RemotesModel }, { hove
         commandRtnHandler(prtn, this.errorMessage);
     }
+    @boundMethod
+    inlineUpdateOpenAITimeout(newTimeout: string): void {
+        const prtn = GlobalCommandRunner.setClientOpenAISettings({ timeout: newTimeout });
+        commandRtnHandler(prtn, this.errorMessage);
+    }
     @boundMethod
     setErrorMessage(msg: string): void {
         mobx.action(() => {
@@ -203,6 +209,9 @@ class ClientSettingsView extends React.Component<{ model: RemotesModel }, { hove
         const maxTokensStr = String(
             openAIOpts.maxtokens == null || openAIOpts.maxtokens == 0 ? 1000 : openAIOpts.maxtokens
         );
+        const aiTimeoutStr = String(
+            openAIOpts.timeout == null || openAIOpts.timeout == 0 ? 10 : openAIOpts.timeout / 1000
+        );
         const curFontSize = GlobalModel.getTermFontSize();
         const curFontFamily = GlobalModel.getTermFontFamily();
         const curTheme = GlobalModel.getThemeSource();
@@ -342,6 +351,19 @@ class ClientSettingsView extends React.Component<{ model: RemotesModel }, { hove
                         />
                     </div>
                 </div>
+                <div className="settings-field">
+                    <div className="settings-label">AI Timeout (seconds)</div>
+                    <div className="settings-input">
+                        <InlineSettingsTextEdit
+                            placeholder=""
+                            text={aiTimeoutStr}
+                            value={aiTimeoutStr}
+                            onChange={this.inlineUpdateOpenAITimeout}
+                            maxLength={10}
+                            showIcon={true}
+                        />
+                    </div>
+                </div>
                 <div className="settings-field">
                     <div className="settings-label">Global Hotkey</div>
                     <div className="settings-input">

View File

@@ -424,6 +424,7 @@ class CommandRunner {
         apitoken?: string;
         maxtokens?: string;
         baseurl?: string;
+        timeout?: string;
     }): Promise<CommandRtnType> {
         let kwargs = {
             nohist: "1",
@@ -440,6 +441,9 @@ class CommandRunner {
         if (opts.baseurl != null) {
             kwargs["openaibaseurl"] = opts.baseurl;
         }
+        if (opts.timeout != null) {
+            kwargs["openaitimeout"] = opts.timeout;
+        }
         return GlobalModel.submitCommand("client", "set", null, kwargs, false);
     }

View File

@@ -659,6 +659,7 @@ declare global {
         maxtokens?: number;
         maxchoices?: number;
         baseurl?: string;
+        timeout?: number;
    };
    type PlaybookType = {

View File

@@ -84,9 +84,9 @@ const TermFontSizeMax = 24
 const TsFormatStr = "2006-01-02 15:04:05"
-const OpenAIPacketTimeout = 10 * time.Second
+const OpenAIPacketTimeout = 10 * 1000 * time.Millisecond
 const OpenAIStreamTimeout = 5 * time.Minute
-const OpenAICloudCompletionTelemetryOffErrorMsg = "To ensure responsible usage and prevent misuse, Wave AI requires telemetry to be enabled when using its free AI features.\n\nIf you prefer not to enable telemetry, you can still access Wave AI's features by providing your own OpenAI API key in the Settings menu. Please note that when using your personal API key, requests will be sent directly to the OpenAI API without being proxied through Wave's servers.\n\nIf you wish to continue using Wave AI's free features, you can easily enable telemetry by running the '/telemetry:on' command in the terminal. This will allow you to access the free AI features while helping to protect the platform from abuse."
+const OpenAICloudCompletionTelemetryOffErrorMsg = "To ensure responsible usage and prevent misuse, Wave AI requires telemetry to be enabled when using its free AI features.\n\nIf you prefer not to enable telemetry, you can still access Wave AI's features by providing your own OpenAI API key or AI Base URL in the Settings menu. Please note that when using your personal API key, requests will be sent directly to the OpenAI API or the API that you specified with the AI Base URL, without being proxied through Wave's servers.\n\nIf you wish to continue using Wave AI's free features, you can easily enable telemetry by running the '/telemetry:on' command in the terminal. This will allow you to access the free AI features while helping to protect the platform from abuse."
 const (
     KwArgRenderer = "renderer"
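
A side note on the constant change above: the default is still ten seconds; it is merely restated in millisecond units so that it lines up with the user-configurable timeout, which is stored in milliseconds. A standalone sketch (not part of the commit) that checks the equivalence:

```go
package main

import (
    "fmt"
    "time"
)

func main() {
    // Old and new spellings of the default packet timeout from the hunk above.
    oldDefault := 10 * time.Second
    newDefault := 10 * 1000 * time.Millisecond

    fmt.Println(oldDefault == newDefault) // true -- only the unit of expression changed
    fmt.Println(newDefault)               // 10s
}
```
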
@@ -2693,8 +2693,6 @@ func getCmdInfoEngineeredPrompt(userQuery string, curLineStr string, shellType s
 }
 func doOpenAICmdInfoCompletion(cmd *sstore.CmdType, clientId string, opts *sstore.OpenAIOptsType, prompt []packet.OpenAIPromptMessageType, curLineStr string) {
-    var hadError bool
-    log.Println("had error: ", hadError)
     ctx, cancelFn := context.WithTimeout(context.Background(), OpenAIStreamTimeout)
     defer cancelFn()
     defer func() {
@@ -2702,7 +2700,6 @@ func doOpenAICmdInfoCompletion(cmd *sstore.CmdType, clientId string, opts *sstor
         if r != nil {
             panicMsg := fmt.Sprintf("panic: %v", r)
             log.Printf("panic in doOpenAICompletion: %s\n", panicMsg)
-            hadError = true
         }
     }()
     var ch chan *packet.OpenAIPacketType
@@ -2730,12 +2727,15 @@ func doOpenAICmdInfoCompletion(cmd *sstore.CmdType, clientId string, opts *sstor
         return
     }
     writePacketToUpdateBus(ctx, cmd, asstMessagePk)
+    packetTimeout := OpenAIPacketTimeout
+    if opts.Timeout >= 0 {
+        packetTimeout = time.Duration(opts.Timeout) * time.Millisecond
+    }
     doneWaitingForPackets := false
     for !doneWaitingForPackets {
         select {
-        case <-time.After(OpenAIPacketTimeout):
+        case <-time.After(packetTimeout):
             // timeout reading from channel
-            hadError = true
             doneWaitingForPackets = true
             asstOutputPk.Error = "timeout waiting for server response"
             updateAsstResponseAndWriteToUpdateBus(ctx, cmd, asstMessagePk, asstOutputMessageID)
@@ -2743,7 +2743,6 @@ func doOpenAICmdInfoCompletion(cmd *sstore.CmdType, clientId string, opts *sstor
             if ok {
                 // got a packet
                 if pk.Error != "" {
-                    hadError = true
                     asstOutputPk.Error = pk.Error
                 }
                 if pk.Model != "" && pk.Index == 0 {
@@ -2823,10 +2822,14 @@ func doOpenAIStreamCompletion(cmd *sstore.CmdType, clientId string, opts *sstore
         writeErrorToPty(cmd, fmt.Sprintf("error calling OpenAI API: %v", err), outputPos)
         return
     }
+    packetTimeout := OpenAIPacketTimeout
+    if opts.Timeout >= 0 {
+        packetTimeout = time.Duration(opts.Timeout) * time.Millisecond
+    }
     doneWaitingForPackets := false
     for !doneWaitingForPackets {
         select {
-        case <-time.After(OpenAIPacketTimeout):
+        case <-time.After(packetTimeout):
             // timeout reading from channel
             hadError = true
             pk := openai.CreateErrorPacket(fmt.Sprintf("timeout waiting for server response"))
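
Both doOpenAICmdInfoCompletion and doOpenAIStreamCompletion now pick the per-packet timeout the same way: a client-configured value (stored in milliseconds) overrides the default. A minimal standalone sketch of that selection step, with illustrative names rather than the project's:

```go
package main

import (
    "fmt"
    "time"
)

// defaultPacketTimeout mirrors the OpenAIPacketTimeout constant above.
const defaultPacketTimeout = 10 * 1000 * time.Millisecond

// pickPacketTimeout reproduces the override pattern from the hunks above:
// any non-negative configured value (in milliseconds) replaces the default.
func pickPacketTimeout(configuredMs int) time.Duration {
    timeout := defaultPacketTimeout
    if configuredMs >= 0 {
        timeout = time.Duration(configuredMs) * time.Millisecond
    }
    return timeout
}

func main() {
    fmt.Println(pickPacketTimeout(1500)) // 1.5s
    fmt.Println(pickPacketTimeout(-1))   // 10s, falls back to the default
}
```
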
@@ -2895,7 +2898,7 @@ func OpenAICommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (scbus
         return nil, fmt.Errorf("error retrieving client open ai options")
     }
     opts := clientData.OpenAIOpts
-    if opts.APIToken == "" {
+    if opts.APIToken == "" && opts.BaseURL == "" {
         if clientData.ClientOpts.NoTelemetry {
             return nil, fmt.Errorf(OpenAICloudCompletionTelemetryOffErrorMsg)
         }
@@ -5798,6 +5801,15 @@ func validateFontFamily(fontFamily string) error {
     return nil
 }
+func CheckOptionAlias(kwargs map[string]string, aliases ...string) (string, bool) {
+    for _, alias := range aliases {
+        if val, found := kwargs[alias]; found {
+            return val, found
+        }
+    }
+    return "", false
+}
 func ClientSetCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (scbus.UpdatePacket, error) {
     clientData, err := sstore.EnsureClientData(ctx)
     if err != nil {
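
The new CheckOptionAlias helper is what lets /client:set accept both the legacy openai* kwargs and the shorter ai* aliases introduced in this commit. A self-contained usage sketch; the helper body is copied from the hunk above, while the surrounding main and sample values are illustrative only:

```go
package main

import "fmt"

// CheckOptionAlias returns the first kwarg value found under any of the alias keys.
func CheckOptionAlias(kwargs map[string]string, aliases ...string) (string, bool) {
    for _, alias := range aliases {
        if val, found := kwargs[alias]; found {
            return val, found
        }
    }
    return "", false
}

func main() {
    // Hypothetical kwargs as they might arrive from a /client:set command.
    kwargs := map[string]string{"aimodel": "gpt-4", "openaimaxtokens": "2048"}

    if model, ok := CheckOptionAlias(kwargs, "openaimodel", "aimodel"); ok {
        fmt.Println("model:", model) // model: gpt-4
    }
    if maxTokens, ok := CheckOptionAlias(kwargs, "openaimaxtokens", "aimaxtokens"); ok {
        fmt.Println("maxtokens:", maxTokens) // maxtokens: 2048
    }
}
```
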
@@ -5870,7 +5882,7 @@ func ClientSetCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sc
         }
         varsUpdated = append(varsUpdated, "termtheme")
     }
-    if apiToken, found := pk.Kwargs["openaiapitoken"]; found {
+    if apiToken, found := CheckOptionAlias(pk.Kwargs, "openaiapitoken", "aiapitoken"); found {
         err = validateOpenAIAPIToken(apiToken)
         if err != nil {
             return nil, err
@@ -5884,10 +5896,10 @@ func ClientSetCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sc
         aiOpts.APIToken = apiToken
         err = sstore.UpdateClientOpenAIOpts(ctx, *aiOpts)
         if err != nil {
-            return nil, fmt.Errorf("error updating client openai api token: %v", err)
+            return nil, fmt.Errorf("error updating client ai api token: %v", err)
         }
     }
-    if aiModel, found := pk.Kwargs["openaimodel"]; found {
+    if aiModel, found := CheckOptionAlias(pk.Kwargs, "openaimodel", "aimodel"); found {
         err = validateOpenAIModel(aiModel)
         if err != nil {
             return nil, err
@@ -5901,16 +5913,16 @@ func ClientSetCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sc
         aiOpts.Model = aiModel
         err = sstore.UpdateClientOpenAIOpts(ctx, *aiOpts)
         if err != nil {
-            return nil, fmt.Errorf("error updating client openai model: %v", err)
+            return nil, fmt.Errorf("error updating client ai model: %v", err)
         }
     }
-    if maxTokensStr, found := pk.Kwargs["openaimaxtokens"]; found {
+    if maxTokensStr, found := CheckOptionAlias(pk.Kwargs, "openaimaxtokens", "aimaxtokens"); found {
         maxTokens, err := strconv.Atoi(maxTokensStr)
         if err != nil {
-            return nil, fmt.Errorf("error updating client openai maxtokens, invalid number: %v", err)
+            return nil, fmt.Errorf("error updating client ai maxtokens, invalid number: %v", err)
         }
         if maxTokens < 0 || maxTokens > 1000000 {
-            return nil, fmt.Errorf("error updating client openai maxtokens, out of range: %d", maxTokens)
+            return nil, fmt.Errorf("error updating client ai maxtokens, out of range: %d", maxTokens)
         }
         varsUpdated = append(varsUpdated, "openaimaxtokens")
         aiOpts := clientData.OpenAIOpts
@@ -5921,16 +5933,16 @@ func ClientSetCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sc
         aiOpts.MaxTokens = maxTokens
         err = sstore.UpdateClientOpenAIOpts(ctx, *aiOpts)
         if err != nil {
-            return nil, fmt.Errorf("error updating client openai maxtokens: %v", err)
+            return nil, fmt.Errorf("error updating client ai maxtokens: %v", err)
         }
     }
-    if maxChoicesStr, found := pk.Kwargs["openaimaxchoices"]; found {
+    if maxChoicesStr, found := CheckOptionAlias(pk.Kwargs, "openaimaxchoices", "aimaxchoices"); found {
         maxChoices, err := strconv.Atoi(maxChoicesStr)
         if err != nil {
-            return nil, fmt.Errorf("error updating client openai maxchoices, invalid number: %v", err)
+            return nil, fmt.Errorf("error updating client ai maxchoices, invalid number: %v", err)
         }
         if maxChoices < 0 || maxChoices > 10 {
-            return nil, fmt.Errorf("error updating client openai maxchoices, out of range: %d", maxChoices)
+            return nil, fmt.Errorf("error updating client ai maxchoices, out of range: %d", maxChoices)
         }
         varsUpdated = append(varsUpdated, "openaimaxchoices")
         aiOpts := clientData.OpenAIOpts
@@ -5941,10 +5953,10 @@ func ClientSetCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sc
         aiOpts.MaxChoices = maxChoices
         err = sstore.UpdateClientOpenAIOpts(ctx, *aiOpts)
         if err != nil {
-            return nil, fmt.Errorf("error updating client openai maxchoices: %v", err)
+            return nil, fmt.Errorf("error updating client ai maxchoices: %v", err)
         }
     }
-    if aiBaseURL, found := pk.Kwargs["openaibaseurl"]; found {
+    if aiBaseURL, found := CheckOptionAlias(pk.Kwargs, "openaibaseurl", "aibaseurl"); found {
         aiOpts := clientData.OpenAIOpts
         if aiOpts == nil {
             aiOpts = &sstore.OpenAIOptsType{}
@@ -5954,7 +5966,24 @@ func ClientSetCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sc
         varsUpdated = append(varsUpdated, "openaibaseurl")
         err = sstore.UpdateClientOpenAIOpts(ctx, *aiOpts)
         if err != nil {
-            return nil, fmt.Errorf("error updating client openai base url: %v", err)
+            return nil, fmt.Errorf("error updating client ai base url: %v", err)
+        }
+    }
+    if aiTimeoutStr, found := CheckOptionAlias(pk.Kwargs, "openaitimeout", "aitimeout"); found {
+        aiTimeout, err := strconv.ParseFloat(aiTimeoutStr, 64)
+        if err != nil {
+            return nil, fmt.Errorf("error updating client ai timeout, invalid number: %v", err)
+        }
+        aiOpts := clientData.OpenAIOpts
+        if aiOpts == nil {
+            aiOpts = &sstore.OpenAIOptsType{}
+            clientData.OpenAIOpts = aiOpts
+        }
+        aiOpts.Timeout = int(aiTimeout * 1000)
+        varsUpdated = append(varsUpdated, "openaitimeout")
+        err = sstore.UpdateClientOpenAIOpts(ctx, *aiOpts)
+        if err != nil {
+            return nil, fmt.Errorf("error updating client ai timeout: %v", err)
         }
     }
     if webglStr, found := pk.Kwargs["webgl"]; found {
@@ -5968,7 +5997,7 @@ func ClientSetCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sc
         varsUpdated = append(varsUpdated, "webgl")
     }
     if len(varsUpdated) == 0 {
-        return nil, fmt.Errorf("/client:set requires a value to set: %s", formatStrs([]string{"termfontsize", "termfontfamily", "openaiapitoken", "openaimodel", "openaibaseurl", "openaimaxtokens", "openaimaxchoices", "webgl"}, "or", false))
+        return nil, fmt.Errorf("/client:set requires a value to set: %s", formatStrs([]string{"termfontsize", "termfontfamily", "openaiapitoken", "openaimodel", "openaibaseurl", "openaimaxtokens", "openaimaxchoices", "openaitimeout", "webgl"}, "or", false))
     }
     clientData, err = sstore.EnsureClientData(ctx)
     if err != nil {
@@ -6008,11 +6037,12 @@ func ClientShowCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (s
     buf.WriteString(fmt.Sprintf(" %-15s %d\n", "termfontsize", clientData.FeOpts.TermFontSize))
     buf.WriteString(fmt.Sprintf(" %-15s %s\n", "termfontfamily", clientData.FeOpts.TermFontFamily))
     buf.WriteString(fmt.Sprintf(" %-15s %s\n", "termfontfamily", clientData.FeOpts.Theme))
-    buf.WriteString(fmt.Sprintf(" %-15s %s\n", "openaiapitoken", clientData.OpenAIOpts.APIToken))
-    buf.WriteString(fmt.Sprintf(" %-15s %s\n", "openaimodel", clientData.OpenAIOpts.Model))
-    buf.WriteString(fmt.Sprintf(" %-15s %d\n", "openaimaxtokens", clientData.OpenAIOpts.MaxTokens))
-    buf.WriteString(fmt.Sprintf(" %-15s %d\n", "openaimaxchoices", clientData.OpenAIOpts.MaxChoices))
-    buf.WriteString(fmt.Sprintf(" %-15s %s\n", "openaibaseurl", clientData.OpenAIOpts.BaseURL))
+    buf.WriteString(fmt.Sprintf(" %-15s %s\n", "aiapitoken", clientData.OpenAIOpts.APIToken))
+    buf.WriteString(fmt.Sprintf(" %-15s %s\n", "aimodel", clientData.OpenAIOpts.Model))
+    buf.WriteString(fmt.Sprintf(" %-15s %d\n", "aimaxtokens", clientData.OpenAIOpts.MaxTokens))
+    buf.WriteString(fmt.Sprintf(" %-15s %d\n", "aimaxchoices", clientData.OpenAIOpts.MaxChoices))
+    buf.WriteString(fmt.Sprintf(" %-15s %s\n", "aibaseurl", clientData.OpenAIOpts.BaseURL))
+    buf.WriteString(fmt.Sprintf(" %-15s %ss\n", "aitimeout", strconv.FormatFloat((float64(clientData.OpenAIOpts.Timeout)/1000.0), 'f', -1, 64)))
     update := scbus.MakeUpdatePacket()
     update.AddUpdate(sstore.InfoMsgType{
         InfoTitle: fmt.Sprintf("client info"),
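
The timeout is entered and shown in seconds but stored in milliseconds: /client:set parses the kwarg with ParseFloat and multiplies by 1000, while /client:show divides by 1000 and appends an "s" unit. A standalone sketch of that round trip, using the same conversions as the hunks above but with illustrative helper names:

```go
package main

import (
    "fmt"
    "strconv"
)

// secondsToStoredMs mirrors the /client:set path: the value is parsed as
// seconds (fractions allowed) and stored as an integer millisecond count.
func secondsToStoredMs(arg string) (int, error) {
    secs, err := strconv.ParseFloat(arg, 64)
    if err != nil {
        return 0, fmt.Errorf("invalid number: %v", err)
    }
    return int(secs * 1000), nil
}

// storedMsToDisplay mirrors the /client:show path: the stored millisecond
// value is formatted back as seconds with an "s" suffix.
func storedMsToDisplay(ms int) string {
    return strconv.FormatFloat(float64(ms)/1000.0, 'f', -1, 64) + "s"
}

func main() {
    // e.g. the user runs `/client:set aitimeout=1.5`
    ms, _ := secondsToStoredMs("1.5")
    fmt.Println(ms)                    // 1500
    fmt.Println(storedMsToDisplay(ms)) // 1.5s
}
```
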

View File

@@ -289,6 +289,8 @@ func (cdata *ClientData) Clean() *ClientData {
         Model: cdata.OpenAIOpts.Model,
         MaxTokens: cdata.OpenAIOpts.MaxTokens,
         MaxChoices: cdata.OpenAIOpts.MaxChoices,
+        Timeout: cdata.OpenAIOpts.Timeout,
+        BaseURL: cdata.OpenAIOpts.BaseURL,
         // omit API Token
     }
     if cdata.OpenAIOpts.APIToken != "" {
@@ -736,6 +738,7 @@ type OpenAIOptsType struct {
     BaseURL string `json:"baseurl,omitempty"`
     MaxTokens int `json:"maxtokens,omitempty"`
     MaxChoices int `json:"maxchoices,omitempty"`
+    Timeout int `json:"timeout,omitempty"`
 }
 const (
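
Since the new Timeout field is tagged json:"timeout,omitempty", an unset timeout simply disappears from the serialized client options. A standalone sketch with a local mirror of the fields shown above (not the project's own type or serialization code):

```go
package main

import (
    "encoding/json"
    "fmt"
)

// openAIOpts locally mirrors the OpenAIOptsType fields visible in the hunk above.
type openAIOpts struct {
    BaseURL    string `json:"baseurl,omitempty"`
    MaxTokens  int    `json:"maxtokens,omitempty"`
    MaxChoices int    `json:"maxchoices,omitempty"`
    Timeout    int    `json:"timeout,omitempty"`
}

func main() {
    withTimeout, _ := json.Marshal(openAIOpts{MaxTokens: 1000, Timeout: 1500})
    withoutTimeout, _ := json.Marshal(openAIOpts{MaxTokens: 1000})

    fmt.Println(string(withTimeout))    // {"maxtokens":1000,"timeout":1500}
    fmt.Println(string(withoutTimeout)) // {"maxtokens":1000}
}
```
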