waveterm/pkg/wshutil/wshcmdreader.go

172 lines
3.8 KiB
Go
Raw Normal View History

// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package wshutil
import (
"bytes"
"fmt"
"io"
"sync"
)
// Escape-sequence scanner states for PtyBuffer.processData, stored in
// PtyBuffer.EscMode.
const (
	Mode_Normal  = "normal"  // bytes pass straight through to the output buffer
	Mode_Esc     = "esc"     // saw ESC; buffering bytes while they still match the OSC prefix
	Mode_WaveEsc = "waveesc" // full OSC prefix matched; buffering the payload until BEL or ST
)

// MaxBufferedDataSize is the DataBuf size (in bytes) above which writeData
// blocks until a reader drains the buffer. The check happens before each
// append, so the buffer may temporarily exceed this by one write.
const MaxBufferedDataSize = 256 * 1024
// PtyBuffer splits a raw byte stream into ordinary output and Wave OSC
// messages. A background goroutine (started by MakePtyBuffer) reads
// InputReader. Ordinary bytes are buffered in DataBuf and served by Read;
// the payload of each complete Wave OSC sequence is sent on MessageCh.
type PtyBuffer struct {
	CVar        *sync.Cond    // CVar.L guards DataBuf, AtEOF and Err; broadcast whenever they change
	DataBuf     *bytes.Buffer // pending ordinary output for Read
	EscMode     string        // scanner state (Mode_Normal / Mode_Esc / Mode_WaveEsc); accessed by processData without locking
	EscSeqBuf   []byte        // bytes of the escape sequence currently being scanned (starts with ESC)
	OSCPrefix   string        // prefix identifying a Wave OSC; byte 0 is assumed to be ESC and is never compared
	InputReader io.Reader     // source of raw input bytes
	MessageCh   chan []byte   // receives Wave OSC payloads; closed when the reader goroutine exits
	AtEOF       bool          // set once InputReader returned io.EOF
	Err         error         // first (wrapped) read error, if any
}
// MakePtyBuffer wraps input in a PtyBuffer and starts a goroutine that reads
// input until it returns io.EOF or an error.
//
// Ordinary bytes become available through the returned buffer's Read method.
// For every complete escape sequence that begins with oscPrefix, the bytes
// after the prefix are sent on messageCh instead. The goroutine closes
// messageCh when input is exhausted or fails.
//
// It panics if len(oscPrefix) != WaveOSCPrefixLen.
func MakePtyBuffer(oscPrefix string, input io.Reader, messageCh chan []byte) *PtyBuffer {
	if prefixLen := len(oscPrefix); prefixLen != WaveOSCPrefixLen {
		panic(fmt.Sprintf("invalid OSC prefix length: %d", prefixLen))
	}
	var mu sync.Mutex
	pb := &PtyBuffer{
		InputReader: input,
		MessageCh:   messageCh,
		OSCPrefix:   oscPrefix,
		EscMode:     Mode_Normal,
		DataBuf:     &bytes.Buffer{},
		CVar:        sync.NewCond(&mu),
	}
	go pb.run()
	return pb
}
// setErr records err as the terminal error unless one was already recorded
// (the first error wins), then wakes every goroutine waiting on CVar.
func (b *PtyBuffer) setErr(err error) {
	cond := b.CVar
	cond.L.Lock()
	defer cond.L.Unlock()
	// deferred calls run LIFO: Broadcast fires first, while the lock is held
	defer cond.Broadcast()
	if b.Err == nil {
		b.Err = err
	}
}
// setEOF marks the input as exhausted and wakes every goroutine waiting on
// CVar, so a blocked Read can return io.EOF once DataBuf is drained.
func (b *PtyBuffer) setEOF() {
	cond := b.CVar
	cond.L.Lock()
	defer cond.L.Unlock()
	// deferred calls run LIFO: Broadcast fires first, while the lock is held
	defer cond.Broadcast()
	b.AtEOF = true
}
// processWaveEscSeq delivers one Wave OSC payload to MessageCh. The send
// blocks until a receiver accepts it (or buffer space is available, if the
// channel is buffered).
func (b *PtyBuffer) processWaveEscSeq(escSeq []byte) {
	out := b.MessageCh
	out <- escSeq
}
// run is the reader goroutine started by MakePtyBuffer. It reads InputReader
// in 4 KiB chunks and passes each chunk through processData until the reader
// returns an error. Bytes returned alongside an error are processed first.
// On io.EOF it marks the buffer as at EOF; any other error is wrapped and
// recorded for Read. MessageCh is closed when run returns.
//
// When input ends, an escape sequence that was still being scanned can never
// complete. Its buffered bytes are handed to readers before EOF/error is
// signaled, so a trailing ESC or truncated sequence is not silently dropped.
func (b *PtyBuffer) run() {
	defer close(b.MessageCh)
	buf := make([]byte, 4096)
	for {
		n, err := b.InputReader.Read(buf)
		if n > 0 {
			b.processData(buf[:n])
		}
		if err == nil {
			continue
		}
		// input is finished: flush any partially scanned escape sequence so
		// readers see it before EOF/error
		if len(b.EscSeqBuf) > 0 {
			b.writeData(b.EscSeqBuf)
		}
		b.EscSeqBuf = nil
		b.EscMode = Mode_Normal
		if err == io.EOF {
			b.setEOF()
		} else {
			b.setErr(fmt.Errorf("error reading input: %w", err))
		}
		return
	}
}
// processData scans one chunk of raw input. Ordinary bytes are appended to
// DataBuf (via writeData). Complete Wave OSC sequences, meaning OSCPrefix
// followed by a payload and terminated by BEL or ST, are removed from the
// output and their payload is sent via processWaveEscSeq.
//
// Scanner state (EscMode/EscSeqBuf) persists across calls, so a sequence
// split over several reads is still recognized. Bytes that turn out not to
// form a Wave OSC are passed through unchanged and in order.
//
// An ESC always starts a new escape sequence, even in the middle of another
// one. This matches how terminal parsers treat ESC, and ensures a Wave OSC
// that directly follows a stray or unterminated ESC sequence is still
// recognized instead of leaking into the output.
//
// Note: a Wave OSC payload is sent before the ordinary bytes that preceded it
// in the same chunk are written to DataBuf.
func (b *PtyBuffer) processData(data []byte) {
	outputBuf := make([]byte, 0, len(data))
	for _, ch := range data {
		switch b.EscMode {
		case Mode_WaveEsc:
			switch {
			case ch == ESC:
				// unterminated Wave OSC: pass the buffered bytes through and let
				// this ESC begin a new escape sequence
				outputBuf = append(outputBuf, b.EscSeqBuf...)
				b.EscMode = Mode_Esc
				b.EscSeqBuf = []byte{ch}
			case ch == BEL || ch == ST:
				// terminator: everything after the prefix is the message payload.
				// The slice may share EscSeqBuf's array; that is safe because
				// EscSeqBuf is reset and never appended to again.
				payload := b.EscSeqBuf[len(b.OSCPrefix):]
				b.EscMode = Mode_Normal
				b.EscSeqBuf = nil
				b.processWaveEscSeq(payload)
			default:
				b.EscSeqBuf = append(b.EscSeqBuf, ch)
			}
		case Mode_Esc:
			switch {
			case ch == ESC:
				// a new ESC abandons the current candidate and starts another
				outputBuf = append(outputBuf, b.EscSeqBuf...)
				b.EscSeqBuf = []byte{ch}
			case ch == BEL || ch == ST || ch != b.OSCPrefix[len(b.EscSeqBuf)]:
				// not a Wave OSC: emit the buffered bytes and this byte as-is.
				// The index is in range: in Mode_Esc, 1 <= len(EscSeqBuf) < len(OSCPrefix).
				outputBuf = append(outputBuf, b.EscSeqBuf...)
				outputBuf = append(outputBuf, ch)
				b.EscMode = Mode_Normal
				b.EscSeqBuf = nil
			default:
				// still matching the prefix
				b.EscSeqBuf = append(b.EscSeqBuf, ch)
				if len(b.EscSeqBuf) == len(b.OSCPrefix) {
					b.EscMode = Mode_WaveEsc
				}
			}
		default: // Mode_Normal
			if ch == ESC {
				b.EscMode = Mode_Esc
				b.EscSeqBuf = []byte{ch}
			} else {
				outputBuf = append(outputBuf, ch)
			}
		}
	}
	if len(outputBuf) > 0 {
		b.writeData(outputBuf)
	}
}
// writeData appends data to DataBuf and wakes all waiters on CVar.
//
// Backpressure: it waits while DataBuf already holds more than
// MaxBufferedDataSize bytes. The limit is checked before appending, so a
// single write can push the buffer past the limit; the next call then waits
// until a reader drains it.
func (b *PtyBuffer) writeData(data []byte) {
	mu := b.CVar.L
	mu.Lock()
	defer mu.Unlock()
	for over := b.DataBuf.Len() > MaxBufferedDataSize; over; over = b.DataBuf.Len() > MaxBufferedDataSize {
		b.CVar.Wait()
	}
	// bytes.Buffer.Write always returns a nil error
	_, _ = b.DataBuf.Write(data)
	b.CVar.Broadcast()
}
// Read implements io.Reader over the ordinary (non-Wave-OSC) output.
//
// It blocks until buffered data is available and then returns up to len(p)
// bytes. Buffered data is always returned before any terminal condition.
// Once DataBuf is empty, Read returns the recorded read error if there is
// one, otherwise io.EOF after the input reached EOF.
//
// A zero-length p returns (0, nil) immediately instead of blocking, as the
// io.Reader contract expects.
func (b *PtyBuffer) Read(p []byte) (n int, err error) {
	if len(p) == 0 {
		return 0, nil
	}
	b.CVar.L.Lock()
	defer b.CVar.L.Unlock()
	for b.DataBuf.Len() == 0 {
		if b.Err != nil {
			return 0, b.Err
		}
		if b.AtEOF {
			return 0, io.EOF
		}
		b.CVar.Wait()
	}
	n, err = b.DataBuf.Read(p)
	// space was freed: wake a writer that may be blocked in writeData
	b.CVar.Broadcast()
	return n, err
}