azure ai support (#997)

This commit is contained in:
Mike Sawka 2024-10-09 13:36:02 -07:00 committed by GitHub
parent b81ab63ddc
commit f33028af1d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 100 additions and 68 deletions

View File

@ -194,12 +194,15 @@ export class WaveAiModel implements ViewModel {
};
addMessage(newMessage);
// send message to backend and get response
const settings = globalStore.get(atoms.settingsAtom);
const settings = globalStore.get(atoms.settingsAtom) ?? {};
const opts: OpenAIOptsType = {
model: settings["ai:model"],
apitype: settings["ai:apitype"],
orgid: settings["ai:orgid"],
apitoken: settings["ai:apitoken"],
apiversion: settings["ai:apiversion"],
maxtokens: settings["ai:maxtokens"],
timeout: settings["ai:timeoutms"] / 1000,
timeoutms: settings["ai:timeoutms"] ?? 60000,
baseurl: settings["ai:baseurl"],
};
const newPrompt: OpenAIPromptMessageType = {
@ -222,7 +225,7 @@ export class WaveAiModel implements ViewModel {
opts: opts,
prompt: [...history, newPrompt],
};
const aiGen = RpcApi.StreamWaveAiCommand(WindowRpcClient, beMsg, { timeout: 60000 });
const aiGen = RpcApi.StreamWaveAiCommand(WindowRpcClient, beMsg, { timeout: opts.timeoutms });
let fullMsg = "";
for await (const msg of aiGen) {
fullMsg += msg.text ?? "";

View File

@ -329,11 +329,14 @@ declare global {
// wshrpc.OpenAIOptsType
type OpenAIOptsType = {
model: string;
apitype?: string;
apitoken: string;
orgid?: string;
apiversion?: string;
baseurl?: string;
maxtokens?: number;
maxchoices?: number;
timeout?: number;
timeoutms?: number;
};
// wshrpc.OpenAIPacketType
@ -413,10 +416,13 @@ declare global {
// wconfig.SettingsType
type SettingsType = {
"ai:*"?: boolean;
"ai:apitype"?: string;
"ai:baseurl"?: string;
"ai:apitoken"?: string;
"ai:name"?: string;
"ai:model"?: string;
"ai:orgid"?: string;
"ai:apiversion"?: string;
"ai:maxtokens"?: number;
"ai:timeoutms"?: number;
"term:*"?: boolean;

View File

@ -10,8 +10,11 @@ import (
"io"
"log"
"os"
"regexp"
"strings"
"time"
"github.com/sashabaranov/go-openai"
openaiapi "github.com/sashabaranov/go-openai"
"github.com/wavetermdev/waveterm/pkg/wavebase"
"github.com/wavetermdev/waveterm/pkg/wcloud"
@ -23,6 +26,7 @@ import (
const OpenAIPacketStr = "openai"
const OpenAICloudReqStr = "openai-cloudreq"
const PacketEOFStr = "EOF"
const DefaultAzureAPIVersion = "2023-05-15"
type OpenAICmdInfoPacketOutputType struct {
Model string `json:"model,omitempty"`
@ -52,16 +56,6 @@ type OpenAICloudReqPacketType struct {
MaxChoices int `json:"maxchoices,omitempty"`
}
type OpenAIOptsType struct {
Model string `json:"model"`
APIToken string `json:"apitoken"`
BaseURL string `json:"baseurl,omitempty"`
MaxTokens int `json:"maxtokens,omitempty"`
MaxChoices int `json:"maxchoices,omitempty"`
Timeout int `json:"timeout,omitempty"`
BlockId string `json:"blockid"`
}
func MakeOpenAICloudReqPacket() *OpenAICloudReqPacketType {
return &OpenAICloudReqPacketType{
Type: OpenAICloudReqStr,
@ -69,26 +63,31 @@ func MakeOpenAICloudReqPacket() *OpenAICloudReqPacketType {
}
func GetWSEndpoint() string {
return PCloudWSEndpoint
if !wavebase.IsDevMode() {
return PCloudWSEndpoint
return WCloudWSEndpoint
} else {
endpoint := os.Getenv(PCloudWSEndpointVarName)
endpoint := os.Getenv(WCloudWSEndpointVarName)
if endpoint == "" {
panic("Invalid PCloud ws dev endpoint, PCLOUD_WS_ENDPOINT not set or invalid")
panic("Invalid WCloud websocket dev endpoint, WCLOUD_WS_ENDPOINT not set or invalid")
}
return endpoint
}
}
const DefaultMaxTokens = 1000
const DefaultMaxTokens = 2048
const DefaultModel = "gpt-4o-mini"
const DefaultStreamChanSize = 10
const PCloudWSEndpoint = "wss://wsapi.waveterm.dev/"
const PCloudWSEndpointVarName = "PCLOUD_WS_ENDPOINT"
const WCloudWSEndpoint = "wss://wsapi.waveterm.dev/"
const WCloudWSEndpointVarName = "WCLOUD_WS_ENDPOINT"
const CloudWebsocketConnectTimeout = 1 * time.Minute
func IsCloudAIRequest(opts *wshrpc.OpenAIOptsType) bool {
if opts == nil {
return true
}
return opts.BaseURL == "" && opts.APIToken == ""
}
func convertUsage(resp openaiapi.ChatCompletionResponse) *wshrpc.OpenAIUsageType {
if resp.Usage.TotalTokens == 0 {
return nil
@ -113,6 +112,15 @@ func makeAIError(err error) wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
return wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: err}
}
func RunAICommand(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
if IsCloudAIRequest(request.Opts) {
log.Print("sending ai chat message to default waveterm cloud endpoint\n")
return RunCloudCompletionStream(ctx, request)
}
log.Printf("sending ai chat message to user-configured endpoint %s\n", request.Opts.BaseURL)
return RunLocalCompletionStream(ctx, request)
}
func RunCloudCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType])
wsEndpoint := wcloud.GetWSEndpoint()
@ -187,6 +195,36 @@ func RunCloudCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRe
return rtn
}
// copied from go-openai/config.go
func defaultAzureMapperFn(model string) string {
return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "")
}
func setApiType(opts *wshrpc.OpenAIOptsType, clientConfig *openai.ClientConfig) error {
ourApiType := strings.ToLower(opts.APIType)
if ourApiType == "" || ourApiType == strings.ToLower(string(openai.APITypeOpenAI)) {
clientConfig.APIType = openai.APITypeOpenAI
return nil
} else if ourApiType == strings.ToLower(string(openai.APITypeAzure)) {
clientConfig.APIType = openai.APITypeAzure
clientConfig.APIVersion = DefaultAzureAPIVersion
clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
return nil
} else if ourApiType == strings.ToLower(string(openai.APITypeAzureAD)) {
clientConfig.APIType = openai.APITypeAzureAD
clientConfig.APIVersion = DefaultAzureAPIVersion
clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
return nil
} else if ourApiType == strings.ToLower(string(openai.APITypeCloudflareAzure)) {
clientConfig.APIType = openai.APITypeCloudflareAzure
clientConfig.APIVersion = DefaultAzureAPIVersion
clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
return nil
} else {
return fmt.Errorf("invalid api type %q", opts.APIType)
}
}
func RunLocalCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType])
go func() {
@ -207,6 +245,17 @@ func RunLocalCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRe
if request.Opts.BaseURL != "" {
clientConfig.BaseURL = request.Opts.BaseURL
}
err := setApiType(request.Opts, &clientConfig)
if err != nil {
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: err}
return
}
if request.Opts.OrgID != "" {
clientConfig.OrgID = request.Opts.OrgID
}
if request.Opts.APIVersion != "" {
clientConfig.APIVersion = request.Opts.APIVersion
}
client := openaiapi.NewClientWithConfig(clientConfig)
req := openaiapi.ChatCompletionRequest{
Model: request.Opts.Model,
@ -251,33 +300,3 @@ func RunLocalCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRe
}()
return rtn
}
func marshalResponse(resp openaiapi.ChatCompletionResponse) []*wshrpc.OpenAIPacketType {
var rtn []*wshrpc.OpenAIPacketType
headerPk := MakeOpenAIPacket()
headerPk.Model = resp.Model
headerPk.Created = resp.Created
headerPk.Usage = convertUsage(resp)
rtn = append(rtn, headerPk)
for _, choice := range resp.Choices {
choicePk := MakeOpenAIPacket()
choicePk.Index = choice.Index
choicePk.Text = choice.Message.Content
choicePk.FinishReason = string(choice.FinishReason)
rtn = append(rtn, choicePk)
}
return rtn
}
func CreateErrorPacket(errStr string) *wshrpc.OpenAIPacketType {
errPk := MakeOpenAIPacket()
errPk.FinishReason = "error"
errPk.Error = errStr
return errPk
}
func CreateTextPacket(text string) *wshrpc.OpenAIPacketType {
pk := MakeOpenAIPacket()
pk.Text = text
return pk
}

View File

@ -1,7 +1,7 @@
{
"ai:model": "gpt-4o-mini",
"ai:maxtokens": 1000,
"ai:timeoutms": 10000,
"ai:maxtokens": 2048,
"ai:timeoutms": 60000,
"autoupdate:enabled": true,
"autoupdate:installonquit": true,
"autoupdate:intervalms": 3600000,

View File

@ -7,10 +7,13 @@ package wconfig
const (
ConfigKey_AiClear = "ai:*"
ConfigKey_AiApiType = "ai:apitype"
ConfigKey_AiBaseURL = "ai:baseurl"
ConfigKey_AiApiToken = "ai:apitoken"
ConfigKey_AiName = "ai:name"
ConfigKey_AiModel = "ai:model"
ConfigKey_AiOrgID = "ai:orgid"
ConfigKey_AIApiVersion = "ai:apiversion"
ConfigKey_AiMaxTokens = "ai:maxtokens"
ConfigKey_AiTimeoutMs = "ai:timeoutms"

View File

@ -40,13 +40,16 @@ func (m MetaSettingsType) MarshalJSON() ([]byte, error) {
}
type SettingsType struct {
AiClear bool `json:"ai:*,omitempty"`
AiBaseURL string `json:"ai:baseurl,omitempty"`
AiApiToken string `json:"ai:apitoken,omitempty"`
AiName string `json:"ai:name,omitempty"`
AiModel string `json:"ai:model,omitempty"`
AiMaxTokens float64 `json:"ai:maxtokens,omitempty"`
AiTimeoutMs float64 `json:"ai:timeoutms,omitempty"`
AiClear bool `json:"ai:*,omitempty"`
AiApiType string `json:"ai:apitype,omitempty"`
AiBaseURL string `json:"ai:baseurl,omitempty"`
AiApiToken string `json:"ai:apitoken,omitempty"`
AiName string `json:"ai:name,omitempty"`
AiModel string `json:"ai:model,omitempty"`
AiOrgID string `json:"ai:orgid,omitempty"`
AIApiVersion string `json:"ai:apiversion,omitempty"`
AiMaxTokens float64 `json:"ai:maxtokens,omitempty"`
AiTimeoutMs float64 `json:"ai:timeoutms,omitempty"`
TermClear bool `json:"term:*,omitempty"`
TermFontSize float64 `json:"term:fontsize,omitempty"`

View File

@ -275,11 +275,14 @@ type OpenAIPromptMessageType struct {
type OpenAIOptsType struct {
Model string `json:"model"`
APIType string `json:"apitype,omitempty"`
APIToken string `json:"apitoken"`
OrgID string `json:"orgid,omitempty"`
APIVersion string `json:"apiversion,omitempty"`
BaseURL string `json:"baseurl,omitempty"`
MaxTokens int `json:"maxtokens,omitempty"`
MaxChoices int `json:"maxchoices,omitempty"`
Timeout int `json:"timeout,omitempty"`
TimeoutMs int `json:"timeoutms,omitempty"`
}
type OpenAIPacketType struct {

View File

@ -74,12 +74,7 @@ func (ws *WshServer) StreamTestCommand(ctx context.Context) chan wshrpc.RespOrEr
}
func (ws *WshServer) StreamWaveAiCommand(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
if request.Opts.BaseURL == "" && request.Opts.APIToken == "" {
log.Print("sending ai chat message to waveterm default endpoint with openai\n")
return waveai.RunCloudCompletionStream(ctx, request)
}
log.Printf("sending ai chat message to user-configured endpoint %s\n", request.Opts.BaseURL)
return waveai.RunLocalCompletionStream(ctx, request)
return waveai.RunAICommand(ctx, request)
}
func MakePlotData(ctx context.Context, blockId string) error {