// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0

package shellexec

import (
	"bytes"
	"fmt"
	"io"
	"os"
	"os/exec"
	"reflect"
	"syscall"

	"github.com/creack/pty"
	"github.com/wavetermdev/thenextwave/pkg/util/shellutil"
	"github.com/wavetermdev/thenextwave/pkg/wavebase"
)

// TermSize is a terminal size in character cells, serialized as
// {"rows": ..., "cols": ...}. StartShellProc and RunSimpleCmdInPty replace a
// size with a zero Rows or Cols by the shellutil defaults.
type TermSize struct {
	Rows int `json:"rows"`
	Cols int `json:"cols"`
}

// ShellProc is a started shell process together with the master side of the
// pty whose slave side is the process's stdin/stdout/stderr (see
// StartShellProc). Release it with Close.
type ShellProc struct {
	Cmd *exec.Cmd // the started shell process
	Pty *os.File  // pty master; the shell's terminal I/O flows through it
}

// Close kills the shell process and returns immediately, without waiting for
// it to exit. A background goroutine then waits on the process and closes the
// pty master. Errors from Kill, Wait and Close are all ignored (best-effort
// teardown).
func (sp *ShellProc) Close() {
	_ = sp.Cmd.Process.Kill()
	cleanup := func() {
		_, _ = sp.Cmd.Process.Wait()
		_ = sp.Pty.Close()
	}
	go cleanup()
}

// setBoolConditionally sets the bool field named field on the struct that
// rval points to, but only when that is possible. It is a no-op when:
//   - rval is not a non-nil pointer to a struct,
//   - the struct has no field with that name, or
//   - the field is not a settable (exported) bool.
//
// Looking fields up by name lets callers set fields that may not exist on
// every platform's version of a struct (it is used on syscall.SysProcAttr
// below) without build tags.
func setBoolConditionally(rval reflect.Value, field string, value bool) {
	if rval.Kind() != reflect.Ptr || rval.IsNil() {
		return
	}
	elem := rval.Elem()
	if elem.Kind() != reflect.Struct {
		return
	}
	// Single lookup; guard kind and settability so SetBool cannot panic.
	fv := elem.FieldByName(field)
	if !fv.IsValid() || fv.Kind() != reflect.Bool || !fv.CanSet() {
		return
	}
	fv.SetBool(value)
}

// setSysProcAttrs sets Setsid and Setctty to true on cmd.SysProcAttr, so the
// child runs in a new session with a controlling terminal. Each field is set
// only if the current platform's syscall.SysProcAttr has it. Reflection is
// used so this compiles where those fields are absent (presumably non-Unix
// targets; TODO confirm the supported build platforms).
//
// If cmd.SysProcAttr is nil, a zero-valued one is allocated first. Without
// that, the reflection-based setters would panic on the nil pointer.
func setSysProcAttrs(cmd *exec.Cmd) {
	if cmd.SysProcAttr == nil {
		cmd.SysProcAttr = &syscall.SysProcAttr{}
	}
	rval := reflect.ValueOf(cmd.SysProcAttr)
	setBoolConditionally(rval, "Setsid", true)
	setBoolConditionally(rval, "Setctty", true)
}

// StartShellProc starts the local user's shell as an interactive login shell
// (flags "-i" and "-l") attached to a newly allocated pty, and returns it.
//
// If termSize has a zero Rows or Cols, both are replaced by
// shellutil.DefaultTermRows / shellutil.DefaultTermCols. The function returns
// an error for negative sizes and for sizes larger than the 16-bit pty
// window-size fields can hold.
//
// The shell runs in the user's home directory with this process's environment
// plus the waveshell variables. LANG is added only when it is unset here.
// On success the caller owns the returned ShellProc and should release it
// with Close. If Start fails, its error is returned unwrapped.
func StartShellProc(termSize TermSize) (*ShellProc, error) {
	const maxPtyDim = 1<<16 - 1 // pty.Winsize Rows/Cols are uint16
	shellPath := shellutil.DetectLocalShellPath()
	ecmd := exec.Command(shellPath, "-i", "-l")
	ecmd.Env = os.Environ()
	ecmd.Dir = wavebase.GetHomeDir()
	envToAdd := shellutil.WaveshellEnvVars(shellutil.DefaultTermType)
	if os.Getenv("LANG") == "" {
		envToAdd["LANG"] = wavebase.DetermineLang()
	}
	shellutil.UpdateCmdEnv(ecmd, envToAdd)
	// Validate the size before allocating the pty so a bad size cannot leak
	// the pty file descriptors.
	if termSize.Rows == 0 || termSize.Cols == 0 {
		termSize.Rows = shellutil.DefaultTermRows
		termSize.Cols = shellutil.DefaultTermCols
	}
	if termSize.Rows <= 0 || termSize.Cols <= 0 || termSize.Rows > maxPtyDim || termSize.Cols > maxPtyDim {
		return nil, fmt.Errorf("invalid term size: %v", termSize)
	}
	cmdPty, cmdTty, err := pty.Open()
	if err != nil {
		return nil, fmt.Errorf("opening new pty: %w", err)
	}
	if err = pty.Setsize(cmdPty, &pty.Winsize{Rows: uint16(termSize.Rows), Cols: uint16(termSize.Cols)}); err != nil {
		cmdTty.Close()
		cmdPty.Close()
		return nil, fmt.Errorf("setting pty size: %w", err)
	}
	ecmd.Stdin = cmdTty
	ecmd.Stdout = cmdTty
	ecmd.Stderr = cmdTty
	ecmd.SysProcAttr = &syscall.SysProcAttr{}
	setSysProcAttrs(ecmd)
	err = ecmd.Start()
	// The parent no longer needs the slave side; after Start the child has
	// its own copy.
	cmdTty.Close()
	if err != nil {
		cmdPty.Close()
		return nil, err
	}
	return &ShellProc{Cmd: ecmd, Pty: cmdPty}, nil
}

// RunSimpleCmdInPty runs ecmd to completion with its stdin, stdout and stderr
// attached to a newly allocated pty. It returns everything the command wrote
// to the terminal.
//
// ecmd must not have been started. This function overwrites the following
// fields:
//   - Env: replaced by this process's environment plus the waveshell
//     variables; any Env set by the caller is discarded.
//   - Stdin, Stdout, Stderr and SysProcAttr.
//
// If termSize has a zero Rows or Cols, both are replaced by
// shellutil.DefaultTermRows / shellutil.DefaultTermCols. The function returns
// an error for negative sizes and for sizes larger than the 16-bit pty
// window-size fields can hold.
//
// If ecmd.Wait reports an error (for example a non-zero exit status), that
// error is returned and any captured output is discarded.
func RunSimpleCmdInPty(ecmd *exec.Cmd, termSize TermSize) ([]byte, error) {
	const maxPtyDim = 1<<16 - 1 // pty.Winsize Rows/Cols are uint16
	ecmd.Env = os.Environ()
	shellutil.UpdateCmdEnv(ecmd, shellutil.WaveshellEnvVars(shellutil.DefaultTermType))
	// Validate the size before allocating the pty so a bad size cannot leak
	// the pty file descriptors.
	if termSize.Rows == 0 || termSize.Cols == 0 {
		termSize.Rows = shellutil.DefaultTermRows
		termSize.Cols = shellutil.DefaultTermCols
	}
	if termSize.Rows <= 0 || termSize.Cols <= 0 || termSize.Rows > maxPtyDim || termSize.Cols > maxPtyDim {
		return nil, fmt.Errorf("invalid term size: %v", termSize)
	}
	cmdPty, cmdTty, err := pty.Open()
	if err != nil {
		return nil, fmt.Errorf("opening new pty: %w", err)
	}
	if err = pty.Setsize(cmdPty, &pty.Winsize{Rows: uint16(termSize.Rows), Cols: uint16(termSize.Cols)}); err != nil {
		cmdTty.Close()
		cmdPty.Close()
		return nil, fmt.Errorf("setting pty size: %w", err)
	}
	ecmd.Stdin = cmdTty
	ecmd.Stdout = cmdTty
	ecmd.Stderr = cmdTty
	ecmd.SysProcAttr = &syscall.SysProcAttr{}
	setSysProcAttrs(ecmd)
	err = ecmd.Start()
	// The parent no longer needs the slave side; after Start the child has
	// its own copy.
	cmdTty.Close()
	if err != nil {
		cmdPty.Close()
		return nil, err
	}
	defer cmdPty.Close()
	ioDone := make(chan struct{})
	var outputBuf bytes.Buffer
	go func() {
		defer close(ioDone)
		// The copy error is ignored: /dev/ptmx returns a read error once the
		// process is done, which is the normal way this copy ends.
		_, _ = io.Copy(&outputBuf, cmdPty)
	}()
	exitErr := ecmd.Wait()
	if exitErr != nil {
		// Returning runs the deferred cmdPty.Close, which unblocks the copy
		// goroutine; its partial output is discarded.
		return nil, exitErr
	}
	<-ioDone
	return outputBuf.Bytes(), nil
}