waveterm/pkg/waveai/googlebackend.go
2024-12-27 17:39:42 -05:00

92 lines
2.1 KiB
Go

package waveai
import (
	"context"
	"fmt"
	"log"
	"strings"

	"github.com/google/generative-ai-go/genai"
	"github.com/wavetermdev/waveterm/pkg/wshrpc"
	"google.golang.org/api/iterator"
	"google.golang.org/api/option"
)
// GoogleBackend is the AIBackend that streams completions from Google's
// generative AI API through the github.com/google/generative-ai-go client.
// It has no fields; all per-request configuration comes from the request.
type GoogleBackend struct{}

var (
	// Compile-time check that the value type GoogleBackend implements AIBackend.
	_ AIBackend = GoogleBackend{}
)
// StreamCompletion sends request.Prompt to Google's generative AI API, using
// request.Opts.APIToken as the API key and request.Opts.Model as the model
// name, and streams the reply back on the returned channel.
//
// The last prompt message is sent as the new turn; all earlier messages are
// passed as chat history (see extractHistory). Each streamed response chunk
// becomes one Response packet whose Text is the concatenation of the chunk's
// candidate parts. Any failure (empty prompt, client creation, API error,
// context cancellation) is delivered as a single error packet, after which the
// channel is closed.
//
// The returned channel is never nil, so callers can always range over it
// until it is closed. Streaming sends are unbuffered and block until received,
// so a caller that stops reading early leaves the streaming goroutine blocked.
func (GoogleBackend) StreamCompletion(ctx context.Context, request wshrpc.WaveAIStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType] {
	// errChan returns an already-closed channel holding exactly one error
	// packet. It is used for failures that happen before streaming starts, so
	// callers receive the error instead of blocking forever on a nil channel.
	errChan := func(err error) chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType] {
		ch := make(chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType], 1)
		ch <- makeAIError(err)
		close(ch)
		return ch
	}

	// The last message is the prompt itself; without one there is nothing to send.
	if len(request.Prompt) == 0 {
		return errChan(fmt.Errorf("no prompt messages to send"))
	}

	client, err := genai.NewClient(ctx, option.WithAPIKey(request.Opts.APIToken))
	if err != nil {
		log.Printf("failed to create client: %v", err)
		return errChan(fmt.Errorf("creating Google AI client: %w", err))
	}
	model := client.GenerativeModel(request.Opts.Model)
	if model == nil {
		log.Println("model not found")
		client.Close()
		return errChan(fmt.Errorf("model %q not found", request.Opts.Model))
	}
	cs := model.StartChat()
	cs.History = extractHistory(request.Prompt)
	iter := cs.SendMessageStream(ctx, extractPrompt(request.Prompt))

	rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType])
	go func() {
		// close(rtn) runs first, then the client is released.
		defer client.Close()
		defer close(rtn)
		for {
			// Stop as soon as the caller cancels. The original code used `break`
			// inside a select case, which only exits the select, not this loop.
			if err := ctx.Err(); err != nil {
				rtn <- makeAIError(fmt.Errorf("request cancelled: %w", err))
				return
			}
			resp, err := iter.Next()
			if err == iterator.Done {
				return
			}
			if err != nil {
				rtn <- makeAIError(fmt.Errorf("Google API error: %w", err))
				return
			}
			rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: wshrpc.WaveAIPacketType{Text: convertCandidatesToText(resp.Candidates)}}
		}
	}()
	return rtn
}
// extractHistory converts every prompt message except the last into chat
// history for the genai ChatSession. The last message is sent separately as
// the new turn (see extractPrompt).
//
// Only the "user" and "model" roles are accepted by the genai chat API. An
// "assistant" message is mapped to "model" rather than dropped; dropping it
// would lose the model's side of the conversation. NOTE(review): this assumes
// Wave prompts may carry OpenAI-style "assistant" roles; confirm against the
// callers. Any other role (e.g. "system") is skipped, as before.
//
// It returns nil when there is no history, including when history is empty.
func extractHistory(history []wshrpc.WaveAIPromptMessageType) []*genai.Content {
	// Guard: history[:len(history)-1] would panic on an empty slice.
	if len(history) <= 1 {
		return nil
	}
	var rtn []*genai.Content
	for _, h := range history[:len(history)-1] {
		role := h.Role
		if role == "assistant" {
			role = "model"
		}
		if role != "user" && role != "model" {
			continue
		}
		rtn = append(rtn, &genai.Content{
			Role:  role,
			Parts: []genai.Part{genai.Text(h.Content)},
		})
	}
	return rtn
}
// extractPrompt returns the content of the last prompt message as a text
// part. That message is the new turn sent to the model; earlier messages go
// into chat history (see extractHistory).
//
// An empty prompt slice yields an empty text part instead of panicking on an
// out-of-range index.
func extractPrompt(prompt []wshrpc.WaveAIPromptMessageType) genai.Part {
	if len(prompt) == 0 {
		return genai.Text("")
	}
	return genai.Text(prompt[len(prompt)-1].Content)
}
// convertCandidatesToText concatenates every part of every candidate into a
// single string. Each part is rendered with the %v verb, so text parts
// contribute their raw text.
//
// Nil candidates and candidates with nil Content are skipped rather than
// dereferenced. A response chunk can carry a candidate without content, for
// example one that finished for a non-content reason.
func convertCandidatesToText(candidates []*genai.Candidate) string {
	var b strings.Builder
	for _, c := range candidates {
		if c == nil || c.Content == nil {
			continue
		}
		for _, p := range c.Content.Parts {
			fmt.Fprintf(&b, "%v", p)
		}
	}
	return b.String()
}