checkpoint, working on db calls

This commit is contained in:
sawka 2022-07-01 14:07:13 -07:00
parent 643f08e584
commit b85be3457c
4 changed files with 274 additions and 62 deletions

View File

@ -39,6 +39,8 @@ CREATE TABLE remote (
remoteid varchar(36) PRIMARY KEY,
remotetype varchar(10) NOT NULL,
remotename varchar(50) NOT NULL,
hostname varchar(200) NOT NULL,
lastconnectts bigint NOT NULL,
connectopts varchar(300) NOT NULL,
ptyout BLOB NOT NULL
);
@ -62,7 +64,7 @@ CREATE TABLE history (
sessionid varchar(36) NOT NULL,
windowid varchar(36) NOT NULL,
userid varchar(36) NOT NULL,
ts int64 NOT NULL,
ts bigint NOT NULL,
lineid varchar(36) NOT NULL,
PRIMARY KEY (sessionid, windowid, lineid)
);

147
pkg/sstore/dbops.go Normal file
View File

@ -0,0 +1,147 @@
package sstore
import (
"context"
"database/sql"
"fmt"
"github.com/google/uuid"
)
func NumSessions(ctx context.Context) (int, error) {
db, err := GetDB()
if err != nil {
return 0, err
}
query := "SELECT count(*) FROM session"
var count int
err = db.GetContext(ctx, &count, query)
if err != nil {
return 0, err
}
return count, nil
}
func GetRemoteByName(ctx context.Context, remoteName string) (*RemoteType, error) {
db, err := GetDB()
if err != nil {
return nil, err
}
query := `SELECT rowid, remoteid, remotetype, remotename, hostname, connectopts, lastconnectts FROM remote WHERE remotename = ?`
var remote RemoteType
err = db.GetContext(ctx, &remote, query, remoteName)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
return &remote, nil
}
func GetRemoteById(ctx context.Context, remoteId string) (*RemoteType, error) {
db, err := GetDB()
if err != nil {
return nil, err
}
query := `SELECT rowid, remoteid, remotetype, remotename, hostname, connectopts, lastconnectts FROM remote WHERE remoteid = ?`
var remote RemoteType
err = db.GetContext(ctx, &remote, query, remoteId)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
return &remote, nil
}
func InsertRemote(ctx context.Context, remote *RemoteType) error {
if remote == nil {
return fmt.Errorf("cannot insert nil remote")
}
if remote.RowId != 0 {
return fmt.Errorf("cannot insert a remote that already has rowid set, rowid=%d", remote.RowId)
}
db, err := GetDB()
if err != nil {
return err
}
query := `INSERT INTO remote (remoteid, remotetype, remotename, hostname, connectopts, lastconnectts, ptyout) VALUES (:remoteid, :remotetype, :remotename, :hostname, :connectopts, 0, '')`
result, err := db.NamedExec(query, remote)
if err != nil {
return err
}
remote.RowId, err = result.LastInsertId()
if err != nil {
return fmt.Errorf("cannot get lastinsertid from insert remote: %w", err)
}
return nil
}
func GetSessionById(ctx context.Context, id string) (*SessionType, error) {
db, err := GetDB()
query := `SELECT * FROM session WHERE sessionid = ?`
var session SessionType
err = db.GetContext(ctx, &session, query, id)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
return &session, nil
}
func GetSessionByName(ctx context.Context, name string) (*SessionType, error) {
db, err := GetDB()
query := `SELECT * FROM session WHERE name = ?`
var session SessionType
err = db.GetContext(ctx, &session, query, name)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
return &session, nil
}
// also creates window, and sessionremote
func InsertSessionWithName(ctx context.Context, sessionName string) error {
if sessionName == "" {
return fmt.Errorf("invalid session name '%s'", sessionName)
}
session := &SessionType{
SessionId: uuid.New().String(),
Name: sessionName,
}
localRemote, err := GetRemoteByName(ctx, LocalRemoteName)
if err != nil {
return err
}
return WithTx(ctx, func(tx *TxWrap) error {
query := `INSERT INTO session (sessionid, name) VALUES (:sessionid, :name)`
tx.NamedExecWrap(query, session)
window := &WindowType{
SessionId: session.SessionId,
WindowId: uuid.New().String(),
Name: DefaultWindowName,
CurRemote: LocalRemoteName,
}
query = `INSERT INTO window (sessionid, windowid, name, curremote, version) VALUES (:sessionid, :windowid, :name, :curremote, :version)`
tx.NamedExecWrap(query, window)
sr := &SessionRemote{
SessionId: session.SessionId,
WindowId: window.WindowId,
RemoteName: localRemote.RemoteName,
RemoteId: localRemote.RemoteId,
Cwd: DefaultCwd,
}
query = `INSERT INTO session_remote (sessionid, windowid, remotename, remoteid, cwd) VALUES (:sessionid, :windowid, :remotename, :remoteid, :cwd)`
tx.NamedExecWrap(query, sr)
return nil
})
}

View File

@ -2,8 +2,9 @@ package sstore
import (
"context"
"database/sql"
"fmt"
"log"
"os"
"path"
"sync"
"time"
@ -28,6 +29,8 @@ const DefaultSessionName = "default"
const DefaultWindowName = "default"
const LocalRemoteName = "local"
const DefaultCwd = "~"
var globalDBLock = &sync.Mutex{}
var globalDB *sqlx.DB
var globalDBErr error
@ -84,12 +87,17 @@ type LineType struct {
}
type RemoteType struct {
RowId int64 `json:"rowid"`
RemoteId string `json:"remoteid"`
RemoteType string `json:"remotetype"`
RemoteName string `json:"remotename"`
ConnectOpts string `json:"connectopts"`
Connected bool `json:"connected"`
RowId int64 `json:"rowid"`
RemoteId string `json:"remoteid"`
RemoteType string `json:"remotetype"`
RemoteName string `json:"remotename"`
HostName string `json:"hostname"`
LastConnectTs int64 `json:"lastconnectts"`
ConnectOpts string `json:"connectopts"`
// runtime
Connected bool `json:"connected"`
InitPk *packet.InitPacketType `json:"-"`
}
type CmdType struct {
@ -139,60 +147,6 @@ func GetNextLine() int {
return rtn
}
func NumSessions(ctx context.Context) (int, error) {
db, err := GetDB()
if err != nil {
return 0, err
}
query := "SELECT count(*) FROM session"
var count int
err = db.GetContext(ctx, &count, query)
if err != nil {
return 0, err
}
return count, nil
}
func GetRemoteById(ctx context.Context, remoteId string) (*RemoteType, error) {
db, err := GetDB()
if err != nil {
return nil, err
}
query := `SELECT rowid, remoteid, remotetype, remotename, connectopts FROM remote WHERE remoteid = ?`
var remote RemoteType
err = db.GetContext(ctx, &remote, query, remoteId)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
return &remote, nil
}
func InsertRemote(ctx context.Context, remote *RemoteType) error {
if remote == nil {
return fmt.Errorf("cannot insert nil remote")
}
if remote.RowId != 0 {
return fmt.Errorf("cannot insert a remote that already has rowid set, rowid=%d", remote.RowId)
}
db, err := GetDB()
if err != nil {
return err
}
query := `INSERT INTO remote (remoteid, remotetype, remotename, connectopts, ptyout) VALUES (:remoteid, :remotetype, :remotename, :connectopts, '')`
result, err := db.NamedExec(query, remote)
if err != nil {
return err
}
remote.RowId, err = result.LastInsertId()
if err != nil {
return fmt.Errorf("cannot get lastinsertid from insert remote: %w", err)
}
return nil
}
func EnsureLocalRemote(ctx context.Context) error {
remoteId, err := base.GetRemoteId()
if err != nil {
@ -205,16 +159,37 @@ func EnsureLocalRemote(ctx context.Context) error {
if remote != nil {
return nil
}
hostName, err := os.Hostname()
if err != nil {
return fmt.Errorf("cannot get hostname: %w", err)
}
// create the local remote
localRemote := &RemoteType{
RemoteId: remoteId,
RemoteType: "ssh",
RemoteName: LocalRemoteName,
HostName: hostName,
}
err = InsertRemote(ctx, localRemote)
if err != nil {
return err
}
log.Printf("[db] added remote '%s', id=%s\n", localRemote.RemoteName, localRemote.RemoteId)
return nil
}
func EnsureDefaultSession(ctx context.Context) error {
session, err := GetSessionByName(ctx, DefaultSessionName)
if err != nil {
return err
}
if session != nil {
return nil
}
err = InsertSessionWithName(ctx, DefaultSessionName)
if err != nil {
return err
}
return nil
}

88
pkg/sstore/txwrap.go Normal file
View File

@ -0,0 +1,88 @@
package sstore
import (
"context"
"database/sql"
"github.com/jmoiron/sqlx"
)
type TxWrap struct {
Txx *sqlx.Tx
Err error
}
func WithTx(ctx context.Context, fn func(tx *TxWrap) error) (rtnErr error) {
db, err := GetDB()
if err != nil {
return err
}
tx, beginErr := db.BeginTxx(ctx, nil)
if beginErr != nil {
return beginErr
}
txWrap := &TxWrap{Txx: tx}
defer func() {
if p := recover(); p != nil {
tx.Rollback()
panic(p)
}
if rtnErr != nil {
tx.Rollback()
} else {
rtnErr = tx.Commit()
}
}()
fnErr := fn(txWrap)
if fnErr != nil {
return fnErr
}
if txWrap.Err != nil {
return txWrap.Err
}
return nil
}
func (tx *TxWrap) NamedExecWrap(query string, arg interface{}) sql.Result {
if tx.Err != nil {
return nil
}
result, err := tx.Txx.NamedExec(query, arg)
if err != nil {
tx.Err = err
}
return result
}
func (tx *TxWrap) ExecWrap(query string, args ...interface{}) sql.Result {
if tx.Err != nil {
return nil
}
result, err := tx.Txx.Exec(query, args...)
if err != nil {
tx.Err = err
}
return result
}
func (tx *TxWrap) GetWrap(dest interface{}, query string, args ...interface{}) error {
if tx.Err != nil {
return nil
}
err := tx.Txx.Get(dest, query, args...)
if err != nil && err != sql.ErrNoRows {
tx.Err = err
}
return err
}
func (tx *TxWrap) SelectWrap(dest interface{}, query string, args ...interface{}) error {
if tx.Err != nil {
return nil
}
err := tx.Txx.Select(dest, query, args...)
if err != nil {
tx.Err = err
}
return err
}