checkpoint, working on db calls

This commit is contained in:
sawka 2022-07-01 14:07:13 -07:00
parent 643f08e584
commit b85be3457c
4 changed files with 274 additions and 62 deletions

View File

@ -39,6 +39,8 @@ CREATE TABLE remote (
remoteid varchar(36) PRIMARY KEY, remoteid varchar(36) PRIMARY KEY,
remotetype varchar(10) NOT NULL, remotetype varchar(10) NOT NULL,
remotename varchar(50) NOT NULL, remotename varchar(50) NOT NULL,
hostname varchar(200) NOT NULL,
lastconnectts bigint NOT NULL,
connectopts varchar(300) NOT NULL, connectopts varchar(300) NOT NULL,
ptyout BLOB NOT NULL ptyout BLOB NOT NULL
); );
@ -62,7 +64,7 @@ CREATE TABLE history (
sessionid varchar(36) NOT NULL, sessionid varchar(36) NOT NULL,
windowid varchar(36) NOT NULL, windowid varchar(36) NOT NULL,
userid varchar(36) NOT NULL, userid varchar(36) NOT NULL,
ts int64 NOT NULL, ts bigint NOT NULL,
lineid varchar(36) NOT NULL, lineid varchar(36) NOT NULL,
PRIMARY KEY (sessionid, windowid, lineid) PRIMARY KEY (sessionid, windowid, lineid)
); );

147
pkg/sstore/dbops.go Normal file
View File

@ -0,0 +1,147 @@
package sstore
import (
"context"
"database/sql"
"fmt"
"github.com/google/uuid"
)
func NumSessions(ctx context.Context) (int, error) {
db, err := GetDB()
if err != nil {
return 0, err
}
query := "SELECT count(*) FROM session"
var count int
err = db.GetContext(ctx, &count, query)
if err != nil {
return 0, err
}
return count, nil
}
func GetRemoteByName(ctx context.Context, remoteName string) (*RemoteType, error) {
db, err := GetDB()
if err != nil {
return nil, err
}
query := `SELECT rowid, remoteid, remotetype, remotename, hostname, connectopts, lastconnectts FROM remote WHERE remotename = ?`
var remote RemoteType
err = db.GetContext(ctx, &remote, query, remoteName)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
return &remote, nil
}
func GetRemoteById(ctx context.Context, remoteId string) (*RemoteType, error) {
db, err := GetDB()
if err != nil {
return nil, err
}
query := `SELECT rowid, remoteid, remotetype, remotename, hostname, connectopts, lastconnectts FROM remote WHERE remoteid = ?`
var remote RemoteType
err = db.GetContext(ctx, &remote, query, remoteId)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
return &remote, nil
}
func InsertRemote(ctx context.Context, remote *RemoteType) error {
if remote == nil {
return fmt.Errorf("cannot insert nil remote")
}
if remote.RowId != 0 {
return fmt.Errorf("cannot insert a remote that already has rowid set, rowid=%d", remote.RowId)
}
db, err := GetDB()
if err != nil {
return err
}
query := `INSERT INTO remote (remoteid, remotetype, remotename, hostname, connectopts, lastconnectts, ptyout) VALUES (:remoteid, :remotetype, :remotename, :hostname, :connectopts, 0, '')`
result, err := db.NamedExec(query, remote)
if err != nil {
return err
}
remote.RowId, err = result.LastInsertId()
if err != nil {
return fmt.Errorf("cannot get lastinsertid from insert remote: %w", err)
}
return nil
}
func GetSessionById(ctx context.Context, id string) (*SessionType, error) {
db, err := GetDB()
query := `SELECT * FROM session WHERE sessionid = ?`
var session SessionType
err = db.GetContext(ctx, &session, query, id)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
return &session, nil
}
func GetSessionByName(ctx context.Context, name string) (*SessionType, error) {
db, err := GetDB()
query := `SELECT * FROM session WHERE name = ?`
var session SessionType
err = db.GetContext(ctx, &session, query, name)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
return &session, nil
}
// also creates window, and sessionremote
func InsertSessionWithName(ctx context.Context, sessionName string) error {
if sessionName == "" {
return fmt.Errorf("invalid session name '%s'", sessionName)
}
session := &SessionType{
SessionId: uuid.New().String(),
Name: sessionName,
}
localRemote, err := GetRemoteByName(ctx, LocalRemoteName)
if err != nil {
return err
}
return WithTx(ctx, func(tx *TxWrap) error {
query := `INSERT INTO session (sessionid, name) VALUES (:sessionid, :name)`
tx.NamedExecWrap(query, session)
window := &WindowType{
SessionId: session.SessionId,
WindowId: uuid.New().String(),
Name: DefaultWindowName,
CurRemote: LocalRemoteName,
}
query = `INSERT INTO window (sessionid, windowid, name, curremote, version) VALUES (:sessionid, :windowid, :name, :curremote, :version)`
tx.NamedExecWrap(query, window)
sr := &SessionRemote{
SessionId: session.SessionId,
WindowId: window.WindowId,
RemoteName: localRemote.RemoteName,
RemoteId: localRemote.RemoteId,
Cwd: DefaultCwd,
}
query = `INSERT INTO session_remote (sessionid, windowid, remotename, remoteid, cwd) VALUES (:sessionid, :windowid, :remotename, :remoteid, :cwd)`
tx.NamedExecWrap(query, sr)
return nil
})
}

View File

@ -2,8 +2,9 @@ package sstore
import ( import (
"context" "context"
"database/sql"
"fmt" "fmt"
"log"
"os"
"path" "path"
"sync" "sync"
"time" "time"
@ -28,6 +29,8 @@ const DefaultSessionName = "default"
const DefaultWindowName = "default" const DefaultWindowName = "default"
const LocalRemoteName = "local" const LocalRemoteName = "local"
const DefaultCwd = "~"
var globalDBLock = &sync.Mutex{} var globalDBLock = &sync.Mutex{}
var globalDB *sqlx.DB var globalDB *sqlx.DB
var globalDBErr error var globalDBErr error
@ -88,8 +91,13 @@ type RemoteType struct {
RemoteId string `json:"remoteid"` RemoteId string `json:"remoteid"`
RemoteType string `json:"remotetype"` RemoteType string `json:"remotetype"`
RemoteName string `json:"remotename"` RemoteName string `json:"remotename"`
HostName string `json:"hostname"`
LastConnectTs int64 `json:"lastconnectts"`
ConnectOpts string `json:"connectopts"` ConnectOpts string `json:"connectopts"`
// runtime
Connected bool `json:"connected"` Connected bool `json:"connected"`
InitPk *packet.InitPacketType `json:"-"`
} }
type CmdType struct { type CmdType struct {
@ -139,60 +147,6 @@ func GetNextLine() int {
return rtn return rtn
} }
func NumSessions(ctx context.Context) (int, error) {
db, err := GetDB()
if err != nil {
return 0, err
}
query := "SELECT count(*) FROM session"
var count int
err = db.GetContext(ctx, &count, query)
if err != nil {
return 0, err
}
return count, nil
}
func GetRemoteById(ctx context.Context, remoteId string) (*RemoteType, error) {
db, err := GetDB()
if err != nil {
return nil, err
}
query := `SELECT rowid, remoteid, remotetype, remotename, connectopts FROM remote WHERE remoteid = ?`
var remote RemoteType
err = db.GetContext(ctx, &remote, query, remoteId)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
return &remote, nil
}
func InsertRemote(ctx context.Context, remote *RemoteType) error {
if remote == nil {
return fmt.Errorf("cannot insert nil remote")
}
if remote.RowId != 0 {
return fmt.Errorf("cannot insert a remote that already has rowid set, rowid=%d", remote.RowId)
}
db, err := GetDB()
if err != nil {
return err
}
query := `INSERT INTO remote (remoteid, remotetype, remotename, connectopts, ptyout) VALUES (:remoteid, :remotetype, :remotename, :connectopts, '')`
result, err := db.NamedExec(query, remote)
if err != nil {
return err
}
remote.RowId, err = result.LastInsertId()
if err != nil {
return fmt.Errorf("cannot get lastinsertid from insert remote: %w", err)
}
return nil
}
func EnsureLocalRemote(ctx context.Context) error { func EnsureLocalRemote(ctx context.Context) error {
remoteId, err := base.GetRemoteId() remoteId, err := base.GetRemoteId()
if err != nil { if err != nil {
@ -205,16 +159,37 @@ func EnsureLocalRemote(ctx context.Context) error {
if remote != nil { if remote != nil {
return nil return nil
} }
hostName, err := os.Hostname()
if err != nil {
return fmt.Errorf("cannot get hostname: %w", err)
}
// create the local remote // create the local remote
localRemote := &RemoteType{ localRemote := &RemoteType{
RemoteId: remoteId, RemoteId: remoteId,
RemoteType: "ssh", RemoteType: "ssh",
RemoteName: LocalRemoteName, RemoteName: LocalRemoteName,
HostName: hostName,
} }
err = InsertRemote(ctx, localRemote) err = InsertRemote(ctx, localRemote)
if err != nil { if err != nil {
return err return err
} }
log.Printf("[db] added remote '%s', id=%s\n", localRemote.RemoteName, localRemote.RemoteId)
return nil
}
func EnsureDefaultSession(ctx context.Context) error {
session, err := GetSessionByName(ctx, DefaultSessionName)
if err != nil {
return err
}
if session != nil {
return nil
}
err = InsertSessionWithName(ctx, DefaultSessionName)
if err != nil {
return err
}
return nil return nil
} }

88
pkg/sstore/txwrap.go Normal file
View File

@ -0,0 +1,88 @@
package sstore
import (
"context"
"database/sql"
"github.com/jmoiron/sqlx"
)
type TxWrap struct {
Txx *sqlx.Tx
Err error
}
func WithTx(ctx context.Context, fn func(tx *TxWrap) error) (rtnErr error) {
db, err := GetDB()
if err != nil {
return err
}
tx, beginErr := db.BeginTxx(ctx, nil)
if beginErr != nil {
return beginErr
}
txWrap := &TxWrap{Txx: tx}
defer func() {
if p := recover(); p != nil {
tx.Rollback()
panic(p)
}
if rtnErr != nil {
tx.Rollback()
} else {
rtnErr = tx.Commit()
}
}()
fnErr := fn(txWrap)
if fnErr != nil {
return fnErr
}
if txWrap.Err != nil {
return txWrap.Err
}
return nil
}
func (tx *TxWrap) NamedExecWrap(query string, arg interface{}) sql.Result {
if tx.Err != nil {
return nil
}
result, err := tx.Txx.NamedExec(query, arg)
if err != nil {
tx.Err = err
}
return result
}
func (tx *TxWrap) ExecWrap(query string, args ...interface{}) sql.Result {
if tx.Err != nil {
return nil
}
result, err := tx.Txx.Exec(query, args...)
if err != nil {
tx.Err = err
}
return result
}
func (tx *TxWrap) GetWrap(dest interface{}, query string, args ...interface{}) error {
if tx.Err != nil {
return nil
}
err := tx.Txx.Get(dest, query, args...)
if err != nil && err != sql.ErrNoRows {
tx.Err = err
}
return err
}
func (tx *TxWrap) SelectWrap(dest interface{}, query string, args ...interface{}) error {
if tx.Err != nil {
return nil
}
err := tx.Txx.Select(dest, query, args...)
if err != nil {
tx.Err = err
}
return err
}