update writefile code. changed the way usetemp works to make sure file permissions/owner/attributes are kept on original file

This commit is contained in:
sawka 2023-09-06 21:45:15 -07:00
parent bc488cf242
commit 71ba4b5b46
2 changed files with 101 additions and 49 deletions

View File

@ -417,11 +417,12 @@ func MakeStreamFilePacket() *StreamFilePacketType {
}
type FileInfo struct {
Name string `json:"name"`
Size int64 `json:"size"`
ModTs int64 `json:"modts"`
IsDir bool `json:"isdir,omitempty"`
Perm int `json:"perm"`
Name string `json:"name"`
Size int64 `json:"size"`
ModTs int64 `json:"modts"`
IsDir bool `json:"isdir,omitempty"`
Perm int `json:"perm"`
NotFound bool `json:"notfound,omitempty"` // when NotFound is set, Perm will be set to permission for directory
}
type StreamFileResponseType struct {

View File

@ -272,6 +272,58 @@ func makeTemp(path string, mode fs.FileMode) (*os.File, error) {
return writeFd, nil
}
func checkFileWritable(path string) error {
finfo, err := os.Stat(path) // ok to follow symlinks
if errors.Is(err, fs.ErrNotExist) {
dirName := filepath.Dir(path)
dirInfo, err := os.Stat(dirName)
if err != nil {
return fmt.Errorf("file does not exist, error trying to stat parent directory: %w", err)
}
if !dirInfo.IsDir() {
return fmt.Errorf("file does not exist, parent path [%s] is not a directory", dirName)
}
return nil
} else {
if err != nil {
return fmt.Errorf("cannot stat: %w", err)
}
if finfo.IsDir() {
return fmt.Errorf("invalid path, cannot write a directory")
}
if (finfo.Mode() & fs.ModeSymlink) != 0 {
return fmt.Errorf("writefile does not support symlinks") // note this shouldn't happen because we're using Stat (not Lstat)
}
if (finfo.Mode() & (fs.ModeNamedPipe | fs.ModeSocket | fs.ModeDevice)) != 0 {
return fmt.Errorf("writefile does not support special files (named pipes, sockets, devices): mode=%v", finfo.Mode())
}
writePerm := (finfo.Mode().Perm() & 0o222)
if writePerm == 0 {
return fmt.Errorf("file is not writable, perms: %v", finfo.Mode().Perm())
}
return nil
}
}
func copyFile(dstName string, srcName string) error {
srcFd, err := os.Open(srcName)
if err != nil {
return err
}
defer srcFd.Close()
dstFd, err := os.OpenFile(dstName, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o777) // use 777 because OpenFile respects umask
if err != nil {
return err
}
// we don't defer dstFd.Close() so we can return an error if dstFd.Close() returns an error
_, err = io.Copy(dstFd, srcFd)
if err != nil {
dstFd.Close()
return err
}
return dstFd.Close()
}
func (m *MServer) writeFile(pk *packet.WriteFilePacketType, wfc *WriteFileContext) {
defer wfc.setDone()
if pk.Path == "" {
@ -280,55 +332,19 @@ func (m *MServer) writeFile(pk *packet.WriteFilePacketType, wfc *WriteFileContex
m.Sender.SendPacket(resp)
return
}
var finfo fs.FileInfo
var err error
if pk.UseTemp {
finfo, err = os.Lstat(pk.Path)
} else {
finfo, err = os.Stat(pk.Path)
}
if err == nil && finfo.IsDir() {
err = fmt.Errorf("invalid path, cannot write a directory")
}
if err == nil && ((finfo.Mode() & fs.ModeSymlink) != 0) {
err = fmt.Errorf("writefile (with usetemp) does not support symlinks")
}
if err == nil && ((finfo.Mode() & (fs.ModeNamedPipe | fs.ModeSocket | fs.ModeDevice | fs.ModeSetuid | fs.ModeSetgid)) != 0) {
err = fmt.Errorf("writefile does not support special files (named pipes, sockets, devices, setuid, or setgid): mode=%v", finfo.Mode())
}
if err == nil {
writePerm := (finfo.Mode().Perm() & 0o222)
if writePerm == 0 {
err = fmt.Errorf("file is not writable, perms: %v", finfo.Mode().Perm())
}
}
err := checkFileWritable(pk.Path)
if err != nil {
resp := packet.MakeWriteFileReadyPacket(pk.ReqId)
resp.Error = err.Error()
m.Sender.SendPacket(resp)
return
}
var writeFd *os.File
if pk.UseTemp {
dirName := filepath.Dir(pk.Path)
dirFInfo, err := os.Stat(dirName)
if err == nil {
writePerm := (dirFInfo.Mode().Perm() & 0o222)
if writePerm == 0 {
err = fmt.Errorf("file-write tempmode is set, but parent directory is not writeable, perms: %v", dirFInfo.Mode().Perm())
}
}
writeFd, err = os.CreateTemp("", "mshell.writefile.*") // "" means make this file in standard TempDir
if err != nil {
resp := packet.MakeWriteFileReadyPacket(pk.ReqId)
resp.Error = err.Error()
m.Sender.SendPacket(resp)
return
}
writeFd, err = makeTemp(pk.Path, finfo.Mode().Perm())
if err != nil {
resp := packet.MakeWriteFileReadyPacket(pk.ReqId)
resp.Error = fmt.Sprintf("write-file could not open tempfile: %v", err)
resp.Error = fmt.Sprintf("cannot create temp file: %v", err)
m.Sender.SendPacket(resp)
return
}
@ -388,12 +404,12 @@ func (m *MServer) writeFile(pk *packet.WriteFilePacketType, wfc *WriteFileContex
if doneErr != nil {
os.Remove(writeFd.Name())
} else {
renameErr := os.Rename(writeFd.Name(), pk.Path)
if renameErr != nil {
doneErr = fmt.Errorf("error renaming temp file: %v", renameErr)
// rename failed, try to remove temp file still
os.Remove(writeFd.Name())
// copy file between writeFd.Name() and pk.Path
copyErr := copyFile(pk.Path, writeFd.Name())
if err != nil {
doneErr = fmt.Errorf("error writing file: %v", copyErr)
}
os.Remove(writeFd.Name())
}
}
donePk := packet.MakeWriteFileDonePacket(pk.ReqId)
@ -403,9 +419,44 @@ func (m *MServer) writeFile(pk *packet.WriteFilePacketType, wfc *WriteFileContex
m.Sender.SendPacket(donePk)
}
func (m *MServer) returnStreamFileNewFileResponse(pk *packet.StreamFilePacketType) {
// ok, file doesn't exist, so try to check the directory at least to see if we can write a file here
resp := packet.MakeStreamFileResponse(pk.ReqId)
defer func() {
if resp.Error == "" {
resp.Done = true
}
m.Sender.SendPacket(resp)
}()
dirName := filepath.Dir(pk.Path)
dirInfo, err := os.Stat(dirName)
if err != nil {
resp.Error = fmt.Sprintf("file does not exist, error trying to stat parent directory: %v", err)
return
}
if !dirInfo.IsDir() {
resp.Error = fmt.Sprintf("file does not exist, parent path [%s] is not a directory", dirName)
return
}
resp.Info = &packet.FileInfo{
Name: pk.Path,
Size: 0,
ModTs: 0,
IsDir: false,
Perm: int(dirInfo.Mode().Perm()),
NotFound: true,
}
return
}
func (m *MServer) streamFile(pk *packet.StreamFilePacketType) {
resp := packet.MakeStreamFileResponse(pk.ReqId)
finfo, err := os.Stat(pk.Path)
if errors.Is(err, fs.ErrNotExist) {
// special return
m.returnStreamFileNewFileResponse(pk)
return
}
if err != nil {
resp.Error = fmt.Sprintf("cannot stat file %q: %v", pk.Path, err)
m.Sender.SendPacket(resp)