From 289a6b0659b441833d8b376e1b70a996a2e8b3f0 Mon Sep 17 00:00:00 2001
From: liujian <54946465+redscholar@users.noreply.github.com>
Date: Wed, 11 Sep 2024 15:38:59 +0800
Subject: [PATCH] add `sudo_password` to use sudo mod. (#2402)

* fix: graceful delete runtime dir.

Signed-off-by: joyceliu <joyceliu@yunify.com>

* fix: graceful delete runtime dir.

Signed-off-by: joyceliu <joyceliu@yunify.com>

* fix: add `sudo` and SHELL in connector.

Signed-off-by: joyceliu <joyceliu@yunify.com>

---------

Signed-off-by: joyceliu <joyceliu@yunify.com>
Co-authored-by: joyceliu <joyceliu@yunify.com>
---
 pkg/connector/connector.go            |  2 +-
 pkg/connector/kubernetes_connector.go |  4 +-
 pkg/connector/local_connector.go      | 21 ++++----
 pkg/connector/ssh_connector.go        | 73 ++++++++++++++++++++++-----
 pkg/const/common.go                   |  4 +-
 pkg/executor/pipeline_executor.go     |  2 +-
 6 files changed, 79 insertions(+), 27 deletions(-)

diff --git a/pkg/connector/connector.go b/pkg/connector/connector.go
index 4b08cceae..cd6e8a524 100644
--- a/pkg/connector/connector.go
+++ b/pkg/connector/connector.go
@@ -36,7 +36,7 @@ const (
 	connectedKubernetes = "kubernetes"
 )
 
-var shell = commandShell()
+var localShell = commandShell()
 
 // Connector is the interface for connecting to a remote host
 type Connector interface {
diff --git a/pkg/connector/kubernetes_connector.go b/pkg/connector/kubernetes_connector.go
index 8fa5f2d5d..cbb1ce42a 100644
--- a/pkg/connector/kubernetes_connector.go
+++ b/pkg/connector/kubernetes_connector.go
@@ -114,7 +114,7 @@ func (c *kubernetesConnector) PutFile(_ context.Context, src []byte, dst string,
 func (c *kubernetesConnector) FetchFile(ctx context.Context, src string, dst io.Writer) error {
 	// add "--kubeconfig" to src command
 	klog.V(5).InfoS("exec local command", "cmd", src)
-	command := c.Cmd.CommandContext(ctx, "/bin/sh", "-c", src)
+	command := c.Cmd.CommandContext(ctx, localShell, "-c", src)
 	command.SetDir(c.homeDir)
 	command.SetEnv([]string{"KUBECONFIG=" + filepath.Join(c.homeDir, kubeconfigRelPath)})
 	command.SetStdout(dst)
@@ -127,7 +127,7 @@ func (c *kubernetesConnector) FetchFile(ctx context.Context, src string, dst io.
 func (c *kubernetesConnector) ExecuteCommand(ctx context.Context, cmd string) ([]byte, error) {
 	// add "--kubeconfig" to src command
 	klog.V(5).InfoS("exec local command", "cmd", cmd)
-	command := c.Cmd.CommandContext(ctx, shell, "-c", cmd)
+	command := c.Cmd.CommandContext(ctx, localShell, "-c", cmd)
 	command.SetDir(c.homeDir)
 	command.SetEnv([]string{"KUBECONFIG=" + filepath.Join(c.homeDir, kubeconfigRelPath)})
 
diff --git a/pkg/connector/local_connector.go b/pkg/connector/local_connector.go
index ede6c04cb..f5764278e 100644
--- a/pkg/connector/local_connector.go
+++ b/pkg/connector/local_connector.go
@@ -28,7 +28,6 @@ import (
 
 	"k8s.io/klog/v2"
 	"k8s.io/utils/exec"
-	"k8s.io/utils/ptr"
 
 	_const "github.com/kubesphere/kubekey/v4/pkg/const"
 	"github.com/kubesphere/kubekey/v4/pkg/variable"
@@ -38,17 +37,16 @@ var _ Connector = &localConnector{}
 var _ GatherFacts = &localConnector{}
 
 func newLocalConnector(connectorVars map[string]any) *localConnector {
-	sudo, err := variable.BoolVar(nil, connectorVars, _const.VariableConnectorSudo)
+	sudo, err := variable.StringVar(nil, connectorVars, _const.VariableConnectorSudoPassword)
 	if err != nil {
-		klog.InfoS("get connector sudo failed use default port 22", "error", err)
-		sudo = ptr.To(true)
+		klog.V(4).InfoS("get connector sudo password failed, execute command without sudo", "error", err)
 	}
 
-	return &localConnector{Sudo: *sudo, Cmd: exec.New()}
+	return &localConnector{Sudo: sudo, Cmd: exec.New()}
 }
 
 type localConnector struct {
-	Sudo bool
+	Sudo string
 	Cmd  exec.Interface
 }
 
@@ -97,11 +95,16 @@ func (c *localConnector) FetchFile(_ context.Context, src string, dst io.Writer)
 // ExecuteCommand in local host
 func (c *localConnector) ExecuteCommand(ctx context.Context, cmd string) ([]byte, error) {
 	klog.V(5).InfoS("exec local command", "cmd", cmd)
-	if c.Sudo {
-		return c.Cmd.CommandContext(ctx, "sudo", "-E", shell, "-c", cmd).CombinedOutput()
+	// find command interpreter in env. default /bin/bash
+
+	if c.Sudo != "" {
+		command := c.Cmd.CommandContext(ctx, "sudo", "-E", localShell, "-c", cmd)
+		command.SetStdin(bytes.NewBufferString(c.Sudo + "\n"))
+
+		return command.CombinedOutput()
 	}
 
-	return c.Cmd.CommandContext(ctx, shell, "-c", cmd).CombinedOutput()
+	return c.Cmd.CommandContext(ctx, localShell, "-c", cmd).CombinedOutput()
 }
 
 // HostInfo for GatherFacts
diff --git a/pkg/connector/ssh_connector.go b/pkg/connector/ssh_connector.go
index b9cb6be0a..25f181e5f 100644
--- a/pkg/connector/ssh_connector.go
+++ b/pkg/connector/ssh_connector.go
@@ -26,6 +26,7 @@ import (
 	"os"
 	"os/user"
 	"path/filepath"
+	"strings"
 	"time"
 
 	"github.com/pkg/sftp"
@@ -38,8 +39,9 @@ import (
 )
 
 const (
-	defaultSSHPort = 22
-	defaultSSHUser = "root"
+	defaultSSHPort  = 22
+	defaultSSHUser  = "root"
+	defaultSSHSHELL = "/bin/bash"
 )
 
 var defaultSSHPrivateKey string
@@ -56,16 +58,15 @@ var _ Connector = &sshConnector{}
 var _ GatherFacts = &sshConnector{}
 
 func newSSHConnector(host string, connectorVars map[string]any) *sshConnector {
-	sudo, err := variable.BoolVar(nil, connectorVars, _const.VariableConnectorSudo)
+	sudo, err := variable.StringVar(nil, connectorVars, _const.VariableConnectorSudoPassword)
 	if err != nil {
-		klog.InfoS("get connector sudo failed use default port 22", "error", err)
-		sudo = ptr.To(true)
+		klog.V(4).InfoS("get connector sudo password failed, execute command without sudo", "error", err)
 	}
 
 	// get host in connector variable. if empty, set default host: host_name.
 	hostParam, err := variable.StringVar(nil, connectorVars, _const.VariableConnectorHost)
 	if err != nil {
-		klog.InfoS("get connector host failed use current hostname", "error", err)
+		klog.V(4).InfoS("get connector host failed use current hostname", "error", err)
 		hostParam = host
 	}
 	// get port in connector variable. if empty, set default port: 22.
@@ -93,23 +94,27 @@ func newSSHConnector(host string, connectorVars map[string]any) *sshConnector {
 	}
 
 	return &sshConnector{
-		Sudo:       *sudo,
+		Sudo:       sudo,
 		Host:       hostParam,
 		Port:       *portParam,
 		User:       userParam,
 		Password:   passwdParam,
 		PrivateKey: keyParam,
+		shell:      defaultSSHSHELL,
 	}
 }
 
 type sshConnector struct {
-	Sudo       bool
+	Sudo       string
 	Host       string
 	Port       int
 	User       string
 	Password   string
 	PrivateKey string
-	client     *ssh.Client
+
+	client *ssh.Client
+	// shell to execute command
+	shell string
 }
 
 // Init connector, get ssh.Client
@@ -147,6 +152,22 @@ func (c *sshConnector) Init(context.Context) error {
 	}
 	c.client = sshClient
 
+	// get shell from env
+	session, err := sshClient.NewSession()
+	if err != nil {
+		return fmt.Errorf("create session error: %w", err)
+	}
+	defer session.Close()
+
+	output, err := session.CombinedOutput("echo $SHELL")
+	if err != nil {
+		return fmt.Errorf("env command error: %w", err)
+	}
+
+	if strings.TrimSuffix(string(output), "\n") != "" {
+		c.shell = strings.TrimSuffix(string(output), "\n")
+	}
+
 	return nil
 }
 
@@ -231,11 +252,39 @@ func (c *sshConnector) ExecuteCommand(_ context.Context, cmd string) ([]byte, er
 	}
 	defer session.Close()
 
-	if c.Sudo {
-		return session.CombinedOutput(fmt.Sprintf("sudo -E %s -c \"%q\"", shell, cmd))
+	if c.Sudo != "" {
+		cmd = fmt.Sprintf("sudo -E %s -c \"%q\"", c.shell, cmd)
+		// get pipe from session
+		stdin, _ := session.StdinPipe()
+		stdout, _ := session.StdoutPipe()
+		stderr, _ := session.StderrPipe()
+		// Request a pseudo-terminal (required for sudo password input)
+		if err := session.RequestPty("xterm", 80, 40, ssh.TerminalModes{}); err != nil {
+			return nil, err
+		}
+		// Start the remote command
+		if err := session.Start(cmd); err != nil {
+			return nil, err
+		}
+		// Write sudo password to the standard input
+		if _, err := io.WriteString(stdin, c.Sudo+"\n"); err != nil {
+			return nil, err
+		}
+		// Read the command output
+		output := make([]byte, 0)
+		stdoutData, _ := io.ReadAll(stdout)
+		stderrData, _ := io.ReadAll(stderr)
+		output = append(output, stdoutData...)
+		output = append(output, stderrData...)
+		// Wait for the command to complete
+		if err := session.Wait(); err != nil {
+			return nil, err
+		}
+
+		return output, nil
 	}
 
-	return session.CombinedOutput(fmt.Sprintf("%s -c \"%q\"", shell, cmd))
+	return session.CombinedOutput(fmt.Sprintf("%s -c \"%q\"", c.shell, cmd))
 }
 
 // HostInfo for GatherFacts
diff --git a/pkg/const/common.go b/pkg/const/common.go
index 011272383..3a2eb2bb7 100644
--- a/pkg/const/common.go
+++ b/pkg/const/common.go
@@ -30,8 +30,8 @@ const ( // === From inventory ===
 	VariableConnector = "connector"
 	// VariableConnectorType is connected type for VariableConnector.
 	VariableConnectorType = "type"
-	// VariableConnectorSudo is connected address for VariableConnector.
-	VariableConnectorSudo = "sudo"
+	// VariableConnectorSudoPassword is connected address for VariableConnector.
+	VariableConnectorSudoPassword = "sudo_password"
 	// VariableConnectorHost is connected address for VariableConnector.
 	VariableConnectorHost = "host"
 	// VariableConnectorPort is connected address for VariableConnector.
diff --git a/pkg/executor/pipeline_executor.go b/pkg/executor/pipeline_executor.go
index 7338c8c20..dd7fdb98f 100644
--- a/pkg/executor/pipeline_executor.go
+++ b/pkg/executor/pipeline_executor.go
@@ -168,7 +168,7 @@ func (e pipelineExecutor) execBatchHosts(ctx context.Context, play kkprojectv1.P
 			option:       e.option,
 			hosts:        serials,
 			ignoreErrors: play.IgnoreErrors,
-			blocks:       play.Tasks,
+			blocks:       play.PostTasks,
 			tags:         play.Taggable,
 		}.Exec(ctx)); err != nil {
 			return fmt.Errorf("execute post-tasks error: %w", err)