mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-02-23 02:51:26 +01:00
openai api integration 'working'
This commit is contained in:
parent
ab5deafdb6
commit
8302ca1fcb
@ -24,6 +24,7 @@ import (
|
|||||||
"github.com/scripthaus-dev/sh2-server/pkg/comp"
|
"github.com/scripthaus-dev/sh2-server/pkg/comp"
|
||||||
"github.com/scripthaus-dev/sh2-server/pkg/pcloud"
|
"github.com/scripthaus-dev/sh2-server/pkg/pcloud"
|
||||||
"github.com/scripthaus-dev/sh2-server/pkg/remote"
|
"github.com/scripthaus-dev/sh2-server/pkg/remote"
|
||||||
|
"github.com/scripthaus-dev/sh2-server/pkg/remote/openai"
|
||||||
"github.com/scripthaus-dev/sh2-server/pkg/scbase"
|
"github.com/scripthaus-dev/sh2-server/pkg/scbase"
|
||||||
"github.com/scripthaus-dev/sh2-server/pkg/scpacket"
|
"github.com/scripthaus-dev/sh2-server/pkg/scpacket"
|
||||||
"github.com/scripthaus-dev/sh2-server/pkg/sstore"
|
"github.com/scripthaus-dev/sh2-server/pkg/sstore"
|
||||||
@ -41,7 +42,7 @@ func init() {
|
|||||||
comp.RegisterSimpleCompFn(comp.CGTypeCommandMeta, simpleCompCommandMeta)
|
comp.RegisterSimpleCompFn(comp.CGTypeCommandMeta, simpleCompCommandMeta)
|
||||||
}
|
}
|
||||||
|
|
||||||
const DefaultUserId = "sawka"
|
const DefaultUserId = "user"
|
||||||
const MaxNameLen = 50
|
const MaxNameLen = 50
|
||||||
const MaxShareNameLen = 150
|
const MaxShareNameLen = 150
|
||||||
const MaxRendererLen = 50
|
const MaxRendererLen = 50
|
||||||
@ -198,6 +199,9 @@ func init() {
|
|||||||
registerCmdFn("bookmark:set", BookmarkSetCommand)
|
registerCmdFn("bookmark:set", BookmarkSetCommand)
|
||||||
registerCmdFn("bookmark:delete", BookmarkDeleteCommand)
|
registerCmdFn("bookmark:delete", BookmarkDeleteCommand)
|
||||||
|
|
||||||
|
registerCmdFn("openai", OpenAICommand)
|
||||||
|
registerCmdFn("openai:stream", OpenAICommand)
|
||||||
|
|
||||||
registerCmdFn("_killserver", KillServerCommand)
|
registerCmdFn("_killserver", KillServerCommand)
|
||||||
|
|
||||||
registerCmdFn("set", SetCommand)
|
registerCmdFn("set", SetCommand)
|
||||||
@ -551,7 +555,6 @@ func EvalCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sstore.
|
|||||||
log.Printf("[error] incrementing activity numcommands: %v\n", err)
|
log.Printf("[error] incrementing activity numcommands: %v\n", err)
|
||||||
// fall through (non-fatal error)
|
// fall through (non-fatal error)
|
||||||
}
|
}
|
||||||
log.Printf("inc numcommands\n")
|
|
||||||
}
|
}
|
||||||
if evalDepth > MaxEvalDepth {
|
if evalDepth > MaxEvalDepth {
|
||||||
return nil, fmt.Errorf("alias/history expansion max-depth exceeded")
|
return nil, fmt.Errorf("alias/history expansion max-depth exceeded")
|
||||||
@ -1312,6 +1315,185 @@ func GetFullRemoteDisplayName(rptr *sstore.RemotePtrType, rstate *remote.RemoteR
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func writeErrorToPty(cmd *sstore.CmdType, errStr string, outputPos int64) {
|
||||||
|
errPk := openai.CreateErrorPacket(errStr)
|
||||||
|
errBytes, err := packet.MarshalPacket(errPk)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("error writing error packet to openai response: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
errCtx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancelFn()
|
||||||
|
update, err := sstore.AppendToCmdPtyBlob(errCtx, cmd.ScreenId, cmd.CmdId, errBytes, outputPos)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("error writing ptyupdate for openai response: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sstore.MainBus.SendScreenUpdate(cmd.ScreenId, update)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func writePacketToPty(ctx context.Context, cmd *sstore.CmdType, pk packet.PacketType, outputPos *int64) error {
|
||||||
|
outBytes, err := packet.MarshalPacket(pk)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
update, err := sstore.AppendToCmdPtyBlob(ctx, cmd.ScreenId, cmd.CmdId, outBytes, *outputPos)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*outputPos += int64(len(outBytes))
|
||||||
|
sstore.MainBus.SendScreenUpdate(cmd.ScreenId, update)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func doOpenAICompletion(cmd *sstore.CmdType, opts *sstore.OpenAIOptsType, prompt []sstore.OpenAIPromptMessageType) {
|
||||||
|
var outputPos int64
|
||||||
|
var hadError bool
|
||||||
|
startTime := time.Now()
|
||||||
|
ctx, cancelFn := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
defer cancelFn()
|
||||||
|
defer func() {
|
||||||
|
r := recover()
|
||||||
|
if r != nil {
|
||||||
|
panicMsg := fmt.Sprintf("panic: %v", r)
|
||||||
|
log.Printf("panic in doOpenAICompletion: %s\n", panicMsg)
|
||||||
|
writeErrorToPty(cmd, panicMsg, outputPos)
|
||||||
|
hadError = true
|
||||||
|
}
|
||||||
|
duration := time.Since(startTime)
|
||||||
|
cmdStatus := sstore.CmdStatusDone
|
||||||
|
var exitCode int64
|
||||||
|
if hadError {
|
||||||
|
cmdStatus = sstore.CmdStatusError
|
||||||
|
exitCode = 1
|
||||||
|
}
|
||||||
|
doneInfo := &sstore.CmdDoneInfo{
|
||||||
|
Ts: time.Now().UnixMilli(),
|
||||||
|
ExitCode: exitCode,
|
||||||
|
DurationMs: duration.Milliseconds(),
|
||||||
|
}
|
||||||
|
ck := base.MakeCommandKey(cmd.ScreenId, cmd.CmdId)
|
||||||
|
update, err := sstore.UpdateCmdDoneInfo(context.Background(), ck, doneInfo, cmdStatus)
|
||||||
|
if err != nil {
|
||||||
|
// nothing to do
|
||||||
|
log.Printf("error updating cmddoneinfo (in openai): %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sstore.MainBus.SendScreenUpdate(cmd.ScreenId, update)
|
||||||
|
}()
|
||||||
|
respPks, err := openai.RunCompletion(ctx, opts, prompt)
|
||||||
|
if err != nil {
|
||||||
|
writeErrorToPty(cmd, fmt.Sprintf("error calling OpenAI API: %v", err), outputPos)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, pk := range respPks {
|
||||||
|
err = writePacketToPty(ctx, cmd, pk, &outputPos)
|
||||||
|
if err != nil {
|
||||||
|
writeErrorToPty(cmd, fmt.Sprintf("error writing response to ptybuffer: %v", err), outputPos)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func doOpenAIStreamCompletion(cmd *sstore.CmdType, opts *sstore.OpenAIOptsType, prompt []sstore.OpenAIPromptMessageType) {
|
||||||
|
var outputPos int64
|
||||||
|
var hadError bool
|
||||||
|
startTime := time.Now()
|
||||||
|
ctx, cancelFn := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
defer cancelFn()
|
||||||
|
defer func() {
|
||||||
|
r := recover()
|
||||||
|
if r != nil {
|
||||||
|
panicMsg := fmt.Sprintf("panic: %v", r)
|
||||||
|
log.Printf("panic in doOpenAICompletion: %s\n", panicMsg)
|
||||||
|
writeErrorToPty(cmd, panicMsg, outputPos)
|
||||||
|
hadError = true
|
||||||
|
}
|
||||||
|
duration := time.Since(startTime)
|
||||||
|
cmdStatus := sstore.CmdStatusDone
|
||||||
|
var exitCode int64
|
||||||
|
if hadError {
|
||||||
|
cmdStatus = sstore.CmdStatusError
|
||||||
|
exitCode = 1
|
||||||
|
}
|
||||||
|
doneInfo := &sstore.CmdDoneInfo{
|
||||||
|
Ts: time.Now().UnixMilli(),
|
||||||
|
ExitCode: exitCode,
|
||||||
|
DurationMs: duration.Milliseconds(),
|
||||||
|
}
|
||||||
|
ck := base.MakeCommandKey(cmd.ScreenId, cmd.CmdId)
|
||||||
|
update, err := sstore.UpdateCmdDoneInfo(context.Background(), ck, doneInfo, cmdStatus)
|
||||||
|
if err != nil {
|
||||||
|
// nothing to do
|
||||||
|
log.Printf("error updating cmddoneinfo (in openai): %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sstore.MainBus.SendScreenUpdate(cmd.ScreenId, update)
|
||||||
|
}()
|
||||||
|
ch, err := openai.RunCompletionStream(ctx, opts, prompt)
|
||||||
|
if err != nil {
|
||||||
|
writeErrorToPty(cmd, fmt.Sprintf("error calling OpenAI API: %v", err), outputPos)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for pk := range ch {
|
||||||
|
err = writePacketToPty(ctx, cmd, pk, &outputPos)
|
||||||
|
if err != nil {
|
||||||
|
writeErrorToPty(cmd, fmt.Sprintf("error writing response to ptybuffer: %v", err), outputPos)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func OpenAICommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sstore.UpdatePacket, error) {
|
||||||
|
ids, err := resolveUiIds(ctx, pk, R_Session|R_Screen)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("/%s error: %w", GetCmdStr(pk), err)
|
||||||
|
}
|
||||||
|
opts := &sstore.OpenAIOptsType{
|
||||||
|
Model: "gpt-3.5-turbo",
|
||||||
|
APIToken: OpenAIKey,
|
||||||
|
MaxTokens: 1000,
|
||||||
|
}
|
||||||
|
promptStr := firstArg(pk)
|
||||||
|
if promptStr == "" {
|
||||||
|
return nil, fmt.Errorf("/openai error, prompt string is blank")
|
||||||
|
}
|
||||||
|
ptermVal := defaultStr(pk.Kwargs["pterm"], DefaultPTERM)
|
||||||
|
pkTermOpts, err := GetUITermOpts(pk.UIContext.WinSize, ptermVal)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("/openai error, invalid 'pterm' value %q: %v", ptermVal, err)
|
||||||
|
}
|
||||||
|
termOpts := convertTermOpts(pkTermOpts)
|
||||||
|
cmd, err := makeDynCmd(ctx, GetCmdStr(pk), ids, pk.GetRawStr(), *termOpts)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("/openai error, cannot make dyn cmd")
|
||||||
|
}
|
||||||
|
line, err := sstore.AddOpenAILine(ctx, ids.ScreenId, DefaultUserId, cmd)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("cannot add new line: %v", err)
|
||||||
|
}
|
||||||
|
prompt := []sstore.OpenAIPromptMessageType{{Role: sstore.OpenAIRoleUser, Content: promptStr}}
|
||||||
|
if pk.MetaSubCmd == "stream" {
|
||||||
|
go doOpenAIStreamCompletion(cmd, opts, prompt)
|
||||||
|
} else {
|
||||||
|
go doOpenAICompletion(cmd, opts, prompt)
|
||||||
|
}
|
||||||
|
updateHistoryContext(ctx, line, cmd)
|
||||||
|
updateMap := make(map[string]interface{})
|
||||||
|
updateMap[sstore.ScreenField_SelectedLine] = line.LineNum
|
||||||
|
updateMap[sstore.ScreenField_Focus] = sstore.ScreenFocusInput
|
||||||
|
screen, err := sstore.UpdateScreen(ctx, ids.ScreenId, updateMap)
|
||||||
|
if err != nil {
|
||||||
|
// ignore error again (nothing to do)
|
||||||
|
log.Printf("/openai error updating screen selected line: %v\n", err)
|
||||||
|
}
|
||||||
|
update := sstore.ModelUpdate{Line: line, Cmd: cmd, Screens: []*sstore.ScreenType{screen}}
|
||||||
|
return update, nil
|
||||||
|
}
|
||||||
|
|
||||||
func CrCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sstore.UpdatePacket, error) {
|
func CrCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sstore.UpdatePacket, error) {
|
||||||
ids, err := resolveUiIds(ctx, pk, R_Session|R_Screen)
|
ids, err := resolveUiIds(ctx, pk, R_Session|R_Screen)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -1350,6 +1532,33 @@ func CrCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sstore.Up
|
|||||||
return update, nil
|
return update, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func makeDynCmd(ctx context.Context, metaCmd string, ids resolvedIds, cmdStr string, termOpts sstore.TermOpts) (*sstore.CmdType, error) {
|
||||||
|
cmd := &sstore.CmdType{
|
||||||
|
ScreenId: ids.ScreenId,
|
||||||
|
CmdId: scbase.GenPromptUUID(),
|
||||||
|
CmdStr: cmdStr,
|
||||||
|
RawCmdStr: cmdStr,
|
||||||
|
Remote: ids.Remote.RemotePtr,
|
||||||
|
TermOpts: termOpts,
|
||||||
|
Status: sstore.CmdStatusRunning,
|
||||||
|
StartPk: nil,
|
||||||
|
DoneInfo: nil,
|
||||||
|
RunOut: nil,
|
||||||
|
}
|
||||||
|
if ids.Remote.StatePtr != nil {
|
||||||
|
cmd.StatePtr = *ids.Remote.StatePtr
|
||||||
|
}
|
||||||
|
if ids.Remote.FeState != nil {
|
||||||
|
cmd.FeState = ids.Remote.FeState
|
||||||
|
}
|
||||||
|
err := sstore.CreateCmdPtyFile(ctx, cmd.ScreenId, cmd.CmdId, cmd.TermOpts.MaxPtySize)
|
||||||
|
if err != nil {
|
||||||
|
// TODO tricky error since the command was a success, but we can't show the output
|
||||||
|
return nil, fmt.Errorf("cannot create local ptyout file for %s command: %w", metaCmd, err)
|
||||||
|
}
|
||||||
|
return cmd, nil
|
||||||
|
}
|
||||||
|
|
||||||
func makeStaticCmd(ctx context.Context, metaCmd string, ids resolvedIds, cmdStr string, cmdOutput []byte) (*sstore.CmdType, error) {
|
func makeStaticCmd(ctx context.Context, metaCmd string, ids resolvedIds, cmdStr string, cmdOutput []byte) (*sstore.CmdType, error) {
|
||||||
cmd := &sstore.CmdType{
|
cmd := &sstore.CmdType{
|
||||||
ScreenId: ids.ScreenId,
|
ScreenId: ids.ScreenId,
|
||||||
|
@ -28,6 +28,20 @@ var BareMetaCmds = []BareMetaCmdDecl{
|
|||||||
BareMetaCmdDecl{"reset", "reset"},
|
BareMetaCmdDecl{"reset", "reset"},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
CmdParseTypePositional = "pos"
|
||||||
|
CmdParseTypeRaw = "raw"
|
||||||
|
)
|
||||||
|
|
||||||
|
var CmdParseOverrides map[string]string = map[string]string{
|
||||||
|
"setenv": CmdParseTypePositional,
|
||||||
|
"unset": CmdParseTypePositional,
|
||||||
|
"set": CmdParseTypePositional,
|
||||||
|
"run": CmdParseTypeRaw,
|
||||||
|
"comment": CmdParseTypeRaw,
|
||||||
|
"openai": CmdParseTypeRaw,
|
||||||
|
}
|
||||||
|
|
||||||
func DumpPacket(pk *scpacket.FeCommandPacketType) {
|
func DumpPacket(pk *scpacket.FeCommandPacketType) {
|
||||||
if pk == nil || pk.MetaCmd == "" {
|
if pk == nil || pk.MetaCmd == "" {
|
||||||
fmt.Printf("[no metacmd]\n")
|
fmt.Printf("[no metacmd]\n")
|
||||||
@ -111,11 +125,11 @@ func parseMetaCmd(origCommandStr string) (string, string, string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func onlyPositionalArgs(metaCmd string, metaSubCmd string) bool {
|
func onlyPositionalArgs(metaCmd string, metaSubCmd string) bool {
|
||||||
return (metaCmd == "setenv" || metaCmd == "unset" || metaCmd == "set") && metaSubCmd == ""
|
return (CmdParseOverrides[metaCmd] == CmdParseTypePositional) && metaSubCmd == ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func onlyRawArgs(metaCmd string, metaSubCmd string) bool {
|
func onlyRawArgs(metaCmd string, metaSubCmd string) bool {
|
||||||
return metaCmd == "run" || metaCmd == "comment"
|
return CmdParseOverrides[metaCmd] == CmdParseTypeRaw
|
||||||
}
|
}
|
||||||
|
|
||||||
func setBracketArgs(argMap map[string]string, bracketStr string) error {
|
func setBracketArgs(argMap map[string]string, bracketStr string) error {
|
||||||
|
@ -9,6 +9,7 @@ import (
|
|||||||
"github.com/scripthaus-dev/mshell/pkg/packet"
|
"github.com/scripthaus-dev/mshell/pkg/packet"
|
||||||
"github.com/scripthaus-dev/mshell/pkg/shexec"
|
"github.com/scripthaus-dev/mshell/pkg/shexec"
|
||||||
"github.com/scripthaus-dev/sh2-server/pkg/remote"
|
"github.com/scripthaus-dev/sh2-server/pkg/remote"
|
||||||
|
"github.com/scripthaus-dev/sh2-server/pkg/sstore"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PTERM=MxM,Mx25
|
// PTERM=MxM,Mx25
|
||||||
@ -108,3 +109,12 @@ func GetUITermOpts(winSize *packet.WinSize, ptermStr string) (*packet.TermOpts,
|
|||||||
termOpts.Rows = base.BoundInt(termOpts.Rows, shexec.MinTermRows, shexec.MaxTermRows)
|
termOpts.Rows = base.BoundInt(termOpts.Rows, shexec.MinTermRows, shexec.MaxTermRows)
|
||||||
return termOpts, nil
|
return termOpts, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func convertTermOpts(pkto *packet.TermOpts) *sstore.TermOpts {
|
||||||
|
return &sstore.TermOpts{
|
||||||
|
Rows: int64(pkto.Rows),
|
||||||
|
Cols: int64(pkto.Cols),
|
||||||
|
FlexRows: true,
|
||||||
|
MaxPtySize: pkto.MaxPtySize,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
147
pkg/remote/openai/openai.go
Normal file
147
pkg/remote/openai/openai.go
Normal file
@ -0,0 +1,147 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
openaiapi "github.com/sashabaranov/go-openai"
|
||||||
|
"github.com/scripthaus-dev/mshell/pkg/packet"
|
||||||
|
"github.com/scripthaus-dev/sh2-server/pkg/sstore"
|
||||||
|
)
|
||||||
|
|
||||||
|
// https://github.com/tiktoken-go/tokenizer
|
||||||
|
|
||||||
|
const DefaultStreamChanSize = 10
|
||||||
|
|
||||||
|
func convertUsage(resp openaiapi.ChatCompletionResponse) *packet.OpenAIUsageType {
|
||||||
|
if resp.Usage.TotalTokens == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &packet.OpenAIUsageType{
|
||||||
|
PromptTokens: resp.Usage.PromptTokens,
|
||||||
|
CompletionTokens: resp.Usage.CompletionTokens,
|
||||||
|
TotalTokens: resp.Usage.TotalTokens,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertPrompt(prompt []sstore.OpenAIPromptMessageType) []openaiapi.ChatCompletionMessage {
|
||||||
|
var rtn []openaiapi.ChatCompletionMessage
|
||||||
|
for _, p := range prompt {
|
||||||
|
msg := openaiapi.ChatCompletionMessage{Role: p.Role, Content: p.Content, Name: p.Name}
|
||||||
|
rtn = append(rtn, msg)
|
||||||
|
}
|
||||||
|
return rtn
|
||||||
|
}
|
||||||
|
|
||||||
|
func RunCompletion(ctx context.Context, opts *sstore.OpenAIOptsType, prompt []sstore.OpenAIPromptMessageType) ([]*packet.OpenAIPacketType, error) {
|
||||||
|
if opts == nil {
|
||||||
|
return nil, fmt.Errorf("no openai opts found")
|
||||||
|
}
|
||||||
|
if opts.Model == "" {
|
||||||
|
return nil, fmt.Errorf("no openai model specified")
|
||||||
|
}
|
||||||
|
if opts.APIToken == "" {
|
||||||
|
return nil, fmt.Errorf("no api token")
|
||||||
|
}
|
||||||
|
client := openaiapi.NewClient(opts.APIToken)
|
||||||
|
req := openaiapi.ChatCompletionRequest{
|
||||||
|
Model: opts.Model,
|
||||||
|
Messages: convertPrompt(prompt),
|
||||||
|
MaxTokens: opts.MaxTokens,
|
||||||
|
}
|
||||||
|
if opts.MaxChoices > 1 {
|
||||||
|
req.N = opts.MaxChoices
|
||||||
|
}
|
||||||
|
apiResp, err := client.CreateChatCompletion(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error calling openai API: %v", err)
|
||||||
|
}
|
||||||
|
if len(apiResp.Choices) == 0 {
|
||||||
|
return nil, fmt.Errorf("no response received")
|
||||||
|
}
|
||||||
|
return marshalResponse(apiResp), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func RunCompletionStream(ctx context.Context, opts *sstore.OpenAIOptsType, prompt []sstore.OpenAIPromptMessageType) (chan *packet.OpenAIPacketType, error) {
|
||||||
|
if opts == nil {
|
||||||
|
return nil, fmt.Errorf("no openai opts found")
|
||||||
|
}
|
||||||
|
if opts.Model == "" {
|
||||||
|
return nil, fmt.Errorf("no openai model specified")
|
||||||
|
}
|
||||||
|
if opts.APIToken == "" {
|
||||||
|
return nil, fmt.Errorf("no api token")
|
||||||
|
}
|
||||||
|
client := openaiapi.NewClient(opts.APIToken)
|
||||||
|
req := openaiapi.ChatCompletionRequest{
|
||||||
|
Model: opts.Model,
|
||||||
|
Messages: convertPrompt(prompt),
|
||||||
|
MaxTokens: opts.MaxTokens,
|
||||||
|
Stream: true,
|
||||||
|
}
|
||||||
|
if opts.MaxChoices > 1 {
|
||||||
|
req.N = opts.MaxChoices
|
||||||
|
}
|
||||||
|
apiResp, err := client.CreateChatCompletionStream(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error calling openai API: %v", err)
|
||||||
|
}
|
||||||
|
rtn := make(chan *packet.OpenAIPacketType, DefaultStreamChanSize)
|
||||||
|
go func() {
|
||||||
|
sentHeader := false
|
||||||
|
defer close(rtn)
|
||||||
|
for {
|
||||||
|
streamResp, err := apiResp.Recv()
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
errPk := CreateErrorPacket(fmt.Sprintf("error in recv of streaming data: %v", err))
|
||||||
|
rtn <- errPk
|
||||||
|
break
|
||||||
|
}
|
||||||
|
log.Printf("stream-resp: %#v\n", streamResp)
|
||||||
|
if streamResp.Model != "" && !sentHeader {
|
||||||
|
pk := packet.MakeOpenAIPacket()
|
||||||
|
pk.Model = streamResp.Model
|
||||||
|
pk.Created = streamResp.Created
|
||||||
|
rtn <- pk
|
||||||
|
sentHeader = true
|
||||||
|
}
|
||||||
|
for _, choice := range streamResp.Choices {
|
||||||
|
pk := packet.MakeOpenAIPacket()
|
||||||
|
pk.Index = choice.Index
|
||||||
|
pk.Text = choice.Delta.Content
|
||||||
|
pk.FinishReason = choice.FinishReason
|
||||||
|
rtn <- pk
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return rtn, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func marshalResponse(resp openaiapi.ChatCompletionResponse) []*packet.OpenAIPacketType {
|
||||||
|
var rtn []*packet.OpenAIPacketType
|
||||||
|
headerPk := packet.MakeOpenAIPacket()
|
||||||
|
headerPk.Model = resp.Model
|
||||||
|
headerPk.Created = resp.Created
|
||||||
|
headerPk.Usage = convertUsage(resp)
|
||||||
|
rtn = append(rtn, headerPk)
|
||||||
|
for _, choice := range resp.Choices {
|
||||||
|
choicePk := packet.MakeOpenAIPacket()
|
||||||
|
choicePk.Index = choice.Index
|
||||||
|
choicePk.Text = choice.Message.Content
|
||||||
|
choicePk.FinishReason = choice.FinishReason
|
||||||
|
rtn = append(rtn, choicePk)
|
||||||
|
}
|
||||||
|
return rtn
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateErrorPacket(errStr string) *packet.OpenAIPacketType {
|
||||||
|
errPk := packet.MakeOpenAIPacket()
|
||||||
|
errPk.Text = errStr
|
||||||
|
errPk.FinishReason = "stop"
|
||||||
|
return errPk
|
||||||
|
}
|
@ -1647,7 +1647,7 @@ func (msh *MShellProc) handleCmdDonePacket(donePk *packet.CmdDonePacketType) {
|
|||||||
ExitCode: int64(donePk.ExitCode),
|
ExitCode: int64(donePk.ExitCode),
|
||||||
DurationMs: donePk.DurationMs,
|
DurationMs: donePk.DurationMs,
|
||||||
}
|
}
|
||||||
update, err := sstore.UpdateCmdDoneInfo(context.Background(), donePk.CK, doneInfo)
|
update, err := sstore.UpdateCmdDoneInfo(context.Background(), donePk.CK, doneInfo, sstore.CmdStatusDone)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
msh.WriteToPtyBuffer("*error updating cmddone: %v\n", err)
|
msh.WriteToPtyBuffer("*error updating cmddone: %v\n", err)
|
||||||
return
|
return
|
||||||
|
@ -848,7 +848,7 @@ func GetCmdByScreenId(ctx context.Context, screenId string, cmdId string) (*CmdT
|
|||||||
return cmd, nil
|
return cmd, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateCmdDoneInfo(ctx context.Context, ck base.CommandKey, doneInfo *CmdDoneInfo) (*ModelUpdate, error) {
|
func UpdateCmdDoneInfo(ctx context.Context, ck base.CommandKey, doneInfo *CmdDoneInfo, status string) (*ModelUpdate, error) {
|
||||||
if doneInfo == nil {
|
if doneInfo == nil {
|
||||||
return nil, fmt.Errorf("invalid cmddone packet")
|
return nil, fmt.Errorf("invalid cmddone packet")
|
||||||
}
|
}
|
||||||
@ -859,7 +859,7 @@ func UpdateCmdDoneInfo(ctx context.Context, ck base.CommandKey, doneInfo *CmdDon
|
|||||||
var rtnCmd *CmdType
|
var rtnCmd *CmdType
|
||||||
txErr := WithTx(ctx, func(tx *TxWrap) error {
|
txErr := WithTx(ctx, func(tx *TxWrap) error {
|
||||||
query := `UPDATE cmd SET status = ?, doneinfo = ? WHERE screenid = ? AND cmdid = ?`
|
query := `UPDATE cmd SET status = ?, doneinfo = ? WHERE screenid = ? AND cmdid = ?`
|
||||||
tx.Exec(query, CmdStatusDone, quickJson(doneInfo), screenId, ck.GetCmdId())
|
tx.Exec(query, status, quickJson(doneInfo), screenId, ck.GetCmdId())
|
||||||
var err error
|
var err error
|
||||||
rtnCmd, err = GetCmdByScreenId(tx.Context(), screenId, ck.GetCmdId())
|
rtnCmd, err = GetCmdByScreenId(tx.Context(), screenId, ck.GetCmdId())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -28,8 +28,6 @@ import (
|
|||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
)
|
)
|
||||||
|
|
||||||
const LineTypeCmd = "cmd"
|
|
||||||
const LineTypeText = "text"
|
|
||||||
const LineNoHeight = -1
|
const LineNoHeight = -1
|
||||||
const DBFileName = "prompt.db"
|
const DBFileName = "prompt.db"
|
||||||
const DBFileNameBackup = "backup.prompt.db"
|
const DBFileNameBackup = "backup.prompt.db"
|
||||||
@ -41,6 +39,12 @@ const LocalRemoteAlias = "local"
|
|||||||
|
|
||||||
const DefaultCwd = "~"
|
const DefaultCwd = "~"
|
||||||
|
|
||||||
|
const (
|
||||||
|
LineTypeCmd = "cmd"
|
||||||
|
LineTypeText = "text"
|
||||||
|
LineTypeOpenAI = "openai"
|
||||||
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
MainViewSession = "session"
|
MainViewSession = "session"
|
||||||
MainViewBookmarks = "bookmarks"
|
MainViewBookmarks = "bookmarks"
|
||||||
@ -56,6 +60,16 @@ const (
|
|||||||
CmdStatusWaiting = "waiting"
|
CmdStatusWaiting = "waiting"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
CmdRendererOpenAI = "openai"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
OpenAIRoleSystem = "system"
|
||||||
|
OpenAIRoleUser = "user"
|
||||||
|
OpenAIRoleAssistant = "assistant"
|
||||||
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
RemoteAuthTypeNone = "none"
|
RemoteAuthTypeNone = "none"
|
||||||
RemoteAuthTypePassword = "password"
|
RemoteAuthTypePassword = "password"
|
||||||
@ -685,6 +699,31 @@ type LineType struct {
|
|||||||
Remove bool `json:"remove,omitempty"`
|
Remove bool `json:"remove,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type OpenAIUsage struct {
|
||||||
|
PromptTokens int `json:"prompt_tokens"`
|
||||||
|
CompletionTokens int `json:"completion_tokens"`
|
||||||
|
TotalTokens int `json:"total_tokens"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpenAIChoiceType struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
Index int `json:"index"`
|
||||||
|
FinishReason string `json:"finish_reason"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpenAIResponse struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
Usage *OpenAIUsage `json:"usage,omitempty"`
|
||||||
|
Choices []OpenAIChoiceType `json:"choices,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpenAIPromptMessageType struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
type PlaybookType struct {
|
type PlaybookType struct {
|
||||||
PlaybookId string `json:"playbookid"`
|
PlaybookId string `json:"playbookid"`
|
||||||
PlaybookName string `json:"playbookname"`
|
PlaybookName string `json:"playbookname"`
|
||||||
@ -829,6 +868,10 @@ type RemoteOptsType struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type OpenAIOptsType struct {
|
type OpenAIOptsType struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
APIToken string `json:"apitoken"`
|
||||||
|
MaxTokens int `json:"maxtokens,omitempty"`
|
||||||
|
MaxChoices int `json:"maxchoices,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type RemoteType struct {
|
type RemoteType struct {
|
||||||
@ -1011,6 +1054,20 @@ func makeNewLineText(screenId string, userId string, text string) *LineType {
|
|||||||
return rtn
|
return rtn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func makeNewLineOpenAI(screenId string, userId string, cmdId string) *LineType {
|
||||||
|
rtn := &LineType{}
|
||||||
|
rtn.ScreenId = screenId
|
||||||
|
rtn.UserId = userId
|
||||||
|
rtn.LineId = scbase.GenPromptUUID()
|
||||||
|
rtn.CmdId = cmdId
|
||||||
|
rtn.Ts = time.Now().UnixMilli()
|
||||||
|
rtn.LineLocal = true
|
||||||
|
rtn.LineType = LineTypeOpenAI
|
||||||
|
rtn.ContentHeight = LineNoHeight
|
||||||
|
rtn.Renderer = CmdRendererOpenAI
|
||||||
|
return rtn
|
||||||
|
}
|
||||||
|
|
||||||
func AddCommentLine(ctx context.Context, screenId string, userId string, commentText string) (*LineType, error) {
|
func AddCommentLine(ctx context.Context, screenId string, userId string, commentText string) (*LineType, error) {
|
||||||
rtnLine := makeNewLineText(screenId, userId, commentText)
|
rtnLine := makeNewLineText(screenId, userId, commentText)
|
||||||
err := InsertLine(ctx, rtnLine, nil)
|
err := InsertLine(ctx, rtnLine, nil)
|
||||||
@ -1020,6 +1077,15 @@ func AddCommentLine(ctx context.Context, screenId string, userId string, comment
|
|||||||
return rtnLine, nil
|
return rtnLine, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func AddOpenAILine(ctx context.Context, screenId string, userId string, cmd *CmdType) (*LineType, error) {
|
||||||
|
rtnLine := makeNewLineOpenAI(screenId, userId, cmd.CmdId)
|
||||||
|
err := InsertLine(ctx, rtnLine, cmd)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return rtnLine, nil
|
||||||
|
}
|
||||||
|
|
||||||
func AddCmdLine(ctx context.Context, screenId string, userId string, cmd *CmdType, renderer string) (*LineType, error) {
|
func AddCmdLine(ctx context.Context, screenId string, userId string, cmd *CmdType, renderer string) (*LineType, error) {
|
||||||
rtnLine := makeNewLineCmd(screenId, userId, cmd.CmdId, renderer)
|
rtnLine := makeNewLineCmd(screenId, userId, cmd.CmdId, renderer)
|
||||||
err := InsertLine(ctx, rtnLine, cmd)
|
err := InsertLine(ctx, rtnLine, cmd)
|
||||||
|
Loading…
Reference in New Issue
Block a user