Skip to content

Commit

Permalink
feat(term): add cross-platform console state-aware support (#29)
Browse files Browse the repository at this point in the history
* feat(term): add cross-platform console state-aware support

This basically ports https://github.com/golang/term but without the VT
terminal stuff, ANSI codes, and exports platform-specific terminal
state. On Unix, the state is `unix.Termios`, and on Windows, it's a
`uint32` representing the console mode
https://learn.microsoft.com/en-us/windows/console/getconsolemode.

* feat(term): support setting terminal state (#30)

* feat(term): use uintptr instead of int

Simplify API usage
  • Loading branch information
aymanbagabas authored Jan 12, 2024
1 parent 0b337a6 commit 24b8333
Show file tree
Hide file tree
Showing 8 changed files with 376 additions and 0 deletions.
49 changes: 49 additions & 0 deletions term/term.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package term

// State contains platform-specific state of a terminal.
type State struct {
state
}

// IsTerminal returns whether the given file descriptor is a terminal.
func IsTerminal(fd uintptr) bool {
return isTerminal(fd)
}

// MakeRaw puts the terminal connected to the given file descriptor into raw
// mode and returns the previous state of the terminal so that it can be
// restored.
func MakeRaw(fd uintptr) (*State, error) {
return makeRaw(fd)
}

// GetState returns the current state of a terminal which may be useful to
// restore the terminal after a signal.
func GetState(fd uintptr) (*State, error) {
return getState(fd)
}

// SetState sets the given state of the terminal.
func SetState(fd uintptr, state *State) error {
return setState(fd, state)
}

// Restore restores the terminal connected to the given file descriptor to a
// previous state.
func Restore(fd uintptr, oldState *State) error {
return restore(fd, oldState)
}

// GetSize returns the visible dimensions of the given terminal.
//
// These dimensions don't include any scrollback buffer height.
func GetSize(fd uintptr) (width, height int, err error) {
return getSize(fd)
}

// ReadPassword reads a line of input from a terminal without local echo. This
// is commonly used for inputting passwords and other sensitive data. The slice
// returned does not include the \n.
func ReadPassword(fd uintptr) ([]byte, error) {
return readPassword(fd)
}
39 changes: 39 additions & 0 deletions term/term_other.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
//go:build !aix && !darwin && !dragonfly && !freebsd && !linux && !netbsd && !openbsd && !zos && !windows && !solaris && !plan9
// +build !aix,!darwin,!dragonfly,!freebsd,!linux,!netbsd,!openbsd,!zos,!windows,!solaris,!plan9

package term

import (
"fmt"
"runtime"
)

type state struct{}

func isTerminal(fd uintptr) bool {
return false
}

func makeRaw(fd uintptr) (*State, error) {
return nil, fmt.Errorf("terminal: MakeRaw not implemented on %s/%s", runtime.GOOS, runtime.GOARCH)
}

func getState(fd uintptr) (*State, error) {
return nil, fmt.Errorf("terminal: GetState not implemented on %s/%s", runtime.GOOS, runtime.GOARCH)
}

func restore(fd uintptr, state *State) error {
return fmt.Errorf("terminal: Restore not implemented on %s/%s", runtime.GOOS, runtime.GOARCH)
}

func getSize(fd uintptr) (width, height int, err error) {
return 0, 0, fmt.Errorf("terminal: GetSize not implemented on %s/%s", runtime.GOOS, runtime.GOARCH)
}

func setState(fd uintptr, state *State) error {
return fmt.Errorf("terminal: SetState not implemented on %s/%s", runtime.GOOS, runtime.GOARCH)
}

func readPassword(fd uintptr) ([]byte, error) {
return nil, fmt.Errorf("terminal: ReadPassword not implemented on %s/%s", runtime.GOOS, runtime.GOARCH)
}
37 changes: 37 additions & 0 deletions term/term_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package term_test

import (
"os"
"runtime"
"testing"

"github.com/charmbracelet/x/term"
)

func TestIsTerminalTempFile(t *testing.T) {
file, err := os.CreateTemp("", "TestIsTerminalTempFile")
if err != nil {
t.Fatal(err)
}
defer os.Remove(file.Name())
defer file.Close()

if term.IsTerminal(file.Fd()) {
t.Fatalf("IsTerminal unexpectedly returned true for temporary file %s", file.Name())
}
}

func TestIsTerminalTerm(t *testing.T) {
if runtime.GOOS != "linux" {
t.Skipf("unknown terminal path for GOOS %v", runtime.GOOS)
}
file, err := os.OpenFile("/dev/ptmx", os.O_RDWR, 0)
if err != nil {
t.Fatal(err)
}
defer file.Close()

if !term.IsTerminal(file.Fd()) {
t.Fatalf("IsTerminal unexpectedly returned false for terminal file %s", file.Name())
}
}
96 changes: 96 additions & 0 deletions term/term_unix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || zos
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris zos

package term

import (
"golang.org/x/sys/unix"
)

type state struct {
unix.Termios
}

func isTerminal(fd uintptr) bool {
_, err := unix.IoctlGetTermios(int(fd), ioctlReadTermios)
return err == nil
}

func makeRaw(fd uintptr) (*State, error) {
termios, err := unix.IoctlGetTermios(int(fd), ioctlReadTermios)
if err != nil {
return nil, err
}

oldState := State{state{Termios: *termios}}

// This attempts to replicate the behaviour documented for cfmakeraw in
// the termios(3) manpage.
termios.Iflag &^= unix.IGNBRK | unix.BRKINT | unix.PARMRK | unix.ISTRIP | unix.INLCR | unix.IGNCR | unix.ICRNL | unix.IXON
termios.Oflag &^= unix.OPOST
termios.Lflag &^= unix.ECHO | unix.ECHONL | unix.ICANON | unix.ISIG | unix.IEXTEN
termios.Cflag &^= unix.CSIZE | unix.PARENB
termios.Cflag |= unix.CS8
termios.Cc[unix.VMIN] = 1
termios.Cc[unix.VTIME] = 0
if err := unix.IoctlSetTermios(int(fd), ioctlWriteTermios, termios); err != nil {
return nil, err
}

return &oldState, nil
}

func setState(fd uintptr, state *State) error {
var termios *unix.Termios
if state != nil {
termios = &state.Termios
}
return unix.IoctlSetTermios(int(fd), ioctlWriteTermios, termios)
}

func getState(fd uintptr) (*State, error) {
termios, err := unix.IoctlGetTermios(int(fd), ioctlReadTermios)
if err != nil {
return nil, err
}

return &State{state{Termios: *termios}}, nil
}

func restore(fd uintptr, state *State) error {
return unix.IoctlSetTermios(int(fd), ioctlWriteTermios, &state.Termios)
}

func getSize(fd uintptr) (width, height int, err error) {
ws, err := unix.IoctlGetWinsize(int(fd), unix.TIOCGWINSZ)
if err != nil {
return 0, 0, err
}
return int(ws.Col), int(ws.Row), nil
}

// passwordReader is an io.Reader that reads from a specific file descriptor.
type passwordReader int

func (r passwordReader) Read(buf []byte) (int, error) {
return unix.Read(int(r), buf)
}

func readPassword(fd uintptr) ([]byte, error) {
termios, err := unix.IoctlGetTermios(int(fd), ioctlReadTermios)
if err != nil {
return nil, err
}

newState := *termios
newState.Lflag &^= unix.ECHO
newState.Lflag |= unix.ICANON | unix.ISIG
newState.Iflag |= unix.ICRNL
if err := unix.IoctlSetTermios(int(fd), ioctlWriteTermios, &newState); err != nil {
return nil, err
}

defer unix.IoctlSetTermios(int(fd), ioctlWriteTermios, termios)

return readPasswordLine(passwordReader(fd))
}
11 changes: 11 additions & 0 deletions term/term_unix_bsd.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
// +build darwin dragonfly freebsd netbsd openbsd

package term

import "golang.org/x/sys/unix"

const (
ioctlReadTermios = unix.TIOCGETA
ioctlWriteTermios = unix.TIOCSETA
)
11 changes: 11 additions & 0 deletions term/term_unix_other.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
//go:build aix || linux || solaris || zos
// +build aix linux solaris zos

package term

import "golang.org/x/sys/unix"

const (
ioctlReadTermios = unix.TCGETS
ioctlWriteTermios = unix.TCSETS
)
86 changes: 86 additions & 0 deletions term/term_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
//go:build windows
// +build windows

package term

import (
"os"

"golang.org/x/sys/windows"
)

type state struct {
Mode uint32
}

func isTerminal(fd uintptr) bool {
var st uint32
err := windows.GetConsoleMode(windows.Handle(fd), &st)
return err == nil
}

func makeRaw(fd uintptr) (*State, error) {
var st uint32
if err := windows.GetConsoleMode(windows.Handle(fd), &st); err != nil {
return nil, err
}
raw := st &^ (windows.ENABLE_ECHO_INPUT | windows.ENABLE_PROCESSED_INPUT | windows.ENABLE_LINE_INPUT | windows.ENABLE_PROCESSED_OUTPUT)
if err := windows.SetConsoleMode(windows.Handle(fd), raw); err != nil {
return nil, err
}
return &State{state{st}}, nil
}

func setState(fd uintptr, state *State) error {
var mode uint32
if state != nil {
mode = state.Mode
}
return windows.SetConsoleMode(windows.Handle(fd), mode)
}

func getState(fd uintptr) (*State, error) {
var st uint32
if err := windows.GetConsoleMode(windows.Handle(fd), &st); err != nil {
return nil, err
}
return &State{state{st}}, nil
}

func restore(fd uintptr, state *State) error {
return windows.SetConsoleMode(windows.Handle(fd), state.Mode)
}

func getSize(fd uintptr) (width, height int, err error) {
var info windows.ConsoleScreenBufferInfo
if err := windows.GetConsoleScreenBufferInfo(windows.Handle(fd), &info); err != nil {
return 0, 0, err
}
return int(info.Window.Right - info.Window.Left + 1), int(info.Window.Bottom - info.Window.Top + 1), nil
}

func readPassword(fd uintptr) ([]byte, error) {
var st uint32
if err := windows.GetConsoleMode(windows.Handle(fd), &st); err != nil {
return nil, err
}
old := st

st &^= (windows.ENABLE_ECHO_INPUT | windows.ENABLE_LINE_INPUT)
st |= (windows.ENABLE_PROCESSED_OUTPUT | windows.ENABLE_PROCESSED_INPUT)
if err := windows.SetConsoleMode(windows.Handle(fd), st); err != nil {
return nil, err
}

defer windows.SetConsoleMode(windows.Handle(fd), old)

var h windows.Handle
p, _ := windows.GetCurrentProcess()
if err := windows.DuplicateHandle(p, windows.Handle(fd), p, &h, 0, false, windows.DUPLICATE_SAME_ACCESS); err != nil {
return nil, err
}

f := os.NewFile(uintptr(h), "stdin")
defer f.Close()
return readPasswordLine(f)
}
47 changes: 47 additions & 0 deletions term/util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package term

import (
"io"
"runtime"
)

// readPasswordLine reads from reader until it finds \n or io.EOF.
// The slice returned does not include the \n.
// readPasswordLine also ignores any \r it finds.
// Windows uses \r as end of line. So, on Windows, readPasswordLine
// reads until it finds \r and ignores any \n it finds during processing.
func readPasswordLine(reader io.Reader) ([]byte, error) {
var buf [1]byte
var ret []byte

for {
n, err := reader.Read(buf[:])
if n > 0 {
switch buf[0] {
case '\b':
if len(ret) > 0 {
ret = ret[:len(ret)-1]
}
case '\n':
if runtime.GOOS != "windows" {
return ret, nil
}
// otherwise ignore \n
case '\r':
if runtime.GOOS == "windows" {
return ret, nil
}
// otherwise ignore \r
default:
ret = append(ret, buf[0])
}
continue
}
if err != nil {
if err == io.EOF && len(ret) > 0 {
return ret, nil
}
return ret, err
}
}
}

0 comments on commit 24b8333

Please sign in to comment.