mirror of
https://github.com/wavetermdev/waveterm.git
synced 2024-12-21 16:38:23 +01:00
azure ai support (#997)
This commit is contained in:
parent
b81ab63ddc
commit
f33028af1d
@ -194,12 +194,15 @@ export class WaveAiModel implements ViewModel {
|
||||
};
|
||||
addMessage(newMessage);
|
||||
// send message to backend and get response
|
||||
const settings = globalStore.get(atoms.settingsAtom);
|
||||
const settings = globalStore.get(atoms.settingsAtom) ?? {};
|
||||
const opts: OpenAIOptsType = {
|
||||
model: settings["ai:model"],
|
||||
apitype: settings["ai:apitype"],
|
||||
orgid: settings["ai:orgid"],
|
||||
apitoken: settings["ai:apitoken"],
|
||||
apiversion: settings["ai:apiversion"],
|
||||
maxtokens: settings["ai:maxtokens"],
|
||||
timeout: settings["ai:timeoutms"] / 1000,
|
||||
timeoutms: settings["ai:timeoutms"] ?? 60000,
|
||||
baseurl: settings["ai:baseurl"],
|
||||
};
|
||||
const newPrompt: OpenAIPromptMessageType = {
|
||||
@ -222,7 +225,7 @@ export class WaveAiModel implements ViewModel {
|
||||
opts: opts,
|
||||
prompt: [...history, newPrompt],
|
||||
};
|
||||
const aiGen = RpcApi.StreamWaveAiCommand(WindowRpcClient, beMsg, { timeout: 60000 });
|
||||
const aiGen = RpcApi.StreamWaveAiCommand(WindowRpcClient, beMsg, { timeout: opts.timeoutms });
|
||||
let fullMsg = "";
|
||||
for await (const msg of aiGen) {
|
||||
fullMsg += msg.text ?? "";
|
||||
|
8
frontend/types/gotypes.d.ts
vendored
8
frontend/types/gotypes.d.ts
vendored
@ -329,11 +329,14 @@ declare global {
|
||||
// wshrpc.OpenAIOptsType
|
||||
type OpenAIOptsType = {
|
||||
model: string;
|
||||
apitype?: string;
|
||||
apitoken: string;
|
||||
orgid?: string;
|
||||
apiversion?: string;
|
||||
baseurl?: string;
|
||||
maxtokens?: number;
|
||||
maxchoices?: number;
|
||||
timeout?: number;
|
||||
timeoutms?: number;
|
||||
};
|
||||
|
||||
// wshrpc.OpenAIPacketType
|
||||
@ -413,10 +416,13 @@ declare global {
|
||||
// wconfig.SettingsType
|
||||
type SettingsType = {
|
||||
"ai:*"?: boolean;
|
||||
"ai:apitype"?: string;
|
||||
"ai:baseurl"?: string;
|
||||
"ai:apitoken"?: string;
|
||||
"ai:name"?: string;
|
||||
"ai:model"?: string;
|
||||
"ai:orgid"?: string;
|
||||
"ai:apiversion"?: string;
|
||||
"ai:maxtokens"?: number;
|
||||
"ai:timeoutms"?: number;
|
||||
"term:*"?: boolean;
|
||||
|
@ -10,8 +10,11 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
openaiapi "github.com/sashabaranov/go-openai"
|
||||
"github.com/wavetermdev/waveterm/pkg/wavebase"
|
||||
"github.com/wavetermdev/waveterm/pkg/wcloud"
|
||||
@ -23,6 +26,7 @@ import (
|
||||
const OpenAIPacketStr = "openai"
|
||||
const OpenAICloudReqStr = "openai-cloudreq"
|
||||
const PacketEOFStr = "EOF"
|
||||
const DefaultAzureAPIVersion = "2023-05-15"
|
||||
|
||||
type OpenAICmdInfoPacketOutputType struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
@ -52,16 +56,6 @@ type OpenAICloudReqPacketType struct {
|
||||
MaxChoices int `json:"maxchoices,omitempty"`
|
||||
}
|
||||
|
||||
type OpenAIOptsType struct {
|
||||
Model string `json:"model"`
|
||||
APIToken string `json:"apitoken"`
|
||||
BaseURL string `json:"baseurl,omitempty"`
|
||||
MaxTokens int `json:"maxtokens,omitempty"`
|
||||
MaxChoices int `json:"maxchoices,omitempty"`
|
||||
Timeout int `json:"timeout,omitempty"`
|
||||
BlockId string `json:"blockid"`
|
||||
}
|
||||
|
||||
func MakeOpenAICloudReqPacket() *OpenAICloudReqPacketType {
|
||||
return &OpenAICloudReqPacketType{
|
||||
Type: OpenAICloudReqStr,
|
||||
@ -69,26 +63,31 @@ func MakeOpenAICloudReqPacket() *OpenAICloudReqPacketType {
|
||||
}
|
||||
|
||||
func GetWSEndpoint() string {
|
||||
return PCloudWSEndpoint
|
||||
if !wavebase.IsDevMode() {
|
||||
return PCloudWSEndpoint
|
||||
return WCloudWSEndpoint
|
||||
} else {
|
||||
endpoint := os.Getenv(PCloudWSEndpointVarName)
|
||||
endpoint := os.Getenv(WCloudWSEndpointVarName)
|
||||
if endpoint == "" {
|
||||
panic("Invalid PCloud ws dev endpoint, PCLOUD_WS_ENDPOINT not set or invalid")
|
||||
panic("Invalid WCloud websocket dev endpoint, WCLOUD_WS_ENDPOINT not set or invalid")
|
||||
}
|
||||
return endpoint
|
||||
}
|
||||
}
|
||||
|
||||
const DefaultMaxTokens = 1000
|
||||
const DefaultMaxTokens = 2048
|
||||
const DefaultModel = "gpt-4o-mini"
|
||||
const DefaultStreamChanSize = 10
|
||||
const PCloudWSEndpoint = "wss://wsapi.waveterm.dev/"
|
||||
const PCloudWSEndpointVarName = "PCLOUD_WS_ENDPOINT"
|
||||
const WCloudWSEndpoint = "wss://wsapi.waveterm.dev/"
|
||||
const WCloudWSEndpointVarName = "WCLOUD_WS_ENDPOINT"
|
||||
|
||||
const CloudWebsocketConnectTimeout = 1 * time.Minute
|
||||
|
||||
func IsCloudAIRequest(opts *wshrpc.OpenAIOptsType) bool {
|
||||
if opts == nil {
|
||||
return true
|
||||
}
|
||||
return opts.BaseURL == "" && opts.APIToken == ""
|
||||
}
|
||||
|
||||
func convertUsage(resp openaiapi.ChatCompletionResponse) *wshrpc.OpenAIUsageType {
|
||||
if resp.Usage.TotalTokens == 0 {
|
||||
return nil
|
||||
@ -113,6 +112,15 @@ func makeAIError(err error) wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
|
||||
return wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: err}
|
||||
}
|
||||
|
||||
func RunAICommand(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
|
||||
if IsCloudAIRequest(request.Opts) {
|
||||
log.Print("sending ai chat message to default waveterm cloud endpoint\n")
|
||||
return RunCloudCompletionStream(ctx, request)
|
||||
}
|
||||
log.Printf("sending ai chat message to user-configured endpoint %s\n", request.Opts.BaseURL)
|
||||
return RunLocalCompletionStream(ctx, request)
|
||||
}
|
||||
|
||||
func RunCloudCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
|
||||
rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType])
|
||||
wsEndpoint := wcloud.GetWSEndpoint()
|
||||
@ -187,6 +195,36 @@ func RunCloudCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRe
|
||||
return rtn
|
||||
}
|
||||
|
||||
// copied from go-openai/config.go
|
||||
func defaultAzureMapperFn(model string) string {
|
||||
return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "")
|
||||
}
|
||||
|
||||
func setApiType(opts *wshrpc.OpenAIOptsType, clientConfig *openai.ClientConfig) error {
|
||||
ourApiType := strings.ToLower(opts.APIType)
|
||||
if ourApiType == "" || ourApiType == strings.ToLower(string(openai.APITypeOpenAI)) {
|
||||
clientConfig.APIType = openai.APITypeOpenAI
|
||||
return nil
|
||||
} else if ourApiType == strings.ToLower(string(openai.APITypeAzure)) {
|
||||
clientConfig.APIType = openai.APITypeAzure
|
||||
clientConfig.APIVersion = DefaultAzureAPIVersion
|
||||
clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
|
||||
return nil
|
||||
} else if ourApiType == strings.ToLower(string(openai.APITypeAzureAD)) {
|
||||
clientConfig.APIType = openai.APITypeAzureAD
|
||||
clientConfig.APIVersion = DefaultAzureAPIVersion
|
||||
clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
|
||||
return nil
|
||||
} else if ourApiType == strings.ToLower(string(openai.APITypeCloudflareAzure)) {
|
||||
clientConfig.APIType = openai.APITypeCloudflareAzure
|
||||
clientConfig.APIVersion = DefaultAzureAPIVersion
|
||||
clientConfig.AzureModelMapperFunc = defaultAzureMapperFn
|
||||
return nil
|
||||
} else {
|
||||
return fmt.Errorf("invalid api type %q", opts.APIType)
|
||||
}
|
||||
}
|
||||
|
||||
func RunLocalCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
|
||||
rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType])
|
||||
go func() {
|
||||
@ -207,6 +245,17 @@ func RunLocalCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRe
|
||||
if request.Opts.BaseURL != "" {
|
||||
clientConfig.BaseURL = request.Opts.BaseURL
|
||||
}
|
||||
err := setApiType(request.Opts, &clientConfig)
|
||||
if err != nil {
|
||||
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: err}
|
||||
return
|
||||
}
|
||||
if request.Opts.OrgID != "" {
|
||||
clientConfig.OrgID = request.Opts.OrgID
|
||||
}
|
||||
if request.Opts.APIVersion != "" {
|
||||
clientConfig.APIVersion = request.Opts.APIVersion
|
||||
}
|
||||
client := openaiapi.NewClientWithConfig(clientConfig)
|
||||
req := openaiapi.ChatCompletionRequest{
|
||||
Model: request.Opts.Model,
|
||||
@ -251,33 +300,3 @@ func RunLocalCompletionStream(ctx context.Context, request wshrpc.OpenAiStreamRe
|
||||
}()
|
||||
return rtn
|
||||
}
|
||||
|
||||
func marshalResponse(resp openaiapi.ChatCompletionResponse) []*wshrpc.OpenAIPacketType {
|
||||
var rtn []*wshrpc.OpenAIPacketType
|
||||
headerPk := MakeOpenAIPacket()
|
||||
headerPk.Model = resp.Model
|
||||
headerPk.Created = resp.Created
|
||||
headerPk.Usage = convertUsage(resp)
|
||||
rtn = append(rtn, headerPk)
|
||||
for _, choice := range resp.Choices {
|
||||
choicePk := MakeOpenAIPacket()
|
||||
choicePk.Index = choice.Index
|
||||
choicePk.Text = choice.Message.Content
|
||||
choicePk.FinishReason = string(choice.FinishReason)
|
||||
rtn = append(rtn, choicePk)
|
||||
}
|
||||
return rtn
|
||||
}
|
||||
|
||||
func CreateErrorPacket(errStr string) *wshrpc.OpenAIPacketType {
|
||||
errPk := MakeOpenAIPacket()
|
||||
errPk.FinishReason = "error"
|
||||
errPk.Error = errStr
|
||||
return errPk
|
||||
}
|
||||
|
||||
func CreateTextPacket(text string) *wshrpc.OpenAIPacketType {
|
||||
pk := MakeOpenAIPacket()
|
||||
pk.Text = text
|
||||
return pk
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
{
|
||||
"ai:model": "gpt-4o-mini",
|
||||
"ai:maxtokens": 1000,
|
||||
"ai:timeoutms": 10000,
|
||||
"ai:maxtokens": 2048,
|
||||
"ai:timeoutms": 60000,
|
||||
"autoupdate:enabled": true,
|
||||
"autoupdate:installonquit": true,
|
||||
"autoupdate:intervalms": 3600000,
|
||||
|
@ -7,10 +7,13 @@ package wconfig
|
||||
|
||||
const (
|
||||
ConfigKey_AiClear = "ai:*"
|
||||
ConfigKey_AiApiType = "ai:apitype"
|
||||
ConfigKey_AiBaseURL = "ai:baseurl"
|
||||
ConfigKey_AiApiToken = "ai:apitoken"
|
||||
ConfigKey_AiName = "ai:name"
|
||||
ConfigKey_AiModel = "ai:model"
|
||||
ConfigKey_AiOrgID = "ai:orgid"
|
||||
ConfigKey_AIApiVersion = "ai:apiversion"
|
||||
ConfigKey_AiMaxTokens = "ai:maxtokens"
|
||||
ConfigKey_AiTimeoutMs = "ai:timeoutms"
|
||||
|
||||
|
@ -41,10 +41,13 @@ func (m MetaSettingsType) MarshalJSON() ([]byte, error) {
|
||||
|
||||
type SettingsType struct {
|
||||
AiClear bool `json:"ai:*,omitempty"`
|
||||
AiApiType string `json:"ai:apitype,omitempty"`
|
||||
AiBaseURL string `json:"ai:baseurl,omitempty"`
|
||||
AiApiToken string `json:"ai:apitoken,omitempty"`
|
||||
AiName string `json:"ai:name,omitempty"`
|
||||
AiModel string `json:"ai:model,omitempty"`
|
||||
AiOrgID string `json:"ai:orgid,omitempty"`
|
||||
AIApiVersion string `json:"ai:apiversion,omitempty"`
|
||||
AiMaxTokens float64 `json:"ai:maxtokens,omitempty"`
|
||||
AiTimeoutMs float64 `json:"ai:timeoutms,omitempty"`
|
||||
|
||||
|
@ -275,11 +275,14 @@ type OpenAIPromptMessageType struct {
|
||||
|
||||
type OpenAIOptsType struct {
|
||||
Model string `json:"model"`
|
||||
APIType string `json:"apitype,omitempty"`
|
||||
APIToken string `json:"apitoken"`
|
||||
OrgID string `json:"orgid,omitempty"`
|
||||
APIVersion string `json:"apiversion,omitempty"`
|
||||
BaseURL string `json:"baseurl,omitempty"`
|
||||
MaxTokens int `json:"maxtokens,omitempty"`
|
||||
MaxChoices int `json:"maxchoices,omitempty"`
|
||||
Timeout int `json:"timeout,omitempty"`
|
||||
TimeoutMs int `json:"timeoutms,omitempty"`
|
||||
}
|
||||
|
||||
type OpenAIPacketType struct {
|
||||
|
@ -74,12 +74,7 @@ func (ws *WshServer) StreamTestCommand(ctx context.Context) chan wshrpc.RespOrEr
|
||||
}
|
||||
|
||||
func (ws *WshServer) StreamWaveAiCommand(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
|
||||
if request.Opts.BaseURL == "" && request.Opts.APIToken == "" {
|
||||
log.Print("sending ai chat message to waveterm default endpoint with openai\n")
|
||||
return waveai.RunCloudCompletionStream(ctx, request)
|
||||
}
|
||||
log.Printf("sending ai chat message to user-configured endpoint %s\n", request.Opts.BaseURL)
|
||||
return waveai.RunLocalCompletionStream(ctx, request)
|
||||
return waveai.RunAICommand(ctx, request)
|
||||
}
|
||||
|
||||
func MakePlotData(ctx context.Context, blockId string) error {
|
||||
|
Loading…
Reference in New Issue
Block a user