waveterm/pkg/sstore/dbops.go
2022-07-07 00:10:37 -07:00

237 lines
6.6 KiB
Go

package sstore
import (
"context"
"database/sql"
"fmt"
"github.com/google/uuid"
)
func NumSessions(ctx context.Context) (int, error) {
db, err := GetDB()
if err != nil {
return 0, err
}
query := "SELECT count(*) FROM session"
var count int
err = db.GetContext(ctx, &count, query)
if err != nil {
return 0, err
}
return count, nil
}
const remoteSelectCols = "remoteid, remotetype, remotename, autoconnect, sshhost, sshopts, sshidentity, sshuser, lastconnectts"
func GetAllRemotes(ctx context.Context) ([]*RemoteType, error) {
db, err := GetDB()
if err != nil {
return nil, err
}
query := fmt.Sprintf(`SELECT %s FROM remote`, remoteSelectCols)
var remoteArr []*RemoteType
err = db.SelectContext(ctx, &remoteArr, query)
if err != nil {
return nil, err
}
return remoteArr, nil
}
func GetRemoteByName(ctx context.Context, remoteName string) (*RemoteType, error) {
db, err := GetDB()
if err != nil {
return nil, err
}
query := fmt.Sprintf(`SELECT %s FROM remote WHERE remotename = ?`, remoteSelectCols)
var remote RemoteType
err = db.GetContext(ctx, &remote, query, remoteName)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
return &remote, nil
}
func GetRemoteById(ctx context.Context, remoteId string) (*RemoteType, error) {
db, err := GetDB()
if err != nil {
return nil, err
}
query := fmt.Sprintf(`SELECT %s FROM remote WHERE remoteid = ?`, remoteSelectCols)
var remote RemoteType
err = db.GetContext(ctx, &remote, query, remoteId)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
return &remote, nil
}
func InsertRemote(ctx context.Context, remote *RemoteType) error {
if remote == nil {
return fmt.Errorf("cannot insert nil remote")
}
db, err := GetDB()
if err != nil {
return err
}
query := `INSERT INTO remote ( remoteid, remotetype, remotename, autoconnect, sshhost, sshopts, sshidentity, sshuser, lastconnectts) VALUES
(:remoteid,:remotetype,:remotename,:autoconnect,:sshhost,:sshopts,:sshidentity,:sshuser, 0 )`
_, err = db.NamedExec(query, remote)
if err != nil {
return err
}
return nil
}
func GetSessionById(ctx context.Context, id string) (*SessionType, error) {
var rtnSession *SessionType
err := WithTx(ctx, func(tx *TxWrap) error {
var session SessionType
query := `SELECT * FROM session WHERE sessionid = ?`
found := tx.GetWrap(&session, query, id)
if !found {
return nil
}
rtnSession = &session
query = `SELECT sessionid, windowid, name, curremote, version FROM window WHERE sessionid = ?`
tx.SelectWrap(&session.Windows, query, session.SessionId)
query = `SELECT * FROM remote_instance WHERE sessionid = ?`
tx.SelectWrap(&session.Remotes, query, session.SessionId)
query = `SELECT * FROM cmd WHERE sessionid = ?`
marr := tx.SelectMaps(query, session.SessionId)
for _, m := range marr {
session.Cmds = append(session.Cmds, CmdFromMap(m))
}
return nil
})
if err != nil {
return nil, err
}
return rtnSession, nil
}
func GetSessionByName(ctx context.Context, name string) (*SessionType, error) {
db, err := GetDB()
if err != nil {
return nil, err
}
var sessionId string
query := `SELECT sessionid FROM session WHERE name = ?`
err = db.GetContext(ctx, &sessionId, query, name)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
return GetSessionById(ctx, sessionId)
}
func GetWindowLines(ctx context.Context, sessionId string, windowId string) ([]*LineType, error) {
var lines []*LineType
db, err := GetDB()
if err != nil {
return nil, err
}
query := `SELECT * FROM line WHERE sessionid = ? AND windowid = ?`
err = db.SelectContext(ctx, &lines, query, sessionId, windowId)
if err != nil {
return nil, err
}
return lines, nil
}
// also creates window
func InsertSessionWithName(ctx context.Context, sessionName string) error {
if sessionName == "" {
return fmt.Errorf("invalid session name '%s'", sessionName)
}
session := &SessionType{
SessionId: uuid.New().String(),
Name: sessionName,
}
return WithTx(ctx, func(tx *TxWrap) error {
query := `INSERT INTO session (sessionid, name) VALUES (:sessionid, :name)`
tx.NamedExecWrap(query, session)
window := &WindowType{
SessionId: session.SessionId,
WindowId: uuid.New().String(),
Name: DefaultWindowName,
CurRemote: LocalRemoteName,
}
query = `INSERT INTO window (sessionid, windowid, name, curremote, version) VALUES (:sessionid, :windowid, :name, :curremote, :version)`
tx.NamedExecWrap(query, window)
return nil
})
}
func InsertLine(ctx context.Context, line *LineType) error {
if line == nil {
return fmt.Errorf("line cannot be nil")
}
if line.LineId != 0 {
return fmt.Errorf("new line cannot have LineId set")
}
return WithTx(ctx, func(tx *TxWrap) error {
var windowId string
query := `SELECT windowid FROM window WHERE sessionid = ? AND windowid = ?`
hasWindow := tx.GetWrap(&windowId, query, line.SessionId, line.WindowId)
if !hasWindow {
return fmt.Errorf("window not found, cannot insert line[%s/%s]", line.SessionId, line.WindowId)
}
var maxLineId int
query = `SELECT COALESCE(max(lineid), 0) FROM line WHERE sessionid = ? AND windowid = ?`
tx.GetWrap(&maxLineId, query, line.SessionId, line.WindowId)
line.LineId = maxLineId + 1
query = `INSERT INTO line ( sessionid, windowid, lineid, ts, userid, linetype, text, cmdid)
VALUES (:sessionid,:windowid,:lineid,:ts,:userid,:linetype,:text,:cmdid)`
tx.NamedExecWrap(query, line)
return nil
})
}
func InsertCmd(ctx context.Context, cmd *CmdType) error {
if cmd == nil {
return fmt.Errorf("cmd cannot be nil")
}
return WithTx(ctx, func(tx *TxWrap) error {
var sessionId string
query := `SELECT sessionid FROM session WHERE sessionid = ?`
hasSession := tx.GetWrap(&sessionId, query, cmd.SessionId)
if !hasSession {
return fmt.Errorf("session not found, cannot insert cmd")
}
cmdMap := cmd.ToMap()
query = `
INSERT INTO cmd ( sessionid, cmdid, remoteid, cmdstr, remotestate, termopts, status, startpk, donepk, runout)
VALUES (:sessionid,:cmdid,:remoteid,:cmdstr,:remotestate,:termopts,:status,:startpk,:donepk,:runout)
`
tx.NamedExecWrap(query, cmdMap)
return nil
})
}
func GetCmd(ctx context.Context, sessionId string, cmdId string) (*CmdType, error) {
db, err := GetDB()
if err != nil {
return nil, err
}
var m map[string]interface{}
query := `SELECT * FROM cmd WHERE sessionid = ? AND cmdid = ?`
err = db.GetContext(ctx, &m, query, sessionId, cmdId)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
return CmdFromMap(m), nil
}