openai api integration 'working'

This commit is contained in:
sawka 2023-05-04 01:01:13 -07:00
parent ab5deafdb6
commit 8302ca1fcb
7 changed files with 455 additions and 9 deletions

View File

@ -24,6 +24,7 @@ import (
"github.com/scripthaus-dev/sh2-server/pkg/comp"
"github.com/scripthaus-dev/sh2-server/pkg/pcloud"
"github.com/scripthaus-dev/sh2-server/pkg/remote"
"github.com/scripthaus-dev/sh2-server/pkg/remote/openai"
"github.com/scripthaus-dev/sh2-server/pkg/scbase"
"github.com/scripthaus-dev/sh2-server/pkg/scpacket"
"github.com/scripthaus-dev/sh2-server/pkg/sstore"
@ -41,7 +42,7 @@ func init() {
comp.RegisterSimpleCompFn(comp.CGTypeCommandMeta, simpleCompCommandMeta)
}
const DefaultUserId = "sawka"
const DefaultUserId = "user"
const MaxNameLen = 50
const MaxShareNameLen = 150
const MaxRendererLen = 50
@ -198,6 +199,9 @@ func init() {
registerCmdFn("bookmark:set", BookmarkSetCommand)
registerCmdFn("bookmark:delete", BookmarkDeleteCommand)
registerCmdFn("openai", OpenAICommand)
registerCmdFn("openai:stream", OpenAICommand)
registerCmdFn("_killserver", KillServerCommand)
registerCmdFn("set", SetCommand)
@ -551,7 +555,6 @@ func EvalCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sstore.
log.Printf("[error] incrementing activity numcommands: %v\n", err)
// fall through (non-fatal error)
}
log.Printf("inc numcommands\n")
}
if evalDepth > MaxEvalDepth {
return nil, fmt.Errorf("alias/history expansion max-depth exceeded")
@ -1312,6 +1315,185 @@ func GetFullRemoteDisplayName(rptr *sstore.RemotePtrType, rstate *remote.RemoteR
}
}
func writeErrorToPty(cmd *sstore.CmdType, errStr string, outputPos int64) {
errPk := openai.CreateErrorPacket(errStr)
errBytes, err := packet.MarshalPacket(errPk)
if err != nil {
log.Printf("error writing error packet to openai response: %v\n", err)
return
}
errCtx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
defer cancelFn()
update, err := sstore.AppendToCmdPtyBlob(errCtx, cmd.ScreenId, cmd.CmdId, errBytes, outputPos)
if err != nil {
log.Printf("error writing ptyupdate for openai response: %v\n", err)
return
}
sstore.MainBus.SendScreenUpdate(cmd.ScreenId, update)
return
}
func writePacketToPty(ctx context.Context, cmd *sstore.CmdType, pk packet.PacketType, outputPos *int64) error {
outBytes, err := packet.MarshalPacket(pk)
if err != nil {
return err
}
update, err := sstore.AppendToCmdPtyBlob(ctx, cmd.ScreenId, cmd.CmdId, outBytes, *outputPos)
if err != nil {
return err
}
*outputPos += int64(len(outBytes))
sstore.MainBus.SendScreenUpdate(cmd.ScreenId, update)
return nil
}
func doOpenAICompletion(cmd *sstore.CmdType, opts *sstore.OpenAIOptsType, prompt []sstore.OpenAIPromptMessageType) {
var outputPos int64
var hadError bool
startTime := time.Now()
ctx, cancelFn := context.WithTimeout(context.Background(), 30*time.Second)
defer cancelFn()
defer func() {
r := recover()
if r != nil {
panicMsg := fmt.Sprintf("panic: %v", r)
log.Printf("panic in doOpenAICompletion: %s\n", panicMsg)
writeErrorToPty(cmd, panicMsg, outputPos)
hadError = true
}
duration := time.Since(startTime)
cmdStatus := sstore.CmdStatusDone
var exitCode int64
if hadError {
cmdStatus = sstore.CmdStatusError
exitCode = 1
}
doneInfo := &sstore.CmdDoneInfo{
Ts: time.Now().UnixMilli(),
ExitCode: exitCode,
DurationMs: duration.Milliseconds(),
}
ck := base.MakeCommandKey(cmd.ScreenId, cmd.CmdId)
update, err := sstore.UpdateCmdDoneInfo(context.Background(), ck, doneInfo, cmdStatus)
if err != nil {
// nothing to do
log.Printf("error updating cmddoneinfo (in openai): %v\n", err)
return
}
sstore.MainBus.SendScreenUpdate(cmd.ScreenId, update)
}()
respPks, err := openai.RunCompletion(ctx, opts, prompt)
if err != nil {
writeErrorToPty(cmd, fmt.Sprintf("error calling OpenAI API: %v", err), outputPos)
return
}
for _, pk := range respPks {
err = writePacketToPty(ctx, cmd, pk, &outputPos)
if err != nil {
writeErrorToPty(cmd, fmt.Sprintf("error writing response to ptybuffer: %v", err), outputPos)
return
}
}
return
}
func doOpenAIStreamCompletion(cmd *sstore.CmdType, opts *sstore.OpenAIOptsType, prompt []sstore.OpenAIPromptMessageType) {
var outputPos int64
var hadError bool
startTime := time.Now()
ctx, cancelFn := context.WithTimeout(context.Background(), 30*time.Second)
defer cancelFn()
defer func() {
r := recover()
if r != nil {
panicMsg := fmt.Sprintf("panic: %v", r)
log.Printf("panic in doOpenAICompletion: %s\n", panicMsg)
writeErrorToPty(cmd, panicMsg, outputPos)
hadError = true
}
duration := time.Since(startTime)
cmdStatus := sstore.CmdStatusDone
var exitCode int64
if hadError {
cmdStatus = sstore.CmdStatusError
exitCode = 1
}
doneInfo := &sstore.CmdDoneInfo{
Ts: time.Now().UnixMilli(),
ExitCode: exitCode,
DurationMs: duration.Milliseconds(),
}
ck := base.MakeCommandKey(cmd.ScreenId, cmd.CmdId)
update, err := sstore.UpdateCmdDoneInfo(context.Background(), ck, doneInfo, cmdStatus)
if err != nil {
// nothing to do
log.Printf("error updating cmddoneinfo (in openai): %v\n", err)
return
}
sstore.MainBus.SendScreenUpdate(cmd.ScreenId, update)
}()
ch, err := openai.RunCompletionStream(ctx, opts, prompt)
if err != nil {
writeErrorToPty(cmd, fmt.Sprintf("error calling OpenAI API: %v", err), outputPos)
return
}
for pk := range ch {
err = writePacketToPty(ctx, cmd, pk, &outputPos)
if err != nil {
writeErrorToPty(cmd, fmt.Sprintf("error writing response to ptybuffer: %v", err), outputPos)
return
}
}
return
}
func OpenAICommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sstore.UpdatePacket, error) {
ids, err := resolveUiIds(ctx, pk, R_Session|R_Screen)
if err != nil {
return nil, fmt.Errorf("/%s error: %w", GetCmdStr(pk), err)
}
opts := &sstore.OpenAIOptsType{
Model: "gpt-3.5-turbo",
APIToken: OpenAIKey,
MaxTokens: 1000,
}
promptStr := firstArg(pk)
if promptStr == "" {
return nil, fmt.Errorf("/openai error, prompt string is blank")
}
ptermVal := defaultStr(pk.Kwargs["pterm"], DefaultPTERM)
pkTermOpts, err := GetUITermOpts(pk.UIContext.WinSize, ptermVal)
if err != nil {
return nil, fmt.Errorf("/openai error, invalid 'pterm' value %q: %v", ptermVal, err)
}
termOpts := convertTermOpts(pkTermOpts)
cmd, err := makeDynCmd(ctx, GetCmdStr(pk), ids, pk.GetRawStr(), *termOpts)
if err != nil {
return nil, fmt.Errorf("/openai error, cannot make dyn cmd")
}
line, err := sstore.AddOpenAILine(ctx, ids.ScreenId, DefaultUserId, cmd)
if err != nil {
return nil, fmt.Errorf("cannot add new line: %v", err)
}
prompt := []sstore.OpenAIPromptMessageType{{Role: sstore.OpenAIRoleUser, Content: promptStr}}
if pk.MetaSubCmd == "stream" {
go doOpenAIStreamCompletion(cmd, opts, prompt)
} else {
go doOpenAICompletion(cmd, opts, prompt)
}
updateHistoryContext(ctx, line, cmd)
updateMap := make(map[string]interface{})
updateMap[sstore.ScreenField_SelectedLine] = line.LineNum
updateMap[sstore.ScreenField_Focus] = sstore.ScreenFocusInput
screen, err := sstore.UpdateScreen(ctx, ids.ScreenId, updateMap)
if err != nil {
// ignore error again (nothing to do)
log.Printf("/openai error updating screen selected line: %v\n", err)
}
update := sstore.ModelUpdate{Line: line, Cmd: cmd, Screens: []*sstore.ScreenType{screen}}
return update, nil
}
func CrCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sstore.UpdatePacket, error) {
ids, err := resolveUiIds(ctx, pk, R_Session|R_Screen)
if err != nil {
@ -1350,6 +1532,33 @@ func CrCommand(ctx context.Context, pk *scpacket.FeCommandPacketType) (sstore.Up
return update, nil
}
func makeDynCmd(ctx context.Context, metaCmd string, ids resolvedIds, cmdStr string, termOpts sstore.TermOpts) (*sstore.CmdType, error) {
cmd := &sstore.CmdType{
ScreenId: ids.ScreenId,
CmdId: scbase.GenPromptUUID(),
CmdStr: cmdStr,
RawCmdStr: cmdStr,
Remote: ids.Remote.RemotePtr,
TermOpts: termOpts,
Status: sstore.CmdStatusRunning,
StartPk: nil,
DoneInfo: nil,
RunOut: nil,
}
if ids.Remote.StatePtr != nil {
cmd.StatePtr = *ids.Remote.StatePtr
}
if ids.Remote.FeState != nil {
cmd.FeState = ids.Remote.FeState
}
err := sstore.CreateCmdPtyFile(ctx, cmd.ScreenId, cmd.CmdId, cmd.TermOpts.MaxPtySize)
if err != nil {
// TODO tricky error since the command was a success, but we can't show the output
return nil, fmt.Errorf("cannot create local ptyout file for %s command: %w", metaCmd, err)
}
return cmd, nil
}
func makeStaticCmd(ctx context.Context, metaCmd string, ids resolvedIds, cmdStr string, cmdOutput []byte) (*sstore.CmdType, error) {
cmd := &sstore.CmdType{
ScreenId: ids.ScreenId,

View File

@ -28,6 +28,20 @@ var BareMetaCmds = []BareMetaCmdDecl{
BareMetaCmdDecl{"reset", "reset"},
}
const (
CmdParseTypePositional = "pos"
CmdParseTypeRaw = "raw"
)
var CmdParseOverrides map[string]string = map[string]string{
"setenv": CmdParseTypePositional,
"unset": CmdParseTypePositional,
"set": CmdParseTypePositional,
"run": CmdParseTypeRaw,
"comment": CmdParseTypeRaw,
"openai": CmdParseTypeRaw,
}
func DumpPacket(pk *scpacket.FeCommandPacketType) {
if pk == nil || pk.MetaCmd == "" {
fmt.Printf("[no metacmd]\n")
@ -111,11 +125,11 @@ func parseMetaCmd(origCommandStr string) (string, string, string) {
}
func onlyPositionalArgs(metaCmd string, metaSubCmd string) bool {
return (metaCmd == "setenv" || metaCmd == "unset" || metaCmd == "set") && metaSubCmd == ""
return (CmdParseOverrides[metaCmd] == CmdParseTypePositional) && metaSubCmd == ""
}
func onlyRawArgs(metaCmd string, metaSubCmd string) bool {
return metaCmd == "run" || metaCmd == "comment"
return CmdParseOverrides[metaCmd] == CmdParseTypeRaw
}
func setBracketArgs(argMap map[string]string, bracketStr string) error {

View File

@ -9,6 +9,7 @@ import (
"github.com/scripthaus-dev/mshell/pkg/packet"
"github.com/scripthaus-dev/mshell/pkg/shexec"
"github.com/scripthaus-dev/sh2-server/pkg/remote"
"github.com/scripthaus-dev/sh2-server/pkg/sstore"
)
// PTERM=MxM,Mx25
@ -108,3 +109,12 @@ func GetUITermOpts(winSize *packet.WinSize, ptermStr string) (*packet.TermOpts,
termOpts.Rows = base.BoundInt(termOpts.Rows, shexec.MinTermRows, shexec.MaxTermRows)
return termOpts, nil
}
func convertTermOpts(pkto *packet.TermOpts) *sstore.TermOpts {
return &sstore.TermOpts{
Rows: int64(pkto.Rows),
Cols: int64(pkto.Cols),
FlexRows: true,
MaxPtySize: pkto.MaxPtySize,
}
}

147
pkg/remote/openai/openai.go Normal file
View File

@ -0,0 +1,147 @@
package openai
import (
"context"
"fmt"
"io"
"log"
openaiapi "github.com/sashabaranov/go-openai"
"github.com/scripthaus-dev/mshell/pkg/packet"
"github.com/scripthaus-dev/sh2-server/pkg/sstore"
)
// https://github.com/tiktoken-go/tokenizer
const DefaultStreamChanSize = 10
func convertUsage(resp openaiapi.ChatCompletionResponse) *packet.OpenAIUsageType {
if resp.Usage.TotalTokens == 0 {
return nil
}
return &packet.OpenAIUsageType{
PromptTokens: resp.Usage.PromptTokens,
CompletionTokens: resp.Usage.CompletionTokens,
TotalTokens: resp.Usage.TotalTokens,
}
}
func convertPrompt(prompt []sstore.OpenAIPromptMessageType) []openaiapi.ChatCompletionMessage {
var rtn []openaiapi.ChatCompletionMessage
for _, p := range prompt {
msg := openaiapi.ChatCompletionMessage{Role: p.Role, Content: p.Content, Name: p.Name}
rtn = append(rtn, msg)
}
return rtn
}
func RunCompletion(ctx context.Context, opts *sstore.OpenAIOptsType, prompt []sstore.OpenAIPromptMessageType) ([]*packet.OpenAIPacketType, error) {
if opts == nil {
return nil, fmt.Errorf("no openai opts found")
}
if opts.Model == "" {
return nil, fmt.Errorf("no openai model specified")
}
if opts.APIToken == "" {
return nil, fmt.Errorf("no api token")
}
client := openaiapi.NewClient(opts.APIToken)
req := openaiapi.ChatCompletionRequest{
Model: opts.Model,
Messages: convertPrompt(prompt),
MaxTokens: opts.MaxTokens,
}
if opts.MaxChoices > 1 {
req.N = opts.MaxChoices
}
apiResp, err := client.CreateChatCompletion(ctx, req)
if err != nil {
return nil, fmt.Errorf("error calling openai API: %v", err)
}
if len(apiResp.Choices) == 0 {
return nil, fmt.Errorf("no response received")
}
return marshalResponse(apiResp), nil
}
func RunCompletionStream(ctx context.Context, opts *sstore.OpenAIOptsType, prompt []sstore.OpenAIPromptMessageType) (chan *packet.OpenAIPacketType, error) {
if opts == nil {
return nil, fmt.Errorf("no openai opts found")
}
if opts.Model == "" {
return nil, fmt.Errorf("no openai model specified")
}
if opts.APIToken == "" {
return nil, fmt.Errorf("no api token")
}
client := openaiapi.NewClient(opts.APIToken)
req := openaiapi.ChatCompletionRequest{
Model: opts.Model,
Messages: convertPrompt(prompt),
MaxTokens: opts.MaxTokens,
Stream: true,
}
if opts.MaxChoices > 1 {
req.N = opts.MaxChoices
}
apiResp, err := client.CreateChatCompletionStream(ctx, req)
if err != nil {
return nil, fmt.Errorf("error calling openai API: %v", err)
}
rtn := make(chan *packet.OpenAIPacketType, DefaultStreamChanSize)
go func() {
sentHeader := false
defer close(rtn)
for {
streamResp, err := apiResp.Recv()
if err == io.EOF {
break
}
if err != nil {
errPk := CreateErrorPacket(fmt.Sprintf("error in recv of streaming data: %v", err))
rtn <- errPk
break
}
log.Printf("stream-resp: %#v\n", streamResp)
if streamResp.Model != "" && !sentHeader {
pk := packet.MakeOpenAIPacket()
pk.Model = streamResp.Model
pk.Created = streamResp.Created
rtn <- pk
sentHeader = true
}
for _, choice := range streamResp.Choices {
pk := packet.MakeOpenAIPacket()
pk.Index = choice.Index
pk.Text = choice.Delta.Content
pk.FinishReason = choice.FinishReason
rtn <- pk
}
}
}()
return rtn, err
}
func marshalResponse(resp openaiapi.ChatCompletionResponse) []*packet.OpenAIPacketType {
var rtn []*packet.OpenAIPacketType
headerPk := packet.MakeOpenAIPacket()
headerPk.Model = resp.Model
headerPk.Created = resp.Created
headerPk.Usage = convertUsage(resp)
rtn = append(rtn, headerPk)
for _, choice := range resp.Choices {
choicePk := packet.MakeOpenAIPacket()
choicePk.Index = choice.Index
choicePk.Text = choice.Message.Content
choicePk.FinishReason = choice.FinishReason
rtn = append(rtn, choicePk)
}
return rtn
}
func CreateErrorPacket(errStr string) *packet.OpenAIPacketType {
errPk := packet.MakeOpenAIPacket()
errPk.Text = errStr
errPk.FinishReason = "stop"
return errPk
}

View File

@ -1647,7 +1647,7 @@ func (msh *MShellProc) handleCmdDonePacket(donePk *packet.CmdDonePacketType) {
ExitCode: int64(donePk.ExitCode),
DurationMs: donePk.DurationMs,
}
update, err := sstore.UpdateCmdDoneInfo(context.Background(), donePk.CK, doneInfo)
update, err := sstore.UpdateCmdDoneInfo(context.Background(), donePk.CK, doneInfo, sstore.CmdStatusDone)
if err != nil {
msh.WriteToPtyBuffer("*error updating cmddone: %v\n", err)
return

View File

@ -848,7 +848,7 @@ func GetCmdByScreenId(ctx context.Context, screenId string, cmdId string) (*CmdT
return cmd, nil
}
func UpdateCmdDoneInfo(ctx context.Context, ck base.CommandKey, doneInfo *CmdDoneInfo) (*ModelUpdate, error) {
func UpdateCmdDoneInfo(ctx context.Context, ck base.CommandKey, doneInfo *CmdDoneInfo, status string) (*ModelUpdate, error) {
if doneInfo == nil {
return nil, fmt.Errorf("invalid cmddone packet")
}
@ -859,7 +859,7 @@ func UpdateCmdDoneInfo(ctx context.Context, ck base.CommandKey, doneInfo *CmdDon
var rtnCmd *CmdType
txErr := WithTx(ctx, func(tx *TxWrap) error {
query := `UPDATE cmd SET status = ?, doneinfo = ? WHERE screenid = ? AND cmdid = ?`
tx.Exec(query, CmdStatusDone, quickJson(doneInfo), screenId, ck.GetCmdId())
tx.Exec(query, status, quickJson(doneInfo), screenId, ck.GetCmdId())
var err error
rtnCmd, err = GetCmdByScreenId(tx.Context(), screenId, ck.GetCmdId())
if err != nil {

View File

@ -28,8 +28,6 @@ import (
_ "github.com/mattn/go-sqlite3"
)
const LineTypeCmd = "cmd"
const LineTypeText = "text"
const LineNoHeight = -1
const DBFileName = "prompt.db"
const DBFileNameBackup = "backup.prompt.db"
@ -41,6 +39,12 @@ const LocalRemoteAlias = "local"
const DefaultCwd = "~"
const (
LineTypeCmd = "cmd"
LineTypeText = "text"
LineTypeOpenAI = "openai"
)
const (
MainViewSession = "session"
MainViewBookmarks = "bookmarks"
@ -56,6 +60,16 @@ const (
CmdStatusWaiting = "waiting"
)
const (
CmdRendererOpenAI = "openai"
)
const (
OpenAIRoleSystem = "system"
OpenAIRoleUser = "user"
OpenAIRoleAssistant = "assistant"
)
const (
RemoteAuthTypeNone = "none"
RemoteAuthTypePassword = "password"
@ -685,6 +699,31 @@ type LineType struct {
Remove bool `json:"remove,omitempty"`
}
type OpenAIUsage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type OpenAIChoiceType struct {
Text string `json:"text"`
Index int `json:"index"`
FinishReason string `json:"finish_reason"`
}
type OpenAIResponse struct {
Model string `json:"model"`
Created int64 `json:"created"`
Usage *OpenAIUsage `json:"usage,omitempty"`
Choices []OpenAIChoiceType `json:"choices,omitempty"`
}
type OpenAIPromptMessageType struct {
Role string `json:"role"`
Content string `json:"content"`
Name string `json:"name,omitempty"`
}
type PlaybookType struct {
PlaybookId string `json:"playbookid"`
PlaybookName string `json:"playbookname"`
@ -829,6 +868,10 @@ type RemoteOptsType struct {
}
type OpenAIOptsType struct {
Model string `json:"model"`
APIToken string `json:"apitoken"`
MaxTokens int `json:"maxtokens,omitempty"`
MaxChoices int `json:"maxchoices,omitempty"`
}
type RemoteType struct {
@ -1011,6 +1054,20 @@ func makeNewLineText(screenId string, userId string, text string) *LineType {
return rtn
}
func makeNewLineOpenAI(screenId string, userId string, cmdId string) *LineType {
rtn := &LineType{}
rtn.ScreenId = screenId
rtn.UserId = userId
rtn.LineId = scbase.GenPromptUUID()
rtn.CmdId = cmdId
rtn.Ts = time.Now().UnixMilli()
rtn.LineLocal = true
rtn.LineType = LineTypeOpenAI
rtn.ContentHeight = LineNoHeight
rtn.Renderer = CmdRendererOpenAI
return rtn
}
func AddCommentLine(ctx context.Context, screenId string, userId string, commentText string) (*LineType, error) {
rtnLine := makeNewLineText(screenId, userId, commentText)
err := InsertLine(ctx, rtnLine, nil)
@ -1020,6 +1077,15 @@ func AddCommentLine(ctx context.Context, screenId string, userId string, comment
return rtnLine, nil
}
func AddOpenAILine(ctx context.Context, screenId string, userId string, cmd *CmdType) (*LineType, error) {
rtnLine := makeNewLineOpenAI(screenId, userId, cmd.CmdId)
err := InsertLine(ctx, rtnLine, cmd)
if err != nil {
return nil, err
}
return rtnLine, nil
}
func AddCmdLine(ctx context.Context, screenId string, userId string, cmd *CmdType, renderer string) (*LineType, error) {
rtnLine := makeNewLineCmd(screenId, userId, cmd.CmdId, renderer)
err := InsertLine(ctx, rtnLine, cmd)