Rename outdated WaveAI types (#1609)

Several of the Wave AI types still referenced OpenAI in their names. Now
that most of them are used across multiple AI backends, the names have
been updated to be backend-agnostic.
This commit is contained in:
Evan Simkowitz 2024-12-23 13:55:04 -05:00 committed by GitHub
parent 5cfbdcab1a
commit dbacae8a99
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 123 additions and 134 deletions

View File

@ -15,7 +15,7 @@ class BlockServiceType {
SaveTerminalState(blockId: string, state: string, stateType: string, ptyOffset: number, termSize: TermSize): Promise<void> { SaveTerminalState(blockId: string, state: string, stateType: string, ptyOffset: number, termSize: TermSize): Promise<void> {
return WOS.callBackendService("block", "SaveTerminalState", Array.from(arguments)) return WOS.callBackendService("block", "SaveTerminalState", Array.from(arguments))
} }
SaveWaveAiData(arg2: string, arg3: OpenAIPromptMessageType[]): Promise<void> { SaveWaveAiData(arg2: string, arg3: WaveAIPromptMessageType[]): Promise<void> {
return WOS.callBackendService("block", "SaveWaveAiData", Array.from(arguments)) return WOS.callBackendService("block", "SaveWaveAiData", Array.from(arguments))
} }
} }

View File

@ -303,7 +303,7 @@ class RpcApiType {
} }
// command "streamwaveai" [responsestream] // command "streamwaveai" [responsestream]
StreamWaveAiCommand(client: WshClient, data: OpenAiStreamRequest, opts?: RpcOpts): AsyncGenerator<OpenAIPacketType, void, boolean> { StreamWaveAiCommand(client: WshClient, data: WaveAIStreamRequest, opts?: RpcOpts): AsyncGenerator<WaveAIPacketType, void, boolean> {
return client.wshRpcStream("streamwaveai", data, opts); return client.wshRpcStream("streamwaveai", data, opts);
} }

View File

@ -35,7 +35,7 @@ interface ChatItemProps {
model: WaveAiModel; model: WaveAiModel;
} }
function promptToMsg(prompt: OpenAIPromptMessageType): ChatMessageType { function promptToMsg(prompt: WaveAIPromptMessageType): ChatMessageType {
return { return {
id: crypto.randomUUID(), id: crypto.randomUUID(),
user: prompt.role, user: prompt.role,
@ -67,7 +67,7 @@ export class WaveAiModel implements ViewModel {
blockAtom: Atom<Block>; blockAtom: Atom<Block>;
presetKey: Atom<string>; presetKey: Atom<string>;
presetMap: Atom<{ [k: string]: MetaType }>; presetMap: Atom<{ [k: string]: MetaType }>;
aiOpts: Atom<OpenAIOptsType>; aiOpts: Atom<WaveAIOptsType>;
viewIcon?: Atom<string | IconButtonDecl>; viewIcon?: Atom<string | IconButtonDecl>;
viewName?: Atom<string>; viewName?: Atom<string>;
viewText?: Atom<string | HeaderElem[]>; viewText?: Atom<string | HeaderElem[]>;
@ -167,7 +167,7 @@ export class WaveAiModel implements ViewModel {
...settings, ...settings,
...meta, ...meta,
}; };
const opts: OpenAIOptsType = { const opts: WaveAIOptsType = {
model: settings["ai:model"] ?? null, model: settings["ai:model"] ?? null,
apitype: settings["ai:apitype"] ?? null, apitype: settings["ai:apitype"] ?? null,
orgid: settings["ai:orgid"] ?? null, orgid: settings["ai:orgid"] ?? null,
@ -293,12 +293,12 @@ export class WaveAiModel implements ViewModel {
globalStore.set(this.messagesAtom, history.map(promptToMsg)); globalStore.set(this.messagesAtom, history.map(promptToMsg));
} }
async fetchAiData(): Promise<Array<OpenAIPromptMessageType>> { async fetchAiData(): Promise<Array<WaveAIPromptMessageType>> {
const { data } = await fetchWaveFile(this.blockId, "aidata"); const { data } = await fetchWaveFile(this.blockId, "aidata");
if (!data) { if (!data) {
return []; return [];
} }
const history: Array<OpenAIPromptMessageType> = JSON.parse(new TextDecoder().decode(data)); const history: Array<WaveAIPromptMessageType> = JSON.parse(new TextDecoder().decode(data));
return history.slice(Math.max(history.length - slidingWindowSize, 0)); return history.slice(Math.max(history.length - slidingWindowSize, 0));
} }
@ -333,7 +333,7 @@ export class WaveAiModel implements ViewModel {
globalStore.set(this.addMessageAtom, newMessage); globalStore.set(this.addMessageAtom, newMessage);
// send message to backend and get response // send message to backend and get response
const opts = globalStore.get(this.aiOpts); const opts = globalStore.get(this.aiOpts);
const newPrompt: OpenAIPromptMessageType = { const newPrompt: WaveAIPromptMessageType = {
role: "user", role: "user",
content: text, content: text,
}; };
@ -368,7 +368,7 @@ export class WaveAiModel implements ViewModel {
// only save the author's prompt // only save the author's prompt
await BlockService.SaveWaveAiData(this.blockId, [...history, newPrompt]); await BlockService.SaveWaveAiData(this.blockId, [...history, newPrompt]);
} else { } else {
const responsePrompt: OpenAIPromptMessageType = { const responsePrompt: WaveAIPromptMessageType = {
role: "assistant", role: "assistant",
content: fullMsg, content: fullMsg,
}; };
@ -383,7 +383,7 @@ export class WaveAiModel implements ViewModel {
globalStore.set(this.removeLastMessageAtom); globalStore.set(this.removeLastMessageAtom);
} else { } else {
globalStore.set(this.updateLastMessageAtom, "", false); globalStore.set(this.updateLastMessageAtom, "", false);
const responsePrompt: OpenAIPromptMessageType = { const responsePrompt: WaveAIPromptMessageType = {
role: "assistant", role: "assistant",
content: fullMsg, content: fullMsg,
}; };
@ -397,7 +397,7 @@ export class WaveAiModel implements ViewModel {
}; };
globalStore.set(this.addMessageAtom, errorMessage); globalStore.set(this.addMessageAtom, errorMessage);
globalStore.set(this.updateLastMessageAtom, "", false); globalStore.set(this.updateLastMessageAtom, "", false);
const errorPrompt: OpenAIPromptMessageType = { const errorPrompt: WaveAIPromptMessageType = {
role: "error", role: "error",
content: errMsg, content: errMsg,
}; };

View File

@ -519,52 +519,6 @@ declare global {
// waveobj.ORef // waveobj.ORef
type ORef = string; type ORef = string;
// wshrpc.OpenAIOptsType
type OpenAIOptsType = {
model: string;
apitype?: string;
apitoken: string;
orgid?: string;
apiversion?: string;
baseurl?: string;
maxtokens?: number;
maxchoices?: number;
timeoutms?: number;
};
// wshrpc.OpenAIPacketType
type OpenAIPacketType = {
type: string;
model?: string;
created?: number;
finish_reason?: string;
usage?: OpenAIUsageType;
index?: number;
text?: string;
error?: string;
};
// wshrpc.OpenAIPromptMessageType
type OpenAIPromptMessageType = {
role: string;
content: string;
name?: string;
};
// wshrpc.OpenAIUsageType
type OpenAIUsageType = {
prompt_tokens?: number;
completion_tokens?: number;
total_tokens?: number;
};
// wshrpc.OpenAiStreamRequest
type OpenAiStreamRequest = {
clientid?: string;
opts: OpenAIOptsType;
prompt: OpenAIPromptMessageType[];
};
// wshrpc.PathCommandData // wshrpc.PathCommandData
type PathCommandData = { type PathCommandData = {
pathtype: string; pathtype: string;
@ -1016,6 +970,52 @@ declare global {
fullconfig: FullConfigType; fullconfig: FullConfigType;
}; };
// wshrpc.WaveAIOptsType
type WaveAIOptsType = {
model: string;
apitype?: string;
apitoken: string;
orgid?: string;
apiversion?: string;
baseurl?: string;
maxtokens?: number;
maxchoices?: number;
timeoutms?: number;
};
// wshrpc.WaveAIPacketType
type WaveAIPacketType = {
type: string;
model?: string;
created?: number;
finish_reason?: string;
usage?: WaveAIUsageType;
index?: number;
text?: string;
error?: string;
};
// wshrpc.WaveAIPromptMessageType
type WaveAIPromptMessageType = {
role: string;
content: string;
name?: string;
};
// wshrpc.WaveAIStreamRequest
type WaveAIStreamRequest = {
clientid?: string;
opts: WaveAIOptsType;
prompt: WaveAIPromptMessageType[];
};
// wshrpc.WaveAIUsageType
type WaveAIUsageType = {
prompt_tokens?: number;
completion_tokens?: number;
total_tokens?: number;
};
// wps.WaveEvent // wps.WaveEvent
type WaveEvent = { type WaveEvent = {
event: string; event: string;

View File

@ -70,7 +70,7 @@ func (bs *BlockService) SaveTerminalState(ctx context.Context, blockId string, s
return nil return nil
} }
func (bs *BlockService) SaveWaveAiData(ctx context.Context, blockId string, history []wshrpc.OpenAIPromptMessageType) error { func (bs *BlockService) SaveWaveAiData(ctx context.Context, blockId string, history []wshrpc.WaveAIPromptMessageType) error {
block, err := wstore.DBMustGet[*waveobj.Block](ctx, blockId) block, err := wstore.DBMustGet[*waveobj.Block](ctx, blockId)
if err != nil { if err != nil {
return err return err

View File

@ -109,8 +109,8 @@ func parseSSE(reader *bufio.Reader) (*sseEvent, error) {
} }
} }
func (AnthropicBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] { func (AnthropicBackend) StreamCompletion(ctx context.Context, request wshrpc.WaveAIStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType] {
rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]) rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType])
go func() { go func() {
defer func() { defer func() {
@ -231,23 +231,23 @@ func (AnthropicBackend) StreamCompletion(ctx context.Context, request wshrpc.Ope
switch sse.Event { switch sse.Event {
case "message_start": case "message_start":
if event.Message != nil { if event.Message != nil {
pk := MakeOpenAIPacket() pk := MakeWaveAIPacket()
pk.Model = event.Message.Model pk.Model = event.Message.Model
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk} rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: *pk}
} }
case "content_block_start": case "content_block_start":
if event.ContentBlock != nil && event.ContentBlock.Text != "" { if event.ContentBlock != nil && event.ContentBlock.Text != "" {
pk := MakeOpenAIPacket() pk := MakeWaveAIPacket()
pk.Text = event.ContentBlock.Text pk.Text = event.ContentBlock.Text
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk} rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: *pk}
} }
case "content_block_delta": case "content_block_delta":
if event.Delta != nil && event.Delta.Text != "" { if event.Delta != nil && event.Delta.Text != "" {
pk := MakeOpenAIPacket() pk := MakeWaveAIPacket()
pk.Text = event.Delta.Text pk.Text = event.Delta.Text
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk} rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: *pk}
} }
case "content_block_stop": case "content_block_stop":
@ -258,27 +258,27 @@ func (AnthropicBackend) StreamCompletion(ctx context.Context, request wshrpc.Ope
case "message_delta": case "message_delta":
// Update message metadata, usage stats // Update message metadata, usage stats
if event.Usage != nil { if event.Usage != nil {
pk := MakeOpenAIPacket() pk := MakeWaveAIPacket()
pk.Usage = &wshrpc.OpenAIUsageType{ pk.Usage = &wshrpc.WaveAIUsageType{
PromptTokens: event.Usage.InputTokens, PromptTokens: event.Usage.InputTokens,
CompletionTokens: event.Usage.OutputTokens, CompletionTokens: event.Usage.OutputTokens,
TotalTokens: event.Usage.InputTokens + event.Usage.OutputTokens, TotalTokens: event.Usage.InputTokens + event.Usage.OutputTokens,
} }
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk} rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: *pk}
} }
case "message_stop": case "message_stop":
if event.Message != nil { if event.Message != nil {
pk := MakeOpenAIPacket() pk := MakeWaveAIPacket()
pk.FinishReason = event.Message.StopReason pk.FinishReason = event.Message.StopReason
if event.Message.Usage != nil { if event.Message.Usage != nil {
pk.Usage = &wshrpc.OpenAIUsageType{ pk.Usage = &wshrpc.WaveAIUsageType{
PromptTokens: event.Message.Usage.InputTokens, PromptTokens: event.Message.Usage.InputTokens,
CompletionTokens: event.Message.Usage.OutputTokens, CompletionTokens: event.Message.Usage.OutputTokens,
TotalTokens: event.Message.Usage.InputTokens + event.Message.Usage.OutputTokens, TotalTokens: event.Message.Usage.InputTokens + event.Message.Usage.OutputTokens,
} }
} }
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk} rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: *pk}
} }
default: default:

View File

@ -20,22 +20,22 @@ type WaveAICloudBackend struct{}
var _ AIBackend = WaveAICloudBackend{} var _ AIBackend = WaveAICloudBackend{}
type OpenAICloudReqPacketType struct { type WaveAICloudReqPacketType struct {
Type string `json:"type"` Type string `json:"type"`
ClientId string `json:"clientid"` ClientId string `json:"clientid"`
Prompt []wshrpc.OpenAIPromptMessageType `json:"prompt"` Prompt []wshrpc.WaveAIPromptMessageType `json:"prompt"`
MaxTokens int `json:"maxtokens,omitempty"` MaxTokens int `json:"maxtokens,omitempty"`
MaxChoices int `json:"maxchoices,omitempty"` MaxChoices int `json:"maxchoices,omitempty"`
} }
func MakeOpenAICloudReqPacket() *OpenAICloudReqPacketType { func MakeWaveAICloudReqPacket() *WaveAICloudReqPacketType {
return &OpenAICloudReqPacketType{ return &WaveAICloudReqPacketType{
Type: OpenAICloudReqStr, Type: OpenAICloudReqStr,
} }
} }
func (WaveAICloudBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] { func (WaveAICloudBackend) StreamCompletion(ctx context.Context, request wshrpc.WaveAIStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType] {
rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]) rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType])
wsEndpoint := wcloud.GetWSEndpoint() wsEndpoint := wcloud.GetWSEndpoint()
go func() { go func() {
defer func() { defer func() {
@ -69,14 +69,14 @@ func (WaveAICloudBackend) StreamCompletion(ctx context.Context, request wshrpc.O
rtn <- makeAIError(fmt.Errorf("unable to close openai channel: %v", err)) rtn <- makeAIError(fmt.Errorf("unable to close openai channel: %v", err))
} }
}() }()
var sendablePromptMsgs []wshrpc.OpenAIPromptMessageType var sendablePromptMsgs []wshrpc.WaveAIPromptMessageType
for _, promptMsg := range request.Prompt { for _, promptMsg := range request.Prompt {
if promptMsg.Role == "error" { if promptMsg.Role == "error" {
continue continue
} }
sendablePromptMsgs = append(sendablePromptMsgs, promptMsg) sendablePromptMsgs = append(sendablePromptMsgs, promptMsg)
} }
reqPk := MakeOpenAICloudReqPacket() reqPk := MakeWaveAICloudReqPacket()
reqPk.ClientId = request.ClientId reqPk.ClientId = request.ClientId
reqPk.Prompt = sendablePromptMsgs reqPk.Prompt = sendablePromptMsgs
reqPk.MaxTokens = request.Opts.MaxTokens reqPk.MaxTokens = request.Opts.MaxTokens
@ -101,7 +101,7 @@ func (WaveAICloudBackend) StreamCompletion(ctx context.Context, request wshrpc.O
rtn <- makeAIError(fmt.Errorf("OpenAI request, websocket error reading message: %v", err)) rtn <- makeAIError(fmt.Errorf("OpenAI request, websocket error reading message: %v", err))
break break
} }
var streamResp *wshrpc.OpenAIPacketType var streamResp *wshrpc.WaveAIPacketType
err = json.Unmarshal(socketMessage, &streamResp) err = json.Unmarshal(socketMessage, &streamResp)
if err != nil { if err != nil {
rtn <- makeAIError(fmt.Errorf("OpenAI request, websocket response json decode error: %v", err)) rtn <- makeAIError(fmt.Errorf("OpenAI request, websocket response json decode error: %v", err))
@ -115,7 +115,7 @@ func (WaveAICloudBackend) StreamCompletion(ctx context.Context, request wshrpc.O
rtn <- makeAIError(fmt.Errorf("%v", streamResp.Error)) rtn <- makeAIError(fmt.Errorf("%v", streamResp.Error))
break break
} }
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *streamResp} rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: *streamResp}
} }
}() }()
return rtn return rtn

View File

@ -25,7 +25,7 @@ func defaultAzureMapperFn(model string) string {
return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "") return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "")
} }
func setApiType(opts *wshrpc.OpenAIOptsType, clientConfig *openaiapi.ClientConfig) error { func setApiType(opts *wshrpc.WaveAIOptsType, clientConfig *openaiapi.ClientConfig) error {
ourApiType := strings.ToLower(opts.APIType) ourApiType := strings.ToLower(opts.APIType)
if ourApiType == "" || ourApiType == strings.ToLower(string(openaiapi.APITypeOpenAI)) { if ourApiType == "" || ourApiType == strings.ToLower(string(openaiapi.APITypeOpenAI)) {
clientConfig.APIType = openaiapi.APITypeOpenAI clientConfig.APIType = openaiapi.APITypeOpenAI
@ -50,7 +50,7 @@ func setApiType(opts *wshrpc.OpenAIOptsType, clientConfig *openaiapi.ClientConfi
} }
} }
func convertPrompt(prompt []wshrpc.OpenAIPromptMessageType) []openaiapi.ChatCompletionMessage { func convertPrompt(prompt []wshrpc.WaveAIPromptMessageType) []openaiapi.ChatCompletionMessage {
var rtn []openaiapi.ChatCompletionMessage var rtn []openaiapi.ChatCompletionMessage
for _, p := range prompt { for _, p := range prompt {
msg := openaiapi.ChatCompletionMessage{Role: p.Role, Content: p.Content, Name: p.Name} msg := openaiapi.ChatCompletionMessage{Role: p.Role, Content: p.Content, Name: p.Name}
@ -59,19 +59,8 @@ func convertPrompt(prompt []wshrpc.OpenAIPromptMessageType) []openaiapi.ChatComp
return rtn return rtn
} }
func convertUsage(resp openaiapi.ChatCompletionResponse) *wshrpc.OpenAIUsageType { func (OpenAIBackend) StreamCompletion(ctx context.Context, request wshrpc.WaveAIStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType] {
if resp.Usage.TotalTokens == 0 { rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType])
return nil
}
return &wshrpc.OpenAIUsageType{
PromptTokens: resp.Usage.PromptTokens,
CompletionTokens: resp.Usage.CompletionTokens,
TotalTokens: resp.Usage.TotalTokens,
}
}
func (OpenAIBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] {
rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType])
go func() { go func() {
defer func() { defer func() {
panicErr := panichandler.PanicHandler("OpenAIBackend.StreamCompletion") panicErr := panichandler.PanicHandler("OpenAIBackend.StreamCompletion")
@ -128,18 +117,18 @@ func (OpenAIBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAi
} }
// Send header packet // Send header packet
headerPk := MakeOpenAIPacket() headerPk := MakeWaveAIPacket()
headerPk.Model = resp.Model headerPk.Model = resp.Model
headerPk.Created = resp.Created headerPk.Created = resp.Created
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *headerPk} rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: *headerPk}
// Send content packet(s) // Send content packet(s)
for i, choice := range resp.Choices { for i, choice := range resp.Choices {
pk := MakeOpenAIPacket() pk := MakeWaveAIPacket()
pk.Index = i pk.Index = i
pk.Text = choice.Message.Content pk.Text = choice.Message.Content
pk.FinishReason = string(choice.FinishReason) pk.FinishReason = string(choice.FinishReason)
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk} rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: *pk}
} }
return return
} }
@ -167,18 +156,18 @@ func (OpenAIBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAi
break break
} }
if streamResp.Model != "" && !sentHeader { if streamResp.Model != "" && !sentHeader {
pk := MakeOpenAIPacket() pk := MakeWaveAIPacket()
pk.Model = streamResp.Model pk.Model = streamResp.Model
pk.Created = streamResp.Created pk.Created = streamResp.Created
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk} rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: *pk}
sentHeader = true sentHeader = true
} }
for _, choice := range streamResp.Choices { for _, choice := range streamResp.Choices {
pk := MakeOpenAIPacket() pk := MakeWaveAIPacket()
pk.Index = choice.Index pk.Index = choice.Index
pk.Text = choice.Delta.Content pk.Text = choice.Delta.Content
pk.FinishReason = string(choice.FinishReason) pk.FinishReason = string(choice.FinishReason)
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk} rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: *pk}
} }
} }
}() }()

View File

@ -49,8 +49,8 @@ type perplexityResponse struct {
Model string `json:"model"` Model string `json:"model"`
} }
func (PerplexityBackend) StreamCompletion(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] { func (PerplexityBackend) StreamCompletion(ctx context.Context, request wshrpc.WaveAIStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType] {
rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]) rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType])
go func() { go func() {
defer func() { defer func() {
@ -160,17 +160,17 @@ func (PerplexityBackend) StreamCompletion(ctx context.Context, request wshrpc.Op
} }
if !sentHeader { if !sentHeader {
pk := MakeOpenAIPacket() pk := MakeWaveAIPacket()
pk.Model = response.Model pk.Model = response.Model
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk} rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: *pk}
sentHeader = true sentHeader = true
} }
for _, choice := range response.Choices { for _, choice := range response.Choices {
pk := MakeOpenAIPacket() pk := MakeWaveAIPacket()
pk.Text = choice.Delta.Content pk.Text = choice.Delta.Content
pk.FinishReason = choice.FinishReason pk.FinishReason = choice.FinishReason
rtn <- wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Response: *pk} rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: *pk}
} }
} }
}() }()

View File

@ -19,7 +19,7 @@ const DefaultAzureAPIVersion = "2023-05-15"
const ApiType_Anthropic = "anthropic" const ApiType_Anthropic = "anthropic"
const ApiType_Perplexity = "perplexity" const ApiType_Perplexity = "perplexity"
type OpenAICmdInfoPacketOutputType struct { type WaveAICmdInfoPacketOutputType struct {
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
Created int64 `json:"created,omitempty"` Created int64 `json:"created,omitempty"`
FinishReason string `json:"finish_reason,omitempty"` FinishReason string `json:"finish_reason,omitempty"`
@ -27,14 +27,14 @@ type OpenAICmdInfoPacketOutputType struct {
Error string `json:"error,omitempty"` Error string `json:"error,omitempty"`
} }
func MakeOpenAIPacket() *wshrpc.OpenAIPacketType { func MakeWaveAIPacket() *wshrpc.WaveAIPacketType {
return &wshrpc.OpenAIPacketType{Type: OpenAIPacketStr} return &wshrpc.WaveAIPacketType{Type: OpenAIPacketStr}
} }
type OpenAICmdInfoChatMessage struct { type WaveAICmdInfoChatMessage struct {
MessageID int `json:"messageid"` MessageID int `json:"messageid"`
IsAssistantResponse bool `json:"isassistantresponse,omitempty"` IsAssistantResponse bool `json:"isassistantresponse,omitempty"`
AssistantResponse *OpenAICmdInfoPacketOutputType `json:"assistantresponse,omitempty"` AssistantResponse *WaveAICmdInfoPacketOutputType `json:"assistantresponse,omitempty"`
UserQuery string `json:"userquery,omitempty"` UserQuery string `json:"userquery,omitempty"`
UserEngineeredQuery string `json:"userengineeredquery,omitempty"` UserEngineeredQuery string `json:"userengineeredquery,omitempty"`
} }
@ -42,8 +42,8 @@ type OpenAICmdInfoChatMessage struct {
type AIBackend interface { type AIBackend interface {
StreamCompletion( StreamCompletion(
ctx context.Context, ctx context.Context,
request wshrpc.OpenAiStreamRequest, request wshrpc.WaveAIStreamRequest,
) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] ) chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]
} }
const DefaultMaxTokens = 2048 const DefaultMaxTokens = 2048
@ -53,18 +53,18 @@ const WCloudWSEndpointVarName = "WCLOUD_WS_ENDPOINT"
const CloudWebsocketConnectTimeout = 1 * time.Minute const CloudWebsocketConnectTimeout = 1 * time.Minute
func IsCloudAIRequest(opts *wshrpc.OpenAIOptsType) bool { func IsCloudAIRequest(opts *wshrpc.WaveAIOptsType) bool {
if opts == nil { if opts == nil {
return true return true
} }
return opts.BaseURL == "" && opts.APIToken == "" return opts.BaseURL == "" && opts.APIToken == ""
} }
func makeAIError(err error) wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] { func makeAIError(err error) wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType] {
return wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType]{Error: err} return wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Error: err}
} }
func RunAICommand(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] { func RunAICommand(ctx context.Context, request wshrpc.WaveAIStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType] {
telemetry.GoUpdateActivityWrap(wshrpc.ActivityUpdate{NumAIReqs: 1}, "RunAICommand") telemetry.GoUpdateActivityWrap(wshrpc.ActivityUpdate{NumAIReqs: 1}, "RunAICommand")
if request.Opts.APIType == ApiType_Anthropic { if request.Opts.APIType == ApiType_Anthropic {
endpoint := request.Opts.BaseURL endpoint := request.Opts.BaseURL

View File

@ -364,8 +364,8 @@ func StreamTestCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) chan wshrpc.Resp
} }
// command "streamwaveai", wshserver.StreamWaveAiCommand // command "streamwaveai", wshserver.StreamWaveAiCommand
func StreamWaveAiCommand(w *wshutil.WshRpc, data wshrpc.OpenAiStreamRequest, opts *wshrpc.RpcOpts) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] { func StreamWaveAiCommand(w *wshutil.WshRpc, data wshrpc.WaveAIStreamRequest, opts *wshrpc.RpcOpts) chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType] {
return sendRpcRequestResponseStreamHelper[wshrpc.OpenAIPacketType](w, "streamwaveai", data, opts) return sendRpcRequestResponseStreamHelper[wshrpc.WaveAIPacketType](w, "streamwaveai", data, opts)
} }
// command "test", wshserver.TestCommand // command "test", wshserver.TestCommand

View File

@ -138,7 +138,7 @@ type WshRpcInterface interface {
EventUnsubAllCommand(ctx context.Context) error EventUnsubAllCommand(ctx context.Context) error
EventReadHistoryCommand(ctx context.Context, data CommandEventReadHistoryData) ([]*wps.WaveEvent, error) EventReadHistoryCommand(ctx context.Context, data CommandEventReadHistoryData) ([]*wps.WaveEvent, error)
StreamTestCommand(ctx context.Context) chan RespOrErrorUnion[int] StreamTestCommand(ctx context.Context) chan RespOrErrorUnion[int]
StreamWaveAiCommand(ctx context.Context, request OpenAiStreamRequest) chan RespOrErrorUnion[OpenAIPacketType] StreamWaveAiCommand(ctx context.Context, request WaveAIStreamRequest) chan RespOrErrorUnion[WaveAIPacketType]
StreamCpuDataCommand(ctx context.Context, request CpuDataRequest) chan RespOrErrorUnion[TimeSeriesData] StreamCpuDataCommand(ctx context.Context, request CpuDataRequest) chan RespOrErrorUnion[TimeSeriesData]
TestCommand(ctx context.Context, data string) error TestCommand(ctx context.Context, data string) error
SetConfigCommand(ctx context.Context, data MetaSettingsType) error SetConfigCommand(ctx context.Context, data MetaSettingsType) error
@ -377,19 +377,19 @@ type CommandEventReadHistoryData struct {
MaxItems int `json:"maxitems"` MaxItems int `json:"maxitems"`
} }
type OpenAiStreamRequest struct { type WaveAIStreamRequest struct {
ClientId string `json:"clientid,omitempty"` ClientId string `json:"clientid,omitempty"`
Opts *OpenAIOptsType `json:"opts"` Opts *WaveAIOptsType `json:"opts"`
Prompt []OpenAIPromptMessageType `json:"prompt"` Prompt []WaveAIPromptMessageType `json:"prompt"`
} }
type OpenAIPromptMessageType struct { type WaveAIPromptMessageType struct {
Role string `json:"role"` Role string `json:"role"`
Content string `json:"content"` Content string `json:"content"`
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
} }
type OpenAIOptsType struct { type WaveAIOptsType struct {
Model string `json:"model"` Model string `json:"model"`
APIType string `json:"apitype,omitempty"` APIType string `json:"apitype,omitempty"`
APIToken string `json:"apitoken"` APIToken string `json:"apitoken"`
@ -401,18 +401,18 @@ type OpenAIOptsType struct {
TimeoutMs int `json:"timeoutms,omitempty"` TimeoutMs int `json:"timeoutms,omitempty"`
} }
type OpenAIPacketType struct { type WaveAIPacketType struct {
Type string `json:"type"` Type string `json:"type"`
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
Created int64 `json:"created,omitempty"` Created int64 `json:"created,omitempty"`
FinishReason string `json:"finish_reason,omitempty"` FinishReason string `json:"finish_reason,omitempty"`
Usage *OpenAIUsageType `json:"usage,omitempty"` Usage *WaveAIUsageType `json:"usage,omitempty"`
Index int `json:"index,omitempty"` Index int `json:"index,omitempty"`
Text string `json:"text,omitempty"` Text string `json:"text,omitempty"`
Error string `json:"error,omitempty"` Error string `json:"error,omitempty"`
} }
type OpenAIUsageType struct { type WaveAIUsageType struct {
PromptTokens int `json:"prompt_tokens,omitempty"` PromptTokens int `json:"prompt_tokens,omitempty"`
CompletionTokens int `json:"completion_tokens,omitempty"` CompletionTokens int `json:"completion_tokens,omitempty"`
TotalTokens int `json:"total_tokens,omitempty"` TotalTokens int `json:"total_tokens,omitempty"`

View File

@ -73,7 +73,7 @@ func (ws *WshServer) StreamTestCommand(ctx context.Context) chan wshrpc.RespOrEr
return rtn return rtn
} }
func (ws *WshServer) StreamWaveAiCommand(ctx context.Context, request wshrpc.OpenAiStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.OpenAIPacketType] { func (ws *WshServer) StreamWaveAiCommand(ctx context.Context, request wshrpc.WaveAIStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType] {
return waveai.RunAICommand(ctx, request) return waveai.RunAICommand(ctx, request)
} }