Skip to content

Commit

Permalink
cmd/utils: Add an asynchronous cancellable version of askForConfirmation
Browse files Browse the repository at this point in the history
A subsequent commit will use this to ensure that the user can still
interact with the image download prompt while 'skopeo inspect' fetches
the image size from the remote registry.  Initially, the prompt will be
shown without the image size.  Once the size has been fetched, the older
prompt will be cancelled and a new one will be shown that includes the
size.

#752
#1263
  • Loading branch information
debarshiray committed Dec 12, 2023
1 parent 0a1ba75 commit 12d082c
Showing 1 changed file with 204 additions and 0 deletions.
204 changes: 204 additions & 0 deletions src/cmd/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,30 @@ package cmd

import (
"bufio"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"os"
"os/exec"
"strings"
"syscall"

"github.com/containers/toolbox/pkg/utils"
"github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)

type askForConfirmationPreFunc func() error
type pollFunc func(error, []unix.PollFd) error

var (
errClosed = errors.New("closed")

errContinue = errors.New("continue")

errHUP = errors.New("HUP")
)

// askForConfirmation prints prompt to stdout and waits for response from the
Expand Down Expand Up @@ -67,6 +83,109 @@ func askForConfirmation(prompt string) bool {
return retVal
}

func askForConfirmationAsync(ctx context.Context,
prompt string,
askForConfirmationPreFn askForConfirmationPreFunc) (<-chan bool, <-chan error) {

retValCh := make(chan bool, 1)
errCh := make(chan error, 1)

done := ctx.Done()
eventFD := -1
if done != nil {
fd, err := unix.Eventfd(0, unix.EFD_CLOEXEC|unix.EFD_NONBLOCK)
if err != nil {
errCh <- fmt.Errorf("eventfd(2) failed: %w", err)
return retValCh, errCh
}

eventFD = fd
}

go func() {
for {
fmt.Printf("%s ", prompt)
if askForConfirmationPreFn != nil {
if err := askForConfirmationPreFn(); err != nil {
if errors.Is(err, errContinue) {
continue
}

errCh <- err
break
}
}

var response string

pollFn := func(errPoll error, pollFDs []unix.PollFd) error {
if len(pollFDs) != 1 {
panic("unexpected number of file descriptors")
}

if errPoll != nil {
return errPoll
}

if pollFDs[0].Revents&unix.POLLIN != 0 {
logrus.Debug("Returned from /dev/stdin: POLLIN")

scanner := bufio.NewScanner(os.Stdin)
scanner.Split(bufio.ScanLines)

if !scanner.Scan() {
if err := scanner.Err(); err != nil {
return err
} else {
return io.EOF
}
}

response = scanner.Text()
return nil
}

if pollFDs[0].Revents&unix.POLLHUP != 0 {
logrus.Debug("Returned from /dev/stdin: POLLHUP")
return errHUP
}

if pollFDs[0].Revents&unix.POLLNVAL != 0 {
logrus.Debug("Returned from /dev/stdin: POLLNVAL")
return errClosed
}

return errContinue
}

stdinFD := int32(os.Stdin.Fd())

err := poll(pollFn, int32(eventFD), stdinFD)
if err != nil {
errCh <- err
break
}

if response == "" {
response = "n"
} else {
response = strings.ToLower(response)
}

if response == "no" || response == "n" {
retValCh <- false
break
} else if response == "yes" || response == "y" {
retValCh <- true
break
}
}
}()

watchContextForEventFD(ctx, eventFD)
return retValCh, errCh
}

func createErrorContainerNotFound(container string) error {
var builder strings.Builder
fmt.Fprintf(&builder, "container %s not found\n", container)
Expand Down Expand Up @@ -148,6 +267,57 @@ func getUsageForCommonCommands() string {
return usage
}

func poll(pollFn pollFunc, eventFD int32, fds ...int32) error {
if len(fds) == 0 {
panic("file descriptors not specified")
}

pollFDs := []unix.PollFd{
{
Fd: eventFD,
Events: unix.POLLIN,
Revents: 0,
},
}

for _, fd := range fds {
pollFD := unix.PollFd{Fd: fd, Events: unix.POLLIN, Revents: 0}
pollFDs = append(pollFDs, pollFD)
}

for {
if _, err := unix.Poll(pollFDs, -1); err != nil {
if errors.Is(err, unix.EINTR) {
logrus.Debugf("Failed to poll(2): %s: ignoring", err)
continue
}

return fmt.Errorf("poll(2) failed: %w", err)
}

var err error

if pollFDs[0].Revents&unix.POLLIN != 0 {
logrus.Debug("Returned from eventfd: POLLIN")
err = context.Canceled

for {
buffer := make([]byte, 8)
if n, err := unix.Read(int(eventFD), buffer); n != len(buffer) || err != nil {
break
}
}
} else if pollFDs[0].Revents&unix.POLLNVAL != 0 {
logrus.Debug("Returned from eventfd: POLLNVAL")
err = context.Canceled
}

if err := pollFn(err, pollFDs[1:]); !errors.Is(err, errContinue) {
return err
}
}
}

func resolveContainerAndImageNames(container, containerArg, distroCLI, imageCLI, releaseCLI string) (
string, string, string, error,
) {
Expand Down Expand Up @@ -245,3 +415,37 @@ func showManual(manual string) error {

return nil
}

func watchContextForEventFD(ctx context.Context, eventFD int) {
done := ctx.Done()
if done == nil {
return
}

if eventFD < 0 {
panic("invalid file descriptor for eventfd")
}

go func() {
defer unix.Close(eventFD)

select {
case <-done:
// A varint-encoded uint64 takes a maximum of 10 bytes,
// as defined by binary.MaxVarintLen64. However, 1
// byte is enough to encode the number 1. See:
// https://protobuf.dev/programming-guides/encoding/
//
// An eventfd(2) file descriptor expects a uint64 to be
// given to write(2). Luckily, a varint-encoded number
// 1 happens to work.
buffer := make([]byte, 8)
binary.PutUvarint(buffer, 1)

if _, err := unix.Write(eventFD, buffer); err != nil {
panicMsg := fmt.Sprintf("write(2) to eventfd failed: %s", err)
panic(panicMsg)
}
}
}()
}

0 comments on commit 12d082c

Please sign in to comment.